diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md new file mode 100644 index 0000000000000..4847406f995a9 --- /dev/null +++ b/.github/CONTRIBUTING.md @@ -0,0 +1,17 @@ +PRs welcome! But please file bugs first and explain the problem or +motivation. For new or changed functionality, strike up a discussion +and get agreement on the design/solution before spending too much time writing +code. + +Commit messages should [reference +bugs](https://docs.github.com/en/github/writing-on-github/autolinked-references-and-urls). + +We require [Developer Certificate of +Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) (DCO) +`Signed-off-by` lines in commits. (`git commit -s`) + +Please squash your code review edits & force push. Multiple commits in +a PR are fine, but only if they're each logically separate and all tests pass +at each stage. No fixup commits. + +See [commit-messages.md](docs/commit-messages.md) (or skim `git log`) for our commit message style. diff --git a/.github/workflows/build-and-publish-images.yml b/.github/workflows/build-and-publish-images.yml new file mode 100644 index 0000000000000..381dc88fc6179 --- /dev/null +++ b/.github/workflows/build-and-publish-images.yml @@ -0,0 +1,45 @@ +name: Publish Dev Operator + +on: + push: + tags: + - 'v*.*.*' + - 'v*.*.*-*' +jobs: + publish: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Build and publish k8s-operator image + env: + REPO: ghcr.io/${{ github.repository_owner }}/tailscale-k8s-operator + TAGS: ${{ github.ref_name }} + run: | + echo "Building and publishing k8s-operator to ${REPO} with tags ${TAGS}" + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=operator 
./build_docker.sh + - name: Build and publish nameserver image + env: + REPO: ghcr.io/${{ github.repository_owner }}/tailscale-k8s-nameserver + TAGS: ${{ github.ref_name }} + run: | + echo "Building and publishing k8s-nameserver to ${REPO} with tags ${TAGS}" + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=k8s-nameserver ./build_docker.sh + - name: Build and publish client image + env: + REPO: ghcr.io/${{ github.repository_owner }}/tailscale + TAGS: ${{ github.ref_name }} + run: | + echo "Building and publishing tailscale client to ${REPO} with tags ${TAGS}" + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=client ./build_docker.sh diff --git a/.github/workflows/chart.yaml b/.github/workflows/chart.yaml new file mode 100644 index 0000000000000..eb0ffd5f62003 --- /dev/null +++ b/.github/workflows/chart.yaml @@ -0,0 +1,38 @@ +name: package-helm-chart + +on: + push: + tags: + - 'v*.*.*' + - 'v*.*.*-*' + workflow_dispatch: + +jobs: + package-and-push-helm-chart: + permissions: + contents: read + packages: write + + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4.2.2 + + - name: Set environment variables + id: set-variables + run: | + echo "REPOSITORY=ghcr.io/$(echo ${{ github.repository }} | tr '[:upper:]' '[:lower:]')" >> "$GITHUB_OUTPUT" + + - name: Login to GitHub Container Registry + uses: docker/login-action@v3.3.0 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ github.token }} + + - name: Build, package and push helm chart + run: | + ./tool/go run cmd/k8s-operator/generate/main.go helmcrd + ./tool/helm package --app-version=${{ github.ref_name }} --version=${{ github.ref_name }} './cmd/k8s-operator/deploy/chart' + ./tool/helm push ./tailscale-operator-${{ github.ref_name }}.tgz oci://${{ steps.set-variables.outputs.REPOSITORY }}/charts diff --git a/.github/workflows/checklocks.yml b/.github/workflows/checklocks.yml index 064797c884a60..7464524ce99e2 100644 --- 
a/.github/workflows/checklocks.yml +++ b/.github/workflows/checklocks.yml @@ -18,7 +18,7 @@ jobs: runs-on: [ ubuntu-latest ] steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Build checklocks run: ./tool/go build -o /tmp/checklocks gvisor.dev/gvisor/tools/checklocks/cmd/checklocks diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml index 4e266c6eae6ab..311f539e1978f 100644 --- a/.github/workflows/codeql-analysis.yml +++ b/.github/workflows/codeql-analysis.yml @@ -45,17 +45,17 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 # Install a more recent Go that understands modern go.mod content. - name: Install Go - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version-file: go.mod # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL - uses: github/codeql-action/init@6db8d6351fd0be61f9ed8ebd12ccd35dcec51fea # v3.26.11 + uses: github/codeql-action/init@45775bd8235c68ba998cffa5171334d58593da47 # v3.28.15 with: languages: ${{ matrix.language }} # If you wish to specify custom queries, you can do so here or in a config file. @@ -66,7 +66,7 @@ jobs: # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). # If this step fails, then you should remove it and run the build manually (see below) - name: Autobuild - uses: github/codeql-action/autobuild@6db8d6351fd0be61f9ed8ebd12ccd35dcec51fea # v3.26.11 + uses: github/codeql-action/autobuild@45775bd8235c68ba998cffa5171334d58593da47 # v3.28.15 # â„šī¸ Command-line programs to run using the OS shell. 
# 📚 https://git.io/JvXDl @@ -80,4 +80,4 @@ jobs: # make release - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@6db8d6351fd0be61f9ed8ebd12ccd35dcec51fea # v3.26.11 + uses: github/codeql-action/analyze@45775bd8235c68ba998cffa5171334d58593da47 # v3.28.15 diff --git a/.github/workflows/docker-file-build.yml b/.github/workflows/docker-file-build.yml index c535755724391..04611e172bbea 100644 --- a/.github/workflows/docker-file-build.yml +++ b/.github/workflows/docker-file-build.yml @@ -10,6 +10,6 @@ jobs: deploy: runs-on: ubuntu-latest steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: "Build Docker image" run: docker build . diff --git a/.github/workflows/flakehub-publish-tagged.yml b/.github/workflows/flakehub-publish-tagged.yml index 60fdba91c1247..9ff12c6a3fd14 100644 --- a/.github/workflows/flakehub-publish-tagged.yml +++ b/.github/workflows/flakehub-publish-tagged.yml @@ -17,7 +17,7 @@ jobs: id-token: "write" contents: "read" steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: "${{ (inputs.tag != null) && format('refs/tags/{0}', inputs.tag) || '' }}" - uses: "DeterminateSystems/nix-installer-action@main" diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 9c34debc5d2f4..04a2e042db27d 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -23,18 +23,17 @@ jobs: name: lint runs-on: ubuntu-latest steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + - uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version-file: go.mod 
cache: false - name: golangci-lint - # Note: this is the 'v6.1.0' tag as of 2024-08-21 - uses: golangci/golangci-lint-action@aaa42aa0628b4ae2578232a66b541047968fac86 + uses: golangci/golangci-lint-action@1481404843c368bc19ca9406f87d6e0fc97bdcfd # v7.0.0 with: - version: v1.60 + version: v2.0.2 # Show only new issues if it's a pull request. only-new-issues: true diff --git a/.github/workflows/govulncheck.yml b/.github/workflows/govulncheck.yml index 4a5ad54f391e8..10269ff0bf078 100644 --- a/.github/workflows/govulncheck.yml +++ b/.github/workflows/govulncheck.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Check out code into the Go module directory - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install govulncheck run: ./tool/go install golang.org/x/vuln/cmd/govulncheck@latest @@ -24,13 +24,13 @@ jobs: - name: Post to slack if: failure() && github.event_name == 'schedule' - uses: slackapi/slack-github-action@37ebaef184d7626c5f204ab8d3baff4262dd30f0 # v1.27.0 - env: - SLACK_BOT_TOKEN: ${{ secrets.GOVULNCHECK_BOT_TOKEN }} + uses: slackapi/slack-github-action@485a9d42d3a73031f12ec201c457e2162c45d02d # v2.0.0 with: - channel-id: 'C05PXRM304B' + method: chat.postMessage + token: ${{ secrets.GOVULNCHECK_BOT_TOKEN }} payload: | { + "channel": "C08FGKZCQTW", "blocks": [ { "type": "section", diff --git a/.github/workflows/installer.yml b/.github/workflows/installer.yml index 48b29c6ec02cd..7888d9ba5d3e2 100644 --- a/.github/workflows/installer.yml +++ b/.github/workflows/installer.yml @@ -1,16 +1,20 @@ name: test installer.sh on: + schedule: + - cron: '0 15 * * *' # 10am EST (UTC-4/5) push: branches: - "main" paths: - scripts/installer.sh + - .github/workflows/installer.yml pull_request: branches: - "*" paths: - scripts/installer.sh + - .github/workflows/installer.yml jobs: test: @@ -29,13 +33,11 @@ jobs: - "debian:stable-slim" - "debian:testing-slim" - 
"debian:sid-slim" - - "ubuntu:18.04" - "ubuntu:20.04" - "ubuntu:22.04" - - "ubuntu:23.04" + - "ubuntu:24.04" - "elementary/docker:stable" - "elementary/docker:unstable" - - "parrotsec/core:lts-amd64" - "parrotsec/core:latest" - "kalilinux/kali-rolling" - "kalilinux/kali-dev" @@ -48,7 +50,7 @@ jobs: - "opensuse/leap:latest" - "opensuse/tumbleweed:latest" - "archlinux:latest" - - "alpine:3.14" + - "alpine:3.21" - "alpine:latest" - "alpine:edge" deps: @@ -58,10 +60,6 @@ jobs: # Check a few images with wget rather than curl. - { image: "debian:oldstable-slim", deps: "wget" } - { image: "debian:sid-slim", deps: "wget" } - - { image: "ubuntu:23.04", deps: "wget" } - # Ubuntu 16.04 also needs apt-transport-https installed. - - { image: "ubuntu:16.04", deps: "curl apt-transport-https" } - - { image: "ubuntu:16.04", deps: "wget apt-transport-https" } runs-on: ubuntu-latest container: image: ${{ matrix.image }} @@ -76,10 +74,10 @@ jobs: # tar and gzip are needed by the actions/checkout below. run: yum install -y --allowerasing tar gzip ${{ matrix.deps }} if: | - contains(matrix.image, 'centos') - || contains(matrix.image, 'oraclelinux') - || contains(matrix.image, 'fedora') - || contains(matrix.image, 'amazonlinux') + contains(matrix.image, 'centos') || + contains(matrix.image, 'oraclelinux') || + contains(matrix.image, 'fedora') || + contains(matrix.image, 'amazonlinux') - name: install dependencies (zypper) # tar and gzip are needed by the actions/checkout below. 
run: zypper --non-interactive install tar gzip ${{ matrix.deps }} @@ -89,16 +87,13 @@ jobs: apt-get update apt-get install -y ${{ matrix.deps }} if: | - contains(matrix.image, 'debian') - || contains(matrix.image, 'ubuntu') - || contains(matrix.image, 'elementary') - || contains(matrix.image, 'parrotsec') - || contains(matrix.image, 'kalilinux') + contains(matrix.image, 'debian') || + contains(matrix.image, 'ubuntu') || + contains(matrix.image, 'elementary') || + contains(matrix.image, 'parrotsec') || + contains(matrix.image, 'kalilinux') - name: checkout - # We cannot use v4, as it requires a newer glibc version than some of the - # tested images provide. See - # https://github.com/actions/checkout/issues/1487 - uses: actions/checkout@f43a0e5ff2bd294095638e18286ca9a3d1956744 # v3.6.0 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: run installer run: scripts/installer.sh # Package installation can fail in docker because systemd is not running @@ -107,3 +102,30 @@ jobs: continue-on-error: true - name: check tailscale version run: tailscale --version + notify-slack: + needs: test + runs-on: ubuntu-latest + steps: + - name: Notify Slack of failure on scheduled runs + if: failure() && github.event_name == 'schedule' + uses: slackapi/slack-github-action@485a9d42d3a73031f12ec201c457e2162c45d02d # v2.0.0 + with: + webhook: ${{ secrets.SLACK_WEBHOOK_URL }} + webhook-type: incoming-webhook + payload: | + { + "attachments": [{ + "title": "Tailscale installer test failed", + "title_link": "https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}", + "text": "One or more OSes in the test matrix failed. 
See the run for details.", + "fields": [ + { + "title": "Ref", + "value": "${{ github.ref_name }}", + "short": true + } + ], + "footer": "${{ github.workflow }} on schedule", + "color": "danger" + }] + } diff --git a/.github/workflows/kubemanifests.yaml b/.github/workflows/kubemanifests.yaml index f943ccb524f35..5b100a2763e3b 100644 --- a/.github/workflows/kubemanifests.yaml +++ b/.github/workflows/kubemanifests.yaml @@ -17,7 +17,7 @@ jobs: runs-on: [ ubuntu-latest ] steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Build and lint Helm chart run: | eval `./tool/go run ./cmd/mkversion` diff --git a/.github/workflows/natlab-integrationtest.yml b/.github/workflows/natlab-integrationtest.yml new file mode 100644 index 0000000000000..1de74cdaa45f8 --- /dev/null +++ b/.github/workflows/natlab-integrationtest.yml @@ -0,0 +1,27 @@ +# Run some natlab integration tests. 
+# See https://github.com/tailscale/tailscale/issues/13038 +name: "natlab-integrationtest" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +on: + pull_request: + paths: + - "tstest/integration/nat/nat_test.go" +jobs: + natlab-integrationtest: + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Install qemu + run: | + sudo rm /var/lib/man-db/auto-update + sudo apt-get -y update + sudo apt-get -y remove man-db + sudo apt-get install -y qemu-system-x86 qemu-utils + - name: Run natlab integration tests + run: | + ./tool/go test -v -run=^TestEasyEasy$ -timeout=3m -count=1 ./tstest/integration/nat --run-vm-tests diff --git a/.github/workflows/ssh-integrationtest.yml b/.github/workflows/ssh-integrationtest.yml index a82696307ea4b..829d10ab8c2c8 100644 --- a/.github/workflows/ssh-integrationtest.yml +++ b/.github/workflows/ssh-integrationtest.yml @@ -17,7 +17,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run SSH integration tests run: | make sshintegrationtest \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bc70040b054bf..100a7288a8236 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,7 +50,7 @@ jobs: - shard: '4/4' steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: build test wrapper run: ./tool/go build -o /tmp/testwrapper ./cmd/testwrapper - name: integration tests as root @@ -64,7 +64,6 @@ jobs: matrix: include: - goarch: amd64 - coverflags: "-coverprofile=/tmp/coverage.out" - goarch: amd64 buildflags: "-race" shard: '1/3' @@ 
-78,9 +77,9 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -119,15 +118,10 @@ jobs: - name: build test wrapper run: ./tool/go build -o /tmp/testwrapper ./cmd/testwrapper - name: test all - run: NOBASHDEBUG=true PATH=$PWD/tool:$PATH /tmp/testwrapper ${{matrix.coverflags}} ./... ${{matrix.buildflags}} + run: NOBASHDEBUG=true PATH=$PWD/tool:$PATH /tmp/testwrapper ./... ${{matrix.buildflags}} env: GOARCH: ${{ matrix.goarch }} TS_TEST_SHARD: ${{ matrix.shard }} - - name: Publish to coveralls.io - if: matrix.coverflags != '' # only publish results if we've tracked coverage - uses: shogo82148/actions-goveralls@v1 - with: - path-to-profile: /tmp/coverage.out - name: bench all run: ./tool/go test ${{matrix.buildflags}} -bench=. -benchtime=1x -run=^$ $(for x in $(git grep -l "^func Benchmark" | xargs dirname | sort | uniq); do echo "./$x"; done) env: @@ -145,21 +139,25 @@ jobs: echo "Build/test created untracked files in the repo (file names above)." 
exit 1 fi - + - name: Tidy cache + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete + find $(go env GOMODCACHE)/cache -type f -mmin +90 -delete windows: runs-on: windows-2022 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install Go - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 + uses: actions/setup-go@0aaccfd150d50ccaeb58ebd88d36e91967a5f35b # v5.4.0 with: go-version-file: go.mod cache: false - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -182,6 +180,11 @@ jobs: # Somewhere in the layers (powershell?) # the equals signs cause great confusion. run: go test ./... -bench . 
-benchtime 1x -run "^$" + - name: Tidy cache + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete + find $(go env GOMODCACHE)/cache -type f -mmin +90 -delete privileged: runs-on: ubuntu-22.04 @@ -190,7 +193,7 @@ jobs: options: --privileged steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: chown run: chown -R $(id -u):$(id -g) $PWD - name: privileged tests @@ -202,7 +205,7 @@ jobs: if: github.repository == 'tailscale/tailscale' steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run VM tests run: ./tool/go test ./tstest/integration/vms -v -no-s3 -run-vm-tests -run=TestRunUbuntu2004 env: @@ -214,7 +217,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: build all run: ./tool/go install -race ./cmd/... 
- name: build tests @@ -258,9 +261,9 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -289,13 +292,18 @@ jobs: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} CGO_ENABLED: "0" + - name: Tidy cache + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete + find $(go env GOMODCACHE)/cache -type f -mmin +90 -delete ios: # similar to cross above, but iOS can't build most of the repo. So, just #make it build a few smoke packages. runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: build some run: ./tool/go build ./ipn/... ./wgengine/ ./types/... 
./control/controlclient env: @@ -313,13 +321,19 @@ jobs: # AIX - goos: aix goarch: ppc64 + # Solaris + - goos: solaris + goarch: amd64 + # illumos + - goos: illumos + goarch: amd64 runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -342,6 +356,11 @@ jobs: GOARCH: ${{ matrix.goarch }} GOARM: ${{ matrix.goarm }} CGO_ENABLED: "0" + - name: Tidy cache + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete + find $(go env GOMODCACHE)/cache -type f -mmin +90 -delete android: # similar to cross above, but android fails to build a few pieces of the @@ -350,7 +369,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 # Super minimal Android build that doesn't even use CGO and doesn't build everything that's needed # and is only arm64. But it's a smoke build: it's not meant to catch everything. But it'll catch # some Android breakages early. 
@@ -365,9 +384,9 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Restore Cache - uses: actions/cache@2cdf405574d6ef1f33a1d12acccd3ae82f47b3f2 # v4.1.0 + uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3 with: # Note: unlike the other setups, this is only grabbing the mod download # cache, rather than the whole mod directory, as the download cache @@ -394,12 +413,17 @@ jobs: run: | ./tool/go run ./cmd/tsconnect --fast-compression build ./tool/go run ./cmd/tsconnect --fast-compression build-pkg + - name: Tidy cache + shell: bash + run: | + find $(go env GOCACHE) -type f -mmin +90 -delete + find $(go env GOMODCACHE)/cache -type f -mmin +90 -delete tailscale_go: # Subset of tests that depend on our custom Go toolchain. runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: test tailscale_go run: ./tool/go test -tags=tailscale_go,ts_enable_sockstats ./net/sockstats/... @@ -461,7 +485,7 @@ jobs: run: | echo "artifacts_path=$(realpath .)" >> $GITHUB_ENV - name: upload crash - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 if: steps.run.outcome != 'success' && steps.build.outcome == 'success' with: name: artifacts @@ -471,17 +495,16 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: check depaware run: | - export PATH=$(./tool/go env GOROOT)/bin:$PATH - find . 
-name 'depaware.txt' | xargs -n1 dirname | xargs ./tool/go run github.com/tailscale/depaware --check + make depaware go_generate: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: check that 'go generate' is clean run: | pkgs=$(./tool/go list ./... | grep -Ev 'dnsfallback|k8s-operator|xdp') @@ -494,7 +517,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: check that 'go mod tidy' is clean run: | ./tool/go mod tidy @@ -506,7 +529,7 @@ jobs: runs-on: ubuntu-22.04 steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: check licenses run: ./scripts/check_license_headers.sh . @@ -522,7 +545,7 @@ jobs: goarch: "386" steps: - name: checkout - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: install staticcheck run: GOBIN=~/.local/bin ./tool/go install honnef.co/go/tools/cmd/staticcheck - name: run staticcheck @@ -563,8 +586,10 @@ jobs: # By having the job always run, but skipping its only step as needed, we # let the CI output collapse nicely in PRs. 
if: failure() && github.event_name == 'push' - uses: slackapi/slack-github-action@37ebaef184d7626c5f204ab8d3baff4262dd30f0 # v1.27.0 + uses: slackapi/slack-github-action@485a9d42d3a73031f12ec201c457e2162c45d02d # v2.0.0 with: + webhook: ${{ secrets.SLACK_WEBHOOK_URL }} + webhook-type: incoming-webhook payload: | { "attachments": [{ @@ -576,9 +601,6 @@ jobs: "color": "danger" }] } - env: - SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} - SLACK_WEBHOOK_TYPE: INCOMING_WEBHOOK check_mergeability: if: always() @@ -587,7 +609,6 @@ jobs: - android - test - windows - - vm - cross - ios - wasm diff --git a/.github/workflows/update-flake.yml b/.github/workflows/update-flake.yml index f79248c1ed4e9..f695c578eca91 100644 --- a/.github/workflows/update-flake.yml +++ b/.github/workflows/update-flake.yml @@ -21,7 +21,7 @@ jobs: steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run update-flakes run: ./update-flake.sh @@ -36,7 +36,7 @@ jobs: private_key: ${{ secrets.LICENSING_APP_PRIVATE_KEY }} - name: Send pull request - uses: peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f #v7.0.5 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e #v7.0.8 with: token: ${{ steps.generate-token.outputs.token }} author: Flakes Updater diff --git a/.github/workflows/update-webclient-prebuilt.yml b/.github/workflows/update-webclient-prebuilt.yml index a0ae95cd77ba4..412836db78e7c 100644 --- a/.github/workflows/update-webclient-prebuilt.yml +++ b/.github/workflows/update-webclient-prebuilt.yml @@ -14,7 +14,7 @@ jobs: steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Run go get run: | @@ -35,7 +35,7 @@ jobs: - name: Send pull request id: pull-request - uses: 
peter-evans/create-pull-request@5e914681df9dc83aa4e4905692ca88beb2f9e91f #v7.0.5 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e #v7.0.8 with: token: ${{ steps.generate-token.outputs.token }} author: OSS Updater diff --git a/.github/workflows/webclient.yml b/.github/workflows/webclient.yml index 9afb7730d9a56..b1cfb7620f97d 100644 --- a/.github/workflows/webclient.yml +++ b/.github/workflows/webclient.yml @@ -24,7 +24,7 @@ jobs: steps: - name: Check out code - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install deps run: ./tool/yarn --cwd client/web - name: Run lint diff --git a/.golangci.yml b/.golangci.yml index 45248de160409..917fc34bdea60 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,104 +1,114 @@ +version: "2" +# Configuration for how we run golangci-lint +# Timeout of 5m was the default in v1. +run: + timeout: 5m linters: # Don't enable any linters by default; just the ones that we explicitly # enable in the list below. - disable-all: true + default: none enable: - bidichk - - gofmt - - goimports - govet - misspell - revive - -# Configuration for how we run golangci-lint -run: - timeout: 5m - -issues: - # Excluding configuration per-path, per-linter, per-text and per-source - exclude-rules: - # These are forks of an upstream package and thus are exempt from stylistic - # changes that would make pulling in upstream changes harder. - - path: tempfork/.*\.go - text: "File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'`" - - path: util/singleflight/.*\.go - text: "File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'`" - -# Per-linter settings are contained in this top-level key -linters-settings: - # Enable all rules by default; we don't use invisible unicode runes. 
- bidichk: - - gofmt: - rewrite-rules: - - pattern: 'interface{}' - replacement: 'any' - - goimports: - - govet: + settings: # Matches what we use in corp as of 2023-12-07 - enable: - - asmdecl - - assign - - atomic - - bools - - buildtag - - cgocall - - copylocks - - deepequalerrors - - errorsas - - framepointer - - httpresponse - - ifaceassert - - loopclosure - - lostcancel - - nilfunc - - nilness - - printf - - reflectvaluecompare - - shift - - sigchanyzer - - sortslice - - stdmethods - - stringintconv - - structtag - - testinggoroutine - - tests - - unmarshal - - unreachable - - unsafeptr - - unusedresult - settings: - printf: - # List of print function names to check (in addition to default) - funcs: - - github.com/tailscale/tailscale/types/logger.Discard - # NOTE(andrew-d): this doesn't currently work because the printf - # analyzer doesn't support type declarations - #- github.com/tailscale/tailscale/types/logger.Logf - - misspell: - - revive: - enable-all-rules: false - ignore-generated-header: true + govet: + enable: + - asmdecl + - assign + - atomic + - bools + - buildtag + - cgocall + - copylocks + - deepequalerrors + - errorsas + - framepointer + - httpresponse + - ifaceassert + - loopclosure + - lostcancel + - nilfunc + - nilness + - printf + - reflectvaluecompare + - shift + - sigchanyzer + - sortslice + - stdmethods + - stringintconv + - structtag + - testinggoroutine + - tests + - unmarshal + - unreachable + - unsafeptr + - unusedresult + settings: + printf: + # List of print function names to check (in addition to default) + funcs: + - github.com/tailscale/tailscale/types/logger.Discard + # NOTE(andrew-d): this doesn't currently work because the printf + # analyzer doesn't support type declarations + #- github.com/tailscale/tailscale/types/logger.Logf + revive: + enable-all-rules: false + rules: + - name: atomic + - name: context-keys-type + - name: defer + arguments: [[ + # Calling 'recover' at the time a defer is registered (i.e. 
"defer recover()") has no effect. + "immediate-recover", + # Calling 'recover' outside of a deferred function has no effect + "recover", + # Returning values from a deferred function has no effect + "return", + ]] + - name: duplicated-imports + - name: errorf + - name: string-of-int + - name: time-equal + - name: unconditional-recursion + - name: useless-break + - name: waitgroup-by-value + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling rules: - - name: atomic - - name: context-keys-type - - name: defer - arguments: [[ - # Calling 'recover' at the time a defer is registered (i.e. "defer recover()") has no effect. - "immediate-recover", - # Calling 'recover' outside of a deferred function has no effect - "recover", - # Returning values from a deferred function has no effect - "return", - ]] - - name: duplicated-imports - - name: errorf - - name: string-of-int - - name: time-equal - - name: unconditional-recursion - - name: useless-break - - name: waitgroup-by-value + # These are forks of an upstream package and thus are exempt from stylistic + # changes that would make pulling in upstream changes harder. 
+ - path: tempfork/.*\.go + text: File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'` + - path: util/singleflight/.*\.go + text: File is not `gofmt`-ed with `-s` `-r 'interface{} -> any'` + paths: + - third_party$ + - builtin$ + - examples$ +formatters: + enable: + - gofmt + - goimports + settings: + gofmt: + rewrite-rules: + - pattern: interface{} + replacement: any + exclusions: + generated: lax + paths: + # the gofmt formatter would replace quite some `interface{}` types with + # `any` in the subdirs of "tempfork" which does not play well with the + # test in "sync_to_upstream_test.go" + - "tempfork/.*$" + - third_party$ + - builtin$ + - examples$ diff --git a/ALPINE.txt b/ALPINE.txt index 55b698c77f5d2..318956c3d51e2 100644 --- a/ALPINE.txt +++ b/ALPINE.txt @@ -1 +1 @@ -3.18 \ No newline at end of file +3.19 \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 4ad3d88d9577a..015022e49fc28 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,7 +27,7 @@ # $ docker exec tailscaled tailscale status -FROM golang:1.23-alpine AS build-env +FROM golang:1.24-alpine AS build-env WORKDIR /go/src/tailscale @@ -62,8 +62,10 @@ RUN GOARCH=$TARGETARCH go install -ldflags="\ -X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \ -v ./cmd/tailscale ./cmd/tailscaled ./cmd/containerboot -FROM alpine:3.18 +FROM alpine:3.19 RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables +RUN rm /sbin/iptables && ln -s /sbin/iptables-legacy /sbin/iptables +RUN rm /sbin/ip6tables && ln -s /sbin/ip6tables-legacy /sbin/ip6tables COPY --from=build-env /go/bin/* /usr/local/bin/ # For compat with the previous run.sh, although ideally you should be diff --git a/Dockerfile.base b/Dockerfile.base index eb4f0a02a8b75..b7e79a43c6fdf 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -1,5 +1,12 @@ # Copyright (c) Tailscale Inc & AUTHORS # SPDX-License-Identifier: BSD-3-Clause -FROM alpine:3.18 -RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables 
iputils +FROM alpine:3.19 +RUN apk add --no-cache ca-certificates iptables iptables-legacy iproute2 ip6tables iputils +# Alpine 3.19 replaces legacy iptables with nftables based implementation. We +# can't be certain that all hosts that run Tailscale containers currently +# support nftables, so link back to legacy for backwards compatibility reasons. +# TODO(irbekrm): add some way to determine if we still run on nodes that +# don't support nftables, so that we can eventually remove these symlinks. +RUN rm /sbin/iptables && ln -s /sbin/iptables-legacy /sbin/iptables +RUN rm /sbin/ip6tables && ln -s /sbin/ip6tables-legacy /sbin/ip6tables diff --git a/Makefile b/Makefile index 98c3d36cc1c9e..c30818c965b77 100644 --- a/Makefile +++ b/Makefile @@ -17,22 +17,26 @@ lint: ## Run golangci-lint updatedeps: ## Update depaware deps # depaware (via x/tools/go/packages) shells back to "go", so make sure the "go" # it finds in its $$PATH is the right one. - PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update \ + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update --internal \ tailscale.com/cmd/tailscaled \ tailscale.com/cmd/tailscale \ tailscale.com/cmd/derper \ tailscale.com/cmd/k8s-operator \ tailscale.com/cmd/stund + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --update --goos=linux,darwin,windows,android,ios --internal \ + tailscale.com/tsnet depaware: ## Run depaware checks # depaware (via x/tools/go/packages) shells back to "go", so make sure the "go" # it finds in its $$PATH is the right one.
- PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check \ + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check --internal \ tailscale.com/cmd/tailscaled \ tailscale.com/cmd/tailscale \ tailscale.com/cmd/derper \ tailscale.com/cmd/k8s-operator \ tailscale.com/cmd/stund + PATH="$$(./tool/go env GOROOT)/bin:$$PATH" ./tool/go run github.com/tailscale/depaware --check --goos=linux,darwin,windows,android,ios --internal \ + tailscale.com/tsnet buildwindows: ## Build tailscale CLI for windows/amd64 GOOS=windows GOARCH=amd64 ./tool/go install tailscale.com/cmd/tailscale tailscale.com/cmd/tailscaled @@ -100,7 +104,7 @@ publishdevoperator: ## Build and publish k8s-operator image to location specifie @test "${REPO}" != "ghcr.io/tailscale/tailscale" || (echo "REPO=... must not be ghcr.io/tailscale/tailscale" && exit 1) @test "${REPO}" != "tailscale/k8s-operator" || (echo "REPO=... must not be tailscale/k8s-operator" && exit 1) @test "${REPO}" != "ghcr.io/tailscale/k8s-operator" || (echo "REPO=... must not be ghcr.io/tailscale/k8s-operator" && exit 1) - TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=operator ./build_docker.sh + TAGS="${TAGS}" REPOS=${REPO} PLATFORM=${PLATFORM} PUSH=true TARGET=k8s-operator ./build_docker.sh publishdevnameserver: ## Build and publish k8s-nameserver image to location specified by ${REPO} @test -n "${REPO}" || (echo "REPO=... required; e.g. 
REPO=ghcr.io/${USER}/tailscale" && exit 1) @@ -116,7 +120,6 @@ sshintegrationtest: ## Run the SSH integration tests in various Docker container GOOS=linux GOARCH=amd64 ./tool/go build -o ssh/tailssh/testcontainers/tailscaled ./cmd/tailscaled && \ echo "Testing on ubuntu:focal" && docker build --build-arg="BASE=ubuntu:focal" -t ssh-ubuntu-focal ssh/tailssh/testcontainers && \ echo "Testing on ubuntu:jammy" && docker build --build-arg="BASE=ubuntu:jammy" -t ssh-ubuntu-jammy ssh/tailssh/testcontainers && \ - echo "Testing on ubuntu:mantic" && docker build --build-arg="BASE=ubuntu:mantic" -t ssh-ubuntu-mantic ssh/tailssh/testcontainers && \ echo "Testing on ubuntu:noble" && docker build --build-arg="BASE=ubuntu:noble" -t ssh-ubuntu-noble ssh/tailssh/testcontainers && \ echo "Testing on alpine:latest" && docker build --build-arg="BASE=alpine:latest" -t ssh-alpine-latest ssh/tailssh/testcontainers diff --git a/README.md b/README.md index 4627d9780f0b5..2c9713a6f339c 100644 --- a/README.md +++ b/README.md @@ -71,8 +71,7 @@ We require [Developer Certificate of Origin](https://en.wikipedia.org/wiki/Developer_Certificate_of_Origin) `Signed-off-by` lines in commits. -See `git log` for our commit message style. It's basically the same as -[Go's style](https://github.com/golang/go/wiki/CommitMessage). +See [commit-messages.md](docs/commit-messages.md) (or skim `git log`) for our commit message style. 
## About Us diff --git a/VERSION.txt b/VERSION.txt index 79e15fd49370a..209084d73d622 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.77.0 +1.84.2 diff --git a/appc/appconnector.go b/appc/appconnector.go index 671ced9534406..89c6c9aeb9aa7 100644 --- a/appc/appconnector.go +++ b/appc/appconnector.go @@ -18,7 +18,6 @@ import ( "sync" "time" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/dns/dnsmessage" "tailscale.com/types/logger" "tailscale.com/types/views" @@ -290,12 +289,14 @@ func (e *AppConnector) updateDomains(domains []string) { toRemove = append(toRemove, netip.PrefixFrom(a, a.BitLen())) } } - if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { - e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", xmaps.Keys(oldDomains), toRemove, err) - } + e.queue.Add(func() { + if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { + e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", slicesx.MapKeys(oldDomains), toRemove, err) + } + }) } - e.logf("handling domains: %v and wildcards: %v", xmaps.Keys(e.domains), e.wildcards) + e.logf("handling domains: %v and wildcards: %v", slicesx.MapKeys(e.domains), e.wildcards) } // updateRoutes merges the supplied routes into the currently configured routes. 
The routes supplied @@ -311,11 +312,6 @@ func (e *AppConnector) updateRoutes(routes []netip.Prefix) { return } - if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { - e.logf("failed to advertise routes: %v: %v", routes, err) - return - } - var toRemove []netip.Prefix // If we're storing routes and know e.controlRoutes is a good @@ -339,9 +335,14 @@ nextRoute: } } - if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { - e.logf("failed to unadvertise routes: %v: %v", toRemove, err) - } + e.queue.Add(func() { + if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil { + e.logf("failed to advertise routes: %v: %v", routes, err) + } + if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil { + e.logf("failed to unadvertise routes: %v: %v", toRemove, err) + } + }) e.controlRoutes = routes if err := e.storeRoutesLocked(); err != nil { @@ -354,7 +355,7 @@ func (e *AppConnector) Domains() views.Slice[string] { e.mu.Lock() defer e.mu.Unlock() - return views.SliceOf(xmaps.Keys(e.domains)) + return views.SliceOf(slicesx.MapKeys(e.domains)) } // DomainRoutes returns a map of domains to resolved IP @@ -375,13 +376,13 @@ func (e *AppConnector) DomainRoutes() map[string][]netip.Addr { // response is being returned over the PeerAPI. The response is parsed and // matched against the configured domains, if matched the routeAdvertiser is // advised to advertise the discovered route. 
-func (e *AppConnector) ObserveDNSResponse(res []byte) { +func (e *AppConnector) ObserveDNSResponse(res []byte) error { var p dnsmessage.Parser if _, err := p.Start(res); err != nil { - return + return err } if err := p.SkipAllQuestions(); err != nil { - return + return err } // cnameChain tracks a chain of CNAMEs for a given query in order to reverse @@ -400,12 +401,12 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) { break } if err != nil { - return + return err } if h.Class != dnsmessage.ClassINET { if err := p.SkipAnswer(); err != nil { - return + return err } continue } @@ -414,7 +415,7 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) { case dnsmessage.TypeCNAME, dnsmessage.TypeA, dnsmessage.TypeAAAA: default: if err := p.SkipAnswer(); err != nil { - return + return err } continue @@ -428,7 +429,7 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) { if h.Type == dnsmessage.TypeCNAME { res, err := p.CNAMEResource() if err != nil { - return + return err } cname := strings.TrimSuffix(strings.ToLower(res.CNAME.String()), ".") if len(cname) == 0 { @@ -442,20 +443,20 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) { case dnsmessage.TypeA: r, err := p.AResource() if err != nil { - return + return err } addr := netip.AddrFrom4(r.A) mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) case dnsmessage.TypeAAAA: r, err := p.AAAAResource() if err != nil { - return + return err } addr := netip.AddrFrom16(r.AAAA) mak.Set(&addressRecords, domain, append(addressRecords[domain], addr)) default: if err := p.SkipAnswer(); err != nil { - return + return err } continue } @@ -486,6 +487,7 @@ func (e *AppConnector) ObserveDNSResponse(res []byte) { e.scheduleAdvertisement(domain, toAdvertise...) 
} } + return nil } // starting from the given domain that resolved to an address, find it, or any diff --git a/appc/appconnector_test.go b/appc/appconnector_test.go index 7dba8cebd9e8c..c13835f39ed9a 100644 --- a/appc/appconnector_test.go +++ b/appc/appconnector_test.go @@ -8,16 +8,17 @@ import ( "net/netip" "reflect" "slices" + "sync/atomic" "testing" "time" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/dns/dnsmessage" "tailscale.com/appc/appctest" "tailscale.com/tstest" "tailscale.com/util/clientmetric" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/slicesx" ) func fakeStoreRoutes(*RouteInfo) error { return nil } @@ -50,7 +51,7 @@ func TestUpdateDomains(t *testing.T) { // domains are explicitly downcased on set. a.UpdateDomains([]string{"UP.EXAMPLE.COM"}) a.Wait(ctx) - if got, want := xmaps.Keys(a.domains), []string{"up.example.com"}; !slices.Equal(got, want) { + if got, want := slicesx.MapKeys(a.domains), []string{"up.example.com"}; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } } @@ -69,7 +70,9 @@ func TestUpdateRoutes(t *testing.T) { a.updateDomains([]string{"*.example.com"}) // This route should be collapsed into the range - a.ObserveDNSResponse(dnsResponse("a.example.com.", "192.0.2.1")) + if err := a.ObserveDNSResponse(dnsResponse("a.example.com.", "192.0.2.1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if !slices.Equal(rc.Routes(), []netip.Prefix{netip.MustParsePrefix("192.0.2.1/32")}) { @@ -77,11 +80,14 @@ func TestUpdateRoutes(t *testing.T) { } // This route should not be collapsed or removed - a.ObserveDNSResponse(dnsResponse("b.example.com.", "192.0.0.1")) + if err := a.ObserveDNSResponse(dnsResponse("b.example.com.", "192.0.0.1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24"), netip.MustParsePrefix("192.0.0.1/32")} a.updateRoutes(routes) + a.Wait(ctx) slices.SortFunc(rc.Routes(), 
prefixCompare) rc.SetRoutes(slices.Compact(rc.Routes())) @@ -101,6 +107,7 @@ func TestUpdateRoutes(t *testing.T) { } func TestUpdateRoutesUnadvertisesContainedRoutes(t *testing.T) { + ctx := context.Background() for _, shouldStore := range []bool{false, true} { rc := &appctest.RouteCollector{} var a *AppConnector @@ -113,6 +120,7 @@ func TestUpdateRoutesUnadvertisesContainedRoutes(t *testing.T) { rc.SetRoutes([]netip.Prefix{netip.MustParsePrefix("192.0.2.1/32")}) routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")} a.updateRoutes(routes) + a.Wait(ctx) if !slices.EqualFunc(routes, rc.Routes(), prefixEqual) { t.Fatalf("got %v, want %v", rc.Routes(), routes) @@ -130,7 +138,9 @@ func TestDomainRoutes(t *testing.T) { a = NewAppConnector(t.Logf, rc, nil, nil) } a.updateDomains([]string{"example.com"}) - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(context.Background()) want := map[string][]netip.Addr{ @@ -155,7 +165,9 @@ func TestObserveDNSResponse(t *testing.T) { } // a has no domains configured, so it should not advertise any routes - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } if got, want := rc.Routes(), ([]netip.Prefix)(nil); !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } @@ -163,7 +175,9 @@ func TestObserveDNSResponse(t *testing.T) { wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} a.updateDomains([]string{"example.com"}) - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { 
t.Errorf("got %v; want %v", got, want) @@ -172,7 +186,9 @@ func TestObserveDNSResponse(t *testing.T) { // a CNAME record chain should result in a route being added if the chain // matches a routed domain. a.updateDomains([]string{"www.example.com", "example.com"}) - a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.9", "www.example.com.", "chain.example.com.", "example.com.")) + if err := a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.9", "www.example.com.", "chain.example.com.", "example.com.")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.9/32")) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { @@ -181,7 +197,9 @@ func TestObserveDNSResponse(t *testing.T) { // a CNAME record chain should result in a route being added if the chain // even if only found in the middle of the chain - a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.10", "outside.example.org.", "www.example.com.", "example.org.")) + if err := a.ObserveDNSResponse(dnsCNAMEResponse("192.0.0.10", "outside.example.org.", "www.example.com.", "example.org.")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) wantRoutes = append(wantRoutes, netip.MustParsePrefix("192.0.0.10/32")) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { @@ -190,14 +208,18 @@ func TestObserveDNSResponse(t *testing.T) { wantRoutes = append(wantRoutes, netip.MustParsePrefix("2001:db8::1/128")) - a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if got, want := rc.Routes(), wantRoutes; !slices.Equal(got, want) { t.Errorf("got %v; want %v", got, want) } // don't re-advertise routes that have already been advertised - a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")) + if err := 
a.ObserveDNSResponse(dnsResponse("example.com.", "2001:db8::1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes) @@ -207,7 +229,9 @@ func TestObserveDNSResponse(t *testing.T) { pfx := netip.MustParsePrefix("192.0.2.0/24") a.updateRoutes([]netip.Prefix{pfx}) wantRoutes = append(wantRoutes, pfx) - a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.2.1")) + if err := a.ObserveDNSResponse(dnsResponse("example.com.", "192.0.2.1")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if !slices.Equal(rc.Routes(), wantRoutes) { t.Errorf("rc.Routes(): got %v; want %v", rc.Routes(), wantRoutes) @@ -230,7 +254,9 @@ func TestWildcardDomains(t *testing.T) { } a.updateDomains([]string{"*.example.com"}) - a.ObserveDNSResponse(dnsResponse("foo.example.com.", "192.0.0.8")) + if err := a.ObserveDNSResponse(dnsResponse("foo.example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } a.Wait(ctx) if got, want := rc.Routes(), []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")}; !slices.Equal(got, want) { t.Errorf("routes: got %v; want %v", got, want) @@ -438,10 +464,16 @@ func TestUpdateDomainRouteRemoval(t *testing.T) { // adding domains doesn't immediately cause any routes to be advertised assertRoutes("update domains", []netip.Prefix{}, []netip.Prefix{}) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.1")) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.2")) - a.ObserveDNSResponse(dnsResponse("b.example.com.", "1.2.3.3")) - a.ObserveDNSResponse(dnsResponse("b.example.com.", "1.2.3.4")) + for _, res := range [][]byte{ + dnsResponse("a.example.com.", "1.2.3.1"), + dnsResponse("a.example.com.", "1.2.3.2"), + dnsResponse("b.example.com.", "1.2.3.3"), + dnsResponse("b.example.com.", "1.2.3.4"), + } { + if err := a.ObserveDNSResponse(res); err != nil { + 
t.Errorf("ObserveDNSResponse: %v", err) + } + } a.Wait(ctx) // observing dns responses causes routes to be advertised assertRoutes("observed dns", prefixes("1.2.3.1/32", "1.2.3.2/32", "1.2.3.3/32", "1.2.3.4/32"), []netip.Prefix{}) @@ -487,10 +519,16 @@ func TestUpdateWildcardRouteRemoval(t *testing.T) { // adding domains doesn't immediately cause any routes to be advertised assertRoutes("update domains", []netip.Prefix{}, []netip.Prefix{}) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.1")) - a.ObserveDNSResponse(dnsResponse("a.example.com.", "1.2.3.2")) - a.ObserveDNSResponse(dnsResponse("1.b.example.com.", "1.2.3.3")) - a.ObserveDNSResponse(dnsResponse("2.b.example.com.", "1.2.3.4")) + for _, res := range [][]byte{ + dnsResponse("a.example.com.", "1.2.3.1"), + dnsResponse("a.example.com.", "1.2.3.2"), + dnsResponse("1.b.example.com.", "1.2.3.3"), + dnsResponse("2.b.example.com.", "1.2.3.4"), + } { + if err := a.ObserveDNSResponse(res); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } + } a.Wait(ctx) // observing dns responses causes routes to be advertised assertRoutes("observed dns", prefixes("1.2.3.1/32", "1.2.3.2/32", "1.2.3.3/32", "1.2.3.4/32"), []netip.Prefix{}) @@ -602,3 +640,57 @@ func TestMetricBucketsAreSorted(t *testing.T) { t.Errorf("metricStoreRoutesNBuckets must be in order") } } + +// TestUpdateRoutesDeadlock is a regression test for a deadlock in +// LocalBackend<->AppConnector interaction. When using real LocalBackend as the +// routeAdvertiser, calls to Advertise/UnadvertiseRoutes can end up calling +// back into AppConnector via authReconfig. If everything is called +// synchronously, this results in a deadlock on AppConnector.mu. 
+func TestUpdateRoutesDeadlock(t *testing.T) { + ctx := context.Background() + rc := &appctest.RouteCollector{} + a := NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes) + + advertiseCalled := new(atomic.Bool) + unadvertiseCalled := new(atomic.Bool) + rc.AdvertiseCallback = func() { + // Call something that requires a.mu to be held. + a.DomainRoutes() + advertiseCalled.Store(true) + } + rc.UnadvertiseCallback = func() { + // Call something that requires a.mu to be held. + a.DomainRoutes() + unadvertiseCalled.Store(true) + } + + a.updateDomains([]string{"example.com"}) + a.Wait(ctx) + + // Trigger rc.AdvertiseRoute. + a.updateRoutes( + []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + netip.MustParsePrefix("127.0.0.2/32"), + }, + ) + a.Wait(ctx) + // Trigger rc.UnadvertiseRoute. + a.updateRoutes( + []netip.Prefix{ + netip.MustParsePrefix("127.0.0.1/32"), + }, + ) + a.Wait(ctx) + + if !advertiseCalled.Load() { + t.Error("AdvertiseRoute was not called") + } + if !unadvertiseCalled.Load() { + t.Error("UnadvertiseRoute was not called") + } + + if want := []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")}; !slices.Equal(slices.Compact(rc.Routes()), want) { + t.Fatalf("got %v, want %v", rc.Routes(), want) + } +} diff --git a/appc/appctest/appctest.go b/appc/appctest/appctest.go index aa77bc3b41044..9726a2b97d72b 100644 --- a/appc/appctest/appctest.go +++ b/appc/appctest/appctest.go @@ -11,12 +11,22 @@ import ( // RouteCollector is a test helper that collects the list of routes advertised type RouteCollector struct { + // AdvertiseCallback (optional) is called synchronously from + // AdvertiseRoute. + AdvertiseCallback func() + // UnadvertiseCallback (optional) is called synchronously from + // UnadvertiseRoute. + UnadvertiseCallback func() + routes []netip.Prefix removedRoutes []netip.Prefix } func (rc *RouteCollector) AdvertiseRoute(pfx ...netip.Prefix) error { rc.routes = append(rc.routes, pfx...)
+ if rc.AdvertiseCallback != nil { + rc.AdvertiseCallback() + } return nil } @@ -30,6 +40,9 @@ func (rc *RouteCollector) UnadvertiseRoute(toRemove ...netip.Prefix) error { rc.removedRoutes = append(rc.removedRoutes, r) } } + if rc.UnadvertiseCallback != nil { + rc.UnadvertiseCallback() + } return nil } diff --git a/atomicfile/atomicfile.go b/atomicfile/atomicfile.go index 5c18e85a896eb..b3c8c93da2af9 100644 --- a/atomicfile/atomicfile.go +++ b/atomicfile/atomicfile.go @@ -15,8 +15,9 @@ import ( ) // WriteFile writes data to filename+some suffix, then renames it into filename. -// The perm argument is ignored on Windows. If the target filename already -// exists but is not a regular file, WriteFile returns an error. +// The perm argument is ignored on Windows, but if the target filename already +// exists then the target file's attributes and ACLs are preserved. If the target +// filename already exists but is not a regular file, WriteFile returns an error. func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { fi, err := os.Stat(filename) if err == nil && !fi.Mode().IsRegular() { @@ -47,5 +48,5 @@ func WriteFile(filename string, data []byte, perm os.FileMode) (err error) { if err := f.Close(); err != nil { return err } - return os.Rename(tmpName, filename) + return rename(tmpName, filename) } diff --git a/atomicfile/atomicfile_notwindows.go b/atomicfile/atomicfile_notwindows.go new file mode 100644 index 0000000000000..1ce2bb8acda7a --- /dev/null +++ b/atomicfile/atomicfile_notwindows.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package atomicfile + +import ( + "os" +) + +func rename(srcFile, destFile string) error { + return os.Rename(srcFile, destFile) +} diff --git a/atomicfile/atomicfile_windows.go b/atomicfile/atomicfile_windows.go new file mode 100644 index 0000000000000..c67762df2b56c --- /dev/null +++ b/atomicfile/atomicfile_windows.go @@ -0,0 +1,33 @@ 
+// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package atomicfile + +import ( + "os" + + "golang.org/x/sys/windows" +) + +func rename(srcFile, destFile string) error { + // Use replaceFile when possible to preserve the original file's attributes and ACLs. + if err := replaceFile(destFile, srcFile); err == nil || err != windows.ERROR_FILE_NOT_FOUND { + return err + } + // destFile doesn't exist. Just do a normal rename. + return os.Rename(srcFile, destFile) +} + +func replaceFile(destFile, srcFile string) error { + destFile16, err := windows.UTF16PtrFromString(destFile) + if err != nil { + return err + } + + srcFile16, err := windows.UTF16PtrFromString(srcFile) + if err != nil { + return err + } + + return replaceFileW(destFile16, srcFile16, nil, 0, nil, nil) +} diff --git a/atomicfile/atomicfile_windows_test.go b/atomicfile/atomicfile_windows_test.go new file mode 100644 index 0000000000000..4dec1493e0224 --- /dev/null +++ b/atomicfile/atomicfile_windows_test.go @@ -0,0 +1,146 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package atomicfile + +import ( + "os" + "testing" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _SECURITY_RESOURCE_MANAGER_AUTHORITY = windows.SidIdentifierAuthority{[6]byte{0, 0, 0, 0, 0, 9}} + +// makeRandomSID generates a SID derived from a v4 GUID. +// This is basically the same algorithm used by browser sandboxes for generating +// random SIDs. 
+func makeRandomSID() (*windows.SID, error) { + guid, err := windows.GenerateGUID() + if err != nil { + return nil, err + } + + rids := *((*[4]uint32)(unsafe.Pointer(&guid))) + + var pSID *windows.SID + if err := windows.AllocateAndInitializeSid(&_SECURITY_RESOURCE_MANAGER_AUTHORITY, 4, rids[0], rids[1], rids[2], rids[3], 0, 0, 0, 0, &pSID); err != nil { + return nil, err + } + defer windows.FreeSid(pSID) + + // Make a copy that lives on the Go heap + return pSID.Copy() +} + +func getExistingFileSD(name string) (*windows.SECURITY_DESCRIPTOR, error) { + const infoFlags = windows.DACL_SECURITY_INFORMATION + return windows.GetNamedSecurityInfo(name, windows.SE_FILE_OBJECT, infoFlags) +} + +func getExistingFileDACL(name string) (*windows.ACL, error) { + sd, err := getExistingFileSD(name) + if err != nil { + return nil, err + } + + dacl, _, err := sd.DACL() + return dacl, err +} + +func addDenyACEForRandomSID(dacl *windows.ACL) (*windows.ACL, error) { + randomSID, err := makeRandomSID() + if err != nil { + return nil, err + } + + randomSIDTrustee := windows.TRUSTEE{nil, windows.NO_MULTIPLE_TRUSTEE, + windows.TRUSTEE_IS_SID, windows.TRUSTEE_IS_UNKNOWN, + windows.TrusteeValueFromSID(randomSID)} + + entries := []windows.EXPLICIT_ACCESS{ + { + windows.GENERIC_ALL, + windows.DENY_ACCESS, + windows.NO_INHERITANCE, + randomSIDTrustee, + }, + } + + return windows.ACLFromEntries(entries, dacl) +} + +func setExistingFileDACL(name string, dacl *windows.ACL) error { + return windows.SetNamedSecurityInfo(name, windows.SE_FILE_OBJECT, + windows.DACL_SECURITY_INFORMATION, nil, nil, dacl, nil) +} + +// makeOrigFileWithCustomDACL creates a new, temporary file with a custom +// DACL that we can check for later. It returns the name of the temporary +// file and the security descriptor for the file in SDDL format. 
+func makeOrigFileWithCustomDACL() (name, sddl string, err error) { + f, err := os.CreateTemp("", "foo*.tmp") + if err != nil { + return "", "", err + } + name = f.Name() + if err := f.Close(); err != nil { + return "", "", err + } + f = nil + defer func() { + if err != nil { + os.Remove(name) + } + }() + + dacl, err := getExistingFileDACL(name) + if err != nil { + return "", "", err + } + + // Add a harmless, deny-only ACE for a random SID that isn't used for anything + // (but that we can check for later). + dacl, err = addDenyACEForRandomSID(dacl) + if err != nil { + return "", "", err + } + + if err := setExistingFileDACL(name, dacl); err != nil { + return "", "", err + } + + sd, err := getExistingFileSD(name) + if err != nil { + return "", "", err + } + + return name, sd.String(), nil +} + +func TestPreserveSecurityInfo(t *testing.T) { + // Make a test file with a custom ACL. + origFileName, want, err := makeOrigFileWithCustomDACL() + if err != nil { + t.Fatalf("makeOrigFileWithCustomDACL returned %v", err) + } + t.Cleanup(func() { + os.Remove(origFileName) + }) + + if err := WriteFile(origFileName, []byte{}, 0); err != nil { + t.Fatalf("WriteFile returned %v", err) + } + + // We expect origFileName's security descriptor to be unchanged despite + // the WriteFile call. 
+ sd, err := getExistingFileSD(origFileName) + if err != nil { + t.Fatalf("getExistingFileSD(%q) returned %v", origFileName, err) + } + + if got := sd.String(); got != want { + t.Errorf("security descriptor comparison failed: got %q, want %q", got, want) + } +} diff --git a/atomicfile/mksyscall.go b/atomicfile/mksyscall.go new file mode 100644 index 0000000000000..d8951a77c5ac6 --- /dev/null +++ b/atomicfile/mksyscall.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package atomicfile + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go + +//sys replaceFileW(replaced *uint16, replacement *uint16, backup *uint16, flags uint32, exclude unsafe.Pointer, reserved unsafe.Pointer) (err error) [int32(failretval)==0] = kernel32.ReplaceFileW diff --git a/atomicfile/zsyscall_windows.go b/atomicfile/zsyscall_windows.go new file mode 100644 index 0000000000000..f2f0b6d08cbb7 --- /dev/null +++ b/atomicfile/zsyscall_windows.go @@ -0,0 +1,52 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package atomicfile + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) 
+ return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + + procReplaceFileW = modkernel32.NewProc("ReplaceFileW") +) + +func replaceFileW(replaced *uint16, replacement *uint16, backup *uint16, flags uint32, exclude unsafe.Pointer, reserved unsafe.Pointer) (err error) { + r1, _, e1 := syscall.Syscall6(procReplaceFileW.Addr(), 6, uintptr(unsafe.Pointer(replaced)), uintptr(unsafe.Pointer(replacement)), uintptr(unsafe.Pointer(backup)), uintptr(flags), uintptr(exclude), uintptr(reserved)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} diff --git a/build_dist.sh b/build_dist.sh index 66afa8f745c4c..fed37c2646175 100755 --- a/build_dist.sh +++ b/build_dist.sh @@ -28,18 +28,26 @@ EOF exit 0 fi -tags="" +tags="${TAGS:-}" ldflags="-X tailscale.com/version.longStamp=${VERSION_LONG} -X tailscale.com/version.shortStamp=${VERSION_SHORT}" # build_dist.sh arguments must precede go build arguments. while [ "$#" -gt 1 ]; do case "$1" in --extra-small) + if [ ! -z "${TAGS:-}" ]; then + echo "set either --extra-small or \$TAGS, but not both" + exit 1 + fi shift ldflags="$ldflags -w -s" - tags="${tags:+$tags,}ts_omit_aws,ts_omit_bird,ts_omit_tap,ts_omit_kube,ts_omit_completion" + tags="${tags:+$tags,}ts_omit_aws,ts_omit_bird,ts_omit_tap,ts_omit_kube,ts_omit_completion,ts_omit_ssh,ts_omit_wakeonlan,ts_omit_capture,ts_omit_relayserver,ts_omit_taildrop,ts_omit_tpm" ;; --box) + if [ ! -z "${TAGS:-}" ]; then + echo "set either --box or \$TAGS, but not both" + exit 1 + fi shift tags="${tags:+$tags,}ts_include_cli" ;; diff --git a/build_docker.sh b/build_docker.sh index e8b1c8f28f450..15105c2ef5541 100755 --- a/build_docker.sh +++ b/build_docker.sh @@ -16,13 +16,21 @@ eval "$(./build_dist.sh shellvars)" DEFAULT_TARGET="client" DEFAULT_TAGS="v${VERSION_SHORT},v${VERSION_MINOR}" -DEFAULT_BASE="tailscale/alpine-base:3.18" +DEFAULT_BASE="tailscale/alpine-base:3.19" +# Set a few pre-defined OCI annotations. 
The source annotation is used by tools such as Renovate that scan the linked +# Github repo to find release notes for any new image tags. Note that for official Tailscale images the default +# annotations defined here will be overriden by release scripts that call this script. +# https://github.com/opencontainers/image-spec/blob/main/annotations.md#pre-defined-annotation-keys +DEFAULT_ANNOTATIONS="org.opencontainers.image.source=https://github.com/tailscale/tailscale/blob/main/build_docker.sh,org.opencontainers.image.vendor=Tailscale" PUSH="${PUSH:-false}" TARGET="${TARGET:-${DEFAULT_TARGET}}" TAGS="${TAGS:-${DEFAULT_TAGS}}" BASE="${BASE:-${DEFAULT_BASE}}" PLATFORM="${PLATFORM:-}" # default to all platforms +# OCI annotations that will be added to the image. +# https://github.com/opencontainers/image-spec/blob/main/annotations.md +ANNOTATIONS="${ANNOTATIONS:-${DEFAULT_ANNOTATIONS}}" case "$TARGET" in client) @@ -43,9 +51,10 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ /usr/local/bin/containerboot ;; - operator) + k8s-operator) DEFAULT_REPOS="tailscale/k8s-operator" REPOS="${REPOS:-${DEFAULT_REPOS}}" go run github.com/tailscale/mkctr \ @@ -60,6 +69,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ /usr/local/bin/operator ;; k8s-nameserver) @@ -77,6 +87,7 @@ case "$TARGET" in --repos="${REPOS}" \ --push="${PUSH}" \ --target="${PLATFORM}" \ + --annotations="${ANNOTATIONS}" \ /usr/local/bin/k8s-nameserver ;; *) diff --git a/client/tailscale/localclient.go b/client/local/local.go similarity index 78% rename from client/tailscale/localclient.go rename to client/local/local.go index df51dc1cab52c..0e4d495d3dd18 100644 --- a/client/tailscale/localclient.go +++ b/client/local/local.go @@ -1,15 +1,17 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build go1.19 +//go:build go1.22 
-package tailscale +// Package local contains a Go client for the Tailscale LocalAPI. +package local import ( "bytes" "cmp" "context" "crypto/tls" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -40,13 +42,14 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/tkatype" + "tailscale.com/util/syspolicy/setting" ) -// defaultLocalClient is the default LocalClient when using the legacy +// defaultClient is the default Client when using the legacy // package-level functions. -var defaultLocalClient LocalClient +var defaultClient Client -// LocalClient is a client to Tailscale's "LocalAPI", communicating with the +// Client is a client to Tailscale's "LocalAPI", communicating with the // Tailscale daemon on the local machine. Its API is not necessarily stable and // subject to changes between releases. Some API calls have stricter // compatibility guarantees, once they've been widely adopted. See method docs @@ -56,11 +59,17 @@ var defaultLocalClient LocalClient // // Any exported fields should be set before using methods on the type // and not changed thereafter. -type LocalClient struct { +type Client struct { // Dial optionally specifies an alternate func that connects to the local // machine's tailscaled or equivalent. If nil, a default is used. Dial func(ctx context.Context, network, addr string) (net.Conn, error) + // Transport optionally specifies an alternate [http.RoundTripper] + // used to execute HTTP requests. If nil, a default [http.Transport] is used, + // potentially with custom dialing logic from [Dial]. + // It is primarily used for testing. + Transport http.RoundTripper + // Socket specifies an alternate path to the local Tailscale socket. // If empty, a platform-specific default is used. 
Socket string @@ -84,21 +93,21 @@ type LocalClient struct { tsClientOnce sync.Once } -func (lc *LocalClient) socket() string { +func (lc *Client) socket() string { if lc.Socket != "" { return lc.Socket } return paths.DefaultTailscaledSocket() } -func (lc *LocalClient) dialer() func(ctx context.Context, network, addr string) (net.Conn, error) { +func (lc *Client) dialer() func(ctx context.Context, network, addr string) (net.Conn, error) { if lc.Dial != nil { return lc.Dial } return lc.defaultDialer } -func (lc *LocalClient) defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) { +func (lc *Client) defaultDialer(ctx context.Context, network, addr string) (net.Conn, error) { if addr != "local-tailscaled.sock:80" { return nil, fmt.Errorf("unexpected URL address %q", addr) } @@ -124,13 +133,13 @@ func (lc *LocalClient) defaultDialer(ctx context.Context, network, addr string) // authenticating to the local Tailscale daemon vary by platform. // // DoLocalRequest may mutate the request to add Authorization headers. 
-func (lc *LocalClient) DoLocalRequest(req *http.Request) (*http.Response, error) { +func (lc *Client) DoLocalRequest(req *http.Request) (*http.Response, error) { req.Header.Set("Tailscale-Cap", strconv.Itoa(int(tailcfg.CurrentCapabilityVersion))) lc.tsClientOnce.Do(func() { lc.tsClient = &http.Client{ - Transport: &http.Transport{ - DialContext: lc.dialer(), - }, + Transport: cmp.Or(lc.Transport, http.RoundTripper( + &http.Transport{DialContext: lc.dialer()}), + ), } }) if !lc.OmitAuth { @@ -141,7 +150,7 @@ func (lc *LocalClient) DoLocalRequest(req *http.Request) (*http.Response, error) return lc.tsClient.Do(req) } -func (lc *LocalClient) doLocalRequestNiceError(req *http.Request) (*http.Response, error) { +func (lc *Client) doLocalRequestNiceError(req *http.Request) (*http.Response, error) { res, err := lc.DoLocalRequest(req) if err == nil { if server := res.Header.Get("Tailscale-Version"); server != "" && server != envknob.IPCVersion() && onVersionMismatch != nil { @@ -230,12 +239,17 @@ func SetVersionMismatchHandler(f func(clientVer, serverVer string)) { onVersionMismatch = f } -func (lc *LocalClient) send(ctx context.Context, method, path string, wantStatus int, body io.Reader) ([]byte, error) { - slurp, _, err := lc.sendWithHeaders(ctx, method, path, wantStatus, body, nil) +func (lc *Client) send(ctx context.Context, method, path string, wantStatus int, body io.Reader) ([]byte, error) { + var headers http.Header + if reason := apitype.RequestReasonKey.Value(ctx); reason != "" { + reasonBase64 := base64.StdEncoding.EncodeToString([]byte(reason)) + headers = http.Header{apitype.RequestReasonHeader: {reasonBase64}} + } + slurp, _, err := lc.sendWithHeaders(ctx, method, path, wantStatus, body, headers) return slurp, err } -func (lc *LocalClient) sendWithHeaders( +func (lc *Client) sendWithHeaders( ctx context.Context, method, path string, @@ -274,15 +288,15 @@ type httpStatusError struct { HTTPStatus int } -func (lc *LocalClient) get200(ctx context.Context, path 
string) ([]byte, error) { +func (lc *Client) get200(ctx context.Context, path string) ([]byte, error) { return lc.send(ctx, "GET", path, 200, nil) } // WhoIs returns the owner of the remoteAddr, which must be an IP or IP:port. // -// Deprecated: use LocalClient.WhoIs. +// Deprecated: use [Client.WhoIs]. func WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { - return defaultLocalClient.WhoIs(ctx, remoteAddr) + return defaultClient.WhoIs(ctx, remoteAddr) } func decodeJSON[T any](b []byte) (ret T, err error) { @@ -295,12 +309,12 @@ func decodeJSON[T any](b []byte) (ret T, err error) { // WhoIs returns the owner of the remoteAddr, which must be an IP or IP:port. // -// If not found, the error is ErrPeerNotFound. +// If not found, the error is [ErrPeerNotFound]. // // For connections proxied by tailscaled, this looks up the owner of the given // address as TCP first, falling back to UDP; if you want to only check a // specific address family, use WhoIsProto. -func (lc *LocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { +func (lc *Client) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(remoteAddr)) if err != nil { if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { @@ -311,13 +325,14 @@ func (lc *LocalClient) WhoIs(ctx context.Context, remoteAddr string) (*apitype.W return decodeJSON[*apitype.WhoIsResponse](body) } -// ErrPeerNotFound is returned by WhoIs and WhoIsNodeKey when a peer is not found. +// ErrPeerNotFound is returned by [Client.WhoIs], [Client.WhoIsNodeKey] and +// [Client.WhoIsProto] when a peer is not found. var ErrPeerNotFound = errors.New("peer not found") // WhoIsNodeKey returns the owner of the given wireguard public key. // // If not found, the error is ErrPeerNotFound. 
-func (lc *LocalClient) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*apitype.WhoIsResponse, error) { +func (lc *Client) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*apitype.WhoIsResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(key.String())) if err != nil { if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { @@ -331,8 +346,8 @@ func (lc *LocalClient) WhoIsNodeKey(ctx context.Context, key key.NodePublic) (*a // WhoIsProto returns the owner of the remoteAddr, which must be an IP or // IP:port, for the given protocol (tcp or udp). // -// If not found, the error is ErrPeerNotFound. -func (lc *LocalClient) WhoIsProto(ctx context.Context, proto, remoteAddr string) (*apitype.WhoIsResponse, error) { +// If not found, the error is [ErrPeerNotFound]. +func (lc *Client) WhoIsProto(ctx context.Context, proto, remoteAddr string) (*apitype.WhoIsResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/whois?proto="+url.QueryEscape(proto)+"&addr="+url.QueryEscape(remoteAddr)) if err != nil { if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { @@ -344,19 +359,19 @@ func (lc *LocalClient) WhoIsProto(ctx context.Context, proto, remoteAddr string) } // Goroutines returns a dump of the Tailscale daemon's current goroutines. -func (lc *LocalClient) Goroutines(ctx context.Context) ([]byte, error) { +func (lc *Client) Goroutines(ctx context.Context) ([]byte, error) { return lc.get200(ctx, "/localapi/v0/goroutines") } // DaemonMetrics returns the Tailscale daemon's metrics in // the Prometheus text exposition format. -func (lc *LocalClient) DaemonMetrics(ctx context.Context) ([]byte, error) { +func (lc *Client) DaemonMetrics(ctx context.Context) ([]byte, error) { return lc.get200(ctx, "/localapi/v0/metrics") } // UserMetrics returns the user metrics in // the Prometheus text exposition format. 
-func (lc *LocalClient) UserMetrics(ctx context.Context) ([]byte, error) { +func (lc *Client) UserMetrics(ctx context.Context) ([]byte, error) { return lc.get200(ctx, "/localapi/v0/usermetrics") } @@ -365,7 +380,7 @@ func (lc *LocalClient) UserMetrics(ctx context.Context) ([]byte, error) { // metric is created and initialized to delta. // // IncrementCounter does not support gauge metrics or negative delta values. -func (lc *LocalClient) IncrementCounter(ctx context.Context, name string, delta int) error { +func (lc *Client) IncrementCounter(ctx context.Context, name string, delta int) error { type metricUpdate struct { Name string `json:"name"` Type string `json:"type"` @@ -384,7 +399,7 @@ func (lc *LocalClient) IncrementCounter(ctx context.Context, name string, delta // TailDaemonLogs returns a stream the Tailscale daemon's logs as they arrive. // Close the context to stop the stream. -func (lc *LocalClient) TailDaemonLogs(ctx context.Context) (io.Reader, error) { +func (lc *Client) TailDaemonLogs(ctx context.Context) (io.Reader, error) { req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/logtap", nil) if err != nil { return nil, err @@ -400,7 +415,7 @@ func (lc *LocalClient) TailDaemonLogs(ctx context.Context) (io.Reader, error) { } // Pprof returns a pprof profile of the Tailscale daemon. -func (lc *LocalClient) Pprof(ctx context.Context, pprofType string, sec int) ([]byte, error) { +func (lc *Client) Pprof(ctx context.Context, pprofType string, sec int) ([]byte, error) { var secArg string if sec < 0 || sec > 300 { return nil, errors.New("duration out of range") @@ -433,7 +448,7 @@ type BugReportOpts struct { // // The opts type specifies options to pass to the Tailscale daemon when // generating this bug report. 
-func (lc *LocalClient) BugReportWithOpts(ctx context.Context, opts BugReportOpts) (string, error) { +func (lc *Client) BugReportWithOpts(ctx context.Context, opts BugReportOpts) (string, error) { qparams := make(url.Values) if opts.Note != "" { qparams.Set("note", opts.Note) @@ -476,15 +491,15 @@ func (lc *LocalClient) BugReportWithOpts(ctx context.Context, opts BugReportOpts // BugReport logs and returns a log marker that can be shared by the user with support. // -// This is the same as calling BugReportWithOpts and only specifying the Note +// This is the same as calling [Client.BugReportWithOpts] and only specifying the Note // field. -func (lc *LocalClient) BugReport(ctx context.Context, note string) (string, error) { +func (lc *Client) BugReport(ctx context.Context, note string) (string, error) { return lc.BugReportWithOpts(ctx, BugReportOpts{Note: note}) } // DebugAction invokes a debug action, such as "rebind" or "restun". // These are development tools and subject to change or removal over time. -func (lc *LocalClient) DebugAction(ctx context.Context, action string) error { +func (lc *Client) DebugAction(ctx context.Context, action string) error { body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, nil) if err != nil { return fmt.Errorf("error %w: %s", err, body) @@ -492,9 +507,20 @@ func (lc *LocalClient) DebugAction(ctx context.Context, action string) error { return nil } +// DebugActionBody invokes a debug action with a body parameter, such as +// "debug-force-prefer-derp". +// These are development tools and subject to change or removal over time. 
+func (lc *Client) DebugActionBody(ctx context.Context, action string, rbody io.Reader) error { + body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, rbody) + if err != nil { + return fmt.Errorf("error %w: %s", err, body) + } + return nil +} + // DebugResultJSON invokes a debug action and returns its result as something JSON-able. // These are development tools and subject to change or removal over time. -func (lc *LocalClient) DebugResultJSON(ctx context.Context, action string) (any, error) { +func (lc *Client) DebugResultJSON(ctx context.Context, action string) (any, error) { body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, nil) if err != nil { return nil, fmt.Errorf("error %w: %s", err, body) @@ -506,7 +532,7 @@ func (lc *LocalClient) DebugResultJSON(ctx context.Context, action string) (any, return x, nil } -// DebugPortmapOpts contains options for the DebugPortmap command. +// DebugPortmapOpts contains options for the [Client.DebugPortmap] command. type DebugPortmapOpts struct { // Duration is how long the mapping should be created for. It defaults // to 5 seconds if not set. @@ -537,7 +563,7 @@ type DebugPortmapOpts struct { // process. // // opts can be nil; if so, default values will be used. -func (lc *LocalClient) DebugPortmap(ctx context.Context, opts *DebugPortmapOpts) (io.ReadCloser, error) { +func (lc *Client) DebugPortmap(ctx context.Context, opts *DebugPortmapOpts) (io.ReadCloser, error) { vals := make(url.Values) if opts == nil { opts = &DebugPortmapOpts{} @@ -572,7 +598,7 @@ func (lc *LocalClient) DebugPortmap(ctx context.Context, opts *DebugPortmapOpts) // SetDevStoreKeyValue set a statestore key/value. It's only meant for development. // The schema (including when keys are re-read) is not a stable interface. 
-func (lc *LocalClient) SetDevStoreKeyValue(ctx context.Context, key, value string) error { +func (lc *Client) SetDevStoreKeyValue(ctx context.Context, key, value string) error { body, err := lc.send(ctx, "POST", "/localapi/v0/dev-set-state-store?"+(url.Values{ "key": {key}, "value": {value}, @@ -586,7 +612,7 @@ func (lc *LocalClient) SetDevStoreKeyValue(ctx context.Context, key, value strin // SetComponentDebugLogging sets component's debug logging enabled for // the provided duration. If the duration is in the past, the debug logging // is disabled. -func (lc *LocalClient) SetComponentDebugLogging(ctx context.Context, component string, d time.Duration) error { +func (lc *Client) SetComponentDebugLogging(ctx context.Context, component string, d time.Duration) error { body, err := lc.send(ctx, "POST", fmt.Sprintf("/localapi/v0/component-debug-logging?component=%s&secs=%d", url.QueryEscape(component), int64(d.Seconds())), 200, nil) @@ -607,25 +633,25 @@ func (lc *LocalClient) SetComponentDebugLogging(ctx context.Context, component s // Status returns the Tailscale daemon's status. func Status(ctx context.Context) (*ipnstate.Status, error) { - return defaultLocalClient.Status(ctx) + return defaultClient.Status(ctx) } // Status returns the Tailscale daemon's status. -func (lc *LocalClient) Status(ctx context.Context) (*ipnstate.Status, error) { +func (lc *Client) Status(ctx context.Context) (*ipnstate.Status, error) { return lc.status(ctx, "") } // StatusWithoutPeers returns the Tailscale daemon's status, without the peer info. func StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { - return defaultLocalClient.StatusWithoutPeers(ctx) + return defaultClient.StatusWithoutPeers(ctx) } // StatusWithoutPeers returns the Tailscale daemon's status, without the peer info. 
-func (lc *LocalClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { +func (lc *Client) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { return lc.status(ctx, "?peers=false") } -func (lc *LocalClient) status(ctx context.Context, queryString string) (*ipnstate.Status, error) { +func (lc *Client) status(ctx context.Context, queryString string) (*ipnstate.Status, error) { body, err := lc.get200(ctx, "/localapi/v0/status"+queryString) if err != nil { return nil, err @@ -636,7 +662,7 @@ func (lc *LocalClient) status(ctx context.Context, queryString string) (*ipnstat // IDToken is a request to get an OIDC ID token for an audience. // The token can be presented to any resource provider which offers OIDC // Federation. -func (lc *LocalClient) IDToken(ctx context.Context, aud string) (*tailcfg.TokenResponse, error) { +func (lc *Client) IDToken(ctx context.Context, aud string) (*tailcfg.TokenResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/id-token?aud="+url.QueryEscape(aud)) if err != nil { return nil, err @@ -648,14 +674,14 @@ func (lc *LocalClient) IDToken(ctx context.Context, aud string) (*tailcfg.TokenR // received by the Tailscale daemon in its staging/cache directory but not yet // transferred by the user's CLI or GUI client and written to a user's home // directory somewhere. -func (lc *LocalClient) WaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { +func (lc *Client) WaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { return lc.AwaitWaitingFiles(ctx, 0) } -// AwaitWaitingFiles is like WaitingFiles but takes a duration to await for an answer. +// AwaitWaitingFiles is like [Client.WaitingFiles] but takes a duration to await for an answer. // If the duration is 0, it will return immediately. The duration is respected at second // granularity only. If no files are available, it returns (nil, nil). 
-func (lc *LocalClient) AwaitWaitingFiles(ctx context.Context, d time.Duration) ([]apitype.WaitingFile, error) { +func (lc *Client) AwaitWaitingFiles(ctx context.Context, d time.Duration) ([]apitype.WaitingFile, error) { path := "/localapi/v0/files/?waitsec=" + fmt.Sprint(int(d.Seconds())) body, err := lc.get200(ctx, path) if err != nil { @@ -664,12 +690,12 @@ func (lc *LocalClient) AwaitWaitingFiles(ctx context.Context, d time.Duration) ( return decodeJSON[[]apitype.WaitingFile](body) } -func (lc *LocalClient) DeleteWaitingFile(ctx context.Context, baseName string) error { +func (lc *Client) DeleteWaitingFile(ctx context.Context, baseName string) error { _, err := lc.send(ctx, "DELETE", "/localapi/v0/files/"+url.PathEscape(baseName), http.StatusNoContent, nil) return err } -func (lc *LocalClient) GetWaitingFile(ctx context.Context, baseName string) (rc io.ReadCloser, size int64, err error) { +func (lc *Client) GetWaitingFile(ctx context.Context, baseName string) (rc io.ReadCloser, size int64, err error) { req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/files/"+url.PathEscape(baseName), nil) if err != nil { return nil, 0, err @@ -690,7 +716,7 @@ func (lc *LocalClient) GetWaitingFile(ctx context.Context, baseName string) (rc return res.Body, res.ContentLength, nil } -func (lc *LocalClient) FileTargets(ctx context.Context) ([]apitype.FileTarget, error) { +func (lc *Client) FileTargets(ctx context.Context) ([]apitype.FileTarget, error) { body, err := lc.get200(ctx, "/localapi/v0/file-targets") if err != nil { return nil, err @@ -702,7 +728,7 @@ func (lc *LocalClient) FileTargets(ctx context.Context) ([]apitype.FileTarget, e // // A size of -1 means unknown. // The name parameter is the original filename, not escaped. 
-func (lc *LocalClient) PushFile(ctx context.Context, target tailcfg.StableNodeID, size int64, name string, r io.Reader) error { +func (lc *Client) PushFile(ctx context.Context, target tailcfg.StableNodeID, size int64, name string, r io.Reader) error { req, err := http.NewRequestWithContext(ctx, "PUT", "http://"+apitype.LocalAPIHost+"/localapi/v0/file-put/"+string(target)+"/"+url.PathEscape(name), r) if err != nil { return err @@ -725,7 +751,7 @@ func (lc *LocalClient) PushFile(ctx context.Context, target tailcfg.StableNodeID // CheckIPForwarding asks the local Tailscale daemon whether it looks like the // machine is properly configured to forward IP packets as a subnet router // or exit node. -func (lc *LocalClient) CheckIPForwarding(ctx context.Context) error { +func (lc *Client) CheckIPForwarding(ctx context.Context) error { body, err := lc.get200(ctx, "/localapi/v0/check-ip-forwarding") if err != nil { return err @@ -745,7 +771,7 @@ func (lc *LocalClient) CheckIPForwarding(ctx context.Context) error { // CheckUDPGROForwarding asks the local Tailscale daemon whether it looks like // the machine is optimally configured to forward UDP packets as a subnet router // or exit node. -func (lc *LocalClient) CheckUDPGROForwarding(ctx context.Context) error { +func (lc *Client) CheckUDPGROForwarding(ctx context.Context) error { body, err := lc.get200(ctx, "/localapi/v0/check-udp-gro-forwarding") if err != nil { return err @@ -766,7 +792,7 @@ func (lc *LocalClient) CheckUDPGROForwarding(ctx context.Context) error { // node. This can be done to improve performance of tailnet nodes acting as exit // nodes or subnet routers. 
// See https://tailscale.com/kb/1320/performance-best-practices#linux-optimizations-for-subnet-routers-and-exit-nodes -func (lc *LocalClient) SetUDPGROForwarding(ctx context.Context) error { +func (lc *Client) SetUDPGROForwarding(ctx context.Context) error { body, err := lc.get200(ctx, "/localapi/v0/set-udp-gro-forwarding") if err != nil { return err @@ -789,12 +815,12 @@ func (lc *LocalClient) SetUDPGROForwarding(ctx context.Context) error { // work. Currently (2022-04-18) this only checks for SSH server compatibility. // Note that EditPrefs does the same validation as this, so call CheckPrefs before // EditPrefs is not necessary. -func (lc *LocalClient) CheckPrefs(ctx context.Context, p *ipn.Prefs) error { +func (lc *Client) CheckPrefs(ctx context.Context, p *ipn.Prefs) error { _, err := lc.send(ctx, "POST", "/localapi/v0/check-prefs", http.StatusOK, jsonBody(p)) return err } -func (lc *LocalClient) GetPrefs(ctx context.Context) (*ipn.Prefs, error) { +func (lc *Client) GetPrefs(ctx context.Context) (*ipn.Prefs, error) { body, err := lc.get200(ctx, "/localapi/v0/prefs") if err != nil { return nil, err @@ -806,7 +832,12 @@ func (lc *LocalClient) GetPrefs(ctx context.Context) (*ipn.Prefs, error) { return &p, nil } -func (lc *LocalClient) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error) { +// EditPrefs updates the [ipn.Prefs] of the current Tailscale profile, applying the changes in mp. +// It returns an error if the changes cannot be applied, such as due to the caller's access rights +// or a policy restriction. An optional reason or justification for the request can be +// provided as a context value using [apitype.RequestReasonKey]. If permitted by policy, +// access may be granted, and the reason will be logged for auditing purposes. 
+func (lc *Client) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error) { body, err := lc.send(ctx, "PATCH", "/localapi/v0/prefs", http.StatusOK, jsonBody(mp)) if err != nil { return nil, err @@ -814,9 +845,36 @@ func (lc *LocalClient) EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn return decodeJSON[*ipn.Prefs](body) } +// GetEffectivePolicy returns the effective policy for the specified scope. +func (lc *Client) GetEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) { + scopeID, err := scope.MarshalText() + if err != nil { + return nil, err + } + body, err := lc.get200(ctx, "/localapi/v0/policy/"+string(scopeID)) + if err != nil { + return nil, err + } + return decodeJSON[*setting.Snapshot](body) +} + +// ReloadEffectivePolicy reloads the effective policy for the specified scope +// by reading and merging policy settings from all applicable policy sources. +func (lc *Client) ReloadEffectivePolicy(ctx context.Context, scope setting.PolicyScope) (*setting.Snapshot, error) { + scopeID, err := scope.MarshalText() + if err != nil { + return nil, err + } + body, err := lc.send(ctx, "POST", "/localapi/v0/policy/"+string(scopeID), 200, http.NoBody) + if err != nil { + return nil, err + } + return decodeJSON[*setting.Snapshot](body) +} + // GetDNSOSConfig returns the system DNS configuration for the current device. // That is, it returns the DNS configuration that the system would use if Tailscale weren't being used. -func (lc *LocalClient) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig, error) { +func (lc *Client) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig, error) { body, err := lc.get200(ctx, "/localapi/v0/dns-osconfig") if err != nil { return nil, err @@ -831,7 +889,7 @@ func (lc *LocalClient) GetDNSOSConfig(ctx context.Context) (*apitype.DNSOSConfig // QueryDNS executes a DNS query for a name (`google.com.`) and query type (`CNAME`). 
// It returns the raw DNS response bytes and the resolvers that were used to answer the query // (often just one, but can be more if we raced multiple resolvers). -func (lc *LocalClient) QueryDNS(ctx context.Context, name string, queryType string) (bytes []byte, resolvers []*dnstype.Resolver, err error) { +func (lc *Client) QueryDNS(ctx context.Context, name string, queryType string) (bytes []byte, resolvers []*dnstype.Resolver, err error) { body, err := lc.get200(ctx, fmt.Sprintf("/localapi/v0/dns-query?name=%s&type=%s", url.QueryEscape(name), queryType)) if err != nil { return nil, nil, err @@ -844,20 +902,20 @@ func (lc *LocalClient) QueryDNS(ctx context.Context, name string, queryType stri } // StartLoginInteractive starts an interactive login. -func (lc *LocalClient) StartLoginInteractive(ctx context.Context) error { +func (lc *Client) StartLoginInteractive(ctx context.Context) error { _, err := lc.send(ctx, "POST", "/localapi/v0/login-interactive", http.StatusNoContent, nil) return err } // Start applies the configuration specified in opts, and starts the // state machine. -func (lc *LocalClient) Start(ctx context.Context, opts ipn.Options) error { +func (lc *Client) Start(ctx context.Context, opts ipn.Options) error { _, err := lc.send(ctx, "POST", "/localapi/v0/start", http.StatusNoContent, jsonBody(opts)) return err } // Logout logs out the current node. -func (lc *LocalClient) Logout(ctx context.Context) error { +func (lc *Client) Logout(ctx context.Context) error { _, err := lc.send(ctx, "POST", "/localapi/v0/logout", http.StatusNoContent, nil) return err } @@ -876,7 +934,7 @@ func (lc *LocalClient) Logout(ctx context.Context) error { // This is a low-level interface; it's expected that most Tailscale // users use a higher level interface to getting/using TLS // certificates. 
-func (lc *LocalClient) SetDNS(ctx context.Context, name, value string) error { +func (lc *Client) SetDNS(ctx context.Context, name, value string) error { v := url.Values{} v.Set("name", name) v.Set("value", value) @@ -889,8 +947,8 @@ func (lc *LocalClient) SetDNS(ctx context.Context, name, value string) error { // The host may be a base DNS name (resolved from the netmap inside // tailscaled), a FQDN, or an IP address. // -// The ctx is only used for the duration of the call, not the lifetime of the net.Conn. -func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (net.Conn, error) { +// The ctx is only used for the duration of the call, not the lifetime of the [net.Conn]. +func (lc *Client) DialTCP(ctx context.Context, host string, port uint16) (net.Conn, error) { return lc.UserDial(ctx, "tcp", host, port) } @@ -900,8 +958,8 @@ func (lc *LocalClient) DialTCP(ctx context.Context, host string, port uint16) (n // a FQDN, or an IP address. // // The ctx is only used for the duration of the call, not the lifetime of the -// net.Conn. -func (lc *LocalClient) UserDial(ctx context.Context, network, host string, port uint16) (net.Conn, error) { +// [net.Conn]. +func (lc *Client) UserDial(ctx context.Context, network, host string, port uint16) (net.Conn, error) { connCh := make(chan net.Conn, 1) trace := httptrace.ClientTrace{ GotConn: func(info httptrace.GotConnInfo) { @@ -952,7 +1010,7 @@ func (lc *LocalClient) UserDial(ctx context.Context, network, host string, port // CurrentDERPMap returns the current DERPMap that is being used by the local tailscaled. // It is intended to be used with netcheck to see availability of DERPs. 
-func (lc *LocalClient) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) { +func (lc *Client) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) { var derpMap tailcfg.DERPMap res, err := lc.send(ctx, "GET", "/localapi/v0/derpmap", 200, nil) if err != nil { @@ -968,9 +1026,9 @@ func (lc *LocalClient) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, er // // It returns a cached certificate from disk if it's still valid. // -// Deprecated: use LocalClient.CertPair. +// Deprecated: use [Client.CertPair]. func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { - return defaultLocalClient.CertPair(ctx, domain) + return defaultClient.CertPair(ctx, domain) } // CertPair returns a cert and private key for the provided DNS domain. @@ -978,7 +1036,7 @@ func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err e // It returns a cached certificate from disk if it's still valid. // // API maturity: this is considered a stable API. -func (lc *LocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { +func (lc *Client) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { return lc.CertPairWithValidity(ctx, domain, 0) } @@ -991,7 +1049,7 @@ func (lc *LocalClient) CertPair(ctx context.Context, domain string) (certPEM, ke // valid, but for less than minValidity, it will be synchronously renewed. // // API maturity: this is considered a stable API. 
-func (lc *LocalClient) CertPairWithValidity(ctx context.Context, domain string, minValidity time.Duration) (certPEM, keyPEM []byte, err error) { +func (lc *Client) CertPairWithValidity(ctx context.Context, domain string, minValidity time.Duration) (certPEM, keyPEM []byte, err error) { res, err := lc.send(ctx, "GET", fmt.Sprintf("/localapi/v0/cert/%s?type=pair&min_validity=%s", domain, minValidity), 200, nil) if err != nil { return nil, nil, err @@ -1015,11 +1073,11 @@ func (lc *LocalClient) CertPairWithValidity(ctx context.Context, domain string, // It returns a cached certificate from disk if it's still valid. // // It's the right signature to use as the value of -// tls.Config.GetCertificate. +// [tls.Config.GetCertificate]. // -// Deprecated: use LocalClient.GetCertificate. +// Deprecated: use [Client.GetCertificate]. func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - return defaultLocalClient.GetCertificate(hi) + return defaultClient.GetCertificate(hi) } // GetCertificate fetches a TLS certificate for the TLS ClientHello in hi. @@ -1027,10 +1085,10 @@ func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { // It returns a cached certificate from disk if it's still valid. // // It's the right signature to use as the value of -// tls.Config.GetCertificate. +// [tls.Config.GetCertificate]. // // API maturity: this is considered a stable API. -func (lc *LocalClient) GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (lc *Client) GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { if hi == nil || hi.ServerName == "" { return nil, errors.New("no SNI ServerName") } @@ -1056,13 +1114,13 @@ func (lc *LocalClient) GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate // ExpandSNIName expands bare label name into the most likely actual TLS cert name. // -// Deprecated: use LocalClient.ExpandSNIName. +// Deprecated: use [Client.ExpandSNIName]. 
func ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { - return defaultLocalClient.ExpandSNIName(ctx, name) + return defaultClient.ExpandSNIName(ctx, name) } // ExpandSNIName expands bare label name into the most likely actual TLS cert name. -func (lc *LocalClient) ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { +func (lc *Client) ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { st, err := lc.StatusWithoutPeers(ctx) if err != nil { return "", false @@ -1090,7 +1148,7 @@ type PingOpts struct { // Ping sends a ping of the provided type to the provided IP and waits // for its response. The opts type specifies additional options. -func (lc *LocalClient) PingWithOpts(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType, opts PingOpts) (*ipnstate.PingResult, error) { +func (lc *Client) PingWithOpts(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType, opts PingOpts) (*ipnstate.PingResult, error) { v := url.Values{} v.Set("ip", ip.String()) v.Set("size", strconv.Itoa(opts.Size)) @@ -1104,12 +1162,12 @@ func (lc *LocalClient) PingWithOpts(ctx context.Context, ip netip.Addr, pingtype // Ping sends a ping of the provided type to the provided IP and waits // for its response. -func (lc *LocalClient) Ping(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType) (*ipnstate.PingResult, error) { +func (lc *Client) Ping(ctx context.Context, ip netip.Addr, pingtype tailcfg.PingType) (*ipnstate.PingResult, error) { return lc.PingWithOpts(ctx, ip, pingtype, PingOpts{}) } // NetworkLockStatus fetches information about the tailnet key authority, if one is configured. 
-func (lc *LocalClient) NetworkLockStatus(ctx context.Context) (*ipnstate.NetworkLockStatus, error) { +func (lc *Client) NetworkLockStatus(ctx context.Context) (*ipnstate.NetworkLockStatus, error) { body, err := lc.send(ctx, "GET", "/localapi/v0/tka/status", 200, nil) if err != nil { return nil, fmt.Errorf("error: %w", err) @@ -1120,7 +1178,7 @@ func (lc *LocalClient) NetworkLockStatus(ctx context.Context) (*ipnstate.Network // NetworkLockInit initializes the tailnet key authority. // // TODO(tom): Plumb through disablement secrets. -func (lc *LocalClient) NetworkLockInit(ctx context.Context, keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) (*ipnstate.NetworkLockStatus, error) { +func (lc *Client) NetworkLockInit(ctx context.Context, keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) (*ipnstate.NetworkLockStatus, error) { var b bytes.Buffer type initRequest struct { Keys []tka.Key @@ -1141,7 +1199,7 @@ func (lc *LocalClient) NetworkLockInit(ctx context.Context, keys []tka.Key, disa // NetworkLockWrapPreauthKey wraps a pre-auth key with information to // enable unattended bringup in the locked tailnet. -func (lc *LocalClient) NetworkLockWrapPreauthKey(ctx context.Context, preauthKey string, tkaKey key.NLPrivate) (string, error) { +func (lc *Client) NetworkLockWrapPreauthKey(ctx context.Context, preauthKey string, tkaKey key.NLPrivate) (string, error) { encodedPrivate, err := tkaKey.MarshalText() if err != nil { return "", err @@ -1164,7 +1222,7 @@ func (lc *LocalClient) NetworkLockWrapPreauthKey(ctx context.Context, preauthKey } // NetworkLockModify adds and/or removes key(s) to the tailnet key authority. 
-func (lc *LocalClient) NetworkLockModify(ctx context.Context, addKeys, removeKeys []tka.Key) error { +func (lc *Client) NetworkLockModify(ctx context.Context, addKeys, removeKeys []tka.Key) error { var b bytes.Buffer type modifyRequest struct { AddKeys []tka.Key @@ -1183,7 +1241,7 @@ func (lc *LocalClient) NetworkLockModify(ctx context.Context, addKeys, removeKey // NetworkLockSign signs the specified node-key and transmits that signature to the control plane. // rotationPublic, if specified, must be an ed25519 public key. -func (lc *LocalClient) NetworkLockSign(ctx context.Context, nodeKey key.NodePublic, rotationPublic []byte) error { +func (lc *Client) NetworkLockSign(ctx context.Context, nodeKey key.NodePublic, rotationPublic []byte) error { var b bytes.Buffer type signRequest struct { NodeKey key.NodePublic @@ -1201,7 +1259,7 @@ func (lc *LocalClient) NetworkLockSign(ctx context.Context, nodeKey key.NodePubl } // NetworkLockAffectedSigs returns all signatures signed by the specified keyID. -func (lc *LocalClient) NetworkLockAffectedSigs(ctx context.Context, keyID tkatype.KeyID) ([]tkatype.MarshaledSignature, error) { +func (lc *Client) NetworkLockAffectedSigs(ctx context.Context, keyID tkatype.KeyID) ([]tkatype.MarshaledSignature, error) { body, err := lc.send(ctx, "POST", "/localapi/v0/tka/affected-sigs", 200, bytes.NewReader(keyID)) if err != nil { return nil, fmt.Errorf("error: %w", err) @@ -1210,7 +1268,7 @@ func (lc *LocalClient) NetworkLockAffectedSigs(ctx context.Context, keyID tkatyp } // NetworkLockLog returns up to maxEntries number of changes to network-lock state. 
-func (lc *LocalClient) NetworkLockLog(ctx context.Context, maxEntries int) ([]ipnstate.NetworkLockUpdate, error) { +func (lc *Client) NetworkLockLog(ctx context.Context, maxEntries int) ([]ipnstate.NetworkLockUpdate, error) { v := url.Values{} v.Set("limit", fmt.Sprint(maxEntries)) body, err := lc.send(ctx, "GET", "/localapi/v0/tka/log?"+v.Encode(), 200, nil) @@ -1221,7 +1279,7 @@ func (lc *LocalClient) NetworkLockLog(ctx context.Context, maxEntries int) ([]ip } // NetworkLockForceLocalDisable forcibly shuts down network lock on this node. -func (lc *LocalClient) NetworkLockForceLocalDisable(ctx context.Context) error { +func (lc *Client) NetworkLockForceLocalDisable(ctx context.Context) error { // This endpoint expects an empty JSON stanza as the payload. var b bytes.Buffer if err := json.NewEncoder(&b).Encode(struct{}{}); err != nil { @@ -1236,7 +1294,7 @@ func (lc *LocalClient) NetworkLockForceLocalDisable(ctx context.Context) error { // NetworkLockVerifySigningDeeplink verifies the network lock deeplink contained // in url and returns information extracted from it. -func (lc *LocalClient) NetworkLockVerifySigningDeeplink(ctx context.Context, url string) (*tka.DeeplinkValidationResult, error) { +func (lc *Client) NetworkLockVerifySigningDeeplink(ctx context.Context, url string) (*tka.DeeplinkValidationResult, error) { vr := struct { URL string }{url} @@ -1250,7 +1308,7 @@ func (lc *LocalClient) NetworkLockVerifySigningDeeplink(ctx context.Context, url } // NetworkLockGenRecoveryAUM generates an AUM for recovering from a tailnet-lock key compromise. 
-func (lc *LocalClient) NetworkLockGenRecoveryAUM(ctx context.Context, removeKeys []tkatype.KeyID, forkFrom tka.AUMHash) ([]byte, error) { +func (lc *Client) NetworkLockGenRecoveryAUM(ctx context.Context, removeKeys []tkatype.KeyID, forkFrom tka.AUMHash) ([]byte, error) { vr := struct { Keys []tkatype.KeyID ForkFrom string @@ -1265,7 +1323,7 @@ func (lc *LocalClient) NetworkLockGenRecoveryAUM(ctx context.Context, removeKeys } // NetworkLockCosignRecoveryAUM co-signs a recovery AUM using the node's tailnet lock key. -func (lc *LocalClient) NetworkLockCosignRecoveryAUM(ctx context.Context, aum tka.AUM) ([]byte, error) { +func (lc *Client) NetworkLockCosignRecoveryAUM(ctx context.Context, aum tka.AUM) ([]byte, error) { r := bytes.NewReader(aum.Serialize()) body, err := lc.send(ctx, "POST", "/localapi/v0/tka/cosign-recovery-aum", 200, r) if err != nil { @@ -1276,7 +1334,7 @@ func (lc *LocalClient) NetworkLockCosignRecoveryAUM(ctx context.Context, aum tka } // NetworkLockSubmitRecoveryAUM submits a recovery AUM to the control plane. -func (lc *LocalClient) NetworkLockSubmitRecoveryAUM(ctx context.Context, aum tka.AUM) error { +func (lc *Client) NetworkLockSubmitRecoveryAUM(ctx context.Context, aum tka.AUM) error { r := bytes.NewReader(aum.Serialize()) _, err := lc.send(ctx, "POST", "/localapi/v0/tka/submit-recovery-aum", 200, r) if err != nil { @@ -1287,7 +1345,7 @@ func (lc *LocalClient) NetworkLockSubmitRecoveryAUM(ctx context.Context, aum tka // SetServeConfig sets or replaces the serving settings. // If config is nil, settings are cleared and serving is disabled. 
-func (lc *LocalClient) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error { +func (lc *Client) SetServeConfig(ctx context.Context, config *ipn.ServeConfig) error { h := make(http.Header) if config != nil { h.Set("If-Match", config.ETag) @@ -1299,8 +1357,19 @@ func (lc *LocalClient) SetServeConfig(ctx context.Context, config *ipn.ServeConf return nil } +// DisconnectControl shuts down all connections to control, thus making control consider this node inactive. This can be +// run on HA subnet router or app connector replicas before shutting them down to ensure peers get told to switch over +// to another replica whilst there is still some grace period for the existing connections to terminate. +func (lc *Client) DisconnectControl(ctx context.Context) error { + _, _, err := lc.sendWithHeaders(ctx, "POST", "/localapi/v0/disconnect-control", 200, nil, nil) + if err != nil { + return fmt.Errorf("error disconnecting control: %w", err) + } + return nil +} + // NetworkLockDisable shuts down network-lock across the tailnet. -func (lc *LocalClient) NetworkLockDisable(ctx context.Context, secret []byte) error { +func (lc *Client) NetworkLockDisable(ctx context.Context, secret []byte) error { if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/disable", 200, bytes.NewReader(secret)); err != nil { return fmt.Errorf("error: %w", err) } @@ -1310,7 +1379,7 @@ func (lc *LocalClient) NetworkLockDisable(ctx context.Context, secret []byte) er // GetServeConfig return the current serve config. // // If the serve config is empty, it returns (nil, nil). 
-func (lc *LocalClient) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { +func (lc *Client) GetServeConfig(ctx context.Context) (*ipn.ServeConfig, error) { body, h, err := lc.sendWithHeaders(ctx, "GET", "/localapi/v0/serve-config", 200, nil, nil) if err != nil { return nil, fmt.Errorf("getting serve config: %w", err) @@ -1385,7 +1454,7 @@ func (r jsonReader) Read(p []byte) (n int, err error) { } // ProfileStatus returns the current profile and the list of all profiles. -func (lc *LocalClient) ProfileStatus(ctx context.Context) (current ipn.LoginProfile, all []ipn.LoginProfile, err error) { +func (lc *Client) ProfileStatus(ctx context.Context) (current ipn.LoginProfile, all []ipn.LoginProfile, err error) { body, err := lc.send(ctx, "GET", "/localapi/v0/profiles/current", 200, nil) if err != nil { return @@ -1403,7 +1472,7 @@ func (lc *LocalClient) ProfileStatus(ctx context.Context) (current ipn.LoginProf } // ReloadConfig reloads the config file, if possible. -func (lc *LocalClient) ReloadConfig(ctx context.Context) (ok bool, err error) { +func (lc *Client) ReloadConfig(ctx context.Context) (ok bool, err error) { body, err := lc.send(ctx, "POST", "/localapi/v0/reload-config", 200, nil) if err != nil { return @@ -1421,22 +1490,22 @@ func (lc *LocalClient) ReloadConfig(ctx context.Context) (ok bool, err error) { // SwitchToEmptyProfile creates and switches to a new unnamed profile. The new // profile is not assigned an ID until it is persisted after a successful login. // In order to login to the new profile, the user must call LoginInteractive. -func (lc *LocalClient) SwitchToEmptyProfile(ctx context.Context) error { +func (lc *Client) SwitchToEmptyProfile(ctx context.Context) error { _, err := lc.send(ctx, "PUT", "/localapi/v0/profiles/", http.StatusCreated, nil) return err } // SwitchProfile switches to the given profile. 
-func (lc *LocalClient) SwitchProfile(ctx context.Context, profile ipn.ProfileID) error { +func (lc *Client) SwitchProfile(ctx context.Context, profile ipn.ProfileID) error { _, err := lc.send(ctx, "POST", "/localapi/v0/profiles/"+url.PathEscape(string(profile)), 204, nil) return err } // DeleteProfile removes the profile with the given ID. // If the profile is the current profile, an empty profile -// will be selected as if SwitchToEmptyProfile was called. -func (lc *LocalClient) DeleteProfile(ctx context.Context, profile ipn.ProfileID) error { - _, err := lc.send(ctx, "DELETE", "/localapi/v0/profiles"+url.PathEscape(string(profile)), http.StatusNoContent, nil) +// will be selected as if [Client.SwitchToEmptyProfile] was called. +func (lc *Client) DeleteProfile(ctx context.Context, profile ipn.ProfileID) error { + _, err := lc.send(ctx, "DELETE", "/localapi/v0/profiles/"+url.PathEscape(string(profile)), http.StatusNoContent, nil) return err } @@ -1452,7 +1521,7 @@ func (lc *LocalClient) DeleteProfile(ctx context.Context, profile ipn.ProfileID) // to block until the feature has been enabled. // // 2023-08-09: Valid feature values are "serve" and "funnel". 
-func (lc *LocalClient) QueryFeature(ctx context.Context, feature string) (*tailcfg.QueryFeatureResponse, error) { +func (lc *Client) QueryFeature(ctx context.Context, feature string) (*tailcfg.QueryFeatureResponse, error) { v := url.Values{"feature": {feature}} body, err := lc.send(ctx, "POST", "/localapi/v0/query-feature?"+v.Encode(), 200, nil) if err != nil { @@ -1461,7 +1530,7 @@ func (lc *LocalClient) QueryFeature(ctx context.Context, feature string) (*tailc return decodeJSON[*tailcfg.QueryFeatureResponse](body) } -func (lc *LocalClient) DebugDERPRegion(ctx context.Context, regionIDOrCode string) (*ipnstate.DebugDERPRegionReport, error) { +func (lc *Client) DebugDERPRegion(ctx context.Context, regionIDOrCode string) (*ipnstate.DebugDERPRegionReport, error) { v := url.Values{"region": {regionIDOrCode}} body, err := lc.send(ctx, "POST", "/localapi/v0/debug-derp-region?"+v.Encode(), 200, nil) if err != nil { @@ -1471,7 +1540,7 @@ func (lc *LocalClient) DebugDERPRegion(ctx context.Context, regionIDOrCode strin } // DebugPacketFilterRules returns the packet filter rules for the current device. -func (lc *LocalClient) DebugPacketFilterRules(ctx context.Context) ([]tailcfg.FilterRule, error) { +func (lc *Client) DebugPacketFilterRules(ctx context.Context) ([]tailcfg.FilterRule, error) { body, err := lc.send(ctx, "POST", "/localapi/v0/debug-packet-filter-rules", 200, nil) if err != nil { return nil, fmt.Errorf("error %w: %s", err, body) @@ -1482,7 +1551,7 @@ func (lc *LocalClient) DebugPacketFilterRules(ctx context.Context) ([]tailcfg.Fi // DebugSetExpireIn marks the current node key to expire in d. // // This is meant primarily for debug and testing. 
-func (lc *LocalClient) DebugSetExpireIn(ctx context.Context, d time.Duration) error { +func (lc *Client) DebugSetExpireIn(ctx context.Context, d time.Duration) error { v := url.Values{"expiry": {fmt.Sprint(time.Now().Add(d).Unix())}} _, err := lc.send(ctx, "POST", "/localapi/v0/set-expiry-sooner?"+v.Encode(), 200, nil) return err @@ -1491,8 +1560,8 @@ func (lc *LocalClient) DebugSetExpireIn(ctx context.Context, d time.Duration) er // StreamDebugCapture streams a pcap-formatted packet capture. // // The provided context does not determine the lifetime of the -// returned io.ReadCloser. -func (lc *LocalClient) StreamDebugCapture(ctx context.Context) (io.ReadCloser, error) { +// returned [io.ReadCloser]. +func (lc *Client) StreamDebugCapture(ctx context.Context) (io.ReadCloser, error) { req, err := http.NewRequestWithContext(ctx, "POST", "http://"+apitype.LocalAPIHost+"/localapi/v0/debug-capture", nil) if err != nil { return nil, err @@ -1514,11 +1583,11 @@ func (lc *LocalClient) StreamDebugCapture(ctx context.Context) (io.ReadCloser, e // The context is used for the life of the watch, not just the call to // WatchIPNBus. // -// The returned IPNBusWatcher's Close method must be called when done to release +// The returned [IPNBusWatcher]'s Close method must be called when done to release // resources. // // A default set of ipn.Notify messages are returned but the set can be modified by mask. 
-func (lc *LocalClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*IPNBusWatcher, error) { +func (lc *Client) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*IPNBusWatcher, error) { req, err := http.NewRequestWithContext(ctx, "GET", "http://"+apitype.LocalAPIHost+"/localapi/v0/watch-ipn-bus?mask="+fmt.Sprint(mask), nil) @@ -1541,10 +1610,10 @@ func (lc *LocalClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) }, nil } -// CheckUpdate returns a tailcfg.ClientVersion indicating whether or not an update is available +// CheckUpdate returns a [*tailcfg.ClientVersion] indicating whether or not an update is available // to be installed via the LocalAPI. In case the LocalAPI can't install updates, it returns a // ClientVersion that says that we are up to date. -func (lc *LocalClient) CheckUpdate(ctx context.Context) (*tailcfg.ClientVersion, error) { +func (lc *Client) CheckUpdate(ctx context.Context) (*tailcfg.ClientVersion, error) { body, err := lc.get200(ctx, "/localapi/v0/update/check") if err != nil { return nil, err @@ -1560,7 +1629,7 @@ func (lc *LocalClient) CheckUpdate(ctx context.Context) (*tailcfg.ClientVersion, // To turn it on, there must have been a previously used exit node. // The most previously used one is reused. // This is a convenience method for GUIs. To select an actual one, update the prefs. -func (lc *LocalClient) SetUseExitNode(ctx context.Context, on bool) error { +func (lc *Client) SetUseExitNode(ctx context.Context, on bool) error { _, err := lc.send(ctx, "POST", "/localapi/v0/set-use-exit-node-enabled?enabled="+strconv.FormatBool(on), http.StatusOK, nil) return err } @@ -1568,7 +1637,7 @@ func (lc *LocalClient) SetUseExitNode(ctx context.Context, on bool) error { // DriveSetServerAddr instructs Taildrive to use the server at addr to access // the filesystem. This is used on platforms like Windows and MacOS to let // Taildrive know to use the file server running in the GUI app. 
-func (lc *LocalClient) DriveSetServerAddr(ctx context.Context, addr string) error { +func (lc *Client) DriveSetServerAddr(ctx context.Context, addr string) error { _, err := lc.send(ctx, "PUT", "/localapi/v0/drive/fileserver-address", http.StatusCreated, strings.NewReader(addr)) return err } @@ -1576,14 +1645,14 @@ func (lc *LocalClient) DriveSetServerAddr(ctx context.Context, addr string) erro // DriveShareSet adds or updates the given share in the list of shares that // Taildrive will serve to remote nodes. If a share with the same name already // exists, the existing share is replaced/updated. -func (lc *LocalClient) DriveShareSet(ctx context.Context, share *drive.Share) error { +func (lc *Client) DriveShareSet(ctx context.Context, share *drive.Share) error { _, err := lc.send(ctx, "PUT", "/localapi/v0/drive/shares", http.StatusCreated, jsonBody(share)) return err } // DriveShareRemove removes the share with the given name from the list of // shares that Taildrive will serve to remote nodes. -func (lc *LocalClient) DriveShareRemove(ctx context.Context, name string) error { +func (lc *Client) DriveShareRemove(ctx context.Context, name string) error { _, err := lc.send( ctx, "DELETE", @@ -1594,7 +1663,7 @@ func (lc *LocalClient) DriveShareRemove(ctx context.Context, name string) error } // DriveShareRename renames the share from old to new name. -func (lc *LocalClient) DriveShareRename(ctx context.Context, oldName, newName string) error { +func (lc *Client) DriveShareRename(ctx context.Context, oldName, newName string) error { _, err := lc.send( ctx, "POST", @@ -1606,7 +1675,7 @@ func (lc *LocalClient) DriveShareRename(ctx context.Context, oldName, newName st // DriveShareList returns the list of shares that drive is currently serving // to remote nodes. 
-func (lc *LocalClient) DriveShareList(ctx context.Context) ([]*drive.Share, error) { +func (lc *Client) DriveShareList(ctx context.Context) ([]*drive.Share, error) { result, err := lc.get200(ctx, "/localapi/v0/drive/shares") if err != nil { return nil, err @@ -1617,7 +1686,7 @@ func (lc *LocalClient) DriveShareList(ctx context.Context) ([]*drive.Share, erro } // IPNBusWatcher is an active subscription (watch) of the local tailscaled IPN bus. -// It's returned by LocalClient.WatchIPNBus. +// It's returned by [Client.WatchIPNBus]. // // It must be closed when done. type IPNBusWatcher struct { @@ -1641,7 +1710,7 @@ func (w *IPNBusWatcher) Close() error { } // Next returns the next ipn.Notify from the stream. -// If the context from LocalClient.WatchIPNBus is done, that error is returned. +// If the context from Client.WatchIPNBus is done, that error is returned. func (w *IPNBusWatcher) Next() (ipn.Notify, error) { var n ipn.Notify if err := w.dec.Decode(&n); err != nil { @@ -1654,7 +1723,7 @@ func (w *IPNBusWatcher) Next() (ipn.Notify, error) { } // SuggestExitNode requests an exit node suggestion and returns the exit node's details. 
-func (lc *LocalClient) SuggestExitNode(ctx context.Context) (apitype.ExitNodeSuggestionResponse, error) { +func (lc *Client) SuggestExitNode(ctx context.Context) (apitype.ExitNodeSuggestionResponse, error) { body, err := lc.get200(ctx, "/localapi/v0/suggest-exit-node") if err != nil { return apitype.ExitNodeSuggestionResponse{}, err diff --git a/client/tailscale/localclient_test.go b/client/local/local_test.go similarity index 87% rename from client/tailscale/localclient_test.go rename to client/local/local_test.go index 950a22f474c32..0e01e74cd1813 100644 --- a/client/tailscale/localclient_test.go +++ b/client/local/local_test.go @@ -3,16 +3,16 @@ //go:build go1.19 -package tailscale +package local import ( "context" "net" "net/http" - "net/http/httptest" "testing" "tailscale.com/tstest/deptest" + "tailscale.com/tstest/nettest" "tailscale.com/types/key" ) @@ -36,15 +36,15 @@ func TestGetServeConfigFromJSON(t *testing.T) { } func TestWhoIsPeerNotFound(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + nw := nettest.GetNetwork(t) + ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(404) })) defer ts.Close() - lc := &LocalClient{ + lc := &Client{ Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { - var std net.Dialer - return std.DialContext(ctx, network, ts.Listener.Addr().(*net.TCPAddr).String()) + return nw.Dial(ctx, network, ts.Listener.Addr().String()) }, } var k key.NodePublic diff --git a/client/systray/logo.go b/client/systray/logo.go new file mode 100644 index 0000000000000..3467d1b741f93 --- /dev/null +++ b/client/systray/logo.go @@ -0,0 +1,327 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo || !darwin + +package systray + +import ( + "bytes" + "context" + "image" + "image/color" + "image/png" + "runtime" + "sync" + "time" + + "fyne.io/systray" + ico 
"github.com/Kodeworks/golang-image-ico" + "github.com/fogleman/gg" +) + +// tsLogo represents the Tailscale logo displayed as the systray icon. +type tsLogo struct { + // dots represents the state of the 3x3 dot grid in the logo. + // A 0 represents a gray dot, any other value is a white dot. + dots [9]byte + + // dotMask returns an image mask to be used when rendering the logo dots. + dotMask func(dc *gg.Context, borderUnits int, radius int) *image.Alpha + + // overlay is called after the dots are rendered to draw an additional overlay. + overlay func(dc *gg.Context, borderUnits int, radius int) +} + +var ( + // disconnected is all gray dots + disconnected = tsLogo{dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 0, 0, 0, + }} + + // connected is the normal Tailscale logo + connected = tsLogo{dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 1, 0, + }} + + // loading is a special tsLogo value that is not meant to be rendered directly, + // but indicates that the loading animation should be shown. + loading = tsLogo{dots: [9]byte{'l', 'o', 'a', 'd', 'i', 'n', 'g'}} + + // loadingIcons are shown in sequence as an animated loading icon. 
+ loadingLogos = []tsLogo{ + {dots: [9]byte{ + 0, 1, 1, + 1, 0, 1, + 0, 0, 1, + }}, + {dots: [9]byte{ + 0, 1, 1, + 0, 0, 1, + 0, 1, 0, + }}, + {dots: [9]byte{ + 0, 1, 1, + 0, 0, 0, + 0, 0, 1, + }}, + {dots: [9]byte{ + 0, 0, 1, + 0, 1, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 1, 0, + 0, 0, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 1, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 1, + 0, 0, 0, + 0, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 1, 0, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 0, 0, 0, + 1, 1, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 0, 0, + 1, 1, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 1, 0, + 0, 1, 0, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 1, 0, + 0, 1, 1, + }}, + {dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 0, 1, + }}, + {dots: [9]byte{ + 0, 1, 0, + 0, 1, 1, + 1, 0, 1, + }}, + } + + // exitNodeOnline is the Tailscale logo with an additional arrow overlay in the corner. + exitNodeOnline = tsLogo{ + dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 1, 0, + }, + // draw an arrow mask in the bottom right corner with a reasonably thick line width. + dotMask: func(dc *gg.Context, borderUnits int, radius int) *image.Alpha { + bu, r := float64(borderUnits), float64(radius) + + x1 := r * (bu + 3.5) + y := r * (bu + 7) + x2 := x1 + (r * 5) + + mc := gg.NewContext(dc.Width(), dc.Height()) + mc.DrawLine(x1, y, x2, y) // arrow center line + mc.DrawLine(x2-(1.5*r), y-(1.5*r), x2, y) // top of arrow tip + mc.DrawLine(x2-(1.5*r), y+(1.5*r), x2, y) // bottom of arrow tip + mc.SetLineWidth(r * 3) + mc.Stroke() + return mc.AsMask() + }, + // draw an arrow in the bottom right corner over the masked area. 
+ overlay: func(dc *gg.Context, borderUnits int, radius int) { + bu, r := float64(borderUnits), float64(radius) + + x1 := r * (bu + 3.5) + y := r * (bu + 7) + x2 := x1 + (r * 5) + + dc.DrawLine(x1, y, x2, y) // arrow center line + dc.DrawLine(x2-(1.5*r), y-(1.5*r), x2, y) // top of arrow tip + dc.DrawLine(x2-(1.5*r), y+(1.5*r), x2, y) // bottom of arrow tip + dc.SetColor(fg) + dc.SetLineWidth(r) + dc.Stroke() + }, + } + + // exitNodeOffline is the Tailscale logo with a red "x" in the corner. + exitNodeOffline = tsLogo{ + dots: [9]byte{ + 0, 0, 0, + 1, 1, 1, + 0, 1, 0, + }, + // Draw a square that hides the four dots in the bottom right corner, + dotMask: func(dc *gg.Context, borderUnits int, radius int) *image.Alpha { + bu, r := float64(borderUnits), float64(radius) + x := r * (bu + 3) + + mc := gg.NewContext(dc.Width(), dc.Height()) + mc.DrawRectangle(x, x, r*6, r*6) + mc.Fill() + return mc.AsMask() + }, + // draw a red "x" over the bottom right corner. + overlay: func(dc *gg.Context, borderUnits int, radius int) { + bu, r := float64(borderUnits), float64(radius) + + x1 := r * (bu + 4) + x2 := x1 + (r * 3.5) + dc.DrawLine(x1, x1, x2, x2) // top-left to bottom-right stroke + dc.DrawLine(x1, x2, x2, x1) // bottom-left to top-right stroke + dc.SetColor(red) + dc.SetLineWidth(r) + dc.Stroke() + }, + } +) + +var ( + bg = color.NRGBA{0, 0, 0, 255} + fg = color.NRGBA{255, 255, 255, 255} + gray = color.NRGBA{255, 255, 255, 102} + red = color.NRGBA{229, 111, 74, 255} +) + +// render returns a PNG image of the logo. +func (logo tsLogo) render() *bytes.Buffer { + const borderUnits = 1 + return logo.renderWithBorder(borderUnits) +} + +// renderWithBorder returns a PNG image of the logo with the specified border width. +// One border unit is equal to the radius of a tailscale logo dot. 
+func (logo tsLogo) renderWithBorder(borderUnits int) *bytes.Buffer { + const radius = 25 + dim := radius * (8 + borderUnits*2) + + dc := gg.NewContext(dim, dim) + dc.DrawRectangle(0, 0, float64(dim), float64(dim)) + dc.SetColor(bg) + dc.Fill() + + if logo.dotMask != nil { + mask := logo.dotMask(dc, borderUnits, radius) + dc.SetMask(mask) + dc.InvertMask() + } + + for y := 0; y < 3; y++ { + for x := 0; x < 3; x++ { + px := (borderUnits + 1 + 3*x) * radius + py := (borderUnits + 1 + 3*y) * radius + col := fg + if logo.dots[y*3+x] == 0 { + col = gray + } + dc.DrawCircle(float64(px), float64(py), radius) + dc.SetColor(col) + dc.Fill() + } + } + + if logo.overlay != nil { + dc.ResetClip() + logo.overlay(dc, borderUnits, radius) + } + + b := bytes.NewBuffer(nil) + + // Encode as ICO format on Windows, PNG on all other platforms. + if runtime.GOOS == "windows" { + _ = ico.Encode(b, dc.Image()) + } else { + _ = png.Encode(b, dc.Image()) + } + return b +} + +// setAppIcon renders logo and sets it as the systray icon. +func setAppIcon(icon tsLogo) { + if icon.dots == loading.dots { + startLoadingAnimation() + } else { + stopLoadingAnimation() + systray.SetIcon(icon.render().Bytes()) + } +} + +var ( + loadingMu sync.Mutex // protects loadingCancel + + // loadingCancel stops the loading animation in the systray icon. + // This is nil if the animation is not currently active. + loadingCancel func() +) + +// startLoadingAnimation starts the animated loading icon in the system tray. +// The animation continues until [stopLoadingAnimation] is called. +// If the loading animation is already active, this func does nothing. 
+func startLoadingAnimation() { + loadingMu.Lock() + defer loadingMu.Unlock() + + if loadingCancel != nil { + // loading icon already displayed + return + } + + ctx := context.Background() + ctx, loadingCancel = context.WithCancel(ctx) + + go func() { + t := time.NewTicker(500 * time.Millisecond) + var i int + for { + select { + case <-ctx.Done(): + return + case <-t.C: + systray.SetIcon(loadingLogos[i].render().Bytes()) + i++ + if i >= len(loadingLogos) { + i = 0 + } + } + } + }() +} + +// stopLoadingAnimation stops the animated loading icon in the system tray. +// If the loading animation is not currently active, this func does nothing. +func stopLoadingAnimation() { + loadingMu.Lock() + defer loadingMu.Unlock() + + if loadingCancel != nil { + loadingCancel() + loadingCancel = nil + } +} diff --git a/client/systray/systray.go b/client/systray/systray.go new file mode 100644 index 0000000000000..195a157fb1386 --- /dev/null +++ b/client/systray/systray.go @@ -0,0 +1,769 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo || !darwin + +// Package systray provides a minimal Tailscale systray application. +package systray + +import ( + "bytes" + "context" + "errors" + "fmt" + "image" + "io" + "log" + "net/http" + "os" + "os/signal" + "runtime" + "slices" + "strings" + "sync" + "syscall" + "time" + + "fyne.io/systray" + ico "github.com/Kodeworks/golang-image-ico" + "github.com/atotto/clipboard" + dbus "github.com/godbus/dbus/v5" + "github.com/toqueteos/webbrowser" + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/util/slicesx" + "tailscale.com/util/stringsx" +) + +var ( + // newMenuDelay is the amount of time to sleep after creating a new menu, + // but before adding items to it. This works around a bug in some dbus implementations. + newMenuDelay time.Duration + + // if true, treat all mullvad exit node countries as single-city. 
+ // Instead of rendering a submenu with cities, just select the highest-priority peer. + hideMullvadCities bool +) + +// Run starts the systray menu and blocks until the menu exits. +func (menu *Menu) Run() { + menu.updateState() + + // exit cleanly on SIGINT and SIGTERM + go func() { + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, syscall.SIGINT, syscall.SIGTERM) + select { + case <-interrupt: + menu.onExit() + case <-menu.bgCtx.Done(): + } + }() + go menu.lc.IncrementCounter(menu.bgCtx, "systray_start", 1) + + systray.Run(menu.onReady, menu.onExit) +} + +// Menu represents the systray menu, its items, and the current Tailscale state. +type Menu struct { + mu sync.Mutex // protects the entire Menu + + lc local.Client + status *ipnstate.Status + curProfile ipn.LoginProfile + allProfiles []ipn.LoginProfile + + // readonly is whether the systray app is running in read-only mode. + // This is set if LocalAPI returns a permission error, + // typically because the user needs to run `tailscale set --operator=$USER`. + readonly bool + + bgCtx context.Context // ctx for background tasks not involving menu item clicks + bgCancel context.CancelFunc + + // Top-level menu items + connect *systray.MenuItem + disconnect *systray.MenuItem + self *systray.MenuItem + exitNodes *systray.MenuItem + more *systray.MenuItem + rebuildMenu *systray.MenuItem + quit *systray.MenuItem + + rebuildCh chan struct{} // triggers a menu rebuild + accountsCh chan ipn.ProfileID + exitNodeCh chan tailcfg.StableNodeID // ID of selected exit node + + eventCancel context.CancelFunc // cancel eventLoop + + notificationIcon *os.File // icon used for desktop notifications +} + +func (menu *Menu) init() { + if menu.bgCtx != nil { + // already initialized + return + } + + menu.rebuildCh = make(chan struct{}, 1) + menu.accountsCh = make(chan ipn.ProfileID) + menu.exitNodeCh = make(chan tailcfg.StableNodeID) + + // dbus wants a file path for notification icons, so copy to a temp file. 
+ menu.notificationIcon, _ = os.CreateTemp("", "tailscale-systray.png") + io.Copy(menu.notificationIcon, connected.renderWithBorder(3)) + + menu.bgCtx, menu.bgCancel = context.WithCancel(context.Background()) + go menu.watchIPNBus() +} + +func init() { + if runtime.GOOS != "linux" { + // so far, these tweaks are only needed on Linux + return + } + + desktop := strings.ToLower(os.Getenv("XDG_CURRENT_DESKTOP")) + switch desktop { + case "gnome": + // GNOME expands submenus downward in the main menu, rather than flyouts to the side. + // Either as a result of that or another limitation, there seems to be a maximum depth of submenus. + // Mullvad countries that have a city submenu are not being rendered, and so can't be selected. + // Handle this by simply treating all mullvad countries as single-city and select the best peer. + hideMullvadCities = true + case "kde": + // KDE doesn't need a delay, and actually won't render submenus + // if we delay for more than about 400Âĩs. + newMenuDelay = 0 + default: + // Add a slight delay to ensure the menu is created before adding items. + // + // Systray implementations that use libdbusmenu sometimes process messages out of order, + // resulting in errors such as: + // (waybar:153009): LIBDBUSMENU-GTK-WARNING **: 18:07:11.551: Children but no menu, someone's been naughty with their 'children-display' property: 'submenu' + // + // See also: https://github.com/fyne-io/systray/issues/12 + newMenuDelay = 10 * time.Millisecond + } +} + +// onReady is called by the systray package when the menu is ready to be built. +func (menu *Menu) onReady() { + log.Printf("starting") + setAppIcon(disconnected) + menu.rebuild() +} + +// updateState updates the Menu state from the Tailscale local client. 
+func (menu *Menu) updateState() {
+	menu.mu.Lock()
+	defer menu.mu.Unlock()
+	menu.init()
+
+	menu.readonly = false
+
+	var err error
+	menu.status, err = menu.lc.Status(menu.bgCtx)
+	if err != nil {
+		// menu.status may be left nil here; rebuild nil-checks it.
+		log.Print(err)
+	}
+	menu.curProfile, menu.allProfiles, err = menu.lc.ProfileStatus(menu.bgCtx)
+	if err != nil {
+		if local.IsAccessDeniedError(err) {
+			menu.readonly = true
+		}
+		log.Print(err)
+	}
+}
+
+// rebuild the systray menu based on the current Tailscale state.
+//
+// We currently rebuild the entire menu because it is not easy to update the existing menu.
+// You cannot iterate over the items in a menu, nor can you remove some items like separators.
+// So for now we rebuild the whole thing, and can optimize this later if needed.
+func (menu *Menu) rebuild() {
+	menu.mu.Lock()
+	defer menu.mu.Unlock()
+	menu.init()
+
+	// Cancel the click handlers and event loop registered for the previous
+	// menu incarnation before registering new ones against a fresh ctx.
+	if menu.eventCancel != nil {
+		menu.eventCancel()
+	}
+	ctx := context.Background()
+	ctx, menu.eventCancel = context.WithCancel(ctx)
+
+	systray.ResetMenu()
+
+	if menu.readonly {
+		const readonlyMsg = "No permission to manage Tailscale.\nSee tailscale.com/s/cli-operator"
+		m := systray.AddMenuItem(readonlyMsg, "")
+		onClick(ctx, m, func(_ context.Context) {
+			webbrowser.Open("https://tailscale.com/s/cli-operator")
+		})
+		systray.AddSeparator()
+	}
+
+	menu.connect = systray.AddMenuItem("Connect", "")
+	menu.disconnect = systray.AddMenuItem("Disconnect", "")
+	menu.disconnect.Hide()
+	systray.AddSeparator()
+
+	// delay to prevent race setting icon on first start
+	time.Sleep(newMenuDelay)
+
+	// Set systray menu icon and title.
+	// Also adjust connect/disconnect menu items if needed.
+	var backendState string
+	if menu.status != nil {
+		backendState = menu.status.BackendState
+	}
+	switch backendState {
+	case ipn.Running.String():
+		// backendState is only Running when menu.status is non-nil.
+		if menu.status.ExitNodeStatus != nil && !menu.status.ExitNodeStatus.ID.IsZero() {
+			if menu.status.ExitNodeStatus.Online {
+				setTooltip("Using exit node")
+				setAppIcon(exitNodeOnline)
+			} else {
+				setTooltip("Exit node offline")
+				setAppIcon(exitNodeOffline)
+			}
+		} else {
+			setTooltip(fmt.Sprintf("Connected to %s", menu.status.CurrentTailnet.Name))
+			setAppIcon(connected)
+		}
+		menu.connect.SetTitle("Connected")
+		menu.connect.Disable()
+		menu.disconnect.Show()
+		menu.disconnect.Enable()
+	case ipn.Starting.String():
+		setTooltip("Connecting")
+		setAppIcon(loading)
+	default:
+		setTooltip("Disconnected")
+		setAppIcon(disconnected)
+	}
+
+	if menu.readonly {
+		menu.connect.Disable()
+		menu.disconnect.Disable()
+	}
+
+	account := "Account"
+	if pt := profileTitle(menu.curProfile); pt != "" {
+		account = pt
+	}
+	if !menu.readonly {
+		accounts := systray.AddMenuItem(account, "")
+		// NOTE(review): setRemoteIcon performs a blocking HTTP fetch while
+		// menu.mu is held — confirm this cannot stall the tray for long.
+		setRemoteIcon(accounts, menu.curProfile.UserProfile.ProfilePicURL)
+		time.Sleep(newMenuDelay)
+		for _, profile := range menu.allProfiles {
+			title := profileTitle(profile)
+			var item *systray.MenuItem
+			if profile.ID == menu.curProfile.ID {
+				item = accounts.AddSubMenuItemCheckbox(title, "", true)
+			} else {
+				item = accounts.AddSubMenuItem(title, "")
+			}
+			setRemoteIcon(item, profile.UserProfile.ProfilePicURL)
+			onClick(ctx, item, func(ctx context.Context) {
+				select {
+				case <-ctx.Done():
+				case menu.accountsCh <- profile.ID:
+				}
+			})
+		}
+	}
+
+	if menu.status != nil && menu.status.Self != nil && len(menu.status.Self.TailscaleIPs) > 0 {
+		title := fmt.Sprintf("This Device: %s (%s)", menu.status.Self.HostName, menu.status.Self.TailscaleIPs[0])
+		menu.self = systray.AddMenuItem(title, "")
+	} else {
+		// No usable self status; item is shown but disabled.
+		menu.self = systray.AddMenuItem("This Device: not connected", "")
+		menu.self.Disable()
+	}
+	systray.AddSeparator()
+
+	if !menu.readonly {
+		menu.rebuildExitNodeMenu(ctx)
+	}
+
+	if menu.status != nil {
+		menu.more = systray.AddMenuItem("More settings", "")
+		onClick(ctx, menu.more, func(_ context.Context) {
+			webbrowser.Open("http://100.100.100.100/")
+		})
+	}
+
+	// TODO(#15528): this menu item shouldn't be necessary at all,
+	// but is at least more discoverable than having users switch profiles or exit nodes.
+	menu.rebuildMenu = systray.AddMenuItem("Rebuild menu", "Fix missing menu items")
+	onClick(ctx, menu.rebuildMenu, func(ctx context.Context) {
+		select {
+		case <-ctx.Done():
+		case menu.rebuildCh <- struct{}{}:
+		}
+	})
+	menu.rebuildMenu.Enable()
+
+	menu.quit = systray.AddMenuItem("Quit", "Quit the app")
+	menu.quit.Enable()
+
+	go menu.eventLoop(ctx)
+}
+
+// profileTitle returns the title string for a profile menu item.
+func profileTitle(profile ipn.LoginProfile) string {
+	title := profile.Name
+	if profile.NetworkProfile.DomainName != "" {
+		if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
+			// windows and mac don't support multi-line menu
+			title += " (" + profile.NetworkProfile.DomainName + ")"
+		} else {
+			title += "\n" + profile.NetworkProfile.DomainName
+		}
+	}
+	return title
+}
+
+var (
+	cacheMu   sync.Mutex
+	httpCache = map[string][]byte{} // URL => response body
+)
+
+// setRemoteIcon sets the icon for menu to the specified remote image.
+// Remote images are fetched as needed and cached.
+func setRemoteIcon(menu *systray.MenuItem, urlStr string) {
+	if menu == nil || urlStr == "" {
+		return
+	}
+
+	cacheMu.Lock()
+	b, ok := httpCache[urlStr]
+	if !ok {
+		// Not cached yet. On failure fetchIcon returns nil and nothing is
+		// cached, so the fetch will be retried on a later call.
+		if b = fetchIcon(urlStr); b != nil {
+			httpCache[urlStr] = b
+		}
+	}
+	cacheMu.Unlock()
+
+	if len(b) > 0 {
+		menu.SetIcon(b)
+	}
+}
+
+// fetchIcon fetches the image at urlStr, converting it to ICO format on
+// Windows (the format required for Windows systray icons).
+// It returns nil if the image cannot be fetched or converted.
+func fetchIcon(urlStr string) []byte {
+	resp, err := http.Get(urlStr)
+	if err != nil {
+		return nil
+	}
+	// Always close the body, on every path, so the connection can be reused.
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		return nil
+	}
+	b, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil
+	}
+
+	if runtime.GOOS == "windows" {
+		im, _, err := image.Decode(bytes.NewReader(b))
+		if err != nil {
+			return nil
+		}
+		buf := bytes.NewBuffer(nil)
+		if err := ico.Encode(buf, im); err != nil {
+			return nil
+		}
+		b = buf.Bytes()
+	}
+	return b
+}
+
+// setTooltip sets the tooltip text for the systray icon.
+func setTooltip(text string) {
+	if runtime.GOOS == "darwin" || runtime.GOOS == "windows" {
+		systray.SetTooltip(text)
+	} else {
+		// on Linux, SetTitle actually sets the tooltip
+		systray.SetTitle(text)
+	}
+}
+
+// eventLoop is the main event loop for handling click events on menu items
+// and responding to Tailscale state changes.
+// This method does not return until ctx.Done is closed.
+func (menu *Menu) eventLoop(ctx context.Context) {
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-menu.rebuildCh:
+			// State or prefs changed (or the user asked): re-query tailscaled
+			// and rebuild the whole menu.
+			menu.updateState()
+			menu.rebuild()
+		case <-menu.connect.ClickedCh:
+			_, err := menu.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
+				Prefs: ipn.Prefs{
+					WantRunning: true,
+				},
+				WantRunningSet: true,
+			})
+			if err != nil {
+				log.Printf("error connecting: %v", err)
+			}
+
+		case <-menu.disconnect.ClickedCh:
+			_, err := menu.lc.EditPrefs(ctx, &ipn.MaskedPrefs{
+				Prefs: ipn.Prefs{
+					WantRunning: false,
+				},
+				WantRunningSet: true,
+			})
+			if err != nil {
+				log.Printf("error disconnecting: %v", err)
+			}
+
+		case <-menu.self.ClickedCh:
+			// NOTE(review): dereferences menu.status; rebuild disables
+			// menu.self when status is nil — confirm disabled items never
+			// deliver on ClickedCh.
+			menu.copyTailscaleIP(menu.status.Self)
+
+		case id := <-menu.accountsCh:
+			if err := menu.lc.SwitchProfile(ctx, id); err != nil {
+				log.Printf("error switching to profile ID %v: %v", id, err)
+			}
+
+		case exitNode := <-menu.exitNodeCh:
+			if exitNode.IsZero() {
+				// Zero node ID means "None" was selected.
+				log.Print("disable exit node")
+				if err := menu.lc.SetUseExitNode(ctx, false); err != nil {
+					log.Printf("error disabling exit node: %v", err)
+				}
+			} else {
+				log.Printf("enable exit node: %v", exitNode)
+				mp := &ipn.MaskedPrefs{
+					Prefs: ipn.Prefs{
+						ExitNodeID: exitNode,
+					},
+					ExitNodeIDSet: true,
+				}
+				if _, err := menu.lc.EditPrefs(ctx, mp); err != nil {
+					log.Printf("error setting exit node: %v", err)
+				}
+			}
+
+		case <-menu.quit.ClickedCh:
+			systray.Quit()
+		}
+	}
+}
+
+// onClick registers a click handler for a menu item.
+// The handler is re-armed after each click and unregistered when ctx is canceled.
+func onClick(ctx context.Context, item *systray.MenuItem, fn func(ctx context.Context)) {
+	go func() {
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-item.ClickedCh:
+				fn(ctx)
+			}
+		}
+	}()
+}
+
+// watchIPNBus subscribes to the tailscale event bus and sends state updates to chState.
+// This method does not return.
+func (menu *Menu) watchIPNBus() {
+	for {
+		if err := menu.watchIPNBusInner(); err != nil {
+			log.Println(err)
+			if errors.Is(err, context.Canceled) {
+				// If the context got canceled, we will never be able to
+				// reconnect to IPN bus, so exit the process.
+				log.Fatalf("watchIPNBus: %v", err)
+			}
+		}
+		// If our watch connection breaks, wait a bit before reconnecting. No
+		// reason to spam the logs if e.g. tailscaled is restarting or goes
+		// down.
+		time.Sleep(3 * time.Second)
+	}
+}
+
+// watchIPNBusInner watches the IPN bus until the watch fails or bgCtx is done,
+// triggering a menu rebuild (via rebuildCh) whenever state or prefs change.
+func (menu *Menu) watchIPNBusInner() error {
+	watcher, err := menu.lc.WatchIPNBus(menu.bgCtx, ipn.NotifyNoPrivateKeys)
+	if err != nil {
+		return fmt.Errorf("watching ipn bus: %w", err)
+	}
+	defer watcher.Close()
+	for {
+		select {
+		case <-menu.bgCtx.Done():
+			return nil
+		default:
+			// watcher.Next blocks; bgCtx cancellation presumably surfaces as
+			// an error from Next since the watcher was created with bgCtx.
+			n, err := watcher.Next()
+			if err != nil {
+				return fmt.Errorf("ipnbus error: %w", err)
+			}
+			var rebuild bool
+			if n.State != nil {
+				log.Printf("new state: %v", n.State)
+				rebuild = true
+			}
+			if n.Prefs != nil {
+				rebuild = true
+			}
+			if rebuild {
+				menu.rebuildCh <- struct{}{}
+			}
+		}
+	}
+}
+
+// copyTailscaleIP copies the first Tailscale IP of the given device to the clipboard
+// and sends a notification with the copied value.
+// A clipboard failure is logged but the notification is still sent.
+func (menu *Menu) copyTailscaleIP(device *ipnstate.PeerStatus) {
+	if device == nil || len(device.TailscaleIPs) == 0 {
+		return
+	}
+	name := strings.Split(device.DNSName, ".")[0]
+	ip := device.TailscaleIPs[0].String()
+	err := clipboard.WriteAll(ip)
+	if err != nil {
+		log.Printf("clipboard error: %v", err)
+	}
+
+	menu.sendNotification(fmt.Sprintf("Copied Address for %v", name), ip)
+}
+
+// sendNotification sends a desktop notification with the given title and content.
+func (menu *Menu) sendNotification(title, content string) {
+	conn, err := dbus.SessionBus()
+	if err != nil {
+		log.Printf("dbus: %v", err)
+		return
+	}
+	// timeout is the notification expiry, passed to the server in milliseconds.
+	timeout := 3 * time.Second
+	obj := conn.Object("org.freedesktop.Notifications", "/org/freedesktop/Notifications")
+	// NOTE(review): menu.notificationIcon may be nil if temp-file creation
+	// failed in init — confirm Name() cannot be called on a nil *os.File here.
+	call := obj.Call("org.freedesktop.Notifications.Notify", 0, "Tailscale", uint32(0),
+		menu.notificationIcon.Name(), title, content, []string{}, map[string]dbus.Variant{}, int32(timeout.Milliseconds()))
+	if call.Err != nil {
+		log.Printf("dbus: %v", call.Err)
+	}
+}
+
+// rebuildExitNodeMenu populates the "Exit Nodes" submenu with a "None" item,
+// an optional recommended node, tailnet exit nodes, and mullvad (location-based)
+// exit nodes grouped by country and city. It does nothing if status is unknown.
+func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) {
+	if menu.status == nil {
+		return
+	}
+
+	status := menu.status
+	menu.exitNodes = systray.AddMenuItem("Exit Nodes", "")
+	time.Sleep(newMenuDelay)
+
+	// register a click handler for a menu item to set nodeID as the exit node.
+	setExitNodeOnClick := func(item *systray.MenuItem, nodeID tailcfg.StableNodeID) {
+		onClick(ctx, item, func(ctx context.Context) {
+			select {
+			case <-ctx.Done():
+			case menu.exitNodeCh <- nodeID:
+			}
+		})
+	}
+
+	// "None" is checked when no exit node is in use; clicking it sends the
+	// zero node ID, which eventLoop treats as "disable exit node".
+	noExitNodeMenu := menu.exitNodes.AddSubMenuItemCheckbox("None", "", status.ExitNodeStatus == nil)
+	setExitNodeOnClick(noExitNodeMenu, "")
+
+	// Show recommended exit node if available.
+	if status.Self.CapMap.Contains(tailcfg.NodeAttrSuggestExitNodeUI) {
+		sugg, err := menu.lc.SuggestExitNode(ctx)
+		if err == nil {
+			title := "Recommended: "
+			if loc := sugg.Location; loc.Valid() && loc.Country() != "" {
+				flag := countryFlag(loc.CountryCode())
+				title += fmt.Sprintf("%s %s: %s", flag, loc.Country(), loc.City())
+			} else {
+				title += strings.Split(sugg.Name, ".")[0]
+			}
+			menu.exitNodes.AddSeparator()
+			rm := menu.exitNodes.AddSubMenuItemCheckbox(title, "", false)
+			setExitNodeOnClick(rm, sugg.ID)
+			if status.ExitNodeStatus != nil && sugg.ID == status.ExitNodeStatus.ID {
+				rm.Check()
+			}
+		}
+	}
+
+	// Add tailnet exit nodes if present.
+	var tailnetExitNodes []*ipnstate.PeerStatus
+	for _, ps := range status.Peer {
+		if ps.ExitNodeOption && ps.Location == nil {
+			tailnetExitNodes = append(tailnetExitNodes, ps)
+		}
+	}
+	if len(tailnetExitNodes) > 0 {
+		menu.exitNodes.AddSeparator()
+		menu.exitNodes.AddSubMenuItem("Tailnet Exit Nodes", "").Disable()
+		for _, ps := range status.Peer {
+			if !ps.ExitNodeOption || ps.Location != nil {
+				continue
+			}
+			name := strings.Split(ps.DNSName, ".")[0]
+			if !ps.Online {
+				name += " (offline)"
+			}
+			sm := menu.exitNodes.AddSubMenuItemCheckbox(name, "", false)
+			if !ps.Online {
+				sm.Disable()
+			}
+			if status.ExitNodeStatus != nil && ps.ID == status.ExitNodeStatus.ID {
+				sm.Check()
+			}
+			setExitNodeOnClick(sm, ps.ID)
+		}
+	}
+
+	// Add mullvad exit nodes if present.
+	var mullvadExitNodes mullvadPeers
+	if status.Self.CapMap.Contains("mullvad") {
+		mullvadExitNodes = newMullvadPeers(status)
+	}
+	if len(mullvadExitNodes.countries) > 0 {
+		menu.exitNodes.AddSeparator()
+		menu.exitNodes.AddSubMenuItem("Location-based Exit Nodes", "").Disable()
+		mullvadMenu := menu.exitNodes.AddSubMenuItemCheckbox("Mullvad VPN", "", false)
+
+		for _, country := range mullvadExitNodes.sortedCountries() {
+			flag := countryFlag(country.code)
+			countryMenu := mullvadMenu.AddSubMenuItemCheckbox(flag+" "+country.name, "", false)
+
+			// single-city country, no submenu
+			// (hideMullvadCities also forces this path; see the GNOME
+			// workaround in init.)
+			if len(country.cities) == 1 || hideMullvadCities {
+				setExitNodeOnClick(countryMenu, country.best.ID)
+				if status.ExitNodeStatus != nil {
+					for _, city := range country.cities {
+						for _, ps := range city.peers {
+							if status.ExitNodeStatus.ID == ps.ID {
+								mullvadMenu.Check()
+								countryMenu.Check()
+							}
+						}
+					}
+				}
+				continue
+			}
+
+			// multi-city country, build submenu with "best available" option and cities.
+			time.Sleep(newMenuDelay)
+			bm := countryMenu.AddSubMenuItemCheckbox("Best Available", "", false)
+			setExitNodeOnClick(bm, country.best.ID)
+			countryMenu.AddSeparator()
+
+			for _, city := range country.sortedCities() {
+				cityMenu := countryMenu.AddSubMenuItemCheckbox(city.name, "", false)
+				setExitNodeOnClick(cityMenu, city.best.ID)
+				if status.ExitNodeStatus != nil {
+					for _, ps := range city.peers {
+						if status.ExitNodeStatus.ID == ps.ID {
+							mullvadMenu.Check()
+							countryMenu.Check()
+							cityMenu.Check()
+						}
+					}
+				}
+			}
+		}
+	}
+
+	// TODO: "Allow Local Network Access" and "Run Exit Node" menu items
+}
+
+// mullvadPeers contains all mullvad peer nodes, sorted by country and city.
+type mullvadPeers struct {
+	countries map[string]*mvCountry // country code (uppercase) => country
+}
+
+// sortedCountries returns countries containing mullvad nodes, sorted by name.
+func (mp mullvadPeers) sortedCountries() []*mvCountry {
+	countries := slicesx.MapValues(mp.countries)
+	slices.SortFunc(countries, func(a, b *mvCountry) int {
+		return stringsx.CompareFold(a.name, b.name)
+	})
+	return countries
+}
+
+type mvCountry struct {
+	code   string
+	name   string
+	best   *ipnstate.PeerStatus // highest priority peer in the country
+	cities map[string]*mvCity   // city code => city
+}
+
+// sortedCities returns cities containing mullvad nodes, sorted by name.
+func (mc *mvCountry) sortedCities() []*mvCity {
+	cities := slicesx.MapValues(mc.cities)
+	slices.SortFunc(cities, func(a, b *mvCity) int {
+		return stringsx.CompareFold(a.name, b.name)
+	})
+	return cities
+}
+
+// countryFlag takes a 2-character ASCII string and returns the corresponding emoji flag.
+// It returns the empty string on error.
+func countryFlag(code string) string {
+	if len(code) != 2 {
+		return ""
+	}
+	runes := make([]rune, 0, 2)
+	for i := range 2 {
+		b := code[i] | 32 // lowercase
+		if b < 'a' || b > 'z' {
+			// non-ASCII-letter input has no flag representation
+			return ""
+		}
+		// https://en.wikipedia.org/wiki/Regional_indicator_symbol
+		runes = append(runes, 0x1F1E6+rune(b-'a'))
+	}
+	return string(runes)
+}
+
+type mvCity struct {
+	name  string
+	best  *ipnstate.PeerStatus // highest priority peer in the city
+	peers []*ipnstate.PeerStatus
+}
+
+// newMullvadPeers groups the mullvad exit node peers in status by country and
+// city, tracking the highest-priority peer for each country and each city.
+func newMullvadPeers(status *ipnstate.Status) mullvadPeers {
+	countries := make(map[string]*mvCountry)
+	for _, ps := range status.Peer {
+		// Only peers offering exit-node service with location info are mullvad nodes here.
+		if !ps.ExitNodeOption || ps.Location == nil {
+			continue
+		}
+		loc := ps.Location
+		country, ok := countries[loc.CountryCode]
+		if !ok {
+			country = &mvCountry{
+				code:   loc.CountryCode,
+				name:   loc.Country,
+				cities: make(map[string]*mvCity),
+			}
+			countries[loc.CountryCode] = country
+		}
+		city, ok := countries[loc.CountryCode].cities[loc.CityCode]
+		if !ok {
+			city = &mvCity{
+				name: loc.City,
+			}
+			countries[loc.CountryCode].cities[loc.CityCode] = city
+		}
+		city.peers = append(city.peers, ps)
+		if city.best == nil || ps.Location.Priority > city.best.Location.Priority {
+			city.best = ps
+		}
+		if country.best == nil || ps.Location.Priority > country.best.Location.Priority {
+			country.best = ps
+		}
+	}
+	return mullvadPeers{countries}
+}
+
+// onExit is called by the systray package when the menu is exiting.
+// It is also invoked directly by the SIGINT/SIGTERM handler started in Run.
+func (menu *Menu) onExit() {
+	log.Printf("exiting")
+	if menu.bgCancel != nil {
+		menu.bgCancel()
+	}
+	if menu.eventCancel != nil {
+		menu.eventCancel()
+	}
+
+	// notificationIcon is nil if temp file creation failed during init;
+	// guard before dereferencing.
+	if menu.notificationIcon != nil {
+		os.Remove(menu.notificationIcon.Name())
+	}
+}
diff --git a/client/tailscale/acl.go b/client/tailscale/acl.go
index 8d8bdfc86baf1..929ec2b3b1ca9 100644
--- a/client/tailscale/acl.go
+++ b/client/tailscale/acl.go
@@ -12,6 +12,7 @@ import (
 	"fmt"
 	"net/http"
 	"net/netip"
+	"net/url"
 )
 
 // ACLRow defines a rule that grants access by a set of users or groups to a set
@@ -83,7 +84,7 @@ func (c *Client) ACL(ctx context.Context) (acl *ACL, err error) {
 		}
 	}()
 
-	path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl", c.baseURL(), c.tailnet)
+	path := c.BuildTailnetURL("acl")
 	req, err := http.NewRequestWithContext(ctx, "GET", path, nil)
 	if err != nil {
 		return nil, err
@@ -97,7 +98,7 @@ func (c *Client) ACL(ctx context.Context) (acl *ACL, err error) {
 	// If status code was not successful, return the error.
 	// TODO: Change the check for the StatusCode to include other 2XX success codes.
 	if resp.StatusCode != http.StatusOK {
-		return nil, handleErrorResponse(b, resp)
+		return nil, HandleErrorResponse(b, resp)
 	}
 
 	// Otherwise, try to decode the response.
@@ -126,7 +127,7 @@ func (c *Client) ACLHuJSON(ctx context.Context) (acl *ACLHuJSON, err error) { } }() - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl?details=1", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl", url.Values{"details": {"1"}}) req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -138,7 +139,7 @@ func (c *Client) ACLHuJSON(ctx context.Context) (acl *ACLHuJSON, err error) { } if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } data := struct { @@ -146,7 +147,7 @@ func (c *Client) ACLHuJSON(ctx context.Context) (acl *ACLHuJSON, err error) { Warnings []string `json:"warnings"` }{} if err := json.Unmarshal(b, &data); err != nil { - return nil, err + return nil, fmt.Errorf("json.Unmarshal %q: %w", b, err) } acl = &ACLHuJSON{ @@ -184,7 +185,7 @@ func (e ACLTestError) Error() string { } func (c *Client) aclPOSTRequest(ctx context.Context, body []byte, avoidCollisions bool, etag, acceptHeader string) ([]byte, string, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(body)) if err != nil { return nil, "", err @@ -328,7 +329,7 @@ type ACLPreview struct { } func (c *Client) previewACLPostRequest(ctx context.Context, body []byte, previewType string, previewFor string) (res *ACLPreviewResponse, err error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl/preview", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl", "preview") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(body)) if err != nil { return nil, err @@ -350,7 +351,7 @@ func (c *Client) previewACLPostRequest(ctx context.Context, body []byte, preview // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. 
if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } if err = json.Unmarshal(b, &res); err != nil { return nil, err @@ -488,7 +489,7 @@ func (c *Client) ValidateACLJSON(ctx context.Context, source, dest string) (test return nil, err } - path := fmt.Sprintf("%s/api/v2/tailnet/%s/acl/validate", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("acl", "validate") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewBuffer(postData)) if err != nil { return nil, err diff --git a/client/tailscale/apitype/apitype.go b/client/tailscale/apitype/apitype.go index b1c273a4fd462..58cdcecc78d4f 100644 --- a/client/tailscale/apitype/apitype.go +++ b/client/tailscale/apitype/apitype.go @@ -7,11 +7,29 @@ package apitype import ( "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/util/ctxkey" ) // LocalAPIHost is the Host header value used by the LocalAPI. const LocalAPIHost = "local-tailscaled.sock" +// RequestReasonHeader is the header used to pass justification for a LocalAPI request, +// such as when a user wants to perform an action they don't have permission for, +// and a policy allows it with justification. As of 2025-01-29, it is only used to +// allow a user to disconnect Tailscale when the "always-on" mode is enabled. +// +// The header value is base64-encoded using the standard encoding defined in RFC 4648. +// +// See tailscale/corp#26146. +const RequestReasonHeader = "X-Tailscale-Reason" + +// RequestReasonKey is the context key used to pass the request reason +// when making a LocalAPI request via [local.Client]. +// It's value is a raw string. An empty string means no reason was provided. +// +// See tailscale/corp#26146. +var RequestReasonKey = ctxkey.New(RequestReasonHeader, "") + // WhoIsResponse is the JSON type returned by tailscaled debug server's /whois?ip=$IP handler. // In successful whois responses, Node and UserProfile are never nil. 
type WhoIsResponse struct { diff --git a/client/tailscale/devices.go b/client/tailscale/devices.go index 9008d4d0d0d54..0664f9e63edb1 100644 --- a/client/tailscale/devices.go +++ b/client/tailscale/devices.go @@ -79,6 +79,13 @@ type Device struct { // Tailscale have attempted to collect this from the device but it has not // opted in, PostureIdentity will have Disabled=true. PostureIdentity *DevicePostureIdentity `json:"postureIdentity"` + + // TailnetLockKey is the tailnet lock public key of the node as a hex string. + TailnetLockKey string `json:"tailnetLockKey,omitempty"` + + // TailnetLockErr indicates an issue with the tailnet lock node-key signature + // on this device. This field is only populated when tailnet lock is enabled. + TailnetLockErr string `json:"tailnetLockError,omitempty"` } type DevicePostureIdentity struct { @@ -131,7 +138,7 @@ func (c *Client) Devices(ctx context.Context, fields *DeviceFieldsOpts) (deviceL } }() - path := fmt.Sprintf("%s/api/v2/tailnet/%s/devices", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("devices") req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -149,7 +156,7 @@ func (c *Client) Devices(ctx context.Context, fields *DeviceFieldsOpts) (deviceL // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var devices GetDevicesResponse @@ -188,7 +195,7 @@ func (c *Client) Device(ctx context.Context, deviceID string, fields *DeviceFiel // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. 
if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } err = json.Unmarshal(b, &device) @@ -221,7 +228,7 @@ func (c *Client) DeleteDevice(ctx context.Context, deviceID string) (err error) // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil } @@ -253,7 +260,7 @@ func (c *Client) SetAuthorized(ctx context.Context, deviceID string, authorized // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil @@ -281,7 +288,7 @@ func (c *Client) SetTags(ctx context.Context, deviceID string, tags []string) er // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil diff --git a/client/tailscale/dns.go b/client/tailscale/dns.go index f198742b3ca51..bbdc7c56c65f7 100644 --- a/client/tailscale/dns.go +++ b/client/tailscale/dns.go @@ -44,7 +44,7 @@ type DNSPreferences struct { } func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + path := c.BuildTailnetURL("dns", endpoint) req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -57,14 +57,14 @@ func (c *Client) dnsGETRequest(ctx context.Context, endpoint string) ([]byte, er // If status code was not successful, return the error. 
// TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } return b, nil } func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData any) ([]byte, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/dns/%s", c.baseURL(), c.tailnet, endpoint) + path := c.BuildTailnetURL("dns", endpoint) data, err := json.Marshal(&postData) if err != nil { return nil, err @@ -84,7 +84,7 @@ func (c *Client) dnsPOSTRequest(ctx context.Context, endpoint string, postData a // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } return b, nil diff --git a/client/tailscale/example/servetls/servetls.go b/client/tailscale/example/servetls/servetls.go index f48e90d163527..0ade420887634 100644 --- a/client/tailscale/example/servetls/servetls.go +++ b/client/tailscale/example/servetls/servetls.go @@ -11,13 +11,14 @@ import ( "log" "net/http" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" ) func main() { + var lc local.Client s := &http.Server{ TLSConfig: &tls.Config{ - GetCertificate: tailscale.GetCertificate, + GetCertificate: lc.GetCertificate, }, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { io.WriteString(w, "

Hello from Tailscale!

It works.") diff --git a/client/tailscale/keys.go b/client/tailscale/keys.go index 84bcdfae6aeeb..79e19e99880f7 100644 --- a/client/tailscale/keys.go +++ b/client/tailscale/keys.go @@ -40,7 +40,7 @@ type KeyDeviceCreateCapabilities struct { // Keys returns the list of keys for the current user. func (c *Client) Keys(ctx context.Context) ([]string, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("keys") req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -51,7 +51,7 @@ func (c *Client) Keys(ctx context.Context) ([]string, error) { return nil, err } if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var keys struct { @@ -99,7 +99,7 @@ func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, return "", nil, err } - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys", c.baseURL(), c.tailnet) + path := c.BuildTailnetURL("keys") req, err := http.NewRequestWithContext(ctx, "POST", path, bytes.NewReader(bs)) if err != nil { return "", nil, err @@ -110,7 +110,7 @@ func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, return "", nil, err } if resp.StatusCode != http.StatusOK { - return "", nil, handleErrorResponse(b, resp) + return "", nil, HandleErrorResponse(b, resp) } var key struct { @@ -126,7 +126,7 @@ func (c *Client) CreateKeyWithExpiry(ctx context.Context, caps KeyCapabilities, // Key returns the metadata for the given key ID. Currently, capabilities are // only returned for auth keys, API keys only return general metadata. 
func (c *Client) Key(ctx context.Context, id string) (*Key, error) { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + path := c.BuildTailnetURL("keys", id) req, err := http.NewRequestWithContext(ctx, "GET", path, nil) if err != nil { return nil, err @@ -137,7 +137,7 @@ func (c *Client) Key(ctx context.Context, id string) (*Key, error) { return nil, err } if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var key Key @@ -149,7 +149,7 @@ func (c *Client) Key(ctx context.Context, id string) (*Key, error) { // DeleteKey deletes the key with the given ID. func (c *Client) DeleteKey(ctx context.Context, id string) error { - path := fmt.Sprintf("%s/api/v2/tailnet/%s/keys/%s", c.baseURL(), c.tailnet, id) + path := c.BuildTailnetURL("keys", id) req, err := http.NewRequestWithContext(ctx, "DELETE", path, nil) if err != nil { return err @@ -160,7 +160,7 @@ func (c *Client) DeleteKey(ctx context.Context, id string) error { return err } if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil } diff --git a/client/tailscale/localclient_aliases.go b/client/tailscale/localclient_aliases.go new file mode 100644 index 0000000000000..2b53906b71ae4 --- /dev/null +++ b/client/tailscale/localclient_aliases.go @@ -0,0 +1,106 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "context" + "crypto/tls" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn/ipnstate" +) + +// ErrPeerNotFound is an alias for [tailscale.com/client/local.ErrPeerNotFound]. +// +// Deprecated: import [tailscale.com/client/local] instead. +var ErrPeerNotFound = local.ErrPeerNotFound + +// LocalClient is an alias for [tailscale.com/client/local.Client]. +// +// Deprecated: import [tailscale.com/client/local] instead. 
+type LocalClient = local.Client + +// IPNBusWatcher is an alias for [tailscale.com/client/local.IPNBusWatcher]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type IPNBusWatcher = local.IPNBusWatcher + +// BugReportOpts is an alias for [tailscale.com/client/local.BugReportOpts]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type BugReportOpts = local.BugReportOpts + +// DebugPortmapOpts is an alias for [tailscale.com/client/local.DebugPortmapOpts]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type DebugPortmapOpts = local.DebugPortmapOpts + +// PingOpts is an alias for [tailscale.com/client/local.PingOpts]. +// +// Deprecated: import [tailscale.com/client/local] instead. +type PingOpts = local.PingOpts + +// GetCertificate is an alias for [tailscale.com/client/local.GetCertificate]. +// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.GetCertificate]. +func GetCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + return local.GetCertificate(hi) +} + +// SetVersionMismatchHandler is an alias for [tailscale.com/client/local.SetVersionMismatchHandler]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func SetVersionMismatchHandler(f func(clientVer, serverVer string)) { + local.SetVersionMismatchHandler(f) +} + +// IsAccessDeniedError is an alias for [tailscale.com/client/local.IsAccessDeniedError]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func IsAccessDeniedError(err error) bool { + return local.IsAccessDeniedError(err) +} + +// IsPreconditionsFailedError is an alias for [tailscale.com/client/local.IsPreconditionsFailedError]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func IsPreconditionsFailedError(err error) bool { + return local.IsPreconditionsFailedError(err) +} + +// WhoIs is an alias for [tailscale.com/client/local.WhoIs]. 
+// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.WhoIs]. +func WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { + return local.WhoIs(ctx, remoteAddr) +} + +// Status is an alias for [tailscale.com/client/local.Status]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func Status(ctx context.Context) (*ipnstate.Status, error) { + return local.Status(ctx) +} + +// StatusWithoutPeers is an alias for [tailscale.com/client/local.StatusWithoutPeers]. +// +// Deprecated: import [tailscale.com/client/local] instead. +func StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { + return local.StatusWithoutPeers(ctx) +} + +// CertPair is an alias for [tailscale.com/client/local.CertPair]. +// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.CertPair]. +func CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { + return local.CertPair(ctx, domain) +} + +// ExpandSNIName is an alias for [tailscale.com/client/local.ExpandSNIName]. +// +// Deprecated: import [tailscale.com/client/local] instead and use [local.Client.ExpandSNIName]. +func ExpandSNIName(ctx context.Context, name string) (fqdn string, ok bool) { + return local.ExpandSNIName(ctx, name) +} diff --git a/client/tailscale/routes.go b/client/tailscale/routes.go index 5912fc46c09a6..b72f2743ff9fb 100644 --- a/client/tailscale/routes.go +++ b/client/tailscale/routes.go @@ -44,7 +44,7 @@ func (c *Client) Routes(ctx context.Context, deviceID string) (routes *Routes, e // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. 
if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var sr Routes @@ -84,7 +84,7 @@ func (c *Client) SetRoutes(ctx context.Context, deviceID string, subnets []netip // If status code was not successful, return the error. // TODO: Change the check for the StatusCode to include other 2XX success codes. if resp.StatusCode != http.StatusOK { - return nil, handleErrorResponse(b, resp) + return nil, HandleErrorResponse(b, resp) } var srr *Routes diff --git a/client/tailscale/tailnet.go b/client/tailscale/tailnet.go index 2539e7f235b0e..9453962c908c8 100644 --- a/client/tailscale/tailnet.go +++ b/client/tailscale/tailnet.go @@ -9,7 +9,6 @@ import ( "context" "fmt" "net/http" - "net/url" "tailscale.com/util/httpm" ) @@ -22,7 +21,7 @@ func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (er } }() - path := fmt.Sprintf("%s/api/v2/tailnet/%s", c.baseURL(), url.PathEscape(string(tailnetID))) + path := c.BuildURL("tailnet", tailnetID) req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) if err != nil { return err @@ -35,7 +34,7 @@ func (c *Client) TailnetDeleteRequest(ctx context.Context, tailnetID string) (er } if resp.StatusCode != http.StatusOK { - return handleErrorResponse(b, resp) + return HandleErrorResponse(b, resp) } return nil diff --git a/client/tailscale/tailscale.go b/client/tailscale/tailscale.go index 8533b47129e01..76e44454b2fc2 100644 --- a/client/tailscale/tailscale.go +++ b/client/tailscale/tailscale.go @@ -3,11 +3,12 @@ //go:build go1.19 -// Package tailscale contains Go clients for the Tailscale LocalAPI and -// Tailscale control plane API. +// Package tailscale contains a Go client for the Tailscale control plane API. // -// Warning: this package is in development and makes no API compatibility -// promises as of 2022-04-29. It is subject to change at any time. +// This package is only intended for internal and transitional use. 
+//
+// Deprecated: the official control plane client is available at +// [tailscale.com/client/tailscale/v2]. package tailscale import ( @@ -16,13 +17,12 @@ import ( "fmt" "io" "net/http" + "net/url" + "path" ) // I_Acknowledge_This_API_Is_Unstable must be set true to use this package -// for now. It was added 2022-04-29 when it was moved to this git repo -// and will be removed when the public API has settled. -// -// TODO(bradfitz): remove this after the we're happy with the public API. +// for now. This package is being replaced by [tailscale.com/client/tailscale/v2]. var I_Acknowledge_This_API_Is_Unstable = false // TODO: use url.PathEscape() for deviceID and tailnets when constructing requests. @@ -34,8 +34,10 @@ const maxReadSize = 10 << 20 // Client makes API calls to the Tailscale control plane API server. // -// Use NewClient to instantiate one. Exported fields should be set before +// Use [NewClient] to instantiate one. Exported fields should be set before // the client is used and not changed thereafter. +// +// Deprecated: use [tailscale.com/client/tailscale/v2] instead. type Client struct { // tailnet is the globally unique identifier for a Tailscale network, such // as "example.com" or "user@gmail.com". @@ -49,7 +51,7 @@ type Client struct { BaseURL string // HTTPClient optionally specifies an alternate HTTP client to use. - // If nil, http.DefaultClient is used. + // If nil, [http.DefaultClient] is used. HTTPClient *http.Client // UserAgent optionally specifies an alternate User-Agent header @@ -63,6 +65,46 @@ func (c *Client) httpClient() *http.Client { return http.DefaultClient } +// BuildURL builds a url to http(s)://<apiserver>/api/v2/ +// using the given pathElements. It url escapes each path element, so the +// caller doesn't need to worry about that. The last item of pathElements can +// be of type url.Values to add a query string to the URL. 
+//
+// For example, BuildURL("devices", 5) with the default server URL would result in +// https://api.tailscale.com/api/v2/devices/5. +func (c *Client) BuildURL(pathElements ...any) string { + elem := make([]string, 1, len(pathElements)+1) + elem[0] = "/api/v2" + var query string + for i, pathElement := range pathElements { + if uv, ok := pathElement.(url.Values); ok && i == len(pathElements)-1 { + query = uv.Encode() + } else { + elem = append(elem, url.PathEscape(fmt.Sprint(pathElement))) + } + } + url := c.baseURL() + path.Join(elem...) + if query != "" { + url += "?" + query + } + return url +} + +// BuildTailnetURL builds a url to http(s)://<apiserver>/api/v2/tailnet/<tailnet>/ +// using the given pathElements. It url escapes each path element, so the +// caller doesn't need to worry about that. The last item of pathElements can +// be of type url.Values to add a query string to the URL. +// +// For example, BuildTailnetURL("policy", "validate") with the default server URL and a tailnet of "example.com" +// would result in https://api.tailscale.com/api/v2/tailnet/example.com/policy/validate. +func (c *Client) BuildTailnetURL(pathElements ...any) string { + allElements := make([]any, 2, len(pathElements)+2) + allElements[0] = "tailnet" + allElements[1] = c.tailnet + allElements = append(allElements, pathElements...) + return c.BuildURL(allElements...) +} + func (c *Client) baseURL() string { if c.BaseURL != "" { return c.BaseURL @@ -77,7 +119,7 @@ type AuthMethod interface { modifyRequest(req *http.Request) } -// APIKey is an AuthMethod for NewClient that authenticates requests +// APIKey is an [AuthMethod] for [NewClient] that authenticates requests // using an authkey. type APIKey string @@ -91,13 +133,15 @@ func (c *Client) setAuth(r *http.Request) { } } -// NewClient is a convenience method for instantiating a new Client. +// NewClient is a convenience method for instantiating a new [Client]. 
// // tailnet is the globally unique identifier for a Tailscale network, such // as "example.com" or "user@gmail.com". -// If httpClient is nil, then http.DefaultClient is used. +// If httpClient is nil, then [http.DefaultClient] is used. // "api.tailscale.com" is set as the BaseURL for the returned client // and can be changed manually by the user. +// +// Deprecated: use [tailscale.com/client/tailscale/v2] instead. func NewClient(tailnet string, auth AuthMethod) *Client { return &Client{ tailnet: tailnet, @@ -148,12 +192,14 @@ func (e ErrResponse) Error() string { return fmt.Sprintf("Status: %d, Message: %q", e.Status, e.Message) } -// handleErrorResponse decodes the error message from the server and returns -// an ErrResponse from it. -func handleErrorResponse(b []byte, resp *http.Response) error { +// HandleErrorResponse decodes the error message from the server and returns +// an [ErrResponse] from it. +// +// Deprecated: use [tailscale.com/client/tailscale/v2] instead. +func HandleErrorResponse(b []byte, resp *http.Response) error { var errResp ErrResponse if err := json.Unmarshal(b, &errResp); err != nil { - return err + return fmt.Errorf("json.Unmarshal %q: %w", b, err) } errResp.Status = resp.StatusCode return errResp diff --git a/client/tailscale/tailscale_test.go b/client/tailscale/tailscale_test.go new file mode 100644 index 0000000000000..67379293bd580 --- /dev/null +++ b/client/tailscale/tailscale_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "net/url" + "testing" +) + +func TestClientBuildURL(t *testing.T) { + c := Client{BaseURL: "http://127.0.0.1:1234"} + for _, tt := range []struct { + desc string + elements []any + want string + }{ + { + desc: "single-element", + elements: []any{"devices"}, + want: "http://127.0.0.1:1234/api/v2/devices", + }, + { + desc: "multiple-elements", + elements: []any{"tailnet", "example.com"}, + want: 
"http://127.0.0.1:1234/api/v2/tailnet/example.com", + }, + { + desc: "escape-element", + elements: []any{"tailnet", "example dot com?foo=bar"}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example%20dot%20com%3Ffoo=bar`, + }, + { + desc: "url.Values", + elements: []any{"tailnet", "example.com", "acl", url.Values{"details": {"1"}}}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/acl?details=1`, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + got := c.BuildURL(tt.elements...) + if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestClientBuildTailnetURL(t *testing.T) { + c := Client{ + BaseURL: "http://127.0.0.1:1234", + tailnet: "example.com", + } + for _, tt := range []struct { + desc string + elements []any + want string + }{ + { + desc: "single-element", + elements: []any{"devices"}, + want: "http://127.0.0.1:1234/api/v2/tailnet/example.com/devices", + }, + { + desc: "multiple-elements", + elements: []any{"devices", 123}, + want: "http://127.0.0.1:1234/api/v2/tailnet/example.com/devices/123", + }, + { + desc: "escape-element", + elements: []any{"foo bar?baz=qux"}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/foo%20bar%3Fbaz=qux`, + }, + { + desc: "url.Values", + elements: []any{"acl", url.Values{"details": {"1"}}}, + want: `http://127.0.0.1:1234/api/v2/tailnet/example.com/acl?details=1`, + }, + } { + t.Run(tt.desc, func(t *testing.T) { + got := c.BuildTailnetURL(tt.elements...) 
+ if got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} diff --git a/client/web/package.json b/client/web/package.json index 4b3afb1df6ef5..c45f7d6a867ec 100644 --- a/client/web/package.json +++ b/client/web/package.json @@ -3,7 +3,7 @@ "version": "0.0.1", "license": "BSD-3-Clause", "engines": { - "node": "18.20.4", + "node": "22.14.0", "yarn": "1.22.19" }, "type": "module", @@ -20,7 +20,7 @@ "zustand": "^4.4.7" }, "devDependencies": { - "@types/node": "^18.16.1", + "@types/node": "^22.14.0", "@types/react": "^18.0.20", "@types/react-dom": "^18.0.6", "@vitejs/plugin-react-swc": "^3.6.0", diff --git a/client/web/src/components/views/login-view.tsx b/client/web/src/components/views/login-view.tsx index b2868bb46c991..f8c15b16dbcaa 100644 --- a/client/web/src/components/views/login-view.tsx +++ b/client/web/src/components/views/login-view.tsx @@ -1,13 +1,11 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -import React, { useState } from "react" +import React from "react" import { useAPI } from "src/api" import TailscaleIcon from "src/assets/icons/tailscale-icon.svg?react" import { NodeData } from "src/types" import Button from "src/ui/button" -import Collapsible from "src/ui/collapsible" -import Input from "src/ui/input" /** * LoginView is rendered when the client is not authenticated @@ -15,8 +13,6 @@ import Input from "src/ui/input" */ export default function LoginView({ data }: { data: NodeData }) { const api = useAPI() - const [controlURL, setControlURL] = useState("") - const [authKey, setAuthKey] = useState("") return (
@@ -88,8 +84,6 @@ export default function LoginView({ data }: { data: NodeData }) { action: "up", data: { Reauthenticate: true, - ControlURL: controlURL, - AuthKey: authKey, }, }) } @@ -98,34 +92,6 @@ export default function LoginView({ data }: { data: NodeData }) { > Log In - -

Auth Key

-

- Connect with a pre-authenticated key.{" "} - - Learn more → - -

- setAuthKey(e.target.value)} - placeholder="tskey-auth-XXX" - /> -

Server URL

-

Base URL of control server.

- setControlURL(e.target.value)} - placeholder="https://login.tailscale.com/" - /> -
)}
diff --git a/client/web/web.go b/client/web/web.go index 04ba2d086334a..6eccdadcfdb65 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -22,10 +22,11 @@ import ( "time" "github.com/gorilla/csrf" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/clientupdate" "tailscale.com/envknob" + "tailscale.com/envknob/featureknob" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" @@ -49,7 +50,7 @@ type Server struct { mode ServerMode logf logger.Logf - lc *tailscale.LocalClient + lc *local.Client timeNow func() time.Time // devMode indicates that the server run with frontend assets @@ -88,8 +89,8 @@ type Server struct { type ServerMode string const ( - // LoginServerMode serves a readonly login client for logging a - // node into a tailnet, and viewing a readonly interface of the + // LoginServerMode serves a read-only login client for logging a + // node into a tailnet, and viewing a read-only interface of the // node's current Tailscale settings. // // In this mode, API calls are authenticated via platform auth. @@ -109,7 +110,7 @@ const ( // This mode restricts the app to only being assessible over Tailscale, // and API calls are authenticated via browser sessions associated with // the source's Tailscale identity. If the source browser does not have - // a valid session, a readonly version of the app is displayed. + // a valid session, a read-only version of the app is displayed. ManageServerMode ServerMode = "manage" ) @@ -124,9 +125,9 @@ type ServerOpts struct { // PathPrefix is the URL prefix added to requests by CGI or reverse proxy. PathPrefix string - // LocalClient is the tailscale.LocalClient to use for this web server. + // LocalClient is the local.Client to use for this web server. // If nil, a new one will be created. - LocalClient *tailscale.LocalClient + LocalClient *local.Client // TimeNow optionally provides a time function. // time.Now is used as default. 
@@ -165,7 +166,7 @@ func NewServer(opts ServerOpts) (s *Server, err error) { return nil, fmt.Errorf("invalid Mode provided") } if opts.LocalClient == nil { - opts.LocalClient = &tailscale.LocalClient{} + opts.LocalClient = &local.Client{} } s = &Server{ mode: opts.Mode, @@ -202,25 +203,9 @@ func NewServer(opts ServerOpts) (s *Server, err error) { } s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode) - var metric string // clientmetric to report on startup - - // Create handler for "/api" requests with CSRF protection. - // We don't require secure cookies, since the web client is regularly used - // on network appliances that are served on local non-https URLs. - // The client is secured by limiting the interface it listens on, - // or by authenticating requests before they reach the web client. - csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) - switch s.mode { - case LoginServerMode: - s.apiHandler = csrfProtect(http.HandlerFunc(s.serveLoginAPI)) - metric = "web_login_client_initialization" - case ReadOnlyServerMode: - s.apiHandler = csrfProtect(http.HandlerFunc(s.serveLoginAPI)) - metric = "web_readonly_client_initialization" - case ManageServerMode: - s.apiHandler = csrfProtect(http.HandlerFunc(s.serveAPI)) - metric = "web_client_initialization" - } + var metric string + s.apiHandler, metric = s.modeAPIHandler(s.mode) + s.apiHandler = s.withCSRF(s.apiHandler) // Don't block startup on reporting metric. // Report in separate go routine with 5 second timeout. @@ -233,6 +218,39 @@ func NewServer(opts ServerOpts) (s *Server, err error) { return s, nil } +func (s *Server) withCSRF(h http.Handler) http.Handler { + csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false)) + + // ref https://github.com/tailscale/tailscale/pull/14822 + // signal to the CSRF middleware that the request is being served over + // plaintext HTTP to skip TLS-only header checks. 
+ withSetPlaintext := func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = csrf.PlaintextHTTPRequest(r) + h.ServeHTTP(w, r) + }) + } + + // NB: the order of the withSetPlaintext and csrfProtect calls is important + // to ensure that we signal to the CSRF middleware that the request is being + // served over plaintext HTTP and not over TLS as it presumes by default. + return withSetPlaintext(csrfProtect(h)) +} + +func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) { + switch mode { + case LoginServerMode: + return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization" + case ReadOnlyServerMode: + return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization" + case ManageServerMode: + return http.HandlerFunc(s.serveAPI), "web_client_initialization" + default: // invalid mode + log.Fatalf("invalid mode: %v", mode) + } + return nil, "" +} + func (s *Server) Shutdown() { s.logf("web.Server: shutting down") if s.assetsCleanup != nil { @@ -317,7 +335,8 @@ func (s *Server) requireTailscaleIP(w http.ResponseWriter, r *http.Request) (han ipv6ServiceHost = "[" + tsaddr.TailscaleServiceIPv6String + "]" ) // allow requests on quad-100 (or ipv6 equivalent) - if r.Host == ipv4ServiceHost || r.Host == ipv6ServiceHost { + host := strings.TrimSuffix(r.Host, ":80") + if host == ipv4ServiceHost || host == ipv6ServiceHost { return false } @@ -694,16 +713,16 @@ func (s *Server) serveAPIAuth(w http.ResponseWriter, r *http.Request) { switch { case sErr != nil && errors.Is(sErr, errNotUsingTailscale): s.lc.IncrementCounter(r.Context(), "web_client_viewing_local", 1) - resp.Authorized = false // restricted to the readonly view + resp.Authorized = false // restricted to the read-only view case sErr != nil && errors.Is(sErr, errNotOwner): s.lc.IncrementCounter(r.Context(), "web_client_viewing_not_owner", 1) - resp.Authorized = false // restricted to the readonly view + 
resp.Authorized = false // restricted to the read-only view case sErr != nil && errors.Is(sErr, errTaggedLocalSource): s.lc.IncrementCounter(r.Context(), "web_client_viewing_local_tag", 1) - resp.Authorized = false // restricted to the readonly view + resp.Authorized = false // restricted to the read-only view case sErr != nil && errors.Is(sErr, errTaggedRemoteSource): s.lc.IncrementCounter(r.Context(), "web_client_viewing_remote_tag", 1) - resp.Authorized = false // restricted to the readonly view + resp.Authorized = false // restricted to the read-only view case sErr != nil && !errors.Is(sErr, errNoSession): // Any other error. http.Error(w, sErr.Error(), http.StatusInternalServerError) @@ -803,8 +822,8 @@ type nodeData struct { DeviceName string TailnetName string // TLS cert name DomainName string - IPv4 string - IPv6 string + IPv4 netip.Addr + IPv6 netip.Addr OS string IPNVersion string @@ -863,10 +882,14 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { return } filterRules, _ := s.lc.DebugPacketFilterRules(r.Context()) + ipv4, ipv6 := s.selfNodeAddresses(r, st) + data := &nodeData{ ID: st.Self.ID, Status: st.BackendState, DeviceName: strings.Split(st.Self.DNSName, ".")[0], + IPv4: ipv4, + IPv6: ipv6, OS: st.Self.OS, IPNVersion: strings.Split(st.Version, "-")[0], Profile: st.User[st.Self.UserID], @@ -886,10 +909,6 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { ACLAllowsAnyIncomingTraffic: s.aclsAllowAccess(filterRules), } - ipv4, ipv6 := s.selfNodeAddresses(r, st) - data.IPv4 = ipv4.String() - data.IPv6 = ipv6.String() - if hostinfo.GetEnvType() == hostinfo.HomeAssistantAddOn && data.URLPrefix == "" { // X-Ingress-Path is the path prefix in use for Home Assistant // https://developers.home-assistant.io/docs/add-ons/presentation#ingress @@ -960,37 +979,16 @@ func (s *Server) serveGetNodeData(w http.ResponseWriter, r *http.Request) { } func availableFeatures() map[string]bool { - env := 
hostinfo.GetEnvType() features := map[string]bool{ "advertise-exit-node": true, // available on all platforms "advertise-routes": true, // available on all platforms - "use-exit-node": canUseExitNode(env) == nil, - "ssh": envknob.CanRunTailscaleSSH() == nil, + "use-exit-node": featureknob.CanUseExitNode() == nil, + "ssh": featureknob.CanRunTailscaleSSH() == nil, "auto-update": version.IsUnstableBuild() && clientupdate.CanAutoUpdate(), } - if env == hostinfo.HomeAssistantAddOn { - // Setting SSH on Home Assistant causes trouble on startup - // (since the flag is not being passed to `tailscale up`). - // Although Tailscale SSH does work here, - // it's not terribly useful since it's running in a separate container. - features["ssh"] = false - } return features } -func canUseExitNode(env hostinfo.EnvType) error { - switch dist := distro.Get(); dist { - case distro.Synology, // see https://github.com/tailscale/tailscale/issues/1995 - distro.QNAP, - distro.Unraid: - return fmt.Errorf("Tailscale exit nodes cannot be used on %s.", dist) - } - if env == hostinfo.HomeAssistantAddOn { - return errors.New("Tailscale exit nodes cannot be used on Home Assistant.") - } - return nil -} - // aclsAllowAccess returns whether tailnet ACLs (as expressed in the provided filter rules) // permit any devices to access the local web client. // This does not currently check whether a specific device can connect, just any device. 
diff --git a/client/web/web_test.go b/client/web/web_test.go index 3c5543c12014c..2a6bc787ac396 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -11,6 +11,7 @@ import ( "fmt" "io" "net/http" + "net/http/cookiejar" "net/http/httptest" "net/netip" "net/url" @@ -20,12 +21,14 @@ import ( "time" "github.com/google/go-cmp/cmp" - "tailscale.com/client/tailscale" + "github.com/gorilla/csrf" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/memnet" "tailscale.com/tailcfg" + "tailscale.com/tstest/nettest" "tailscale.com/types/views" "tailscale.com/util/httpm" ) @@ -120,7 +123,7 @@ func TestServeAPI(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: time.Now, } @@ -288,7 +291,7 @@ func TestGetTailscaleBrowserSession(t *testing.T) { s := &Server{ timeNow: time.Now, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, } // Add some browser sessions to cache state. 
@@ -457,7 +460,7 @@ func TestAuthorizeRequest(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: time.Now, } validCookie := "ts-cookie" @@ -572,7 +575,7 @@ func TestServeAuth(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: func() time.Time { return timeNow }, newAuthURL: mockNewAuthURL, waitAuthURL: mockWaitAuthURL, @@ -914,7 +917,7 @@ func TestServeAPIAuthMetricLogging(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: func() time.Time { return timeNow }, newAuthURL: mockNewAuthURL, waitAuthURL: mockWaitAuthURL, @@ -1126,7 +1129,7 @@ func TestRequireTailscaleIP(t *testing.T) { s := &Server{ mode: ManageServerMode, - lc: &tailscale.LocalClient{Dial: lal.Dial}, + lc: &local.Client{Dial: lal.Dial}, timeNow: time.Now, logf: t.Logf, } @@ -1175,6 +1178,16 @@ func TestRequireTailscaleIP(t *testing.T) { target: "http://[fd7a:115c:a1e0::53]/", wantHandled: false, }, + { + name: "quad-100:80", + target: "http://100.100.100.100:80/", + wantHandled: false, + }, + { + name: "ipv6-service-addr:80", + target: "http://[fd7a:115c:a1e0::53]:80/", + wantHandled: false, + }, } for _, tt := range tests { @@ -1477,3 +1490,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg return nil, errors.New("unknown id") } } + +func TestCSRFProtect(t *testing.T) { + s := &Server{} + + mux := http.NewServeMux() + mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) { + token := csrf.Token(r) + _, err := io.WriteString(w, token) + if err != nil { + t.Fatal(err) + } + }) + mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) { + _, err := io.WriteString(w, "ok") + if err != nil { + t.Fatal(err) + } + }) + h := 
s.withCSRF(mux) + ser := nettest.NewHTTPServer(nettest.GetNetwork(t), h) + defer ser.Close() + + jar, err := cookiejar.New(nil) + if err != nil { + t.Fatalf("unable to construct cookie jar: %v", err) + } + + client := ser.Client() + client.Jar = jar + + // make GET request to populate cookie jar + resp, err := client.Get(ser.URL + "/test/csrf-token") + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + tokenBytes, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + + csrfToken := strings.TrimSpace(string(tokenBytes)) + if csrfToken == "" { + t.Fatal("empty csrf token") + } + + // make a POST request without the CSRF header; ensure it fails + resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("unexpected status: %v", resp.Status) + } + + // make a POST request with the CSRF header; ensure it succeeds + req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil) + if err != nil { + t.Fatalf("error building request: %v", err) + } + req.Header.Set("X-CSRF-Token", csrfToken) + resp, err = client.Do(req) + if err != nil { + t.Fatalf("unable to make request: %v", err) + } + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v", resp.Status) + } + defer resp.Body.Close() + out, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("unable to read body: %v", err) + } + if string(out) != "ok" { + t.Fatalf("unexpected body: %q", out) + } +} diff --git a/client/web/yarn.lock b/client/web/yarn.lock index 2c8fca5e53e9d..a9b2ae8767b99 100644 --- a/client/web/yarn.lock +++ b/client/web/yarn.lock @@ -1880,12 +1880,12 @@ resolved 
"https://registry.yarnpkg.com/@types/json5/-/json5-0.0.29.tgz#ee28707ae94e11d2b827bcbe5270bcea7f3e71ee" integrity sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ== -"@types/node@^18.16.1": - version "18.19.18" - resolved "https://registry.yarnpkg.com/@types/node/-/node-18.19.18.tgz#7526471b28828d1fef1f7e4960fb9477e6e4369c" - integrity sha512-80CP7B8y4PzZF0GWx15/gVWRrB5y/bIjNI84NK3cmQJu0WZwvmj2WMA5LcofQFVfLqqCSp545+U2LsrVzX36Zg== +"@types/node@^22.14.0": + version "22.14.0" + resolved "https://registry.yarnpkg.com/@types/node/-/node-22.14.0.tgz#d3bfa3936fef0dbacd79ea3eb17d521c628bb47e" + integrity sha512-Kmpl+z84ILoG+3T/zQFyAJsU6EPTmOCj8/2+83fSN6djd6I4o7uOuGIH6vq3PrjY5BGitSbFuMN18j3iknubbA== dependencies: - undici-types "~5.26.4" + undici-types "~6.21.0" "@types/parse-json@^4.0.0": version "4.0.2" @@ -5124,10 +5124,10 @@ unbox-primitive@^1.0.2: has-symbols "^1.0.3" which-boxed-primitive "^1.0.2" -undici-types@~5.26.4: - version "5.26.5" - resolved "https://registry.yarnpkg.com/undici-types/-/undici-types-5.26.5.tgz#bcd539893d00b56e964fd2657a4866b221a65617" - integrity sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA== +undici-types@~6.21.0: + version "6.21.0" + resolved "https://registry.yarnpkg.com/undici-types/-/undici-types-6.21.0.tgz#691d00af3909be93a7faa13be61b3a5b50ef12cb" + integrity sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ== unicode-canonical-property-names-ecmascript@^2.0.0: version "2.0.0" diff --git a/clientupdate/clientupdate.go b/clientupdate/clientupdate.go index 7fa84d67f9d8a..ffd3fb03bb80d 100644 --- a/clientupdate/clientupdate.go +++ b/clientupdate/clientupdate.go @@ -27,6 +27,8 @@ import ( "strconv" "strings" + "tailscale.com/hostinfo" + "tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/util/cmpver" "tailscale.com/version" @@ -169,6 +171,12 @@ func NewUpdater(args Arguments) (*Updater, 
error) { type updateFunction func() error func (up *Updater) getUpdateFunction() (fn updateFunction, canAutoUpdate bool) { + hi := hostinfo.New() + // We don't know how to update custom tsnet binaries, it's up to the user. + if hi.Package == "tsnet" { + return nil, false + } + switch runtime.GOOS { case "windows": return up.updateWindows, true @@ -242,9 +250,13 @@ func (up *Updater) getUpdateFunction() (fn updateFunction, canAutoUpdate bool) { return nil, false } +var canAutoUpdateCache lazy.SyncValue[bool] + // CanAutoUpdate reports whether auto-updating via the clientupdate package // is supported for the current os/distro. -func CanAutoUpdate() bool { +func CanAutoUpdate() bool { return canAutoUpdateCache.Get(canAutoUpdateUncached) } + +func canAutoUpdateUncached() bool { if version.IsMacSysExt() { // Macsys uses Sparkle for auto-updates, which doesn't have an update // function in this package. diff --git a/clientupdate/clientupdate_windows.go b/clientupdate/clientupdate_windows.go index 9737229745332..b79d447ad4d30 100644 --- a/clientupdate/clientupdate_windows.go +++ b/clientupdate/clientupdate_windows.go @@ -16,6 +16,7 @@ import ( "path/filepath" "runtime" "strings" + "time" "github.com/google/uuid" "golang.org/x/sys/windows" @@ -34,6 +35,12 @@ const ( // It is used to re-launch the GUI process (tailscale-ipn.exe) after // install is complete. winExePathEnv = "TS_UPDATE_WIN_EXE_PATH" + // winVersionEnv is the environment variable that is set along with + // winMSIEnv and carries the version of tailscale that is being installed. + // It is used for logging purposes. + winVersionEnv = "TS_UPDATE_WIN_VERSION" + // updaterPrefix is the prefix for the temporary executable created by [makeSelfCopy]. 
+ updaterPrefix = "tailscale-updater" ) func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { @@ -46,7 +53,7 @@ func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { return "", "", err } defer f.Close() - f2, err := os.CreateTemp("", "tailscale-updater-*.exe") + f2, err := os.CreateTemp("", updaterPrefix+"-*.exe") if err != nil { return "", "", err } @@ -137,7 +144,7 @@ you can run the command prompt as Administrator one of these ways: up.Logf("authenticode verification succeeded") up.Logf("making tailscale.exe copy to switch to...") - up.cleanupOldDownloads(filepath.Join(os.TempDir(), "tailscale-updater-*.exe")) + up.cleanupOldDownloads(filepath.Join(os.TempDir(), updaterPrefix+"-*.exe")) selfOrig, selfCopy, err := makeSelfCopy() if err != nil { return err @@ -146,7 +153,7 @@ you can run the command prompt as Administrator one of these ways: up.Logf("running tailscale.exe copy for final install...") cmd := exec.Command(selfCopy, "update") - cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget, winExePathEnv+"="+selfOrig) + cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget, winExePathEnv+"="+selfOrig, winVersionEnv+"="+ver) cmd.Stdout = up.Stderr cmd.Stderr = up.Stderr cmd.Stdin = os.Stdin @@ -162,23 +169,62 @@ you can run the command prompt as Administrator one of these ways: func (up *Updater) installMSI(msi string) error { var err error for tries := 0; tries < 2; tries++ { - cmd := exec.Command("msiexec.exe", "/i", filepath.Base(msi), "/quiet", "/norestart", "/qn") + // msiexec.exe requires exclusive access to the log file, so create a dedicated one for each run. 
+ installLogPath := up.startNewLogFile("tailscale-installer", os.Getenv(winVersionEnv)) + up.Logf("Install log: %s", installLogPath) + cmd := exec.Command("msiexec.exe", "/i", filepath.Base(msi), "/quiet", "/norestart", "/qn", "/L*v", installLogPath) cmd.Dir = filepath.Dir(msi) cmd.Stdout = up.Stdout cmd.Stderr = up.Stderr cmd.Stdin = os.Stdin err = cmd.Run() - if err == nil { - break + switch err := err.(type) { + case nil: + // Success. + return nil + case *exec.ExitError: + // For possible error codes returned by Windows Installer, see + // https://web.archive.org/web/20250409144914/https://learn.microsoft.com/en-us/windows/win32/msi/error-codes + switch windows.Errno(err.ExitCode()) { + case windows.ERROR_SUCCESS_REBOOT_REQUIRED: + // In most cases, updating Tailscale should not require a reboot. + // If it does, it might be because we failed to close the GUI + // and the installer couldn't replace tailscale-ipn.exe. + // The old GUI will continue to run until the next reboot. + // Not ideal, but also not a retryable error. + up.Logf("[unexpected] reboot required") + return nil + case windows.ERROR_SUCCESS_REBOOT_INITIATED: + // Same as above, but perhaps the device is configured to prompt + // the user to reboot and the user has chosen to reboot now. + up.Logf("[unexpected] reboot initiated") + return nil + case windows.ERROR_INSTALL_ALREADY_RUNNING: + // The Windows Installer service is currently busy. + // It could be our own install initiated by user/MDM/GP, another MSI install or perhaps a Windows Update install. + // Anyway, we can't do anything about it right now. The user (or tailscaled) can retry later. + // Retrying now will likely fail, and is risky since we might uninstall the current version + // and then fail to install the new one, leaving the user with no Tailscale at all. + // + // TODO(nickkhyl,awly): should we check if this is actually a downgrade before uninstalling the current version? 
+ // Also, maybe keep retrying the install longer if we uninstalled the current version due to a failed install attempt? + up.Logf("another installation is already in progress") + return err + } + default: + // Everything else is a retryable error. } + up.Logf("Install attempt failed: %v", err) uninstallVersion := up.currentVersion if v := os.Getenv("TS_DEBUG_UNINSTALL_VERSION"); v != "" { uninstallVersion = v } + uninstallLogPath := up.startNewLogFile("tailscale-uninstaller", uninstallVersion) // Assume it's a downgrade, which msiexec won't permit. Uninstall our current version first. up.Logf("Uninstalling current version %q for downgrade...", uninstallVersion) - cmd = exec.Command("msiexec.exe", "/x", msiUUIDForVersion(uninstallVersion), "/norestart", "/qn") + up.Logf("Uninstall log: %s", uninstallLogPath) + cmd = exec.Command("msiexec.exe", "/x", msiUUIDForVersion(uninstallVersion), "/norestart", "/qn", "/L*v", uninstallLogPath) cmd.Stdout = up.Stdout cmd.Stderr = up.Stderr cmd.Stdin = os.Stdin @@ -205,12 +251,14 @@ func (up *Updater) switchOutputToFile() (io.Closer, error) { var logFilePath string exePath, err := os.Executable() if err != nil { - logFilePath = filepath.Join(os.TempDir(), "tailscale-updater.log") + logFilePath = up.startNewLogFile(updaterPrefix, os.Getenv(winVersionEnv)) } else { - logFilePath = strings.TrimSuffix(exePath, ".exe") + ".log" + // Use the same suffix as the self-copy executable. + suffix := strings.TrimSuffix(strings.TrimPrefix(filepath.Base(exePath), updaterPrefix), ".exe") + logFilePath = up.startNewLogFile(updaterPrefix, os.Getenv(winVersionEnv)+suffix) } - up.Logf("writing update output to %q", logFilePath) + up.Logf("writing update output to: %s", logFilePath) logFile, err := os.Create(logFilePath) if err != nil { return nil, err @@ -223,3 +271,20 @@ func (up *Updater) switchOutputToFile() (io.Closer, error) { up.Stderr = logFile return logFile, nil } + +// startNewLogFile returns a name for a new log file. 
+// It cleans up any old log files with the same baseNamePrefix. +func (up *Updater) startNewLogFile(baseNamePrefix, baseNameSuffix string) string { + baseName := fmt.Sprintf("%s-%s-%s.log", baseNamePrefix, + time.Now().Format("20060102T150405"), baseNameSuffix) + + dir := filepath.Join(os.Getenv("ProgramData"), "Tailscale", "Logs") + if err := os.MkdirAll(dir, 0700); err != nil { + up.Logf("failed to create log directory: %v", err) + return filepath.Join(os.TempDir(), baseName) + } + + // TODO(nickkhyl): preserve up to N old log files? + up.cleanupOldDownloads(filepath.Join(dir, baseNamePrefix+"-*.log")) + return filepath.Join(dir, baseName) +} diff --git a/cmd/addlicense/main.go b/cmd/addlicense/main.go index a8fd9dd4ab96a..1cd1b0f19354a 100644 --- a/cmd/addlicense/main.go +++ b/cmd/addlicense/main.go @@ -18,12 +18,12 @@ var ( ) func usage() { - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` usage: addlicense -file FILE `[1:]) flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` addlicense adds a Tailscale license to the beginning of file. It is intended for use with 'go generate', so it also runs a subcommand, diff --git a/cmd/checkmetrics/checkmetrics.go b/cmd/checkmetrics/checkmetrics.go new file mode 100644 index 0000000000000..fb9e8ab4c61ec --- /dev/null +++ b/cmd/checkmetrics/checkmetrics.go @@ -0,0 +1,131 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// checkmetrics validates that all metrics in the tailscale client-metrics +// are documented in a given path or URL. 
+package main + +import ( + "context" + "flag" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "os" + "strings" + "time" + + "tailscale.com/ipn/store/mem" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/util/httpm" +) + +var ( + kbPath = flag.String("kb-path", "", "filepath to the client-metrics knowledge base") + kbUrl = flag.String("kb-url", "", "URL to the client-metrics knowledge base page") +) + +func main() { + flag.Parse() + if *kbPath == "" && *kbUrl == "" { + log.Fatalf("either -kb-path or -kb-url must be set") + } + + var control testcontrol.Server + ts := httptest.NewServer(&control) + defer ts.Close() + + td, err := os.MkdirTemp("", "testcontrol") + if err != nil { + log.Fatal(err) + } + defer os.RemoveAll(td) + + // tsnet is not used as a Tailscale client here, but as a way to + // boot up Tailscale, have all the metrics registered, and then + // verify that all the metrics are documented. + tsn := &tsnet.Server{ + Dir: td, + Store: new(mem.Store), + UserLogf: log.Printf, + Ephemeral: true, + ControlURL: ts.URL, + } + if err := tsn.Start(); err != nil { + log.Fatal(err) + } + defer tsn.Close() + + log.Printf("checking that all metrics are documented, looking for: %s", tsn.Sys().UserMetricsRegistry().MetricNames()) + + if *kbPath != "" { + kb, err := readKB(*kbPath) + if err != nil { + log.Fatalf("reading kb: %v", err) + } + missing := undocumentedMetrics(kb, tsn.Sys().UserMetricsRegistry().MetricNames()) + + if len(missing) > 0 { + log.Fatalf("found undocumented metrics in %q: %v", *kbPath, missing) + } + } + + if *kbUrl != "" { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + kb, err := getKB(ctx, *kbUrl) + if err != nil { + log.Fatalf("getting kb: %v", err) + } + missing := undocumentedMetrics(kb, tsn.Sys().UserMetricsRegistry().MetricNames()) + + if len(missing) > 0 { + log.Fatalf("found undocumented metrics in %q: %v", *kbUrl, missing) + } + } +} + 
+func readKB(path string) (string, error) { + b, err := os.ReadFile(path) + if err != nil { + return "", fmt.Errorf("reading file: %w", err) + } + + return string(b), nil +} + +func getKB(ctx context.Context, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, httpm.GET, url, nil) + if err != nil { + return "", fmt.Errorf("creating request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("getting kb page: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + b, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("reading body: %w", err) + } + return string(b), nil +} + +func undocumentedMetrics(b string, metrics []string) []string { + var missing []string + for _, metric := range metrics { + if !strings.Contains(b, metric) { + missing = append(missing, metric) + } + } + return missing +} diff --git a/cmd/containerboot/certs.go b/cmd/containerboot/certs.go new file mode 100644 index 0000000000000..504ef7988072b --- /dev/null +++ b/cmd/containerboot/certs.go @@ -0,0 +1,156 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "fmt" + "log" + "net" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/util/goroutines" + "tailscale.com/util/mak" +) + +// certManager is responsible for issuing certificates for known domains and for +// maintaining a loop that re-attempts issuance daily. +// Currently cert manager logic is only run on ingress ProxyGroup replicas that are responsible for managing certs for +// HA Ingress HTTPS endpoints ('write' replicas). 
+type certManager struct { + lc localClient + tracker goroutines.Tracker // tracks running goroutines + mu sync.Mutex // guards the following + // certLoops contains a map of DNS names, for which we currently need to + // manage certs to cancel functions that allow stopping a goroutine when + // we no longer need to manage certs for the DNS name. + certLoops map[string]context.CancelFunc +} + +// ensureCertLoops ensures that, for all currently managed Service HTTPS +// endpoints, there is a cert loop responsible for issuing and ensuring the +// renewal of the TLS certs. +// ServeConfig must not be nil. +func (cm *certManager) ensureCertLoops(ctx context.Context, sc *ipn.ServeConfig) error { + if sc == nil { + return fmt.Errorf("[unexpected] ensureCertLoops called with nil ServeConfig") + } + currentDomains := make(map[string]bool) + const httpsPort = "443" + for _, service := range sc.Services { + for hostPort := range service.Web { + domain, port, err := net.SplitHostPort(string(hostPort)) + if err != nil { + return fmt.Errorf("[unexpected] unable to parse HostPort %s", hostPort) + } + if port != httpsPort { // HA Ingress' HTTP endpoint + continue + } + currentDomains[domain] = true + } + } + cm.mu.Lock() + defer cm.mu.Unlock() + for domain := range currentDomains { + if _, exists := cm.certLoops[domain]; !exists { + cancelCtx, cancel := context.WithCancel(ctx) + mak.Set(&cm.certLoops, domain, cancel) + // Note that most of the issuance anyway happens + // serially because the cert client has a shared lock + // that's held during any issuance. + cm.tracker.Go(func() { cm.runCertLoop(cancelCtx, domain) }) + } + } + + // Stop goroutines for domain names that are no longer in the config. 
+ for domain, cancel := range cm.certLoops { + if !currentDomains[domain] { + cancel() + delete(cm.certLoops, domain) + } + } + return nil +} + +// runCertLoop: +// - calls localAPI certificate endpoint to ensure that certs are issued for the +// given domain name +// - calls localAPI certificate endpoint daily to ensure that certs are renewed +// - if certificate issuance failed retries after an exponential backoff period +// starting at 1 minute and capped at 24 hours. Reset the backoff once issuance succeeds. +// Note that renewal check also happens when the node receives an HTTPS request and it is possible that certs get +// renewed at that point. Renewal here is needed to prevent the shared certs from expiry in edge cases where the 'write' +// replica does not get any HTTPS requests. +// https://letsencrypt.org/docs/integration-guide/#retrying-failures +func (cm *certManager) runCertLoop(ctx context.Context, domain string) { + const ( + normalInterval = 24 * time.Hour // regular renewal check + initialRetry = 1 * time.Minute // initial backoff after a failure + maxRetryInterval = 24 * time.Hour // max backoff period + ) + timer := time.NewTimer(0) // fire off timer immediately + defer timer.Stop() + retryCount := 0 + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + // We call the certificate endpoint, but don't do anything + // with the returned certs here. + // The call to the certificate endpoint will ensure that + // certs are issued/renewed as needed and stored in the + // relevant state store. For example, for HA Ingress + // 'write' replica, the cert and key will be stored in a + // Kubernetes Secret named after the domain for which we + // are issuing. 
+ // Note that renewals triggered by the call to the + // certificates endpoint here and by renewal check + // triggered during a call to node's HTTPS endpoint + // share the same state/renewal lock mechanism, so we + // should not run into redundant issuances during + // concurrent renewal checks. + // TODO(irbekrm): maybe it is worth adding a new + // issuance endpoint that explicitly only triggers + // issuance and stores certs in the relevant store, but + // does not return certs to the caller? + + // An issuance holds a shared lock, so we need to avoid + // a situation where other services cannot issue certs + // because a single one is holding the lock. + ctxT, cancel := context.WithTimeout(ctx, time.Second*300) + defer cancel() + _, _, err := cm.lc.CertPair(ctxT, domain) + if err != nil { + log.Printf("error refreshing certificate for %s: %v", domain, err) + } + var nextInterval time.Duration + // TODO(irbekrm): distinguish between LE rate limit + // errors and other error types like transient network + // errors. + if err == nil { + retryCount = 0 + nextInterval = normalInterval + } else { + retryCount++ + // Calculate backoff: initialRetry * 2^(retryCount-1) + // For retryCount=1: 1min * 2^0 = 1min + // For retryCount=2: 1min * 2^1 = 2min + // For retryCount=3: 1min * 2^2 = 4min + backoff := initialRetry * time.Duration(1<<(retryCount-1)) + if backoff > maxRetryInterval { + backoff = maxRetryInterval + } + nextInterval = backoff + log.Printf("Error refreshing certificate for %s (retry %d): %v. 
Will retry in %v\n", + domain, retryCount, err, nextInterval) + } + timer.Reset(nextInterval) + } + } +} diff --git a/cmd/containerboot/certs_test.go b/cmd/containerboot/certs_test.go new file mode 100644 index 0000000000000..577311ea36a64 --- /dev/null +++ b/cmd/containerboot/certs_test.go @@ -0,0 +1,229 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "testing" + "time" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" +) + +// TestEnsureCertLoops tests that the certManager correctly starts and stops +// update loops for certs when the serve config changes. It tracks goroutine +// count and uses that as a validator that the expected number of cert loops are +// running. +func TestEnsureCertLoops(t *testing.T) { + tests := []struct { + name string + initialConfig *ipn.ServeConfig + updatedConfig *ipn.ServeConfig + initialGoroutines int64 // after initial serve config is applied + updatedGoroutines int64 // after updated serve config is applied + wantErr bool + }{ + { + name: "empty_serve_config", + initialConfig: &ipn.ServeConfig{}, + initialGoroutines: 0, + }, + { + name: "nil_serve_config", + initialConfig: nil, + initialGoroutines: 0, + wantErr: true, + }, + { + name: "empty_to_one_service", + initialConfig: &ipn.ServeConfig{}, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 0, + updatedGoroutines: 1, + }, + { + name: "single_service", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 1, + }, + { + name: "multiple_services", + initialConfig: &ipn.ServeConfig{ + Services: 
map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 2, // one loop per domain across all services + }, + { + name: "ignore_non_https_ports", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + "my-app.tailnetxyz.ts.net:80": {}, + }, + }, + }, + }, + initialGoroutines: 1, // only one loop for the 443 endpoint + }, + { + name: "remove_domain", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + initialGoroutines: 2, // initially two loops (one per service) + updatedGoroutines: 1, // one loop after removing service2 + }, + { + name: "add_domain", + initialConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + }, + updatedConfig: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:my-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-app.tailnetxyz.ts.net:443": {}, + }, + }, + "svc:my-other-app": { + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "my-other-app.tailnetxyz.ts.net:443": {}, + }, + }, + }, + 
}, + initialGoroutines: 1, + updatedGoroutines: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + cm := &certManager{ + lc: &fakeLocalClient{}, + certLoops: make(map[string]context.CancelFunc), + } + + allDone := make(chan bool, 1) + defer cm.tracker.AddDoneCallback(func() { + cm.mu.Lock() + defer cm.mu.Unlock() + if cm.tracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + err := cm.ensureCertLoops(ctx, tt.initialConfig) + if (err != nil) != tt.wantErr { + t.Fatalf("ensureCertLoops() error = %v", err) + } + + if got := cm.tracker.RunningGoroutines(); got != tt.initialGoroutines { + t.Errorf("after initial config: got %d running goroutines, want %d", got, tt.initialGoroutines) + } + + if tt.updatedConfig != nil { + if err := cm.ensureCertLoops(ctx, tt.updatedConfig); err != nil { + t.Fatalf("ensureCertLoops() error on update = %v", err) + } + + // Although starting goroutines and cancelling + // the context happens in the main goroutine, + // the actual goroutine exit when a context is + // cancelled does not — so wait for a bit for the + // running goroutine count to reach the expected + // number. 
+ deadline := time.After(5 * time.Second) + for { + if got := cm.tracker.RunningGoroutines(); got == tt.updatedGoroutines { + break + } + select { + case <-deadline: + t.Fatalf("timed out waiting for goroutine count to reach %d, currently at %d", + tt.updatedGoroutines, cm.tracker.RunningGoroutines()) + case <-time.After(10 * time.Millisecond): + continue + } + } + } + + if tt.updatedGoroutines == 0 { + return // no goroutines to wait for + } + // cancel context to make goroutines exit + cancel() + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutine to finish") + case <-allDone: + } + }) + } +} diff --git a/cmd/containerboot/egressservices.go b/cmd/containerboot/egressservices.go new file mode 100644 index 0000000000000..71141f17a9bb6 --- /dev/null +++ b/cmd/containerboot/egressservices.go @@ -0,0 +1,766 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "net/http" + "net/netip" + "os" + "path/filepath" + "reflect" + "strconv" + "strings" + "time" + + "github.com/fsnotify/fsnotify" + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/kube/egressservices" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/util/httpm" + "tailscale.com/util/linuxfw" + "tailscale.com/util/mak" +) + +const tailscaleTunInterface = "tailscale0" + +// Modified using a build flag to speed up tests. +var testSleepDuration string + +// This file contains functionality to run containerboot as a proxy that can +// route cluster traffic to one or more tailnet targets, based on portmapping +// rules read from a configfile. Currently (9/2024) this is only used for the +// Kubernetes operator egress proxies. 
+ +// egressProxy knows how to configure firewall rules to route cluster traffic to +// one or more tailnet services. +type egressProxy struct { + cfgPath string // path to a directory with egress services config files + + nfr linuxfw.NetfilterRunner // never nil + + kc kubeclient.Client // never nil + stateSecret string // name of the kube state Secret + + tsClient *local.Client // never nil + + netmapChan chan ipn.Notify // chan to receive netmap updates on + + podIPv4 string // never empty string, currently only IPv4 is supported + + // targetFQDNs is the egress service FQDN to tailnet IP mappings that + // were last used to configure firewall rules for this proxy. + // TODO(irbekrm): target addresses are also stored in the state Secret. + // Evaluate whether we should retrieve them from there and not store in + // memory at all. + targetFQDNs map[string][]netip.Prefix + + tailnetAddrs []netip.Prefix // tailnet IPs of this tailnet device + + // shortSleep is the backoff sleep between healthcheck endpoint calls - can be overridden in tests. + shortSleep time.Duration + // longSleep is the time to sleep after the routing rules are updated to increase the chance that kube + // proxies on all nodes have updated their routing configuration. It can be configured to 0 in + // tests. + longSleep time.Duration + // client is a client that can send HTTP requests. + client httpClient +} + +// httpClient is a client that can send HTTP requests and can be mocked in tests. 
+type httpClient interface {
+	Do(*http.Request) (*http.Response, error)
+}
+
+// run configures egress proxy firewall rules and ensures that the firewall rules are reconfigured when:
+// - the mounted egress config has changed
+// - the proxy's tailnet IP addresses have changed
+// - tailnet IPs have changed for any backend targets specified by tailnet FQDN
+func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRunOpts) error {
+	ep.configure(opts)
+	var tickChan <-chan time.Time
+	var eventChan <-chan fsnotify.Event
+	// TODO (irbekrm): take a look if this can be pulled into a single func
+	// shared with serve config loader.
+	if w, err := fsnotify.NewWatcher(); err != nil {
+		log.Printf("failed to create fsnotify watcher, timer-only mode: %v", err)
+		ticker := time.NewTicker(5 * time.Second)
+		defer ticker.Stop()
+		tickChan = ticker.C
+	} else {
+		defer w.Close()
+		if err := w.Add(ep.cfgPath); err != nil {
+			return fmt.Errorf("failed to add fsnotify watch: %w", err)
+		}
+		eventChan = w.Events
+	}
+
+	if err := ep.sync(ctx, n); err != nil {
+		return err
+	}
+	for {
+		select {
+		case <-ctx.Done():
+			return nil
+		case <-tickChan:
+			log.Printf("periodic sync, ensuring firewall config is up to date...")
+		case <-eventChan:
+			log.Printf("config file change detected, ensuring firewall config is up to date...")
+		case n = <-ep.netmapChan:
+			shouldResync := ep.shouldResync(n)
+			if !shouldResync {
+				continue
+			}
+			log.Printf("netmap change detected, ensuring firewall config is up to date...")
+		}
+		if err := ep.sync(ctx, n); err != nil {
+			return fmt.Errorf("error syncing egress service config: %w", err)
+		}
+	}
+}
+
+type egressProxyRunOpts struct {
+	cfgPath      string
+	nfr          linuxfw.NetfilterRunner
+	kc           kubeclient.Client
+	tsClient     *local.Client
+	stateSecret  string
+	netmapChan   chan ipn.Notify
+	podIPv4      string
+	tailnetAddrs []netip.Prefix
+}
+
+// configure sets up the egress proxy using the provided options.
+func (ep *egressProxy) configure(opts egressProxyRunOpts) { + ep.cfgPath = opts.cfgPath + ep.nfr = opts.nfr + ep.kc = opts.kc + ep.tsClient = opts.tsClient + ep.stateSecret = opts.stateSecret + ep.netmapChan = opts.netmapChan + ep.podIPv4 = opts.podIPv4 + ep.tailnetAddrs = opts.tailnetAddrs + ep.client = &http.Client{} // default HTTP client + sleepDuration := time.Second + if d, err := time.ParseDuration(testSleepDuration); err == nil && d > 0 { + log.Printf("using test sleep duration %v", d) + sleepDuration = d + } + ep.shortSleep = sleepDuration + ep.longSleep = sleepDuration * 10 +} + +// sync triggers an egress proxy config resync. The resync calculates the diff between config and status to determine if +// any firewall rules need to be updated. Currently using status in state Secret as a reference for what is the current +// firewall configuration is good enough because - the status is keyed by the Pod IP - we crash the Pod on errors such +// as failed firewall update +func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { + cfgs, err := ep.getConfigs() + if err != nil { + return fmt.Errorf("error retrieving egress service configs: %w", err) + } + status, err := ep.getStatus(ctx) + if err != nil { + return fmt.Errorf("error retrieving current egress proxy status: %w", err) + } + newStatus, err := ep.syncEgressConfigs(cfgs, status, n) + if err != nil { + return fmt.Errorf("error syncing egress service configs: %w", err) + } + if !servicesStatusIsEqual(newStatus, status) { + if err := ep.setStatus(ctx, newStatus, n); err != nil { + return fmt.Errorf("error setting egress proxy status: %w", err) + } + } + return nil +} + +// addrsHaveChanged returns true if the provided netmap update contains tailnet address change for this proxy node. +// Netmap must not be nil. 
+func (ep *egressProxy) addrsHaveChanged(n ipn.Notify) bool { + return !reflect.DeepEqual(ep.tailnetAddrs, n.NetMap.SelfNode.Addresses()) +} + +// syncEgressConfigs adds and deletes firewall rules to match the desired +// configuration. It uses the provided status to determine what is currently +// applied and updates the status after a successful sync. +func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *egressservices.Status, n ipn.Notify) (*egressservices.Status, error) { + if !(wantsServicesConfigured(cfgs) || hasServicesConfigured(status)) { + return nil, nil + } + + // Delete unnecessary services. + if err := ep.deleteUnnecessaryServices(cfgs, status); err != nil { + return nil, fmt.Errorf("error deleting services: %w", err) + + } + newStatus := &egressservices.Status{} + if !wantsServicesConfigured(cfgs) { + return newStatus, nil + } + + // Add new services, update rules for any that have changed. + rulesPerSvcToAdd := make(map[string][]rule, 0) + rulesPerSvcToDelete := make(map[string][]rule, 0) + for svcName, cfg := range *cfgs { + tailnetTargetIPs, err := ep.tailnetTargetIPsForSvc(cfg, n) + if err != nil { + return nil, fmt.Errorf("error determining tailnet target IPs: %w", err) + } + rulesToAdd, rulesToDelete, err := updatesForCfg(svcName, cfg, status, tailnetTargetIPs) + if err != nil { + return nil, fmt.Errorf("error validating service changes: %v", err) + } + log.Printf("syncegressservices: looking at svc %s rulesToAdd %d rulesToDelete %d", svcName, len(rulesToAdd), len(rulesToDelete)) + if len(rulesToAdd) != 0 { + mak.Set(&rulesPerSvcToAdd, svcName, rulesToAdd) + } + if len(rulesToDelete) != 0 { + mak.Set(&rulesPerSvcToDelete, svcName, rulesToDelete) + } + if len(rulesToAdd) != 0 || ep.addrsHaveChanged(n) { + // For each tailnet target, set up SNAT from the local tailnet device address of the matching + // family. 
+			for _, t := range tailnetTargetIPs {
+				var local netip.Addr
+				for _, pfx := range n.NetMap.SelfNode.Addresses().All() {
+					if !pfx.IsSingleIP() {
+						continue
+					}
+					if pfx.Addr().Is4() != t.Is4() {
+						continue
+					}
+					local = pfx.Addr()
+					break
+				}
+				if !local.IsValid() {
+					return nil, fmt.Errorf("no valid local IP: %v", local)
+				}
+				if err := ep.nfr.EnsureSNATForDst(local, t); err != nil {
+					return nil, fmt.Errorf("error setting up SNAT rule: %w", err)
+				}
+			}
+		}
+		// Update the status. Status will be written back to the state Secret by the caller.
+		mak.Set(&newStatus.Services, svcName, &egressservices.ServiceStatus{TailnetTargetIPs: tailnetTargetIPs, TailnetTarget: cfg.TailnetTarget, Ports: cfg.Ports})
+	}
+
+	// Actually apply the firewall rules.
+	if err := ensureRulesAdded(rulesPerSvcToAdd, ep.nfr); err != nil {
+		return nil, fmt.Errorf("error adding rules: %w", err)
+	}
+	if err := ensureRulesDeleted(rulesPerSvcToDelete, ep.nfr); err != nil {
+		return nil, fmt.Errorf("error deleting rules: %w", err)
+	}
+
+	return newStatus, nil
+}
+
+// updatesForCfg calculates any rules that need to be added or deleted for an individual egress service config.
+func updatesForCfg(svcName string, cfg egressservices.Config, status *egressservices.Status, tailnetTargetIPs []netip.Addr) ([]rule, []rule, error) {
+	rulesToAdd := make([]rule, 0)
+	rulesToDelete := make([]rule, 0)
+	currentConfig, ok := lookupCurrentConfig(svcName, status)
+
+	// If no rules for service are present yet, add them all.
+	if !ok {
+		for _, t := range tailnetTargetIPs {
+			for ports := range cfg.Ports {
+				log.Printf("syncegressservices: svc %s adding port %v", svcName, ports)
+				rulesToAdd = append(rulesToAdd, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: t})
+			}
+		}
+		return rulesToAdd, rulesToDelete, nil
+	}
+
+	// If there are no backend targets available, delete any currently configured rules.
+ if len(tailnetTargetIPs) == 0 { + log.Printf("tailnet target for egress service %s does not have any backend addresses, deleting all rules", svcName) + for _, ip := range currentConfig.TailnetTargetIPs { + for ports := range currentConfig.Ports { + rulesToDelete = append(rulesToDelete, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) + } + } + return rulesToAdd, rulesToDelete, nil + } + + // If there are rules present for backend targets that no longer match, delete them. + for _, ip := range currentConfig.TailnetTargetIPs { + var found bool + for _, wantsIP := range tailnetTargetIPs { + if reflect.DeepEqual(ip, wantsIP) { + found = true + break + } + } + if !found { + for ports := range currentConfig.Ports { + rulesToDelete = append(rulesToDelete, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) + } + } + } + + // Sync rules for the currently wanted backend targets. + for _, ip := range tailnetTargetIPs { + + // If the backend target is not yet present in status, add all rules. + var found bool + for _, gotIP := range currentConfig.TailnetTargetIPs { + if reflect.DeepEqual(ip, gotIP) { + found = true + break + } + } + if !found { + for ports := range cfg.Ports { + rulesToAdd = append(rulesToAdd, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) + } + continue + } + + // If the backend target is present in status, check that the + // currently applied rules are up to date. + + // Delete any current portmappings that are no longer present in config. + for port := range currentConfig.Ports { + if _, ok := cfg.Ports[port]; ok { + continue + } + rulesToDelete = append(rulesToDelete, rule{tailnetPort: port.TargetPort, containerPort: port.MatchPort, protocol: port.Protocol, tailnetIP: ip}) + } + + // Add any new portmappings. 
+		for port := range cfg.Ports {
+			if _, ok := currentConfig.Ports[port]; ok {
+				continue
+			}
+			rulesToAdd = append(rulesToAdd, rule{tailnetPort: port.TargetPort, containerPort: port.MatchPort, protocol: port.Protocol, tailnetIP: ip})
+		}
+	}
+	return rulesToAdd, rulesToDelete, nil
+}
+
+// deleteUnnecessaryServices ensures that any services found on status, but not
+// present in config are deleted.
+func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, status *egressservices.Status) error {
+	if !hasServicesConfigured(status) {
+		return nil
+	}
+	if !wantsServicesConfigured(cfgs) {
+		for svcName, svc := range status.Services {
+			log.Printf("service %s is no longer required, deleting", svcName)
+			if err := ensureServiceDeleted(svcName, svc, ep.nfr); err != nil {
+				return fmt.Errorf("error deleting service %s: %w", svcName, err)
+			}
+		}
+		return nil
+	}
+
+	for svcName, svc := range status.Services {
+		if _, ok := (*cfgs)[svcName]; !ok {
+			log.Printf("service %s is no longer required, deleting", svcName)
+			if err := ensureServiceDeleted(svcName, svc, ep.nfr); err != nil {
+				return fmt.Errorf("error deleting service %s: %w", svcName, err)
+			}
+			// TODO (irbekrm): also delete the SNAT rule here
+		}
+	}
+	return nil
+}
+
+// getConfigs gets the mounted egress service configuration.
+func (ep *egressProxy) getConfigs() (*egressservices.Configs, error) {
+	svcsCfg := filepath.Join(ep.cfgPath, egressservices.KeyEgressServices)
+	j, err := os.ReadFile(svcsCfg)
+	if os.IsNotExist(err) {
+		return nil, nil
+	}
+	if err != nil {
+		return nil, err
+	}
+	if len(j) == 0 || string(j) == "" {
+		return nil, nil
+	}
+	cfg := &egressservices.Configs{}
+	if err := json.Unmarshal(j, &cfg); err != nil {
+		return nil, err
+	}
+	return cfg, nil
+}
+
+// getStatus gets the current status of the configured firewall. The current
+// status is stored in state Secret. Returns nil status if no status that
+// applies to the current proxy Pod was found. 
Uses the Pod IP to determine if a +// status found in the state Secret applies to this proxy Pod. +func (ep *egressProxy) getStatus(ctx context.Context) (*egressservices.Status, error) { + secret, err := ep.kc.GetSecret(ctx, ep.stateSecret) + if err != nil { + return nil, fmt.Errorf("error retrieving state secret: %w", err) + } + status := &egressservices.Status{} + raw, ok := secret.Data[egressservices.KeyEgressServices] + if !ok { + return nil, nil + } + if err := json.Unmarshal([]byte(raw), status); err != nil { + return nil, fmt.Errorf("error unmarshalling previous config: %w", err) + } + if reflect.DeepEqual(status.PodIPv4, ep.podIPv4) { + return status, nil + } + return nil, nil +} + +// setStatus writes egress proxy's currently configured firewall to the state +// Secret and updates proxy's tailnet addresses. +func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Status, n ipn.Notify) error { + // Pod IP is used to determine if a stored status applies to THIS proxy Pod. + if status == nil { + status = &egressservices.Status{} + } + status.PodIPv4 = ep.podIPv4 + secret, err := ep.kc.GetSecret(ctx, ep.stateSecret) + if err != nil { + return fmt.Errorf("error retrieving state Secret: %w", err) + } + bs, err := json.Marshal(status) + if err != nil { + return fmt.Errorf("error marshalling service config: %w", err) + } + secret.Data[egressservices.KeyEgressServices] = bs + patch := kubeclient.JSONPatch{ + Op: "replace", + Path: fmt.Sprintf("/data/%s", egressservices.KeyEgressServices), + Value: bs, + } + if err := ep.kc.JSONPatchResource(ctx, ep.stateSecret, kubeclient.TypeSecrets, []kubeclient.JSONPatch{patch}); err != nil { + return fmt.Errorf("error patching state Secret: %w", err) + } + ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() + return nil +} + +// tailnetTargetIPsForSvc returns the tailnet IPs to which traffic for this +// egress service should be proxied. The egress service can be configured by IP +// or by FQDN. 
If it's configured by IP, just return that. If it's configured by +// FQDN, resolve the FQDN and return the resolved IPs. It checks if the +// netfilter runner supports IPv6 NAT and skips any IPv6 addresses if it +// doesn't. +func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.Notify) (addrs []netip.Addr, err error) { + if svc.TailnetTarget.IP != "" { + addr, err := netip.ParseAddr(svc.TailnetTarget.IP) + if err != nil { + return nil, fmt.Errorf("error parsing tailnet target IP: %w", err) + } + if addr.Is6() && !ep.nfr.HasIPV6NAT() { + log.Printf("tailnet target is an IPv6 address, but this host does not support IPv6 in the chosen firewall mode. This will probably not work.") + return addrs, nil + } + return []netip.Addr{addr}, nil + } + + if svc.TailnetTarget.FQDN == "" { + return nil, errors.New("unexpected egress service config- neither tailnet target IP nor FQDN is set") + } + if n.NetMap == nil { + log.Printf("netmap is not available, unable to determine backend addresses for %s", svc.TailnetTarget.FQDN) + return addrs, nil + } + var ( + node tailcfg.NodeView + nodeFound bool + ) + for _, nn := range n.NetMap.Peers { + if equalFQDNs(nn.Name(), svc.TailnetTarget.FQDN) { + node = nn + nodeFound = true + break + } + } + if nodeFound { + for _, addr := range node.Addresses().AsSlice() { + if addr.Addr().Is6() && !ep.nfr.HasIPV6NAT() { + log.Printf("tailnet target %v is an IPv6 address, but this host does not support IPv6 in the chosen firewall mode, skipping.", addr.Addr().String()) + continue + } + addrs = append(addrs, addr.Addr()) + } + // Egress target endpoints configured via FQDN are stored, so + // that we can determine if a netmap update should trigger a + // resync. + mak.Set(&ep.targetFQDNs, svc.TailnetTarget.FQDN, node.Addresses().AsSlice()) + } + return addrs, nil +} + +// shouldResync parses netmap update and returns true if the update contains +// changes for which the egress proxy's firewall should be reconfigured. 
+func (ep *egressProxy) shouldResync(n ipn.Notify) bool { + if n.NetMap == nil { + return false + } + + // If proxy's tailnet addresses have changed, resync. + if !reflect.DeepEqual(n.NetMap.SelfNode.Addresses().AsSlice(), ep.tailnetAddrs) { + log.Printf("node addresses have changed, trigger egress config resync") + ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() + return true + } + + // If the IPs for any of the egress services configured via FQDN have + // changed, resync. + for fqdn, ips := range ep.targetFQDNs { + for _, nn := range n.NetMap.Peers { + if equalFQDNs(nn.Name(), fqdn) { + if !reflect.DeepEqual(ips, nn.Addresses().AsSlice()) { + log.Printf("backend addresses for egress target %q have changed old IPs %v, new IPs %v trigger egress config resync", nn.Name(), ips, nn.Addresses().AsSlice()) + return true + } + break + } + } + } + return false +} + +// ensureServiceDeleted ensures that any rules for an egress service are removed +// from the firewall configuration. +func ensureServiceDeleted(svcName string, svc *egressservices.ServiceStatus, nfr linuxfw.NetfilterRunner) error { + + // Note that the portmap is needed for iptables based firewall only. + // Nftables group rules for a service in a chain, so there is no need to + // specify individual portmapping based rules. + pms := make([]linuxfw.PortMap, 0) + for pm := range svc.Ports { + pms = append(pms, linuxfw.PortMap{MatchPort: pm.MatchPort, TargetPort: pm.TargetPort, Protocol: pm.Protocol}) + } + + if err := nfr.DeleteSvc(svcName, tailscaleTunInterface, svc.TailnetTargetIPs, pms); err != nil { + return fmt.Errorf("error deleting service %s: %w", svcName, err) + } + return nil +} + +// ensureRulesAdded ensures that all portmapping rules are added to the firewall +// configuration. For any rules that already exist, calling this function is a +// no-op. 
In case of nftables, a service consists of one or two (one per IP
+// family) chains that contain the portmapping rules for the service and the
+// chains are created as needed when this function is called.
+func ensureRulesAdded(rulesPerSvc map[string][]rule, nfr linuxfw.NetfilterRunner) error {
+	for svc, rules := range rulesPerSvc {
+		for _, rule := range rules {
+			log.Printf("ensureRulesAdded svc %s tailnetTarget %s container port %d tailnet port %d protocol %s", svc, rule.tailnetIP, rule.containerPort, rule.tailnetPort, rule.protocol)
+			if err := nfr.EnsurePortMapRuleForSvc(svc, tailscaleTunInterface, rule.tailnetIP, linuxfw.PortMap{MatchPort: rule.containerPort, TargetPort: rule.tailnetPort, Protocol: rule.protocol}); err != nil {
+				return fmt.Errorf("error ensuring rule: %w", err)
+			}
+		}
+	}
+	return nil
+}
+
+// ensureRulesDeleted ensures that the given rules are deleted from the firewall
+// configuration. For any rules that do not exist, calling this function is a
+// no-op.
+func ensureRulesDeleted(rulesPerSvc map[string][]rule, nfr linuxfw.NetfilterRunner) error {
+	for svc, rules := range rulesPerSvc {
+		for _, rule := range rules {
+			log.Printf("ensureRulesDeleted svc %s tailnetTarget %s container port %d tailnet port %d protocol %s", svc, rule.tailnetIP, rule.containerPort, rule.tailnetPort, rule.protocol)
+			if err := nfr.DeletePortMapRuleForSvc(svc, tailscaleTunInterface, rule.tailnetIP, linuxfw.PortMap{MatchPort: rule.containerPort, TargetPort: rule.tailnetPort, Protocol: rule.protocol}); err != nil {
+				return fmt.Errorf("error deleting rule: %w", err)
+			}
+		}
+	}
+	return nil
+}
+
+func lookupCurrentConfig(svcName string, status *egressservices.Status) (*egressservices.ServiceStatus, bool) {
+	if status == nil || len(status.Services) == 0 {
+		return nil, false
+	}
+	c, ok := status.Services[svcName]
+	return c, ok
+}
+
+func equalFQDNs(s, s1 string) bool {
+	s, _ = strings.CutSuffix(s, ".")
+	s1, _ = strings.CutSuffix(s1, ".")
+	return strings.EqualFold(s, s1)
+}
+
+// rule contains configuration for an egress proxy firewall rule.
+type rule struct {
+	containerPort uint16     // port to match incoming traffic
+	tailnetPort   uint16     // tailnet service port
+	tailnetIP     netip.Addr // tailnet service IP
+	protocol      string
+}
+
+func wantsServicesConfigured(cfgs *egressservices.Configs) bool {
+	return cfgs != nil && len(*cfgs) != 0
+}
+
+func hasServicesConfigured(status *egressservices.Status) bool {
+	return status != nil && len(status.Services) != 0
+}
+
+func servicesStatusIsEqual(st, st1 *egressservices.Status) bool {
+	if st == nil && st1 == nil {
+		return true
+	}
+	if st == nil || st1 == nil {
+		return false
+	}
+	st.PodIPv4 = ""
+	st1.PodIPv4 = ""
+	return reflect.DeepEqual(*st, *st1)
+}
+
+// registerHandlers adds a new handler to the provided ServeMux that can be called as a Kubernetes prestop hook to
+// delay shutdown till it's safe to do so.
+func (ep *egressProxy) registerHandlers(mux *http.ServeMux) {
+	mux.Handle(fmt.Sprintf("GET %s", kubetypes.EgessServicesPreshutdownEP), ep)
+}
+
+// ServeHTTP serves /internal-egress-services-preshutdown endpoint, when it receives a request, it periodically polls
+// the configured health check endpoint for each egress service till the health check endpoint no longer hits this
+// proxy Pod. It uses the Pod-IPv4 header to verify if health check response is received from this Pod. 
+func (ep *egressProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + cfgs, err := ep.getConfigs() + if err != nil { + http.Error(w, fmt.Sprintf("error retrieving egress services configs: %v", err), http.StatusInternalServerError) + return + } + if cfgs == nil { + if _, err := w.Write([]byte("safe to terminate")); err != nil { + http.Error(w, fmt.Sprintf("error writing termination status: %v", err), http.StatusInternalServerError) + } + return + } + hp, err := ep.getHEPPings() + if err != nil { + http.Error(w, fmt.Sprintf("error determining the number of times health check endpoint should be pinged: %v", err), http.StatusInternalServerError) + return + } + ep.waitTillSafeToShutdown(r.Context(), cfgs, hp) +} + +// waitTillSafeToShutdown looks up all egress targets configured to be proxied via this instance and, for each target +// whose configuration includes a healthcheck endpoint, pings the endpoint till none of the responses +// are returned by this instance or till the HTTP request times out. In practice, the endpoint will be a Kubernetes Service for whom one of the backends +// would normally be this Pod. When this Pod is being deleted, the operator should have removed it from the Service +// backends and eventually kube proxy routing rules should be updated to no longer route traffic for the Service to this +// Pod. 
+func (ep *egressProxy) waitTillSafeToShutdown(ctx context.Context, cfgs *egressservices.Configs, hp int) { + if cfgs == nil || len(*cfgs) == 0 { // avoid sleeping if no services are configured + return + } + log.Printf("Ensuring that cluster traffic for egress targets is no longer routed via this Pod...") + wg := syncs.WaitGroup{} + + for s, cfg := range *cfgs { + hep := cfg.HealthCheckEndpoint + if hep == "" { + log.Printf("Tailnet target %q does not have a cluster healthcheck specified, unable to verify if cluster traffic for the target is still routed via this Pod", s) + continue + } + svc := s + wg.Go(func() { + log.Printf("Ensuring that cluster traffic is no longer routed to %q via this Pod...", svc) + for { + if ctx.Err() != nil { // kubelet's HTTP request timeout + log.Printf("Cluster traffic for %s did not stop being routed to this Pod.", svc) + return + } + found, err := lookupPodRoute(ctx, hep, ep.podIPv4, hp, ep.client) + if err != nil { + log.Printf("unable to reach endpoint %q, assuming the routing rules for this Pod have been deleted: %v", hep, err) + break + } + if !found { + log.Printf("service %q is no longer routed through this Pod", svc) + break + } + log.Printf("service %q is still routed through this Pod, waiting...", svc) + time.Sleep(ep.shortSleep) + } + }) + } + wg.Wait() + // The check above really only checked that the routing rules are updated on this node. Sleep for a bit to + // ensure that the routing rules are updated on other nodes. TODO(irbekrm): this may or may not be good enough. + // If it's not good enough, we'd probably want to do something more complex, where the proxies check each other. + log.Printf("Sleeping for %s before shutdown to ensure that kube proxies on all nodes have updated routing configuration", ep.longSleep) + time.Sleep(ep.longSleep) +} + +// lookupPodRoute calls the healthcheck endpoint repeat times and returns true if the endpoint returns with the podIP +// header at least once. 
+func lookupPodRoute(ctx context.Context, hep, podIP string, repeat int, client httpClient) (bool, error) { + for range repeat { + f, err := lookup(ctx, hep, podIP, client) + if err != nil { + return false, err + } + if f { + return true, nil + } + } + return false, nil +} + +// lookup calls the healthcheck endpoint and returns true if the response contains the podIP header. +func lookup(ctx context.Context, hep, podIP string, client httpClient) (bool, error) { + req, err := http.NewRequestWithContext(ctx, httpm.GET, hep, nil) + if err != nil { + return false, fmt.Errorf("error creating new HTTP request: %v", err) + } + + // Close the TCP connection to ensure that the next request is routed to a different backend. + req.Close = true + + resp, err := client.Do(req) + if err != nil { + log.Printf("Endpoint %q can not be reached: %v, likely because there are no (more) healthy backends", hep, err) + return true, nil + } + defer resp.Body.Close() + gotIP := resp.Header.Get(kubetypes.PodIPv4Header) + return strings.EqualFold(podIP, gotIP), nil +} + +// getHEPPings gets the number of pings that should be sent to a health check endpoint to ensure that each configured +// backend is hit. This assumes that a health check endpoint is a Kubernetes Service and traffic to backend Pods is +// round robin load balanced. 
+func (ep *egressProxy) getHEPPings() (int, error) { + hepPingsPath := filepath.Join(ep.cfgPath, egressservices.KeyHEPPings) + j, err := os.ReadFile(hepPingsPath) + if os.IsNotExist(err) { + return 0, nil + } + if err != nil { + return -1, err + } + if len(j) == 0 || string(j) == "" { + return 0, nil + } + hp, err := strconv.Atoi(string(j)) + if err != nil { + return -1, fmt.Errorf("error parsing hep pings as int: %v", err) + } + if hp < 0 { + log.Printf("[unexpected] hep pings is negative: %d", hp) + return 0, nil + } + return hp, nil +} diff --git a/cmd/containerboot/services_test.go b/cmd/containerboot/egressservices_test.go similarity index 61% rename from cmd/containerboot/services_test.go rename to cmd/containerboot/egressservices_test.go index 46f6db1cf6d0e..724626b072c2b 100644 --- a/cmd/containerboot/services_test.go +++ b/cmd/containerboot/egressservices_test.go @@ -6,11 +6,18 @@ package main import ( + "context" + "fmt" + "io" + "net/http" "net/netip" "reflect" + "strings" + "sync" "testing" "tailscale.com/kube/egressservices" + "tailscale.com/kube/kubetypes" ) func Test_updatesForSvc(t *testing.T) { @@ -173,3 +180,145 @@ func Test_updatesForSvc(t *testing.T) { }) } } + +// A failure of this test will most likely look like a timeout. +func TestWaitTillSafeToShutdown(t *testing.T) { + podIP := "10.0.0.1" + anotherIP := "10.0.0.2" + + tests := []struct { + name string + // services is a map of service name to the number of calls to make to the healthcheck endpoint before + // returning a response that does NOT contain this Pod's IP in headers. 
+ services map[string]int + replicas int + healthCheckSet bool + }{ + { + name: "no_configs", + }, + { + name: "one_service_immediately_safe_to_shutdown", + services: map[string]int{ + "svc1": 0, + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_immediately_safe_to_shutdown", + services: map[string]int{ + "svc1": 0, + "svc2": 0, + "svc3": 0, + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_no_healthcheck_endpoints", + services: map[string]int{ + "svc1": 0, + "svc2": 0, + "svc3": 0, + }, + replicas: 2, + }, + { + name: "one_service_eventually_safe_to_shutdown", + services: map[string]int{ + "svc1": 3, // After 3 calls to health check endpoint, no longer returns this Pod's IP + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_eventually_safe_to_shutdown", + services: map[string]int{ + "svc1": 1, // After 1 call to health check endpoint, no longer returns this Pod's IP + "svc2": 3, // After 3 calls to health check endpoint, no longer returns this Pod's IP + "svc3": 5, // After 5 calls to the health check endpoint, no longer returns this Pod's IP + }, + replicas: 2, + healthCheckSet: true, + }, + { + name: "multiple_services_eventually_safe_to_shutdown_with_higher_replica_count", + services: map[string]int{ + "svc1": 7, + "svc2": 10, + }, + replicas: 5, + healthCheckSet: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfgs := &egressservices.Configs{} + switches := make(map[string]int) + + for svc, callsToSwitch := range tt.services { + endpoint := fmt.Sprintf("http://%s.local", svc) + if tt.healthCheckSet { + (*cfgs)[svc] = egressservices.Config{ + HealthCheckEndpoint: endpoint, + } + } + switches[endpoint] = callsToSwitch + } + + ep := &egressProxy{ + podIPv4: podIP, + client: &mockHTTPClient{ + podIP: podIP, + anotherIP: anotherIP, + switches: switches, + }, + } + + ep.waitTillSafeToShutdown(context.Background(), cfgs, tt.replicas) + }) + } 
+} + +// mockHTTPClient is a client that receives an HTTP call for an egress service endpoint and returns a response with an +// IP address in a 'Pod-IPv4' header. It can be configured to return one IP address for N calls, then switch to another +// IP address to simulate a scenario where an IP is eventually no longer a backend for an endpoint. +// TODO(irbekrm): to test this more thoroughly, we should have the client take into account the number of replicas and +// return as if traffic was round robin load balanced across different Pods. +type mockHTTPClient struct { + // podIP - initial IP address to return, that matches the current proxy's IP address. + podIP string + anotherIP string + // after how many calls to an endpoint, the client should start returning 'anotherIP' instead of 'podIP. + switches map[string]int + mu sync.Mutex // protects the following + // calls tracks the number of calls received. + calls map[string]int +} + +func (m *mockHTTPClient) Do(req *http.Request) (*http.Response, error) { + m.mu.Lock() + if m.calls == nil { + m.calls = make(map[string]int) + } + + endpoint := req.URL.String() + m.calls[endpoint]++ + calls := m.calls[endpoint] + m.mu.Unlock() + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader("")), + } + + if calls <= m.switches[endpoint] { + resp.Header.Set(kubetypes.PodIPv4Header, m.podIP) // Pod is still routable + } else { + resp.Header.Set(kubetypes.PodIPv4Header, m.anotherIP) // Pod is no longer routable + } + return resp, nil +} diff --git a/cmd/containerboot/healthz.go b/cmd/containerboot/healthz.go index fb7fccd968816..d6a64a37c4ac5 100644 --- a/cmd/containerboot/healthz.go +++ b/cmd/containerboot/healthz.go @@ -6,10 +6,12 @@ package main import ( + "fmt" "log" - "net" "net/http" "sync" + + "tailscale.com/kube/kubetypes" ) // healthz is a simple health check server, if enabled it returns 200 OK if @@ -18,34 +20,38 @@ import ( type healthz struct { 
sync.Mutex hasAddrs bool + podIPv4 string } func (h *healthz) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.Lock() defer h.Unlock() + if h.hasAddrs { - w.Write([]byte("ok")) + w.Header().Add(kubetypes.PodIPv4Header, h.podIPv4) + if _, err := w.Write([]byte("ok")); err != nil { + http.Error(w, fmt.Sprintf("error writing status: %v", err), http.StatusInternalServerError) + } } else { - http.Error(w, "node currently has no tailscale IPs", http.StatusInternalServerError) + http.Error(w, "node currently has no tailscale IPs", http.StatusServiceUnavailable) } } -// runHealthz runs a simple HTTP health endpoint on /healthz, listening on the -// provided address. A containerized tailscale instance is considered healthy if -// it has at least one tailnet IP address. -func runHealthz(addr string, h *healthz) { - lis, err := net.Listen("tcp", addr) - if err != nil { - log.Fatalf("error listening on the provided health endpoint address %q: %v", addr, err) +func (h *healthz) update(healthy bool) { + h.Lock() + defer h.Unlock() + + if h.hasAddrs != healthy { + log.Println("Setting healthy", healthy) } - mux := http.NewServeMux() - mux.Handle("/healthz", h) - log.Printf("Running healthcheck endpoint at %s/healthz", addr) - hs := &http.Server{Handler: mux} - - go func() { - if err := hs.Serve(lis); err != nil { - log.Fatalf("failed running health endpoint: %v", err) - } - }() + h.hasAddrs = healthy +} + +// registerHealthHandlers registers a simple health handler at /healthz. +// A containerized tailscale instance is considered healthy if +// it has at least one tailnet IP address. 
+func registerHealthHandlers(mux *http.ServeMux, podIPv4 string) *healthz { + h := &healthz{podIPv4: podIPv4} + mux.Handle("GET /healthz", h) + return h } diff --git a/cmd/containerboot/ingressservices.go b/cmd/containerboot/ingressservices.go new file mode 100644 index 0000000000000..1a2da95675f4e --- /dev/null +++ b/cmd/containerboot/ingressservices.go @@ -0,0 +1,331 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/netip" + "os" + "path/filepath" + "reflect" + "time" + + "github.com/fsnotify/fsnotify" + "tailscale.com/kube/ingressservices" + "tailscale.com/kube/kubeclient" + "tailscale.com/util/linuxfw" + "tailscale.com/util/mak" +) + +// ingressProxy corresponds to a Kubernetes Operator's network layer ingress +// proxy. It configures firewall rules (iptables or nftables) to proxy tailnet +// traffic to Kubernetes Services. Currently this is only used for network +// layer proxies in HA mode. +type ingressProxy struct { + cfgPath string // path to ingress configfile. + + // nfr is the netfilter runner used to configure firewall rules. + // This is going to be either iptables or nftables based runner. + // Never nil. + nfr linuxfw.NetfilterRunner + + kc kubeclient.Client // never nil + stateSecret string // Secret that holds Tailscale state + + // Pod's IP addresses are used as an identifier of this particular Pod. + podIPv4 string // empty if Pod does not have IPv4 address + podIPv6 string // empty if Pod does not have IPv6 address +} + +// run starts the ingress proxy and ensures that firewall rules are set on start +// and refreshed as ingress config changes. 
+func (p *ingressProxy) run(ctx context.Context, opts ingressProxyOpts) error { + log.Printf("starting ingress proxy...") + p.configure(opts) + var tickChan <-chan time.Time + var eventChan <-chan fsnotify.Event + if w, err := fsnotify.NewWatcher(); err != nil { + log.Printf("failed to create fsnotify watcher, timer-only mode: %v", err) + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + tickChan = ticker.C + } else { + defer w.Close() + dir := filepath.Dir(p.cfgPath) + if err := w.Add(dir); err != nil { + return fmt.Errorf("failed to add fsnotify watch for %v: %w", dir, err) + } + eventChan = w.Events + } + + if err := p.sync(ctx); err != nil { + return err + } + for { + select { + case <-ctx.Done(): + return nil + case <-tickChan: + log.Printf("periodic sync, ensuring firewall config is up to date...") + case <-eventChan: + log.Printf("config file change detected, ensuring firewall config is up to date...") + } + if err := p.sync(ctx); err != nil { + return fmt.Errorf("error syncing ingress service config: %w", err) + } + } +} + +// sync reconciles proxy's firewall rules (iptables or nftables) on ingress config changes: +// - ensures that new firewall rules are added +// - ensures that old firewall rules are deleted +// - updates ingress proxy's status in the state Secret +func (p *ingressProxy) sync(ctx context.Context) error { + // 1. Get the desired firewall configuration + cfgs, err := p.getConfigs() + if err != nil { + return fmt.Errorf("ingress proxy: error retrieving configs: %w", err) + } + + // 2. Get the recorded firewall status + status, err := p.getStatus(ctx) + if err != nil { + return fmt.Errorf("ingress proxy: error retrieving current status: %w", err) + } + + // 3. 
Ensure that firewall configuration is up to date + if err := p.syncIngressConfigs(cfgs, status); err != nil { + return fmt.Errorf("ingress proxy: error syncing configs: %w", err) + } + var existingConfigs *ingressservices.Configs + if status != nil { + existingConfigs = &status.Configs + } + + // 4. Update the recorded firewall status + if !(ingressServicesStatusIsEqual(cfgs, existingConfigs) && p.isCurrentStatus(status)) { + if err := p.recordStatus(ctx, cfgs); err != nil { + return fmt.Errorf("ingress proxy: error setting status: %w", err) + } + } + return nil +} + +// getConfigs returns the desired ingress service configuration from the mounted +// configfile. +func (p *ingressProxy) getConfigs() (*ingressservices.Configs, error) { + j, err := os.ReadFile(p.cfgPath) + if os.IsNotExist(err) { + return nil, nil + } + if err != nil { + return nil, err + } + if len(j) == 0 || string(j) == "" { + return nil, nil + } + cfg := &ingressservices.Configs{} + if err := json.Unmarshal(j, &cfg); err != nil { + return nil, err + } + return cfg, nil +} + +// getStatus gets the recorded status of the configured firewall. The status is +// stored in the proxy's state Secret. Note that the recorded status might not +// be the current status of the firewall if it belongs to a previous Pod- we +// take that into account further down the line when determining if the desired +// rules are actually present. 
+func (p *ingressProxy) getStatus(ctx context.Context) (*ingressservices.Status, error) { + secret, err := p.kc.GetSecret(ctx, p.stateSecret) + if err != nil { + return nil, fmt.Errorf("error retrieving state Secret: %w", err) + } + status := &ingressservices.Status{} + raw, ok := secret.Data[ingressservices.IngressConfigKey] + if !ok { + return nil, nil + } + if err := json.Unmarshal([]byte(raw), status); err != nil { + return nil, fmt.Errorf("error unmarshalling previous config: %w", err) + } + return status, nil +} + +// syncIngressConfigs takes the desired firewall configuration and the recorded +// status and ensures that any missing rules are added and no longer needed +// rules are deleted. +func (p *ingressProxy) syncIngressConfigs(cfgs *ingressservices.Configs, status *ingressservices.Status) error { + rulesToAdd := p.getRulesToAdd(cfgs, status) + rulesToDelete := p.getRulesToDelete(cfgs, status) + + if err := ensureIngressRulesDeleted(rulesToDelete, p.nfr); err != nil { + return fmt.Errorf("error deleting ingress rules: %w", err) + } + if err := ensureIngressRulesAdded(rulesToAdd, p.nfr); err != nil { + return fmt.Errorf("error adding ingress rules: %w", err) + } + return nil +} + +// recordStatus writes the configured firewall status to the proxy's state +// Secret. This allows the Kubernetes Operator to determine whether this proxy +// Pod has setup firewall rules to route traffic for an ingress service. +func (p *ingressProxy) recordStatus(ctx context.Context, newCfg *ingressservices.Configs) error { + status := &ingressservices.Status{} + if newCfg != nil { + status.Configs = *newCfg + } + // Pod IPs are used to determine if recorded status applies to THIS proxy Pod. 
+	status.PodIPv4 = p.podIPv4
+	status.PodIPv6 = p.podIPv6
+	_, err := p.kc.GetSecret(ctx, p.stateSecret)
+	if err != nil {
+		return fmt.Errorf("error retrieving state Secret: %w", err)
+	}
+	bs, err := json.Marshal(status)
+	if err != nil {
+		return fmt.Errorf("error marshalling status: %w", err)
+	}
+	// The Secret is updated via the JSON patch below; mutating a local copy would be a no-op.
+	patch := kubeclient.JSONPatch{
+		Op:    "replace",
+		Path:  fmt.Sprintf("/data/%s", ingressservices.IngressConfigKey),
+		Value: bs,
+	}
+	if err := p.kc.JSONPatchResource(ctx, p.stateSecret, kubeclient.TypeSecrets, []kubeclient.JSONPatch{patch}); err != nil {
+		return fmt.Errorf("error patching state Secret: %w", err)
+	}
+	return nil
+}
+
+// getRulesToAdd takes the desired firewall configuration and the recorded
+// firewall status and returns a map of missing Tailscale Services and rules.
+func (p *ingressProxy) getRulesToAdd(cfgs *ingressservices.Configs, status *ingressservices.Status) map[string]ingressservices.Config {
+	if cfgs == nil {
+		return nil
+	}
+	var rulesToAdd map[string]ingressservices.Config
+	for tsSvc, wantsCfg := range *cfgs {
+		if status == nil || !p.isCurrentStatus(status) {
+			mak.Set(&rulesToAdd, tsSvc, wantsCfg)
+			continue
+		}
+		gotCfg := status.Configs.GetConfig(tsSvc)
+		if gotCfg == nil || !reflect.DeepEqual(wantsCfg, *gotCfg) {
+			mak.Set(&rulesToAdd, tsSvc, wantsCfg)
+		}
+	}
+	return rulesToAdd
+}
+
+// getRulesToDelete takes the desired firewall configuration and the recorded
+// status and returns a map of Tailscale Services and rules that need to be deleted.
+func (p *ingressProxy) getRulesToDelete(cfgs *ingressservices.Configs, status *ingressservices.Status) map[string]ingressservices.Config { + if status == nil || !p.isCurrentStatus(status) { + return nil + } + var rulesToDelete map[string]ingressservices.Config + for tsSvc, gotCfg := range status.Configs { + if cfgs == nil { + mak.Set(&rulesToDelete, tsSvc, gotCfg) + continue + } + wantsCfg := cfgs.GetConfig(tsSvc) + if wantsCfg != nil && reflect.DeepEqual(*wantsCfg, gotCfg) { + continue + } + mak.Set(&rulesToDelete, tsSvc, gotCfg) + } + return rulesToDelete +} + +// ensureIngressRulesAdded takes a map of Tailscale Services and rules and ensures that the firewall rules are added. +func ensureIngressRulesAdded(cfgs map[string]ingressservices.Config, nfr linuxfw.NetfilterRunner) error { + for serviceName, cfg := range cfgs { + if cfg.IPv4Mapping != nil { + if err := addDNATRuleForSvc(nfr, serviceName, cfg.IPv4Mapping.TailscaleServiceIP, cfg.IPv4Mapping.ClusterIP); err != nil { + return fmt.Errorf("error adding ingress rule for %s: %w", serviceName, err) + } + } + if cfg.IPv6Mapping != nil { + if err := addDNATRuleForSvc(nfr, serviceName, cfg.IPv6Mapping.TailscaleServiceIP, cfg.IPv6Mapping.ClusterIP); err != nil { + return fmt.Errorf("error adding ingress rule for %s: %w", serviceName, err) + } + } + } + return nil +} + +func addDNATRuleForSvc(nfr linuxfw.NetfilterRunner, serviceName string, tsIP, clusterIP netip.Addr) error { + log.Printf("adding DNAT rule for Tailscale Service %s with IP %s to Kubernetes Service IP %s", serviceName, tsIP, clusterIP) + return nfr.EnsureDNATRuleForSvc(serviceName, tsIP, clusterIP) +} + +// ensureIngressRulesDeleted takes a map of Tailscale Services and rules and ensures that the firewall rules are deleted. 
+func ensureIngressRulesDeleted(cfgs map[string]ingressservices.Config, nfr linuxfw.NetfilterRunner) error { + for serviceName, cfg := range cfgs { + if cfg.IPv4Mapping != nil { + if err := deleteDNATRuleForSvc(nfr, serviceName, cfg.IPv4Mapping.TailscaleServiceIP, cfg.IPv4Mapping.ClusterIP); err != nil { + return fmt.Errorf("error deleting ingress rule for %s: %w", serviceName, err) + } + } + if cfg.IPv6Mapping != nil { + if err := deleteDNATRuleForSvc(nfr, serviceName, cfg.IPv6Mapping.TailscaleServiceIP, cfg.IPv6Mapping.ClusterIP); err != nil { + return fmt.Errorf("error deleting ingress rule for %s: %w", serviceName, err) + } + } + } + return nil +} + +func deleteDNATRuleForSvc(nfr linuxfw.NetfilterRunner, serviceName string, tsIP, clusterIP netip.Addr) error { + log.Printf("deleting DNAT rule for Tailscale Service %s with IP %s to Kubernetes Service IP %s", serviceName, tsIP, clusterIP) + return nfr.DeleteDNATRuleForSvc(serviceName, tsIP, clusterIP) +} + +// isCurrentStatus returns true if the status of an ingress proxy as read from +// the proxy's state Secret is the status of the current proxy Pod. We use +// Pod's IP addresses to determine that the status is for this Pod. +func (p *ingressProxy) isCurrentStatus(status *ingressservices.Status) bool { + if status == nil { + return true + } + return status.PodIPv4 == p.podIPv4 && status.PodIPv6 == p.podIPv6 +} + +type ingressProxyOpts struct { + cfgPath string + nfr linuxfw.NetfilterRunner // never nil + kc kubeclient.Client // never nil + stateSecret string + podIPv4 string + podIPv6 string +} + +// configure sets the ingress proxy's configuration. It is called once on start +// so we don't care about concurrent access to fields. 
+func (p *ingressProxy) configure(opts ingressProxyOpts) { + p.cfgPath = opts.cfgPath + p.nfr = opts.nfr + p.kc = opts.kc + p.stateSecret = opts.stateSecret + p.podIPv4 = opts.podIPv4 + p.podIPv6 = opts.podIPv6 +} + +func ingressServicesStatusIsEqual(st, st1 *ingressservices.Configs) bool { + if st == nil && st1 == nil { + return true + } + if st == nil || st1 == nil { + return false + } + return reflect.DeepEqual(*st, *st1) +} diff --git a/cmd/containerboot/ingressservices_test.go b/cmd/containerboot/ingressservices_test.go new file mode 100644 index 0000000000000..228bbb159f463 --- /dev/null +++ b/cmd/containerboot/ingressservices_test.go @@ -0,0 +1,223 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "net/netip" + "testing" + + "tailscale.com/kube/ingressservices" + "tailscale.com/util/linuxfw" +) + +func TestSyncIngressConfigs(t *testing.T) { + tests := []struct { + name string + currentConfigs *ingressservices.Configs + currentStatus *ingressservices.Status + wantServices map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + } + }{ + { + name: "add_new_rules_when_no_existing_config", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("100.64.0.1", "10.0.0.1"), + }, + }, + { + name: "add_multiple_services", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + "svc:bar": makeServiceConfig("100.64.0.2", "10.0.0.2", "", ""), + "svc:baz": makeServiceConfig("100.64.0.3", "10.0.0.3", "", ""), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("100.64.0.1", "10.0.0.1"), + "svc:bar": 
makeWantService("100.64.0.2", "10.0.0.2"), + "svc:baz": makeWantService("100.64.0.3", "10.0.0.3"), + }, + }, + { + name: "add_both_ipv4_and_ipv6_rules", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "2001:db8::1", "2001:db8::2"), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("2001:db8::1", "2001:db8::2"), + }, + }, + { + name: "add_ipv6_only_rules", + currentConfigs: &ingressservices.Configs{ + "svc:ipv6": makeServiceConfig("", "", "2001:db8::10", "2001:db8::20"), + }, + currentStatus: nil, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:ipv6": makeWantService("2001:db8::10", "2001:db8::20"), + }, + }, + { + name: "delete_all_rules_when_config_removed", + currentConfigs: nil, + currentStatus: &ingressservices.Status{ + Configs: ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + "svc:bar": makeServiceConfig("100.64.0.2", "10.0.0.2", "", ""), + }, + PodIPv4: "10.0.0.2", // Current pod IPv4 + PodIPv6: "2001:db8::2", // Current pod IPv6 + }, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{}, + }, + { + name: "add_remove_modify", + currentConfigs: &ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.2", "", ""), // Changed cluster IP + "svc:new": makeServiceConfig("100.64.0.4", "10.0.0.4", "", ""), + }, + currentStatus: &ingressservices.Status{ + Configs: ingressservices.Configs{ + "svc:foo": makeServiceConfig("100.64.0.1", "10.0.0.1", "", ""), + "svc:bar": makeServiceConfig("100.64.0.2", "10.0.0.2", "", ""), + "svc:baz": makeServiceConfig("100.64.0.3", "10.0.0.3", "", ""), + }, + PodIPv4: "10.0.0.2", // Current pod IPv4 + PodIPv6: "2001:db8::2", // Current pod IPv6 + }, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + 
ClusterIP netip.Addr + }{ + "svc:foo": makeWantService("100.64.0.1", "10.0.0.2"), + "svc:new": makeWantService("100.64.0.4", "10.0.0.4"), + }, + }, + { + name: "update_with_outdated_status", + currentConfigs: &ingressservices.Configs{ + "svc:web": makeServiceConfig("100.64.0.10", "10.0.0.10", "", ""), + "svc:web-ipv6": { + IPv6Mapping: &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr("2001:db8::10"), + ClusterIP: netip.MustParseAddr("2001:db8::20"), + }, + }, + "svc:api": makeServiceConfig("100.64.0.20", "10.0.0.20", "", ""), + }, + currentStatus: &ingressservices.Status{ + Configs: ingressservices.Configs{ + "svc:web": makeServiceConfig("100.64.0.10", "10.0.0.10", "", ""), + "svc:web-ipv6": { + IPv6Mapping: &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr("2001:db8::10"), + ClusterIP: netip.MustParseAddr("2001:db8::20"), + }, + }, + "svc:old": makeServiceConfig("100.64.0.30", "10.0.0.30", "", ""), + }, + PodIPv4: "10.0.0.1", // Outdated pod IP + PodIPv6: "2001:db8::1", // Outdated pod IP + }, + wantServices: map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + "svc:web": makeWantService("100.64.0.10", "10.0.0.10"), + "svc:web-ipv6": makeWantService("2001:db8::10", "2001:db8::20"), + "svc:api": makeWantService("100.64.0.20", "10.0.0.20"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var nfr linuxfw.NetfilterRunner = linuxfw.NewFakeNetfilterRunner() + + ep := &ingressProxy{ + nfr: nfr, + podIPv4: "10.0.0.2", // Current pod IPv4 + podIPv6: "2001:db8::2", // Current pod IPv6 + } + + err := ep.syncIngressConfigs(tt.currentConfigs, tt.currentStatus) + if err != nil { + t.Fatalf("syncIngressConfigs failed: %v", err) + } + + fake := nfr.(*linuxfw.FakeNetfilterRunner) + gotServices := fake.GetServiceState() + if len(gotServices) != len(tt.wantServices) { + t.Errorf("got %d services, want %d", len(gotServices), len(tt.wantServices)) + } + for svc, want := range 
tt.wantServices { + got, ok := gotServices[svc] + if !ok { + t.Errorf("service %s not found", svc) + continue + } + if got.TailscaleServiceIP != want.TailscaleServiceIP { + t.Errorf("service %s: got TailscaleServiceIP %v, want %v", svc, got.TailscaleServiceIP, want.TailscaleServiceIP) + } + if got.ClusterIP != want.ClusterIP { + t.Errorf("service %s: got ClusterIP %v, want %v", svc, got.ClusterIP, want.ClusterIP) + } + } + }) + } +} + +func makeServiceConfig(tsIP, clusterIP string, tsIP6, clusterIP6 string) ingressservices.Config { + cfg := ingressservices.Config{} + if tsIP != "" && clusterIP != "" { + cfg.IPv4Mapping = &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr(tsIP), + ClusterIP: netip.MustParseAddr(clusterIP), + } + } + if tsIP6 != "" && clusterIP6 != "" { + cfg.IPv6Mapping = &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr(tsIP6), + ClusterIP: netip.MustParseAddr(clusterIP6), + } + } + return cfg +} + +func makeWantService(tsIP, clusterIP string) struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr +} { + return struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{ + TailscaleServiceIP: netip.MustParseAddr(tsIP), + ClusterIP: netip.MustParseAddr(clusterIP), + } +} diff --git a/cmd/containerboot/kube.go b/cmd/containerboot/kube.go index 908cc01efc25a..0a2dfa1bf342f 100644 --- a/cmd/containerboot/kube.go +++ b/cmd/containerboot/kube.go @@ -8,31 +8,64 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "log" "net/http" "net/netip" "os" + "strings" + "time" + "tailscale.com/ipn" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/logtail/backoff" "tailscale.com/tailcfg" + "tailscale.com/types/logger" ) -// storeDeviceID writes deviceID to 'device_id' data field of the named -// Kubernetes Secret. 
-func storeDeviceID(ctx context.Context, secretName string, deviceID tailcfg.StableNodeID) error {
+// kubeClient is a wrapper around Tailscale's internal kube client that knows how to talk to the kube API server. We use
+// this rather than any of the upstream Kubernetes client libraries to avoid extra imports.
+type kubeClient struct {
+	kubeclient.Client
+	stateSecret string
+	canPatch    bool // whether the client has permissions to patch Kubernetes Secrets
+}
+
+func newKubeClient(root string, stateSecret string) (*kubeClient, error) {
+	if root != "/" {
+		// If we are running in a test, we need to set the root path to the fake
+		// service account directory.
+		kubeclient.SetRootPathForTesting(root)
+	}
+	var err error
+	kc, err := kubeclient.New("tailscale-container")
+	if err != nil {
+		return nil, fmt.Errorf("error creating kube client: %w", err)
+	}
+	if (root != "/") || os.Getenv("TS_KUBERNETES_READ_API_SERVER_ADDRESS_FROM_ENV") == "true" {
+		// Derive the API server address from the environment variables
+		// Used to set http server in tests, or optionally enabled by flag
+		kc.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS")))
+	}
+	return &kubeClient{Client: kc, stateSecret: stateSecret}, nil
+}
+
+// storeDeviceID writes deviceID to 'device_id' data field of the client's state Secret.
+func (kc *kubeClient) storeDeviceID(ctx context.Context, deviceID tailcfg.StableNodeID) error {
 	s := &kubeapi.Secret{
 		Data: map[string][]byte{
-			"device_id": []byte(deviceID),
+			kubetypes.KeyDeviceID: []byte(deviceID),
 		},
 	}
-	return kc.StrategicMergePatchSecret(ctx, secretName, s, "tailscale-container")
+	return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container")
 }
 
-// storeDeviceEndpoints writes device's tailnet IPs and MagicDNS name to fields
-// 'device_ips', 'device_fqdn' of the named Kubernetes Secret.
-func storeDeviceEndpoints(ctx context.Context, secretName string, fqdn string, addresses []netip.Prefix) error { +// storeDeviceEndpoints writes device's tailnet IPs and MagicDNS name to fields 'device_ips', 'device_fqdn' of client's +// state Secret. +func (kc *kubeClient) storeDeviceEndpoints(ctx context.Context, fqdn string, addresses []netip.Prefix) error { var ips []string for _, addr := range addresses { ips = append(ips, addr.Addr().String()) @@ -44,16 +77,28 @@ func storeDeviceEndpoints(ctx context.Context, secretName string, fqdn string, a s := &kubeapi.Secret{ Data: map[string][]byte{ - "device_fqdn": []byte(fqdn), - "device_ips": deviceIPs, + kubetypes.KeyDeviceFQDN: []byte(fqdn), + kubetypes.KeyDeviceIPs: deviceIPs, }, } - return kc.StrategicMergePatchSecret(ctx, secretName, s, "tailscale-container") + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") +} + +// storeHTTPSEndpoint writes an HTTPS endpoint exposed by this device via 'tailscale serve' to the client's state +// Secret. In practice this will be the same value that gets written to 'device_fqdn', but this should only be called +// when the serve config has been successfully set up. +func (kc *kubeClient) storeHTTPSEndpoint(ctx context.Context, ep string) error { + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyHTTPSEndpoint: []byte(ep), + }, + } + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") } // deleteAuthKey deletes the 'authkey' field of the given kube // secret. No-op if there is no authkey in the secret. -func deleteAuthKey(ctx context.Context, secretName string) error { +func (kc *kubeClient) deleteAuthKey(ctx context.Context) error { // m is a JSON Patch data structure, see https://jsonpatch.com/ or RFC 6902. 
m := []kubeclient.JSONPatch{ { @@ -61,7 +106,7 @@ func deleteAuthKey(ctx context.Context, secretName string) error { Path: "/data/authkey", }, } - if err := kc.JSONPatchSecret(ctx, secretName, m); err != nil { + if err := kc.JSONPatchResource(ctx, kc.stateSecret, kubeclient.TypeSecrets, m); err != nil { if s, ok := err.(*kubeapi.Status); ok && s.Code == http.StatusUnprocessableEntity { // This is kubernetes-ese for "the field you asked to // delete already doesn't exist", aka no-op. @@ -72,22 +117,78 @@ func deleteAuthKey(ctx context.Context, secretName string) error { return nil } -var kc kubeclient.Client - -func initKubeClient(root string) { - if root != "/" { - // If we are running in a test, we need to set the root path to the fake - // service account directory. - kubeclient.SetRootPathForTesting(root) +// storeCapVerUID stores the current capability version of tailscale and, if provided, UID of the Pod in the tailscale +// state Secret. +// These two fields are used by the Kubernetes Operator to observe the current capability version of tailscaled running in this container. 
+func (kc *kubeClient) storeCapVerUID(ctx context.Context, podUID string) error { + capVerS := fmt.Sprintf("%d", tailcfg.CurrentCapabilityVersion) + d := map[string][]byte{ + kubetypes.KeyCapVer: []byte(capVerS), } - var err error - kc, err = kubeclient.New() - if err != nil { - log.Fatalf("Error creating kube client: %v", err) + if podUID != "" { + d[kubetypes.KeyPodUID] = []byte(podUID) } - if (root != "/") || os.Getenv("TS_KUBERNETES_READ_API_SERVER_ADDRESS_FROM_ENV") == "true" { - // Derive the API server address from the environment variables - // Used to set http server in tests, or optionally enabled by flag - kc.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) + s := &kubeapi.Secret{ + Data: d, } + return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, "tailscale-container") +} + +// waitForConsistentState waits for tailscaled to finish writing state if it +// looks like it's started. It is designed to reduce the likelihood that +// tailscaled gets shut down in the window between authenticating to control +// and finishing writing state. However, it's not bullet proof because we can't +// atomically authenticate and write state. +func (kc *kubeClient) waitForConsistentState(ctx context.Context) error { + var logged bool + + bo := backoff.NewBackoff("", logger.Discard, 2*time.Second) + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + secret, err := kc.GetSecret(ctx, kc.stateSecret) + if ctx.Err() != nil || kubeclient.IsNotFoundErr(err) { + return nil + } + if err != nil { + return fmt.Errorf("getting Secret %q: %v", kc.stateSecret, err) + } + + if hasConsistentState(secret.Data) { + return nil + } + + if !logged { + log.Printf("Waiting for tailscaled to finish writing state to Secret %q", kc.stateSecret) + logged = true + } + bo.BackOff(ctx, errors.New("")) // Fake error to trigger actual sleep. 
+ } +} + +// hasConsistentState returns true is there is either no state or the full set +// of expected keys are present. +func hasConsistentState(d map[string][]byte) bool { + var ( + _, hasCurrent = d[string(ipn.CurrentProfileStateKey)] + _, hasKnown = d[string(ipn.KnownProfilesStateKey)] + _, hasMachine = d[string(ipn.MachineKeyStateKey)] + hasProfile bool + ) + + for k := range d { + if strings.HasPrefix(k, "profile-") { + if hasProfile { + return false // We only expect one profile. + } + hasProfile = true + } + } + + // Approximate check, we don't want to reimplement all of profileManager. + return (hasCurrent && hasKnown && hasMachine && hasProfile) || + (!hasCurrent && !hasKnown && !hasMachine && !hasProfile) } diff --git a/cmd/containerboot/kube_test.go b/cmd/containerboot/kube_test.go index 1a5730548838f..413971bc6df23 100644 --- a/cmd/containerboot/kube_test.go +++ b/cmd/containerboot/kube_test.go @@ -9,8 +9,10 @@ import ( "context" "errors" "testing" + "time" "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" ) @@ -21,7 +23,7 @@ func TestSetupKube(t *testing.T) { cfg *settings wantErr bool wantCfg *settings - kc kubeclient.Client + kc *kubeClient }{ { name: "TS_AUTHKEY set, state Secret exists", @@ -29,14 +31,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, nil }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -48,14 +50,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) 
(bool, bool, error) { return false, true, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -67,14 +69,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -87,14 +89,14 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 403} }, - }, + }}, wantCfg: &settings{ AuthKey: "foo", KubeSecret: "foo", @@ -111,11 +113,11 @@ func TestSetupKube(t *testing.T) { AuthKey: "foo", KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, errors.New("broken") }, - }, + }}, wantErr: true, }, { @@ -127,14 +129,14 @@ func TestSetupKube(t *testing.T) { wantCfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, true, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return nil, &kubeapi.Status{Code: 404} }, - }, + }}, }, { // Interactive login 
using URL in Pod logs @@ -145,28 +147,28 @@ func TestSetupKube(t *testing.T) { wantCfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{}, nil }, - }, + }}, }, { name: "TS_AUTHKEY not set, state Secret contains auth key, we do not have RBAC to patch it", cfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return false, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{Data: map[string][]byte{"authkey": []byte("foo")}}, nil }, - }, + }}, wantCfg: &settings{ KubeSecret: "foo", }, @@ -177,14 +179,14 @@ func TestSetupKube(t *testing.T) { cfg: &settings{ KubeSecret: "foo", }, - kc: &kubeclient.FakeClient{ + kc: &kubeClient{stateSecret: "foo", Client: &kubeclient.FakeClient{ CheckSecretPermissionsImpl: func(context.Context, string) (bool, bool, error) { return true, false, nil }, GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { return &kubeapi.Secret{Data: map[string][]byte{"authkey": []byte("foo")}}, nil }, - }, + }}, wantCfg: &settings{ KubeSecret: "foo", AuthKey: "foo", @@ -194,9 +196,9 @@ func TestSetupKube(t *testing.T) { } for _, tt := range tests { - kc = tt.kc + kc := tt.kc t.Run(tt.name, func(t *testing.T) { - if err := tt.cfg.setupKube(context.Background()); (err != nil) != tt.wantErr { + if err := tt.cfg.setupKube(context.Background(), kc); (err != nil) != tt.wantErr { t.Errorf("settings.setupKube() error = %v, wantErr %v", err, tt.wantErr) } if diff := cmp.Diff(*tt.cfg, *tt.wantCfg); diff != "" { @@ -205,3 +207,34 @@ func 
TestSetupKube(t *testing.T) { }) } } + +func TestWaitForConsistentState(t *testing.T) { + data := map[string][]byte{ + // Missing _current-profile. + string(ipn.KnownProfilesStateKey): []byte(""), + string(ipn.MachineKeyStateKey): []byte(""), + "profile-foo": []byte(""), + } + kc := &kubeClient{ + Client: &kubeclient.FakeClient{ + GetSecretImpl: func(context.Context, string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{ + Data: data, + }, nil + }, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := kc.waitForConsistentState(ctx); err != context.DeadlineExceeded { + t.Fatalf("expected DeadlineExceeded, got %v", err) + } + + ctx, cancel = context.WithTimeout(context.Background(), time.Second) + defer cancel() + data[string(ipn.CurrentProfileStateKey)] = []byte("") + if err := kc.waitForConsistentState(ctx); err != nil { + t.Fatalf("expected nil, got %v", err) + } +} diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 4c8ba58073c69..9543308975b79 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -52,11 +52,17 @@ // ${TS_CERT_DOMAIN}, it will be replaced with the value of the available FQDN. // It cannot be used in conjunction with TS_DEST_IP. The file is watched for changes, // and will be re-applied when it changes. -// - TS_HEALTHCHECK_ADDR_PORT: if specified, an HTTP health endpoint will be -// served at /healthz at the provided address, which should be in form [
]:. -// If not set, no health check will be run. If set to :, addr will default to 0.0.0.0 -// The health endpoint will return 200 OK if this node has at least one tailnet IP address, -// otherwise returns 503. +// - TS_HEALTHCHECK_ADDR_PORT: deprecated, use TS_ENABLE_HEALTH_CHECK instead and optionally +// set TS_LOCAL_ADDR_PORT. Will be removed in 1.82.0. +// - TS_LOCAL_ADDR_PORT: the address and port to serve local metrics and health +// check endpoints if enabled via TS_ENABLE_METRICS and/or TS_ENABLE_HEALTH_CHECK. +// Defaults to [::]:9002, serving on all available interfaces. +// - TS_ENABLE_METRICS: if true, a metrics endpoint will be served at /metrics on +// the address specified by TS_LOCAL_ADDR_PORT. See https://tailscale.com/kb/1482/client-metrics +// for more information on the metrics exposed. +// - TS_ENABLE_HEALTH_CHECK: if true, a health check endpoint will be served at /healthz on +// the address specified by TS_LOCAL_ADDR_PORT. The health endpoint will return 200 +// OK if this node has at least one tailnet IP address, otherwise returns 503. // NB: the health criteria might change in the future. // - TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR: if specified, a path to a // directory that containers tailscaled config in file. 
The config file needs to be @@ -99,10 +105,10 @@ import ( "log" "math" "net" + "net/http" "net/netip" "os" "os/signal" - "path" "path/filepath" "slices" "strings" @@ -115,6 +121,7 @@ import ( "tailscale.com/client/tailscale" "tailscale.com/ipn" kubeutils "tailscale.com/k8s-operator" + "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/types/ptr" @@ -130,55 +137,126 @@ func newNetfilterRunner(logf logger.Logf) (linuxfw.NetfilterRunner, error) { } func main() { + if err := run(); err != nil && !errors.Is(err, context.Canceled) { + log.Fatal(err) + } +} + +func run() error { log.SetPrefix("boot: ") tailscale.I_Acknowledge_This_API_Is_Unstable = true cfg, err := configFromEnv() if err != nil { - log.Fatalf("invalid configuration: %v", err) + return fmt.Errorf("invalid configuration: %w", err) } if !cfg.UserspaceMode { if err := ensureTunFile(cfg.Root); err != nil { - log.Fatalf("Unable to create tuntap device file: %v", err) + return fmt.Errorf("unable to create tuntap device file: %w", err) } if cfg.ProxyTargetIP != "" || cfg.ProxyTargetDNSName != "" || cfg.Routes != nil || cfg.TailnetTargetIP != "" || cfg.TailnetTargetFQDN != "" { if err := ensureIPForwarding(cfg.Root, cfg.ProxyTargetIP, cfg.TailnetTargetIP, cfg.TailnetTargetFQDN, cfg.Routes); err != nil { log.Printf("Failed to enable IP forwarding: %v", err) log.Printf("To run tailscale as a proxy or router container, IP forwarding must be enabled.") if cfg.InKubernetes { - log.Fatalf("You can either set the sysctls as a privileged initContainer, or run the tailscale container with privileged=true.") + return fmt.Errorf("you can either set the sysctls as a privileged initContainer, or run the tailscale container with privileged=true.") } else { - log.Fatalf("You can fix this by running the container with privileged=true, or the equivalent in your container runtime that permits access to sysctls.") + return fmt.Errorf("you can fix this by running the container with 
privileged=true, or the equivalent in your container runtime that permits access to sysctls.") } } } } - // Context is used for all setup stuff until we're in steady + // Root context for the whole containerboot process, used to make sure + // shutdown signals are promptly and cleanly handled. + ctx, cancel := contextWithExitSignalWatch() + defer cancel() + + // bootCtx is used for all setup stuff until we're in steady // state, so that if something is hanging we eventually time out // and crashloop the container. - bootCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + bootCtx, cancel := context.WithTimeout(ctx, 60*time.Second) defer cancel() + var kc *kubeClient if cfg.InKubernetes { - initKubeClient(cfg.Root) - if err := cfg.setupKube(bootCtx); err != nil { - log.Fatalf("error setting up for running on Kubernetes: %v", err) + kc, err = newKubeClient(cfg.Root, cfg.KubeSecret) + if err != nil { + return fmt.Errorf("error initializing kube client: %w", err) + } + if err := cfg.setupKube(bootCtx, kc); err != nil { + return fmt.Errorf("error setting up for running on Kubernetes: %w", err) } } client, daemonProcess, err := startTailscaled(bootCtx, cfg) if err != nil { - log.Fatalf("failed to bring up tailscale: %v", err) + return fmt.Errorf("failed to bring up tailscale: %w", err) } killTailscaled := func() { + // The default termination grace period for a Pod is 30s. We wait 25s at + // most so that we still reserve some of that budget for tailscaled + // to receive and react to a SIGTERM before the SIGKILL that k8s + // will send at the end of the grace period. + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) + defer cancel() + + if err := ensureServicesNotAdvertised(ctx, client); err != nil { + log.Printf("Error ensuring services are not advertised: %v", err) + } + + if hasKubeStateStore(cfg) { + // Check we're not shutting tailscaled down while it's still writing + // state. 
If we authenticate and fail to write all the state, we'll + // never recover automatically. + log.Printf("Checking for consistent state") + err := kc.waitForConsistentState(ctx) + if err != nil { + log.Printf("Error waiting for consistent state on shutdown: %v", err) + } + } + log.Printf("Sending SIGTERM to tailscaled") if err := daemonProcess.Signal(unix.SIGTERM); err != nil { log.Fatalf("error shutting tailscaled down: %v", err) } } defer killTailscaled() + var healthCheck *healthz + ep := &egressProxy{} + if cfg.HealthCheckAddrPort != "" { + mux := http.NewServeMux() + + log.Printf("Running healthcheck endpoint at %s/healthz", cfg.HealthCheckAddrPort) + healthCheck = registerHealthHandlers(mux, cfg.PodIPv4) + + close := runHTTPServer(mux, cfg.HealthCheckAddrPort) + defer close() + } + + if cfg.localMetricsEnabled() || cfg.localHealthEnabled() || cfg.egressSvcsTerminateEPEnabled() { + mux := http.NewServeMux() + + if cfg.localMetricsEnabled() { + log.Printf("Running metrics endpoint at %s/metrics", cfg.LocalAddrPort) + registerMetricsHandlers(mux, client, cfg.DebugAddrPort) + } + + if cfg.localHealthEnabled() { + log.Printf("Running healthcheck endpoint at %s/healthz", cfg.LocalAddrPort) + healthCheck = registerHealthHandlers(mux, cfg.PodIPv4) + } + + if cfg.egressSvcsTerminateEPEnabled() { + log.Printf("Running egress preshutdown hook at %s%s", cfg.LocalAddrPort, kubetypes.EgessServicesPreshutdownEP) + ep.registerHandlers(mux) + } + + close := runHTTPServer(mux, cfg.LocalAddrPort) + defer close() + } + if cfg.EnableForwardingOptimizations { if err := client.SetUDPGROForwarding(bootCtx); err != nil { log.Printf("[unexpected] error enabling UDP GRO forwarding: %v", err) @@ -187,7 +265,7 @@ func main() { w, err := client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialPrefs|ipn.NotifyInitialState) if err != nil { - log.Fatalf("failed to watch tailscaled for updates: %v", err) + return fmt.Errorf("failed to watch tailscaled for updates: %w", err) } // 
Now that we've started tailscaled, we can symlink the socket to the @@ -223,18 +301,18 @@ func main() { didLogin = true w.Close() if err := tailscaleUp(bootCtx, cfg); err != nil { - return fmt.Errorf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } w, err = client.WatchIPNBus(bootCtx, ipn.NotifyInitialNetMap|ipn.NotifyInitialState) if err != nil { - return fmt.Errorf("rewatching tailscaled for updates after auth: %v", err) + return fmt.Errorf("rewatching tailscaled for updates after auth: %w", err) } return nil } if isTwoStepConfigAlwaysAuth(cfg) { if err := authTailscale(); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } } @@ -242,7 +320,7 @@ authLoop: for { n, err := w.Next() if err != nil { - log.Fatalf("failed to read from tailscaled: %v", err) + return fmt.Errorf("failed to read from tailscaled: %w", err) } if n.State != nil { @@ -251,10 +329,10 @@ authLoop: if isOneStepConfig(cfg) { // This could happen if this is the first time tailscaled was run for this // device and the auth key was not passed via the configfile. - log.Fatalf("invalid state: tailscaled daemon started with a config file, but tailscale is not logged in: ensure you pass a valid auth key in the config file.") + return fmt.Errorf("invalid state: tailscaled daemon started with a config file, but tailscale is not logged in: ensure you pass a valid auth key in the config file.") } if err := authTailscale(); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } case ipn.NeedsMachineAuth: log.Printf("machine authorization required, please visit the admin panel") @@ -274,22 +352,25 @@ authLoop: w.Close() - ctx, cancel := contextWithExitSignalWatch() - defer cancel() - if isTwoStepConfigAuthOnce(cfg) { // Now that we are authenticated, we can set/reset any of the // settings that we need to. 
if err := tailscaleSet(ctx, cfg); err != nil { - log.Fatalf("failed to auth tailscale: %v", err) + return fmt.Errorf("failed to auth tailscale: %w", err) } } + // Remove any serve config and advertised HTTPS endpoint that may have been set by a previous run of + // containerboot, but only if we're providing a new one. if cfg.ServeConfigPath != "" { - // Remove any serve config that may have been set by a previous run of - // containerboot, but only if we're providing a new one. + log.Printf("serve proxy: unsetting previous config") if err := client.SetServeConfig(ctx, new(ipn.ServeConfig)); err != nil { - log.Fatalf("failed to unset serve config: %v", err) + return fmt.Errorf("failed to unset serve config: %w", err) + } + if hasKubeStateStore(cfg) { + if err := kc.storeHTTPSEndpoint(ctx, ""); err != nil { + return fmt.Errorf("failed to update HTTPS endpoint in tailscale state: %w", err) + } } } @@ -298,14 +379,26 @@ authLoop: // authkey is no longer needed. We don't strictly need to // wipe it, but it's good hygiene. log.Printf("Deleting authkey from kube secret") - if err := deleteAuthKey(ctx, cfg.KubeSecret); err != nil { - log.Fatalf("deleting authkey from kube secret: %v", err) + if err := kc.deleteAuthKey(ctx); err != nil { + return fmt.Errorf("deleting authkey from kube secret: %w", err) + } + } + + if hasKubeStateStore(cfg) { + if err := kc.storeCapVerUID(ctx, cfg.PodUID); err != nil { + return fmt.Errorf("storing capability version and UID: %w", err) } } w, err = client.WatchIPNBus(ctx, ipn.NotifyInitialNetMap|ipn.NotifyInitialState) if err != nil { - log.Fatalf("rewatching tailscaled for updates after auth: %v", err) + return fmt.Errorf("rewatching tailscaled for updates after auth: %w", err) + } + + // If tailscaled config was read from a mounted file, watch the file for updates and reload. 
+ cfgWatchErrChan := make(chan error) + if cfg.TailscaledConfigFilePath != "" { + go watchTailscaledConfigChanges(ctx, cfg.TailscaledConfigFilePath, client, cfgWatchErrChan) } var ( @@ -322,17 +415,14 @@ authLoop: certDomain = new(atomic.Pointer[string]) certDomainChanged = make(chan bool, 1) - h = &healthz{} // http server for the healthz endpoint - healthzRunner = sync.OnceFunc(func() { runHealthz(cfg.HealthCheckAddrPort, h) }) + triggerWatchServeConfigChanges sync.Once ) - if cfg.ServeConfigPath != "" { - go watchServeConfigChanges(ctx, cfg.ServeConfigPath, certDomainChanged, certDomain, client) - } + var nfr linuxfw.NetfilterRunner if isL3Proxy(cfg) { nfr, err = newNetfilterRunner(log.Printf) if err != nil { - log.Fatalf("error creating new netfilter runner: %v", err) + return fmt.Errorf("error creating new netfilter runner: %w", err) } } @@ -351,6 +441,7 @@ authLoop: // egressSvcsErrorChan will get an error sent to it if this containerboot instance is configured to expose 1+ // egress services in HA mode and errored. var egressSvcsErrorChan = make(chan error) + var ingressSvcsErrorChan = make(chan error) defer t.Stop() // resetTimer resets timer for when to next attempt to resolve the DNS // name for the proxy configured with TS_EXPERIMENTAL_DEST_DNS_NAME. The @@ -403,7 +494,9 @@ runLoop: killTailscaled() break runLoop case err := <-errChan: - log.Fatalf("failed to read from tailscaled: %v", err) + return fmt.Errorf("failed to read from tailscaled: %w", err) + case err := <-cfgWatchErrChan: + return fmt.Errorf("failed to watch tailscaled config: %w", err) case n := <-notifyChan: if n.State != nil && *n.State != ipn.Running { // Something's gone wrong and we've left the authenticated state. @@ -411,7 +504,7 @@ runLoop: // control flow required to make it work now is hard. So, just crash // the container and rely on the container runtime to restart us, // whereupon we'll go through initial auth again. 
- log.Fatalf("tailscaled left running state (now in state %q), exiting", *n.State) + return fmt.Errorf("tailscaled left running state (now in state %q), exiting", *n.State) } if n.NetMap != nil { addrs = n.NetMap.SelfNode.Addresses().AsSlice() @@ -428,8 +521,8 @@ runLoop: // fails. deviceID := n.NetMap.SelfNode.StableID() if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceID, &deviceID) { - if err := storeDeviceID(ctx, cfg.KubeSecret, n.NetMap.SelfNode.StableID()); err != nil { - log.Fatalf("storing device ID in Kubernetes Secret: %v", err) + if err := kc.storeDeviceID(ctx, n.NetMap.SelfNode.StableID()); err != nil { + return fmt.Errorf("storing device ID in Kubernetes Secret: %w", err) } } if cfg.TailnetTargetFQDN != "" { @@ -466,12 +559,12 @@ runLoop: rulesInstalled = true log.Printf("Installing forwarding rules for destination %v", ea.String()) if err := installEgressForwardingRule(ctx, ea.String(), addrs, nfr); err != nil { - log.Fatalf("installing egress proxy rules for destination %s: %v", ea.String(), err) + return fmt.Errorf("installing egress proxy rules for destination %s: %v", ea.String(), err) } } } if !rulesInstalled { - log.Fatalf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) + return fmt.Errorf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) } } currentEgressIPs = newCurentEgressIPs @@ -479,7 +572,7 @@ runLoop: if cfg.ProxyTargetIP != "" && len(addrs) != 0 && ipsHaveChanged { log.Printf("Installing proxy rules") if err := installIngressForwardingRule(ctx, cfg.ProxyTargetIP, addrs, nfr); err != nil { - log.Fatalf("installing ingress proxy rules: %v", err) + return fmt.Errorf("installing ingress proxy rules: %w", err) } } if cfg.ProxyTargetDNSName != "" && len(addrs) != 0 && ipsHaveChanged { @@ -495,14 +588,17 @@ runLoop: if backendsHaveChanged { log.Printf("installing ingress proxy rules for backends %v", newBackendAddrs) if err := 
installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil { - log.Fatalf("error installing ingress proxy rules: %v", err) + return fmt.Errorf("error installing ingress proxy rules: %w", err) } } resetTimer(false) backendAddrs = newBackendAddrs } - if cfg.ServeConfigPath != "" && len(n.NetMap.DNS.CertDomains) != 0 { - cd := n.NetMap.DNS.CertDomains[0] + if cfg.ServeConfigPath != "" { + cd := certDomainFromNetmap(n.NetMap) + if cd == "" { + cd = kubetypes.ValueNoHTTPS + } prev := certDomain.Swap(ptr.To(cd)) if prev == nil || *prev != cd { select { @@ -514,7 +610,7 @@ runLoop: if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) != 0 { log.Printf("Installing forwarding rules for destination %v", cfg.TailnetTargetIP) if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { - log.Fatalf("installing egress proxy rules: %v", err) + return fmt.Errorf("installing egress proxy rules: %w", err) } } // If this is a L7 cluster ingress proxy (set up @@ -526,7 +622,7 @@ runLoop: if cfg.AllowProxyingClusterTrafficViaIngress && cfg.ServeConfigPath != "" && ipsHaveChanged && len(addrs) != 0 { log.Printf("installing rules to forward traffic for %s to node's tailnet IP", cfg.PodIP) if err := installTSForwardingRuleForDestination(ctx, cfg.PodIP, addrs, nfr); err != nil { - log.Fatalf("installing rules to forward traffic to node's tailnet IP: %v", err) + return fmt.Errorf("installing rules to forward traffic to node's tailnet IP: %w", err) } } currentIPs = newCurrentIPs @@ -544,17 +640,21 @@ runLoop: // TODO (irbekrm): instead of using the IP and FQDN, have some other mechanism for the proxy signal that it is 'Ready'. 
deviceEndpoints := []any{n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses()} if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceEndpoints, &deviceEndpoints) { - if err := storeDeviceEndpoints(ctx, cfg.KubeSecret, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { - log.Fatalf("storing device IPs and FQDN in Kubernetes Secret: %v", err) + if err := kc.storeDeviceEndpoints(ctx, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { + return fmt.Errorf("storing device IPs and FQDN in Kubernetes Secret: %w", err) } } - if cfg.HealthCheckAddrPort != "" { - h.Lock() - h.hasAddrs = len(addrs) != 0 - h.Unlock() - healthzRunner() + if healthCheck != nil { + healthCheck.update(len(addrs) != 0) } + + if cfg.ServeConfigPath != "" { + triggerWatchServeConfigChanges.Do(func() { + go watchServeConfigChanges(ctx, certDomainChanged, certDomain, client, kc, cfg) + }) + } + if egressSvcsNotify != nil { egressSvcsNotify <- n } @@ -576,24 +676,42 @@ runLoop: // will then continuously monitor the config file and netmap updates and // reconfigure the firewall rules as needed. If any of its operations fail, it // will crash this node. 
- if cfg.EgressSvcsCfgPath != "" { - log.Printf("configuring egress proxy using configuration file at %s", cfg.EgressSvcsCfgPath) + if cfg.EgressProxiesCfgPath != "" { + log.Printf("configuring egress proxy using configuration file at %s", cfg.EgressProxiesCfgPath) egressSvcsNotify = make(chan ipn.Notify) - ep := egressProxy{ - cfgPath: cfg.EgressSvcsCfgPath, + opts := egressProxyRunOpts{ + cfgPath: cfg.EgressProxiesCfgPath, nfr: nfr, kc: kc, + tsClient: client, stateSecret: cfg.KubeSecret, netmapChan: egressSvcsNotify, podIPv4: cfg.PodIPv4, tailnetAddrs: addrs, } go func() { - if err := ep.run(ctx, n); err != nil { + if err := ep.run(ctx, n, opts); err != nil { egressSvcsErrorChan <- err } }() } + ip := ingressProxy{} + if cfg.IngressProxiesCfgPath != "" { + log.Printf("configuring ingress proxy using configuration file at %s", cfg.IngressProxiesCfgPath) + opts := ingressProxyOpts{ + cfgPath: cfg.IngressProxiesCfgPath, + nfr: nfr, + kc: kc, + stateSecret: cfg.KubeSecret, + podIPv4: cfg.PodIPv4, + podIPv6: cfg.PodIPv6, + } + go func() { + if err := ip.run(ctx, opts); err != nil { + ingressSvcsErrorChan <- err + } + }() + } // Wait on tailscaled process. It won't be cleaned up by default when the // container exits as it is not PID1. 
TODO (irbekrm): perhaps we can replace the @@ -631,16 +749,20 @@ runLoop: if backendsHaveChanged && len(addrs) != 0 { log.Printf("Backend address change detected, installing proxy rules for backends %v", newBackendAddrs) if err := installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil { - log.Fatalf("installing ingress proxy rules for DNS target %s: %v", cfg.ProxyTargetDNSName, err) + return fmt.Errorf("installing ingress proxy rules for DNS target %s: %v", cfg.ProxyTargetDNSName, err) } } backendAddrs = newBackendAddrs resetTimer(false) case e := <-egressSvcsErrorChan: - log.Fatalf("egress proxy failed: %v", e) + return fmt.Errorf("egress proxy failed: %v", e) + case e := <-ingressSvcsErrorChan: + return fmt.Errorf("ingress proxy failed: %v", e) } } wg.Wait() + + return nil } // ensureTunFile checks that /dev/net/tun exists, creating it if @@ -669,13 +791,13 @@ func resolveDNS(ctx context.Context, name string) ([]net.IP, error) { ip4s, err := net.DefaultResolver.LookupIP(ctx, "ip4", name) if err != nil { if e, ok := err.(*net.DNSError); !(ok && e.IsNotFound) { - return nil, fmt.Errorf("error looking up IPv4 addresses: %v", err) + return nil, fmt.Errorf("error looking up IPv4 addresses: %w", err) } } ip6s, err := net.DefaultResolver.LookupIP(ctx, "ip6", name) if err != nil { if e, ok := err.(*net.DNSError); !(ok && e.IsNotFound) { - return nil, fmt.Errorf("error looking up IPv6 addresses: %v", err) + return nil, fmt.Errorf("error looking up IPv6 addresses: %w", err) } } if len(ip4s) == 0 && len(ip6s) == 0 { @@ -688,7 +810,7 @@ func resolveDNS(ctx context.Context, name string) ([]net.IP, error) { // context that gets cancelled when a signal is received and a cancel function // that can be called to free the resources when the watch should be stopped. 
func contextWithExitSignalWatch() (context.Context, func()) { - closeChan := make(chan string) + closeChan := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) signalChan := make(chan os.Signal, 1) signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) @@ -700,8 +822,11 @@ func contextWithExitSignalWatch() (context.Context, func()) { return } }() + closeOnce := sync.Once{} f := func() { - closeChan <- "goodbye" + closeOnce.Do(func() { + close(closeChan) + }) } return ctx, f } @@ -731,7 +856,6 @@ func tailscaledConfigFilePath() string { } cv, err := kubeutils.CapVerFromFileName(e.Name()) if err != nil { - log.Printf("skipping file %q in tailscaled config directory %q: %v", e.Name(), dir, err) continue } if cv > maxCompatVer && cv <= tailcfg.CurrentCapabilityVersion { @@ -739,8 +863,32 @@ func tailscaledConfigFilePath() string { } } if maxCompatVer == -1 { - log.Fatalf("no tailscaled config file found in %q for current capability version %q", dir, tailcfg.CurrentCapabilityVersion) + log.Fatalf("no tailscaled config file found in %q for current capability version %d", dir, tailcfg.CurrentCapabilityVersion) + } + filePath := filepath.Join(dir, kubeutils.TailscaledConfigFileName(maxCompatVer)) + log.Printf("Using tailscaled config file %q to match current capability version %d", filePath, tailcfg.CurrentCapabilityVersion) + return filePath +} + +func runHTTPServer(mux *http.ServeMux, addr string) (close func() error) { + ln, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("failed to listen on addr %q: %v", addr, err) + } + srv := &http.Server{Handler: mux} + + go func() { + if err := srv.Serve(ln); err != nil { + if err != http.ErrServerClosed { + log.Fatalf("failed running server: %v", err) + } else { + log.Printf("HTTP server at %s closed", addr) + } + } + }() + + return func() error { + err := srv.Shutdown(context.Background()) + return errors.Join(err, ln.Close()) } - log.Printf("Using tailscaled config file %q for 
capability version %q", maxCompatVer, tailcfg.CurrentCapabilityVersion) - return path.Join(dir, kubeutils.TailscaledConfigFileName(maxCompatVer)) } diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go index 5c92787ce6079..c7293c77a4afa 100644 --- a/cmd/containerboot/main_test.go +++ b/cmd/containerboot/main_test.go @@ -25,12 +25,16 @@ import ( "strconv" "strings" "sync" + "syscall" "testing" "time" "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" "tailscale.com/ipn" + "tailscale.com/kube/egressservices" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/types/netmap" @@ -38,69 +42,23 @@ import ( ) func TestContainerBoot(t *testing.T) { - d := t.TempDir() - - lapi := localAPI{FSRoot: d} - if err := lapi.Start(); err != nil { - t.Fatal(err) - } - defer lapi.Close() - - kube := kubeServer{FSRoot: d} - if err := kube.Start(); err != nil { - t.Fatal(err) - } - defer kube.Close() - - tailscaledConf := &ipn.ConfigVAlpha{AuthKey: ptr.To("foo"), Version: "alpha0"} - tailscaledConfBytes, err := json.Marshal(tailscaledConf) - if err != nil { - t.Fatalf("error unmarshaling tailscaled config: %v", err) + boot := filepath.Join(t.TempDir(), "containerboot") + if err := exec.Command("go", "build", "-ldflags", "-X main.testSleepDuration=1ms", "-o", boot, "tailscale.com/cmd/containerboot").Run(); err != nil { + t.Fatalf("Building containerboot: %v", err) } + egressStatus := egressSvcStatus("foo", "foo.tailnetxyz.ts.net") - dirs := []string{ - "var/lib", - "usr/bin", - "tmp", - "dev/net", - "proc/sys/net/ipv4", - "proc/sys/net/ipv6/conf/all", - "etc/tailscaled", - } - for _, path := range dirs { - if err := os.MkdirAll(filepath.Join(d, path), 0700); err != nil { - t.Fatal(err) - } + metricsURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d/metrics", port) } - files := map[string][]byte{ - "usr/bin/tailscaled": fakeTailscaled, - "usr/bin/tailscale": 
fakeTailscale, - "usr/bin/iptables": fakeTailscale, - "usr/bin/ip6tables": fakeTailscale, - "dev/net/tun": []byte(""), - "proc/sys/net/ipv4/ip_forward": []byte("0"), - "proc/sys/net/ipv6/conf/all/forwarding": []byte("0"), - "etc/tailscaled/cap-95.hujson": tailscaledConfBytes, + healthURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d/healthz", port) } - resetFiles := func() { - for path, content := range files { - // Making everything executable is a little weird, but the - // stuff that doesn't need to be executable doesn't care if we - // do make it executable. - if err := os.WriteFile(filepath.Join(d, path), content, 0700); err != nil { - t.Fatal(err) - } - } + egressSvcTerminateURL := func(port int) string { + return fmt.Sprintf("http://127.0.0.1:%d%s", port, kubetypes.EgessServicesPreshutdownEP) } - resetFiles() - boot := filepath.Join(d, "containerboot") - if err := exec.Command("go", "build", "-o", boot, "tailscale.com/cmd/containerboot").Run(); err != nil { - t.Fatalf("Building containerboot: %v", err) - } - - argFile := filepath.Join(d, "args") - runningSockPath := filepath.Join(d, "tmp/tailscaled.sock") + capver := fmt.Sprintf("%d", tailcfg.CurrentCapabilityVersion) type phase struct { // If non-nil, send this IPN bus notification (and remember it as the @@ -110,15 +68,31 @@ func TestContainerBoot(t *testing.T) { // WantCmds is the commands that containerboot should run in this phase. WantCmds []string + // WantKubeSecret is the secret keys/values that should exist in the // kube secret. WantKubeSecret map[string]string + + // Update the kube secret with these keys/values at the beginning of the + // phase (simulates our fake tailscaled doing it). + UpdateKubeSecret map[string]string + // WantFiles files that should exist in the container and their // contents. WantFiles map[string]string - // WantFatalLog is the fatal log message we expect from containerboot. - // If set for a phase, the test will finish on that phase. 
- WantFatalLog string + + // WantLog is a log message we expect from containerboot. + WantLog string + + // If set for a phase, the test will expect containerboot to exit with + // this error code, and the test will finish on that phase without + // waiting for the successful startup log message. + WantExitCode *int + + // The signal to send to containerboot at the start of the phase. + Signal *syscall.Signal + + EndpointStatuses map[string]int } runningNotify := &ipn.Notify{ State: ptr.To(ipn.Running), @@ -130,601 +104,966 @@ func TestContainerBoot(t *testing.T) { }).View(), }, } - tests := []struct { - Name string + type testCase struct { Env map[string]string KubeSecret map[string]string KubeDenyPatch bool Phases []phase - }{ - { - // Out of the box default: runs in userspace mode, ephemeral storage, interactive login. - Name: "no_args", - Env: nil, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + } + tests := map[string]func(env *testEnv) testCase{ + "no_args": func(env *testEnv) testCase { + return testCase{ + // Out of the box default: runs in userspace mode, ephemeral storage, interactive login. + Env: nil, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + // No metrics or health by default. + EndpointStatuses: map[string]int{ + metricsURL(9002): -1, + healthURL(9002): -1, + }, + }, + { + Notify: runningNotify, }, }, - { - Notify: runningNotify, - }, - }, + } }, - { - // Userspace mode, ephemeral storage, authkey provided on every run. 
- Name: "authkey", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "authkey": func(env *testEnv) testCase { + return testCase{ + // Userspace mode, ephemeral storage, authkey provided on every run. + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - // Userspace mode, ephemeral storage, authkey provided on every run. - Name: "authkey-old-flag", - Env: map[string]string{ - "TS_AUTH_KEY": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "authkey_old_flag": func(env *testEnv) testCase { + return testCase{ + // Userspace mode, ephemeral storage, authkey provided on every run. 
+ Env: map[string]string{ + "TS_AUTH_KEY": "tskey-key", }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - Name: "authkey_disk_state", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_STATE_DIR": filepath.Join(d, "tmp"), - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "authkey_disk_state": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_STATE_DIR": filepath.Join(env.d, "tmp"), }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - Name: "routes", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", - }, + "routes": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "0", - 
"proc/sys/net/ipv6/conf/all/forwarding": "0", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "0", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, }, - }, + } }, - { - Name: "empty routes", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=", - }, + "empty_routes": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "0", - "proc/sys/net/ipv6/conf/all/forwarding": "0", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "0", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, }, - }, + } }, - { - Name: "routes_kernel_ipv4", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up 
--accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", - }, + "routes_kernel_ipv4": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "1.2.3.0/24,10.20.30.0/24", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "0", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=1.2.3.0/24,10.20.30.0/24", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, }, - }, + } }, - { - Name: "routes_kernel_ipv6", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "::/64,1::/64", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1::/64", - }, + "routes_kernel_ipv6": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "::/64,1::/64", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "0", - "proc/sys/net/ipv6/conf/all/forwarding": "1", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1::/64", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + 
"proc/sys/net/ipv4/ip_forward": "0", + "proc/sys/net/ipv6/conf/all/forwarding": "1", + }, }, }, - }, + } }, - { - Name: "routes_kernel_all_families", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_ROUTES": "::/64,1.2.3.0/24", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1.2.3.0/24", - }, + "routes_kernel_all_families": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_ROUTES": "::/64,1.2.3.0/24", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "1", + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key --advertise-routes=::/64,1.2.3.0/24", + }, + }, + { + Notify: runningNotify, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "1", + }, }, }, - }, + } }, - { - Name: "ingress proxy", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_DEST_IP": "1.2.3.4", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, + "ingress_proxy": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_DEST_IP": "1.2.3.4", + "TS_USERSPACE": "false", }, - { - Notify: runningNotify, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled 
--socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, }, - }, + } }, - { - Name: "egress proxy", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_TAILNET_TARGET_IP": "100.99.99.99", - "TS_USERSPACE": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "egress_proxy": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_TAILNET_TARGET_IP": "100.99.99.99", + "TS_USERSPACE": "false", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, }, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "0", + { + Notify: runningNotify, }, }, - { - Notify: runningNotify, - }, - }, + } }, - { - Name: "egress_proxy_fqdn_ipv6_target_on_ipv4_host", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_TAILNET_TARGET_FQDN": "ipv6-node.test.ts.net", // resolves to IPv6 address - "TS_USERSPACE": "false", - "TS_TEST_FAKE_NETFILTER_6": "false", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", - }, - WantFiles: map[string]string{ - "proc/sys/net/ipv4/ip_forward": "1", - "proc/sys/net/ipv6/conf/all/forwarding": "0", - }, - }, - { - Notify: 
&ipn.Notify{ - State: ptr.To(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("myID"), - Name: "test-node.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, - }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("ipv6ID"), - Name: "ipv6-node.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("::1/128")}, + "egress_proxy_fqdn_ipv6_target_on_ipv4_host": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_TAILNET_TARGET_FQDN": "ipv6-node.test.ts.net", // resolves to IPv6 address + "TS_USERSPACE": "false", + "TS_TEST_FAKE_NETFILTER_6": "false", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantFiles: map[string]string{ + "proc/sys/net/ipv4/ip_forward": "1", + "proc/sys/net/ipv6/conf/all/forwarding": "0", + }, + }, + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.Running), + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("myID"), + Name: "test-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("ipv6ID"), + Name: "ipv6-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("::1/128")}, + }).View(), + }, }, }, + WantLog: "no forwarding rules for egress addresses [::1/128], host supports IPv6: false", + WantExitCode: ptr.To(1), }, - WantFatalLog: "no forwarding rules for egress addresses [::1/128], host supports IPv6: false", }, - }, + } }, - { - Name: "authkey_once", - Env: map[string]string{ - "TS_AUTHKEY": "tskey-key", - "TS_AUTH_ONCE": "true", - }, - Phases: []phase{ - { - WantCmds: []string{ - 
"/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - }, + "authkey_once": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_AUTH_ONCE": "true", }, - { - Notify: &ipn.Notify{ - State: ptr.To(ipn.NeedsLogin), + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + }, }, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.NeedsLogin), + }, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, }, - }, - { - Notify: runningNotify, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + { + Notify: runningNotify, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + }, }, }, - }, + } }, - { - Name: "kube_storage", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - }, - KubeSecret: map[string]string{ - "authkey": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "auth_key_once_extra_args_override_dns": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_AUTHKEY": "tskey-key", + "TS_AUTH_ONCE": "true", + "TS_ACCEPT_DNS": "false", + "TS_EXTRA_ARGS": "--accept-dns", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + }, }, - WantKubeSecret: map[string]string{ - 
"authkey": "tskey-key", + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.NeedsLogin), + }, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true --authkey=tskey-key", + }, }, - }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + { + Notify: runningNotify, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=true", + }, }, }, - }, + } }, - { - Name: "kube_disk_storage", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - // Explicitly set to an empty value, to override the default of "tailscale". - "TS_KUBE_SECRET": "", - "TS_STATE_DIR": filepath.Join(d, "tmp"), - "TS_AUTHKEY": "tskey-key", - }, - KubeSecret: map[string]string{}, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "kube_storage": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, + }, }, - 
WantKubeSecret: map[string]string{}, }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{}, + } + }, + "kube_disk_storage": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + // Explicitly set to an empty value, to override the default of "tailscale". + "TS_KUBE_SECRET": "", + "TS_STATE_DIR": filepath.Join(env.d, "tmp"), + "TS_AUTHKEY": "tskey-key", }, - }, + KubeSecret: map[string]string{}, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{}, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{}, + }, + }, + } }, - { - Name: "kube_storage_no_patch", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - "TS_AUTHKEY": "tskey-key", - }, - KubeSecret: map[string]string{}, - KubeDenyPatch: true, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "kube_storage_no_patch": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_AUTHKEY": "tskey-key", + }, + KubeSecret: map[string]string{}, + KubeDenyPatch: true, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: 
map[string]string{}, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{}, }, - WantKubeSecret: map[string]string{}, }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{}, + } + }, + "kube_storage_auth_once": func(env *testEnv) testCase { + return testCase{ + // Same as previous, but deletes the authkey from the kube secret. + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_AUTH_ONCE": "true", }, - }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + }, + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.NeedsLogin), + }, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + }, + { + Notify: runningNotify, + WantCmds: []string{ + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + }, + WantKubeSecret: map[string]string{ + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, + }, + }, + }, + } }, - { - // Same as previous, but deletes the authkey from the kube secret. 
- Name: "kube_storage_auth_once", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - "TS_AUTH_ONCE": "true", - }, - KubeSecret: map[string]string{ - "authkey": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "kube_storage_updates": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + }, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, + }, }, + { + Notify: &ipn.Notify{ + State: ptr.To(ipn.Running), + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("newID"), + Name: "new-name.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + }).View(), + }, + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "new-name.test.ts.net", + "device_id": "newID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, + }, + }, + }, + } + }, + "proxies": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_SOCKS5_SERVER": "localhost:1080", + "TS_OUTBOUND_HTTP_PROXY_LISTEN": "localhost:8080", }, - { - Notify: 
&ipn.Notify{ - State: ptr.To(ipn.NeedsLogin), + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --socks5-server=localhost:1080 --outbound-http-proxy-listen=localhost:8080", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + }, + { + Notify: runningNotify, }, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + } + }, + "dns": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_ACCEPT_DNS": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true", + }, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", + { + Notify: runningNotify, }, }, - { - Notify: runningNotify, - WantCmds: []string{ - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock set --accept-dns=false", + } + }, + "extra_args": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXTRA_ARGS": "--widget=rotated", + "TS_TAILSCALED_EXTRA_ARGS": "--experiments=widgets", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --experiments=widgets", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --widget=rotated", + }, + }, { + Notify: runningNotify, }, - WantKubeSecret: map[string]string{ - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + }, + } + }, + "extra_args_accept_routes": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXTRA_ARGS": "--accept-routes", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled 
--socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --accept-routes", + }, + }, { + Notify: runningNotify, }, }, - }, + } }, - { - Name: "kube_storage_updates", - Env: map[string]string{ - "KUBERNETES_SERVICE_HOST": kube.Host, - "KUBERNETES_SERVICE_PORT_HTTPS": kube.Port, - }, - KubeSecret: map[string]string{ - "authkey": "tskey-key", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + "extra_args_accept_dns": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXTRA_ARGS": "--accept-dns", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true", + }, + }, { + Notify: runningNotify, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", + }, + } + }, + "extra_args_accept_dns_overrides_env_var": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_ACCEPT_DNS": "true", // Overridden by TS_EXTRA_ARGS. 
+ "TS_EXTRA_ARGS": "--accept-dns=false", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + }, { + Notify: runningNotify, }, }, - { - Notify: runningNotify, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "test-node.test.ts.net", - "device_id": "myID", - "device_ips": `["100.64.0.1"]`, + } + }, + "hostname": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_HOSTNAME": "my-server", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --hostname=my-server", + }, + }, { + Notify: runningNotify, }, }, - { - Notify: &ipn.Notify{ - State: ptr.To(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("newID"), - Name: "new-name.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, - }).View(), + } + }, + "experimental_tailscaled_config_path": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR": filepath.Join(env.d, "etc/tailscaled/"), + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --config=/etc/tailscaled/cap-95.hujson", }, + }, { + Notify: runningNotify, }, - WantKubeSecret: map[string]string{ - "authkey": "tskey-key", - "device_fqdn": "new-name.test.ts.net", - "device_id": "newID", - "device_ips": `["100.64.0.1"]`, + }, + } + }, + "metrics_enabled": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + 
"TS_ENABLE_METRICS": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.localAddrPort): -1, + }, + }, { + Notify: runningNotify, }, }, - }, + } }, - { - Name: "proxies", - Env: map[string]string{ - "TS_SOCKS5_SERVER": "localhost:1080", - "TS_OUTBOUND_HTTP_PROXY_LISTEN": "localhost:8080", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --socks5-server=localhost:1080 --outbound-http-proxy-listen=localhost:8080", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + "health_enabled": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + "TS_ENABLE_HEALTH_CHECK": "true", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): -1, + healthURL(env.localAddrPort): 503, // Doesn't start passing until the next phase. 
+ }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): -1, + healthURL(env.localAddrPort): 200, + }, }, }, - { - Notify: runningNotify, + } + }, + "metrics_and_health_on_same_port": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + "TS_ENABLE_METRICS": "true", + "TS_ENABLE_HEALTH_CHECK": "true", }, - }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.localAddrPort): 503, // Doesn't start passing until the next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.localAddrPort): 200, + }, + }, + }, + } }, - { - Name: "dns", - Env: map[string]string{ - "TS_ACCEPT_DNS": "true", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=true", + "local_metrics_and_deprecated_health": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + "TS_ENABLE_METRICS": "true", + "TS_HEALTHCHECK_ADDR_PORT": fmt.Sprintf("[::]:%d", env.healthAddrPort), + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false", + }, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.healthAddrPort): 503, // Doesn't start passing until the 
next phase. + }, + }, { + Notify: runningNotify, + EndpointStatuses: map[string]int{ + metricsURL(env.localAddrPort): 200, + healthURL(env.healthAddrPort): 200, + }, }, }, - { - Notify: runningNotify, + } + }, + "serve_config_no_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_SERVE_CONFIG": filepath.Join(env.d, "etc/tailscaled/serve-config.json"), + "TS_AUTHKEY": "tskey-key", }, - }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + }, + { + Notify: runningNotify, + }, + }, + } }, - { - Name: "extra_args", - Env: map[string]string{ - "TS_EXTRA_ARGS": "--widget=rotated", - "TS_TAILSCALED_EXTRA_ARGS": "--experiments=widgets", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --experiments=widgets", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --widget=rotated", + "serve_config_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_SERVE_CONFIG": filepath.Join(env.d, "etc/tailscaled/serve-config.json"), + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + 
"device_ips": `["100.64.0.1"]`, + "https_endpoint": "no-https", + "tailscale_capver": capver, + }, }, - }, { - Notify: runningNotify, }, - }, + } }, - { - Name: "extra_args_accept_routes", - Env: map[string]string{ - "TS_EXTRA_ARGS": "--accept-routes", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --accept-routes", + "egress_svcs_config_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_EGRESS_PROXIES_CONFIG_PATH": filepath.Join(env.d, "etc/tailscaled"), + "TS_LOCAL_ADDR_PORT": fmt.Sprintf("[::]:%d", env.localAddrPort), + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + EndpointStatuses: map[string]int{ + egressSvcTerminateURL(env.localAddrPort): 200, + }, + }, + { + Notify: runningNotify, + WantKubeSecret: map[string]string{ + "egress-services": mustBase64(t, egressStatus), + "authkey": "tskey-key", + "device_fqdn": "test-node.test.ts.net", + "device_id": "myID", + "device_ips": `["100.64.0.1"]`, + "tailscale_capver": capver, + }, + EndpointStatuses: map[string]int{ + egressSvcTerminateURL(env.localAddrPort): 200, + }, }, - }, { - Notify: runningNotify, }, - }, + } }, - { - Name: "hostname", - Env: map[string]string{ - "TS_HOSTNAME": "my-server", - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp 
--tun=userspace-networking", - "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --hostname=my-server", + "egress_svcs_config_no_kube": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "TS_EGRESS_PROXIES_CONFIG_PATH": filepath.Join(env.d, "etc/tailscaled"), + "TS_AUTHKEY": "tskey-key", + }, + Phases: []phase{ + { + WantLog: "TS_EGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes", + WantExitCode: ptr.To(1), }, - }, { - Notify: runningNotify, }, - }, + } }, - { - Name: "experimental tailscaled config path", - Env: map[string]string{ - "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR": filepath.Join(d, "etc/tailscaled/"), - }, - Phases: []phase{ - { - WantCmds: []string{ - "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=mem: --statedir=/tmp --tun=userspace-networking --config=/etc/tailscaled/cap-95.hujson", + "kube_shutdown_during_state_write": func(env *testEnv) testCase { + return testCase{ + Env: map[string]string{ + "KUBERNETES_SERVICE_HOST": env.kube.Host, + "KUBERNETES_SERVICE_PORT_HTTPS": env.kube.Port, + "TS_ENABLE_HEALTH_CHECK": "true", + }, + KubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + Phases: []phase{ + { + // Normal startup. + WantCmds: []string{ + "/usr/bin/tailscaled --socket=/tmp/tailscaled.sock --state=kube:tailscale --statedir=/tmp --tun=userspace-networking", + "/usr/bin/tailscale --socket=/tmp/tailscaled.sock up --accept-dns=false --authkey=tskey-key", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + }, + }, + { + // SIGTERM before state is finished writing, should wait for + // consistent state before propagating SIGTERM to tailscaled. + Signal: ptr.To(unix.SIGTERM), + UpdateKubeSecret: map[string]string{ + "_machinekey": "foo", + "_profiles": "foo", + "profile-baff": "foo", + // Missing "_current-profile" key. 
+ }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "_machinekey": "foo", + "_profiles": "foo", + "profile-baff": "foo", + }, + WantLog: "Waiting for tailscaled to finish writing state to Secret \"tailscale\"", + }, + { + // tailscaled has finished writing state, should propagate SIGTERM. + UpdateKubeSecret: map[string]string{ + "_current-profile": "foo", + }, + WantKubeSecret: map[string]string{ + "authkey": "tskey-key", + "_machinekey": "foo", + "_profiles": "foo", + "profile-baff": "foo", + "_current-profile": "foo", + }, + WantLog: "HTTP server at [::]:9002 closed", + WantExitCode: ptr.To(0), }, - }, { - Notify: runningNotify, }, - }, + } }, } - for _, test := range tests { - t.Run(test.Name, func(t *testing.T) { - lapi.Reset() - kube.Reset() - os.Remove(argFile) - os.Remove(runningSockPath) - resetFiles() + for name, test := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + env := newTestEnv(t) + tc := test(&env) - for k, v := range test.KubeSecret { - kube.SetSecret(k, v) + for k, v := range tc.KubeSecret { + env.kube.SetSecret(k, v) } - kube.SetPatching(!test.KubeDenyPatch) + env.kube.SetPatching(!tc.KubeDenyPatch) cmd := exec.Command(boot) cmd.Env = []string{ - fmt.Sprintf("PATH=%s/usr/bin:%s", d, os.Getenv("PATH")), - fmt.Sprintf("TS_TEST_RECORD_ARGS=%s", argFile), - fmt.Sprintf("TS_TEST_SOCKET=%s", lapi.Path), - fmt.Sprintf("TS_SOCKET=%s", runningSockPath), - fmt.Sprintf("TS_TEST_ONLY_ROOT=%s", d), + fmt.Sprintf("PATH=%s/usr/bin:%s", env.d, os.Getenv("PATH")), + fmt.Sprintf("TS_TEST_RECORD_ARGS=%s", env.argFile), + fmt.Sprintf("TS_TEST_SOCKET=%s", env.lapi.Path), + fmt.Sprintf("TS_SOCKET=%s", env.runningSockPath), + fmt.Sprintf("TS_TEST_ONLY_ROOT=%s", env.d), fmt.Sprint("TS_TEST_FAKE_NETFILTER=true"), } - for k, v := range test.Env { + for k, v := range tc.Env { cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%s", k, v)) } cbOut := &lockingBuffer{} @@ -734,6 +1073,7 @@ func TestContainerBoot(t *testing.T) { } }() cmd.Stderr = 
cbOut + cmd.Stdout = cbOut if err := cmd.Start(); err != nil { t.Fatalf("starting containerboot: %v", err) } @@ -743,37 +1083,47 @@ func TestContainerBoot(t *testing.T) { }() var wantCmds []string - for i, p := range test.Phases { - lapi.Notify(p.Notify) - if p.WantFatalLog != "" { + for i, p := range tc.Phases { + for k, v := range p.UpdateKubeSecret { + env.kube.SetSecret(k, v) + } + env.lapi.Notify(p.Notify) + if p.Signal != nil { + cmd.Process.Signal(*p.Signal) + } + if p.WantLog != "" { err := tstest.WaitFor(2*time.Second, func() error { - state, err := cmd.Process.Wait() - if err != nil { - return err - } - if state.ExitCode() != 1 { - return fmt.Errorf("process exited with code %d but wanted %d", state.ExitCode(), 1) - } - waitLogLine(t, time.Second, cbOut, p.WantFatalLog) + waitLogLine(t, time.Second, cbOut, p.WantLog) return nil }) if err != nil { t.Fatal(err) } + } + + if p.WantExitCode != nil { + state, err := cmd.Process.Wait() + if err != nil { + t.Fatal(err) + } + if state.ExitCode() != *p.WantExitCode { + t.Fatalf("phase %d: want exit code %d, got %d", i, *p.WantExitCode, state.ExitCode()) + } // Early test return, we don't expect the successful startup log message. return } + wantCmds = append(wantCmds, p.WantCmds...) 
- waitArgs(t, 2*time.Second, d, argFile, strings.Join(wantCmds, "\n")) + waitArgs(t, 2*time.Second, env.d, env.argFile, strings.Join(wantCmds, "\n")) err := tstest.WaitFor(2*time.Second, func() error { if p.WantKubeSecret != nil { - got := kube.Secret() + got := env.kube.Secret() if diff := cmp.Diff(got, p.WantKubeSecret); diff != "" { return fmt.Errorf("unexpected kube secret data (-got+want):\n%s", diff) } } else { - got := kube.Secret() + got := env.kube.Secret() if len(got) > 0 { return fmt.Errorf("kube secret unexpectedly not empty, got %#v", got) } @@ -785,7 +1135,7 @@ func TestContainerBoot(t *testing.T) { } err = tstest.WaitFor(2*time.Second, func() error { for path, want := range p.WantFiles { - gotBs, err := os.ReadFile(filepath.Join(d, path)) + gotBs, err := os.ReadFile(filepath.Join(env.d, path)) if err != nil { return fmt.Errorf("reading wanted file %q: %v", path, err) } @@ -796,10 +1146,32 @@ func TestContainerBoot(t *testing.T) { return nil }) if err != nil { - t.Fatal(err) + t.Fatalf("phase %d: %v", i, err) + } + + for url, want := range p.EndpointStatuses { + err := tstest.WaitFor(2*time.Second, func() error { + resp, err := http.Get(url) + if err != nil && want != -1 { + return fmt.Errorf("GET %s: %v", url, err) + } + if want > 0 && resp.StatusCode != want { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("GET %s, want %d, got %d\n%s", url, want, resp.StatusCode, string(body)) + } + + return nil + }) + if err != nil { + t.Fatalf("phase %d: %v", i, err) + } } } waitLogLine(t, 2*time.Second, cbOut, "Startup complete, waiting for shutdown signal") + if cmd.ProcessState != nil { + t.Fatalf("containerboot should be running but exited with exit code %d", cmd.ProcessState.ExitCode()) + } }) } } @@ -927,13 +1299,6 @@ func (l *localAPI) Close() { l.srv.Close() } -func (l *localAPI) Reset() { - l.Lock() - defer l.Unlock() - l.notify = nil - l.cond.Broadcast() -} - func (l *localAPI) Notify(n *ipn.Notify) { if n == nil { 
return @@ -955,6 +1320,12 @@ func (l *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { panic(fmt.Sprintf("unsupported method %q", r.Method)) } + case "/localapi/v0/usermetrics": + if r.Method != "GET" { + panic(fmt.Sprintf("unsupported method %q", r.Method)) + } + w.Write([]byte("fake metrics")) + return default: panic(fmt.Sprintf("unsupported path %q", r.URL.Path)) } @@ -1019,24 +1390,19 @@ func (k *kubeServer) SetPatching(canPatch bool) { k.canPatch = canPatch } -func (k *kubeServer) Reset() { - k.Lock() - defer k.Unlock() +func (k *kubeServer) Start(t *testing.T) { k.secret = map[string]string{} -} - -func (k *kubeServer) Start() error { root := filepath.Join(k.FSRoot, "var/run/secrets/kubernetes.io/serviceaccount") if err := os.MkdirAll(root, 0700); err != nil { - return err + t.Fatal(err) } if err := os.WriteFile(filepath.Join(root, "namespace"), []byte("default"), 0600); err != nil { - return err + t.Fatal(err) } if err := os.WriteFile(filepath.Join(root, "token"), []byte("bearer_token"), 0600); err != nil { - return err + t.Fatal(err) } k.srv = httptest.NewTLSServer(k) @@ -1045,13 +1411,11 @@ func (k *kubeServer) Start() error { var cert bytes.Buffer if err := pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: k.srv.Certificate().Raw}); err != nil { - return err + t.Fatal(err) } if err := os.WriteFile(filepath.Join(root, "ca.crt"), cert.Bytes(), 0600); err != nil { - return err + t.Fatal(err) } - - return nil } func (k *kubeServer) Close() { @@ -1100,6 +1464,7 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) { http.Error(w, fmt.Sprintf("reading request body: %v", err), http.StatusInternalServerError) return } + defer r.Body.Close() switch r.Method { case "GET": @@ -1132,13 +1497,32 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) { panic(fmt.Sprintf("json decode failed: %v. 
Body:\n\n%s", err, string(bs))) } for _, op := range req { - if op.Op != "remove" { + switch op.Op { + case "remove": + if !strings.HasPrefix(op.Path, "/data/") { + panic(fmt.Sprintf("unsupported json-patch path %q", op.Path)) + } + delete(k.secret, strings.TrimPrefix(op.Path, "/data/")) + case "replace": + path, ok := strings.CutPrefix(op.Path, "/data/") + if !ok { + panic(fmt.Sprintf("unsupported json-patch path %q", op.Path)) + } + req := make([]kubeclient.JSONPatch, 0) + if err := json.Unmarshal(bs, &req); err != nil { + panic(fmt.Sprintf("json decode failed: %v. Body:\n\n%s", err, string(bs))) + } + + for _, patch := range req { + val, ok := patch.Value.(string) + if !ok { + panic(fmt.Sprintf("unsupported json patch value %v: cannot be converted to string", patch.Value)) + } + k.secret[path] = val + } + default: panic(fmt.Sprintf("unsupported json-patch op %q", op.Op)) } - if !strings.HasPrefix(op.Path, "/data/") { - panic(fmt.Sprintf("unsupported json-patch path %q", op.Path)) - } - delete(k.secret, strings.TrimPrefix(op.Path, "/data/")) } case "application/strategic-merge-patch+json": req := struct { @@ -1154,6 +1538,135 @@ func (k *kubeServer) serveSecret(w http.ResponseWriter, r *http.Request) { panic(fmt.Sprintf("unknown content type %q", r.Header.Get("Content-Type"))) } default: - panic(fmt.Sprintf("unhandled HTTP method %q", r.Method)) + panic(fmt.Sprintf("unhandled HTTP request %s %s", r.Method, r.URL)) + } +} + +func mustBase64(t *testing.T, v any) string { + b := mustJSON(t, v) + s := base64.StdEncoding.WithPadding('=').EncodeToString(b) + return s +} + +func mustJSON(t *testing.T, v any) []byte { + b, err := json.Marshal(v) + if err != nil { + t.Fatalf("error converting %v to json: %v", v, err) + } + return b +} + +// egress services status given one named tailnet target specified by FQDN. As written by the proxy to its state Secret. 
+func egressSvcStatus(name, fqdn string) egressservices.Status { + return egressservices.Status{ + Services: map[string]*egressservices.ServiceStatus{ + name: { + TailnetTarget: egressservices.TailnetTarget{ + FQDN: fqdn, + }, + }, + }, + } +} + +// egress config given one named tailnet target specified by FQDN. +func egressSvcConfig(name, fqdn string) egressservices.Configs { + return egressservices.Configs{ + name: egressservices.Config{ + TailnetTarget: egressservices.TailnetTarget{ + FQDN: fqdn, + }, + }, + } +} + +// testEnv represents the environment needed for a single sub-test so that tests +// can run in parallel. +type testEnv struct { + kube *kubeServer // Fake kube server. + lapi *localAPI // Local TS API server. + d string // Temp dir for the specific test. + argFile string // File with commands test_tailscale{,d}.sh were invoked with. + runningSockPath string // Path to the running tailscaled socket. + localAddrPort int // Port for the containerboot HTTP server. + healthAddrPort int // Port for the (deprecated) containerboot health server. 
+} + +func newTestEnv(t *testing.T) testEnv { + d := t.TempDir() + + lapi := localAPI{FSRoot: d} + if err := lapi.Start(); err != nil { + t.Fatal(err) + } + t.Cleanup(lapi.Close) + + kube := kubeServer{FSRoot: d} + kube.Start(t) + t.Cleanup(kube.Close) + + tailscaledConf := &ipn.ConfigVAlpha{AuthKey: ptr.To("foo"), Version: "alpha0"} + serveConf := ipn.ServeConfig{TCP: map[uint16]*ipn.TCPPortHandler{80: {HTTP: true}}} + egressCfg := egressSvcConfig("foo", "foo.tailnetxyz.ts.net") + + dirs := []string{ + "var/lib", + "usr/bin", + "tmp", + "dev/net", + "proc/sys/net/ipv4", + "proc/sys/net/ipv6/conf/all", + "etc/tailscaled", + } + for _, path := range dirs { + if err := os.MkdirAll(filepath.Join(d, path), 0700); err != nil { + t.Fatal(err) + } + } + files := map[string][]byte{ + "usr/bin/tailscaled": fakeTailscaled, + "usr/bin/tailscale": fakeTailscale, + "usr/bin/iptables": fakeTailscale, + "usr/bin/ip6tables": fakeTailscale, + "dev/net/tun": []byte(""), + "proc/sys/net/ipv4/ip_forward": []byte("0"), + "proc/sys/net/ipv6/conf/all/forwarding": []byte("0"), + "etc/tailscaled/cap-95.hujson": mustJSON(t, tailscaledConf), + "etc/tailscaled/serve-config.json": mustJSON(t, serveConf), + filepath.Join("etc/tailscaled/", egressservices.KeyEgressServices): mustJSON(t, egressCfg), + filepath.Join("etc/tailscaled/", egressservices.KeyHEPPings): []byte("4"), + } + for path, content := range files { + // Making everything executable is a little weird, but the + // stuff that doesn't need to be executable doesn't care if we + // do make it executable. 
+ if err := os.WriteFile(filepath.Join(d, path), content, 0700); err != nil { + t.Fatal(err) + } + } + + argFile := filepath.Join(d, "args") + runningSockPath := filepath.Join(d, "tmp/tailscaled.sock") + var localAddrPort, healthAddrPort int + for _, p := range []*int{&localAddrPort, &healthAddrPort} { + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatalf("Failed to open listener: %v", err) + } + if err := ln.Close(); err != nil { + t.Fatalf("Failed to close listener: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + *p = port + } + + return testEnv{ + kube: &kube, + lapi: &lapi, + d: d, + argFile: argFile, + runningSockPath: runningSockPath, + localAddrPort: localAddrPort, + healthAddrPort: healthAddrPort, } } diff --git a/cmd/containerboot/metrics.go b/cmd/containerboot/metrics.go new file mode 100644 index 0000000000000..bbd050de6df26 --- /dev/null +++ b/cmd/containerboot/metrics.go @@ -0,0 +1,79 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "fmt" + "io" + "net/http" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" +) + +// metrics is a simple metrics HTTP server, if enabled it forwards requests to +// the tailscaled's LocalAPI usermetrics endpoint at /localapi/v0/usermetrics. 
+type metrics struct { + debugEndpoint string + lc *local.Client +} + +func proxy(w http.ResponseWriter, r *http.Request, url string, do func(*http.Request) (*http.Response, error)) { + req, err := http.NewRequestWithContext(r.Context(), r.Method, url, r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("failed to construct request: %s", err), http.StatusInternalServerError) + return + } + req.Header = r.Header.Clone() + + resp, err := do(req) + if err != nil { + http.Error(w, fmt.Sprintf("failed to proxy request: %s", err), http.StatusInternalServerError) + return + } + defer resp.Body.Close() + + for key, val := range resp.Header { + for _, v := range val { + w.Header().Add(key, v) + } + } + w.WriteHeader(resp.StatusCode) + if _, err := io.Copy(w, resp.Body); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (m *metrics) handleMetrics(w http.ResponseWriter, r *http.Request) { + localAPIURL := "http://" + apitype.LocalAPIHost + "/localapi/v0/usermetrics" + proxy(w, r, localAPIURL, m.lc.DoLocalRequest) +} + +func (m *metrics) handleDebug(w http.ResponseWriter, r *http.Request) { + if m.debugEndpoint == "" { + http.Error(w, "debug endpoint not configured", http.StatusNotFound) + return + } + + debugURL := "http://" + m.debugEndpoint + r.URL.Path + proxy(w, r, debugURL, http.DefaultClient.Do) +} + +// registerMetricsHandlers registers a simple HTTP metrics handler at /metrics, forwarding +// requests to tailscaled's /localapi/v0/usermetrics API. +// +// In 1.78.x and 1.80.x, it also proxies debug paths to tailscaled's debug +// endpoint if configured to ease migration for a breaking change serving user +// metrics instead of debug metrics on the "metrics" port. 
+func registerMetricsHandlers(mux *http.ServeMux, lc *local.Client, debugAddrPort string) { + m := &metrics{ + lc: lc, + debugEndpoint: debugAddrPort, + } + + mux.HandleFunc("GET /metrics", m.handleMetrics) + mux.HandleFunc("/debug/", m.handleDebug) // TODO(tomhjp): Remove for 1.82.0 release. +} diff --git a/cmd/containerboot/serve.go b/cmd/containerboot/serve.go index 6c22b3eeb651e..37fd497779c75 100644 --- a/cmd/containerboot/serve.go +++ b/cmd/containerboot/serve.go @@ -17,8 +17,10 @@ import ( "time" "github.com/fsnotify/fsnotify" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/ipn" + "tailscale.com/kube/kubetypes" + "tailscale.com/types/netmap" ) // watchServeConfigChanges watches path for changes, and when it sees one, reads @@ -26,27 +28,36 @@ import ( // applies it to lc. It exits when ctx is canceled. cdChanged is a channel that // is written to when the certDomain changes, causing the serve config to be // re-read and applied. -func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *tailscale.LocalClient) { +func watchServeConfigChanges(ctx context.Context, cdChanged <-chan bool, certDomainAtomic *atomic.Pointer[string], lc *local.Client, kc *kubeClient, cfg *settings) { if certDomainAtomic == nil { - panic("cd must not be nil") + panic("certDomainAtomic must not be nil") } + var tickChan <-chan time.Time var eventChan <-chan fsnotify.Event if w, err := fsnotify.NewWatcher(); err != nil { - log.Printf("failed to create fsnotify watcher, timer-only mode: %v", err) + // Creating a new fsnotify watcher would fail for example if inotify was not able to create a new file descriptor. 
+ // See https://github.com/tailscale/tailscale/issues/15081 + log.Printf("serve proxy: failed to create fsnotify watcher, timer-only mode: %v", err) ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() tickChan = ticker.C } else { defer w.Close() - if err := w.Add(filepath.Dir(path)); err != nil { - log.Fatalf("failed to add fsnotify watch: %v", err) + if err := w.Add(filepath.Dir(cfg.ServeConfigPath)); err != nil { + log.Fatalf("serve proxy: failed to add fsnotify watch: %v", err) } eventChan = w.Events } var certDomain string var prevServeConfig *ipn.ServeConfig + var cm certManager + if cfg.CertShareMode == "rw" { + cm = certManager{ + lc: lc, + } + } for { select { case <-ctx.Done(): @@ -59,22 +70,77 @@ func watchServeConfigChanges(ctx context.Context, path string, cdChanged <-chan // k8s handles these mounts. So just re-read the file and apply it // if it's changed. } - if certDomain == "" { - continue - } - sc, err := readServeConfig(path, certDomain) + sc, err := readServeConfig(cfg.ServeConfigPath, certDomain) if err != nil { - log.Fatalf("failed to read serve config: %v", err) + log.Fatalf("serve proxy: failed to read serve config: %v", err) + } + if sc == nil { + log.Printf("serve proxy: no serve config at %q, skipping", cfg.ServeConfigPath) + continue } if prevServeConfig != nil && reflect.DeepEqual(sc, prevServeConfig) { continue } - log.Printf("Applying serve config") - if err := lc.SetServeConfig(ctx, sc); err != nil { - log.Fatalf("failed to set serve config: %v", err) + if err := updateServeConfig(ctx, sc, certDomain, lc); err != nil { + log.Fatalf("serve proxy: error updating serve config: %v", err) + } + if kc != nil && kc.canPatch { + if err := kc.storeHTTPSEndpoint(ctx, certDomain); err != nil { + log.Fatalf("serve proxy: error storing HTTPS endpoint: %v", err) + } } prevServeConfig = sc + if cfg.CertShareMode != "rw" { + continue + } + if err := cm.ensureCertLoops(ctx, sc); err != nil { + log.Fatalf("serve proxy: error ensuring cert 
loops: %v", err) + } + } +} + +func certDomainFromNetmap(nm *netmap.NetworkMap) string { + if len(nm.DNS.CertDomains) == 0 { + return "" + } + return nm.DNS.CertDomains[0] +} + +// localClient is a subset of [local.Client] that can be mocked for testing. +type localClient interface { + SetServeConfig(context.Context, *ipn.ServeConfig) error + CertPair(context.Context, string) ([]byte, []byte, error) +} + +func updateServeConfig(ctx context.Context, sc *ipn.ServeConfig, certDomain string, lc localClient) error { + if !isValidHTTPSConfig(certDomain, sc) { + return nil } + log.Printf("serve proxy: applying serve config") + return lc.SetServeConfig(ctx, sc) +} + +func isValidHTTPSConfig(certDomain string, sc *ipn.ServeConfig) bool { + if certDomain == kubetypes.ValueNoHTTPS && hasHTTPSEndpoint(sc) { + log.Printf( + `serve proxy: this node is configured as a proxy that exposes an HTTPS endpoint to tailnet, + (perhaps a Kubernetes operator Ingress proxy) but it is not able to issue TLS certs, so this will likely not work. + To make it work, ensure that HTTPS is enabled for your tailnet, see https://tailscale.com/kb/1153/enabling-https for more details.`) + return false + } + return true +} + +func hasHTTPSEndpoint(cfg *ipn.ServeConfig) bool { + if cfg == nil { + return false + } + for _, tcpCfg := range cfg.TCP { + if tcpCfg.HTTPS { + return true + } + } + return false } // readServeConfig reads the ipn.ServeConfig from path, replacing @@ -85,8 +151,17 @@ func readServeConfig(path, certDomain string) (*ipn.ServeConfig, error) { } j, err := os.ReadFile(path) if err != nil { + if os.IsNotExist(err) { + return nil, nil + } return nil, err } + // Serve config can be provided by users as well as the Kubernetes Operator (for its proxies). User-provided + // config could be empty for reasons. 
+ if len(j) == 0 { + log.Printf("serve proxy: serve config file is empty, skipping") + return nil, nil + } j = bytes.ReplaceAll(j, []byte("${TS_CERT_DOMAIN}"), []byte(certDomain)) var sc ipn.ServeConfig if err := json.Unmarshal(j, &sc); err != nil { diff --git a/cmd/containerboot/serve_test.go b/cmd/containerboot/serve_test.go new file mode 100644 index 0000000000000..fc18f254dad05 --- /dev/null +++ b/cmd/containerboot/serve_test.go @@ -0,0 +1,271 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/kube/kubetypes" +) + +func TestUpdateServeConfig(t *testing.T) { + tests := []struct { + name string + sc *ipn.ServeConfig + certDomain string + wantCall bool + }{ + { + name: "no_https_no_cert_domain", + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + }, + certDomain: kubetypes.ValueNoHTTPS, // tailnet has HTTPS disabled + wantCall: true, // should set serve config as it doesn't have HTTPS endpoints + }, + { + name: "https_with_cert_domain", + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "${TS_CERT_DOMAIN}:443": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://10.0.1.100:8080"}, + }, + }, + }, + }, + certDomain: "test-node.tailnet.ts.net", + wantCall: true, + }, + { + name: "https_without_cert_domain", + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: {HTTPS: true}, + }, + }, + certDomain: kubetypes.ValueNoHTTPS, + wantCall: false, // incorrect configuration- should not set serve config + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fakeLC := &fakeLocalClient{} + err := updateServeConfig(context.Background(), tt.sc, tt.certDomain, fakeLC) + 
if err != nil { + t.Errorf("updateServeConfig() error = %v", err) + } + if fakeLC.setServeCalled != tt.wantCall { + t.Errorf("SetServeConfig() called = %v, want %v", fakeLC.setServeCalled, tt.wantCall) + } + }) + } +} + +func TestReadServeConfig(t *testing.T) { + tests := []struct { + name string + gotSC string + certDomain string + wantSC *ipn.ServeConfig + wantErr bool + }{ + { + name: "empty_file", + }, + { + name: "valid_config_with_cert_domain_placeholder", + gotSC: `{ + "TCP": { + "443": { + "HTTPS": true + } + }, + "Web": { + "${TS_CERT_DOMAIN}:443": { + "Handlers": { + "/api": { + "Proxy": "https://10.2.3.4/api" + }}}}}`, + certDomain: "example.com", + wantSC: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + ipn.HostPort("example.com:443"): { + Handlers: map[string]*ipn.HTTPHandler{ + "/api": { + Proxy: "https://10.2.3.4/api", + }, + }, + }, + }, + }, + }, + { + name: "valid_config_for_http_proxy", + gotSC: `{ + "TCP": { + "80": { + "HTTP": true + } + }}`, + wantSC: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: { + HTTP: true, + }, + }, + }, + }, + { + name: "config_without_cert_domain", + gotSC: `{ + "TCP": { + "443": { + "HTTPS": true + } + }, + "Web": { + "localhost:443": { + "Handlers": { + "/api": { + "Proxy": "https://10.2.3.4/api" + }}}}}`, + certDomain: "", + wantErr: false, + wantSC: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + ipn.HostPort("localhost:443"): { + Handlers: map[string]*ipn.HTTPHandler{ + "/api": { + Proxy: "https://10.2.3.4/api", + }, + }, + }, + }, + }, + }, + { + name: "invalid_json", + gotSC: "invalid json", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "serve-config.json") + if err := os.WriteFile(path, []byte(tt.gotSC), 0644); err != 
nil { + t.Fatal(err) + } + + got, err := readServeConfig(path, tt.certDomain) + if (err != nil) != tt.wantErr { + t.Errorf("readServeConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !cmp.Equal(got, tt.wantSC) { + t.Errorf("readServeConfig() diff (-got +want):\n%s", cmp.Diff(got, tt.wantSC)) + } + }) + } +} + +type fakeLocalClient struct { + *local.Client + setServeCalled bool +} + +func (m *fakeLocalClient) SetServeConfig(ctx context.Context, cfg *ipn.ServeConfig) error { + m.setServeCalled = true + return nil +} + +func (m *fakeLocalClient) CertPair(ctx context.Context, domain string) (certPEM, keyPEM []byte, err error) { + return nil, nil, nil +} + +func TestHasHTTPSEndpoint(t *testing.T) { + tests := []struct { + name string + cfg *ipn.ServeConfig + want bool + }{ + { + name: "nil_config", + cfg: nil, + want: false, + }, + { + name: "empty_config", + cfg: &ipn.ServeConfig{}, + want: false, + }, + { + name: "no_https_endpoints", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: { + HTTPS: false, + }, + }, + }, + want: false, + }, + { + name: "has_https_endpoint", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 443: { + HTTPS: true, + }, + }, + }, + want: true, + }, + { + name: "mixed_endpoints", + cfg: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: false}, + 443: {HTTPS: true}, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hasHTTPSEndpoint(tt.cfg) + if got != tt.want { + t.Errorf("hasHTTPSEndpoint() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/containerboot/services.go b/cmd/containerboot/services.go index 4da7286b7ca0a..6079128c02b19 100644 --- a/cmd/containerboot/services.go +++ b/cmd/containerboot/services.go @@ -7,565 +7,57 @@ package main import ( "context" - "encoding/json" - "errors" "fmt" "log" - "net/netip" - "os" - "path/filepath" - "reflect" - "strings" "time" - 
"github.com/fsnotify/fsnotify" + "tailscale.com/client/local" "tailscale.com/ipn" - "tailscale.com/kube/egressservices" - "tailscale.com/kube/kubeclient" - "tailscale.com/tailcfg" - "tailscale.com/util/linuxfw" - "tailscale.com/util/mak" ) -const tailscaleTunInterface = "tailscale0" - -// This file contains functionality to run containerboot as a proxy that can -// route cluster traffic to one or more tailnet targets, based on portmapping -// rules read from a configfile. Currently (9/2024) this is only used for the -// Kubernetes operator egress proxies. - -// egressProxy knows how to configure firewall rules to route cluster traffic to -// one or more tailnet services. -type egressProxy struct { - cfgPath string // path to egress service config file - - nfr linuxfw.NetfilterRunner // never nil - - kc kubeclient.Client // never nil - stateSecret string // name of the kube state Secret - - netmapChan chan ipn.Notify // chan to receive netmap updates on - - podIPv4 string // never empty string, currently only IPv4 is supported - - // tailnetFQDNs is the egress service FQDN to tailnet IP mappings that - // were last used to configure firewall rules for this proxy. - // TODO(irbekrm): target addresses are also stored in the state Secret. - // Evaluate whether we should retrieve them from there and not store in - // memory at all. - targetFQDNs map[string][]netip.Prefix - - // used to configure firewall rules. 
- tailnetAddrs []netip.Prefix -} - -// run configures egress proxy firewall rules and ensures that the firewall rules are reconfigured when: -// - the mounted egress config has changed -// - the proxy's tailnet IP addresses have changed -// - tailnet IPs have changed for any backend targets specified by tailnet FQDN -func (ep *egressProxy) run(ctx context.Context, n ipn.Notify) error { - var tickChan <-chan time.Time - var eventChan <-chan fsnotify.Event - // TODO (irbekrm): take a look if this can be pulled into a single func - // shared with serve config loader. - if w, err := fsnotify.NewWatcher(); err != nil { - log.Printf("failed to create fsnotify watcher, timer-only mode: %v", err) - ticker := time.NewTicker(5 * time.Second) - defer ticker.Stop() - tickChan = ticker.C - } else { - defer w.Close() - if err := w.Add(filepath.Dir(ep.cfgPath)); err != nil { - return fmt.Errorf("failed to add fsnotify watch: %w", err) - } - eventChan = w.Events - } - - if err := ep.sync(ctx, n); err != nil { - return err - } - for { - var err error - select { - case <-ctx.Done(): - return nil - case <-tickChan: - err = ep.sync(ctx, n) - case <-eventChan: - log.Printf("config file change detected, ensuring firewall config is up to date...") - err = ep.sync(ctx, n) - case n = <-ep.netmapChan: - shouldResync := ep.shouldResync(n) - if shouldResync { - log.Printf("netmap change detected, ensuring firewall config is up to date...") - err = ep.sync(ctx, n) - } - } - if err != nil { - return fmt.Errorf("error syncing egress service config: %w", err) - } - } -} - -// sync triggers an egress proxy config resync. The resync calculates the diff between config and status to determine if -// any firewall rules need to be updated. 
Currently using status in state Secret as a reference for what is the current -// firewall configuration is good enough because - the status is keyed by the Pod IP - we crash the Pod on errors such -// as failed firewall update -func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { - cfgs, err := ep.getConfigs() - if err != nil { - return fmt.Errorf("error retrieving egress service configs: %w", err) - } - status, err := ep.getStatus(ctx) - if err != nil { - return fmt.Errorf("error retrieving current egress proxy status: %w", err) - } - newStatus, err := ep.syncEgressConfigs(cfgs, status, n) +// ensureServicesNotAdvertised is a function that gets called on containerboot +// termination and ensures that any currently advertised VIPServices get +// unadvertised to give clients time to switch to another node before this one +// is shut down. +func ensureServicesNotAdvertised(ctx context.Context, lc *local.Client) error { + prefs, err := lc.GetPrefs(ctx) if err != nil { - return fmt.Errorf("error syncing egress service configs: %w", err) - } - if !servicesStatusIsEqual(newStatus, status) { - if err := ep.setStatus(ctx, newStatus, n); err != nil { - return fmt.Errorf("error setting egress proxy status: %w", err) - } - } - return nil -} - -// addrsHaveChanged returns true if the provided netmap update contains tailnet address change for this proxy node. -// Netmap must not be nil. -func (ep *egressProxy) addrsHaveChanged(n ipn.Notify) bool { - return !reflect.DeepEqual(ep.tailnetAddrs, n.NetMap.SelfNode.Addresses()) -} - -// syncEgressConfigs adds and deletes firewall rules to match the desired -// configuration. It uses the provided status to determine what is currently -// applied and updates the status after a successful sync. 
-func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *egressservices.Status, n ipn.Notify) (*egressservices.Status, error) { - if !(wantsServicesConfigured(cfgs) || hasServicesConfigured(status)) { - return nil, nil - } - - // Delete unnecessary services. - if err := ep.deleteUnnecessaryServices(cfgs, status); err != nil { - return nil, fmt.Errorf("error deleting services: %w", err) - - } - newStatus := &egressservices.Status{} - if !wantsServicesConfigured(cfgs) { - return newStatus, nil + return fmt.Errorf("error getting prefs: %w", err) } - - // Add new services, update rules for any that have changed. - rulesPerSvcToAdd := make(map[string][]rule, 0) - rulesPerSvcToDelete := make(map[string][]rule, 0) - for svcName, cfg := range *cfgs { - tailnetTargetIPs, err := ep.tailnetTargetIPsForSvc(cfg, n) - if err != nil { - return nil, fmt.Errorf("error determining tailnet target IPs: %w", err) - } - rulesToAdd, rulesToDelete, err := updatesForCfg(svcName, cfg, status, tailnetTargetIPs) - if err != nil { - return nil, fmt.Errorf("error validating service changes: %v", err) - } - log.Printf("syncegressservices: looking at svc %s rulesToAdd %d rulesToDelete %d", svcName, len(rulesToAdd), len(rulesToDelete)) - if len(rulesToAdd) != 0 { - mak.Set(&rulesPerSvcToAdd, svcName, rulesToAdd) - } - if len(rulesToDelete) != 0 { - mak.Set(&rulesPerSvcToDelete, svcName, rulesToDelete) - } - if len(rulesToAdd) != 0 || ep.addrsHaveChanged(n) { - // For each tailnet target, set up SNAT from the local tailnet device address of the matching - // family. 
- for _, t := range tailnetTargetIPs { - var local netip.Addr - for _, pfx := range n.NetMap.SelfNode.Addresses().All() { - if !pfx.IsSingleIP() { - continue - } - if pfx.Addr().Is4() != t.Is4() { - continue - } - local = pfx.Addr() - break - } - if !local.IsValid() { - return nil, fmt.Errorf("no valid local IP: %v", local) - } - if err := ep.nfr.EnsureSNATForDst(local, t); err != nil { - return nil, fmt.Errorf("error setting up SNAT rule: %w", err) - } - } - } - // Update the status. Status will be written back to the state Secret by the caller. - mak.Set(&newStatus.Services, svcName, &egressservices.ServiceStatus{TailnetTargetIPs: tailnetTargetIPs, TailnetTarget: cfg.TailnetTarget, Ports: cfg.Ports}) - } - - // Actually apply the firewall rules. - if err := ensureRulesAdded(rulesPerSvcToAdd, ep.nfr); err != nil { - return nil, fmt.Errorf("error adding rules: %w", err) - } - if err := ensureRulesDeleted(rulesPerSvcToDelete, ep.nfr); err != nil { - return nil, fmt.Errorf("error deleting rules: %w", err) - } - - return newStatus, nil -} - -// updatesForCfg calculates any rules that need to be added or deleted for an individucal egress service config. -func updatesForCfg(svcName string, cfg egressservices.Config, status *egressservices.Status, tailnetTargetIPs []netip.Addr) ([]rule, []rule, error) { - rulesToAdd := make([]rule, 0) - rulesToDelete := make([]rule, 0) - currentConfig, ok := lookupCurrentConfig(svcName, status) - - // If no rules for service are present yet, add them all. - if !ok { - for _, t := range tailnetTargetIPs { - for ports := range cfg.Ports { - log.Printf("syncegressservices: svc %s adding port %v", svcName, ports) - rulesToAdd = append(rulesToAdd, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: t}) - } - } - return rulesToAdd, rulesToDelete, nil - } - - // If there are no backend targets available, delete any currently configured rules. 
- if len(tailnetTargetIPs) == 0 { - log.Printf("tailnet target for egress service %s does not have any backend addresses, deleting all rules", svcName) - for _, ip := range currentConfig.TailnetTargetIPs { - for ports := range currentConfig.Ports { - rulesToDelete = append(rulesToAdd, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) - } - } - return rulesToAdd, rulesToDelete, nil - } - - // If there are rules present for backend targets that no longer match, delete them. - for _, ip := range currentConfig.TailnetTargetIPs { - var found bool - for _, wantsIP := range tailnetTargetIPs { - if reflect.DeepEqual(ip, wantsIP) { - found = true - break - } - } - if !found { - for ports := range currentConfig.Ports { - rulesToDelete = append(rulesToDelete, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) - } - } - } - - // Sync rules for the currently wanted backend targets. - for _, ip := range tailnetTargetIPs { - - // If the backend target is not yet present in status, add all rules. - var found bool - for _, gotIP := range currentConfig.TailnetTargetIPs { - if reflect.DeepEqual(ip, gotIP) { - found = true - break - } - } - if !found { - for ports := range cfg.Ports { - rulesToAdd = append(rulesToAdd, rule{tailnetPort: ports.TargetPort, containerPort: ports.MatchPort, protocol: ports.Protocol, tailnetIP: ip}) - } - continue - } - - // If the backend target is present in status, check that the - // currently applied rules are up to date. - - // Delete any current portmappings that are no longer present in config. - for port := range currentConfig.Ports { - if _, ok := cfg.Ports[port]; ok { - continue - } - rulesToDelete = append(rulesToDelete, rule{tailnetPort: port.TargetPort, containerPort: port.MatchPort, protocol: port.Protocol, tailnetIP: ip}) - } - - // Add any new portmappings. 
- for port := range cfg.Ports { - if _, ok := currentConfig.Ports[port]; ok { - continue - } - rulesToAdd = append(rulesToAdd, rule{tailnetPort: port.TargetPort, containerPort: port.MatchPort, protocol: port.Protocol, tailnetIP: ip}) - } + if len(prefs.AdvertiseServices) == 0 { + return nil } - return rulesToAdd, rulesToDelete, nil -} -// deleteUnneccessaryServices ensure that any services found on status, but not -// present in config are deleted. -func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, status *egressservices.Status) error { - if !hasServicesConfigured(status) { + log.Printf("unadvertising services: %v", prefs.AdvertiseServices) + if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: nil, + }, + }); err != nil { + // EditPrefs only returns an error if it fails _set_ its local prefs. + // If it fails to _persist_ the prefs in state, we don't get an error + // and we continue waiting below, as control will failover as usual. + return fmt.Errorf("error setting prefs AdvertiseServices: %w", err) + } + + // Services use the same (failover XOR regional routing) mechanism that + // HA subnet routers use. Unfortunately we don't yet get a reliable signal + // from control that it's responded to our unadvertisement, so the best we + // can do is wait for 20 seconds, where 15s is the approximate maximum time + // it should take for control to choose a new primary, and 5s is for buffer. + // + // Note: There is no guarantee that clients have been _informed_ of the new + // primary no matter how long we wait. We would need a mechanism to await + // netmap updates for peers to know for sure. + // + // See https://tailscale.com/kb/1115/high-availability for more details. + // TODO(tomhjp): Wait for a netmap update instead of sleeping when control + // supports that. 
+ select { + case <-ctx.Done(): return nil - } - if !wantsServicesConfigured(cfgs) { - for svcName, svc := range status.Services { - log.Printf("service %s is no longer required, deleting", svcName) - if err := ensureServiceDeleted(svcName, svc, ep.nfr); err != nil { - return fmt.Errorf("error deleting service %s: %w", svcName, err) - } - } + case <-time.After(20 * time.Second): return nil } - - for svcName, svc := range status.Services { - if _, ok := (*cfgs)[svcName]; !ok { - log.Printf("service %s is no longer required, deleting", svcName) - if err := ensureServiceDeleted(svcName, svc, ep.nfr); err != nil { - return fmt.Errorf("error deleting service %s: %w", svcName, err) - } - // TODO (irbekrm): also delete the SNAT rule here - } - } - return nil -} - -// getConfigs gets the mounted egress service configuration. -func (ep *egressProxy) getConfigs() (*egressservices.Configs, error) { - j, err := os.ReadFile(ep.cfgPath) - if os.IsNotExist(err) { - return nil, nil - } - if err != nil { - return nil, err - } - if len(j) == 0 || string(j) == "" { - return nil, nil - } - cfg := &egressservices.Configs{} - if err := json.Unmarshal(j, &cfg); err != nil { - return nil, err - } - return cfg, nil -} - -// getStatus gets the current status of the configured firewall. The current -// status is stored in state Secret. Returns nil status if no status that -// applies to the current proxy Pod was found. Uses the Pod IP to determine if a -// status found in the state Secret applies to this proxy Pod. 
-func (ep *egressProxy) getStatus(ctx context.Context) (*egressservices.Status, error) { - secret, err := ep.kc.GetSecret(ctx, ep.stateSecret) - if err != nil { - return nil, fmt.Errorf("error retrieving state secret: %w", err) - } - status := &egressservices.Status{} - raw, ok := secret.Data[egressservices.KeyEgressServices] - if !ok { - return nil, nil - } - if err := json.Unmarshal([]byte(raw), status); err != nil { - return nil, fmt.Errorf("error unmarshalling previous config: %w", err) - } - if reflect.DeepEqual(status.PodIPv4, ep.podIPv4) { - return status, nil - } - return nil, nil -} - -// setStatus writes egress proxy's currently configured firewall to the state -// Secret and updates proxy's tailnet addresses. -func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Status, n ipn.Notify) error { - // Pod IP is used to determine if a stored status applies to THIS proxy Pod. - if status == nil { - status = &egressservices.Status{} - } - status.PodIPv4 = ep.podIPv4 - secret, err := ep.kc.GetSecret(ctx, ep.stateSecret) - if err != nil { - return fmt.Errorf("error retrieving state Secret: %w", err) - } - bs, err := json.Marshal(status) - if err != nil { - return fmt.Errorf("error marshalling service config: %w", err) - } - secret.Data[egressservices.KeyEgressServices] = bs - patch := kubeclient.JSONPatch{ - Op: "replace", - Path: fmt.Sprintf("/data/%s", egressservices.KeyEgressServices), - Value: bs, - } - if err := ep.kc.JSONPatchSecret(ctx, ep.stateSecret, []kubeclient.JSONPatch{patch}); err != nil { - return fmt.Errorf("error patching state Secret: %w", err) - } - ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() - return nil -} - -// tailnetTargetIPsForSvc returns the tailnet IPs to which traffic for this -// egress service should be proxied. The egress service can be configured by IP -// or by FQDN. If it's configured by IP, just return that. If it's configured by -// FQDN, resolve the FQDN and return the resolved IPs. 
It checks if the -// netfilter runner supports IPv6 NAT and skips any IPv6 addresses if it -// doesn't. -func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.Notify) (addrs []netip.Addr, err error) { - if svc.TailnetTarget.IP != "" { - addr, err := netip.ParseAddr(svc.TailnetTarget.IP) - if err != nil { - return nil, fmt.Errorf("error parsing tailnet target IP: %w", err) - } - if addr.Is6() && !ep.nfr.HasIPV6NAT() { - log.Printf("tailnet target is an IPv6 address, but this host does not support IPv6 in the chosen firewall mode. This will probably not work.") - return addrs, nil - } - return []netip.Addr{addr}, nil - } - - if svc.TailnetTarget.FQDN == "" { - return nil, errors.New("unexpected egress service config- neither tailnet target IP nor FQDN is set") - } - if n.NetMap == nil { - log.Printf("netmap is not available, unable to determine backend addresses for %s", svc.TailnetTarget.FQDN) - return addrs, nil - } - var ( - node tailcfg.NodeView - nodeFound bool - ) - for _, nn := range n.NetMap.Peers { - if equalFQDNs(nn.Name(), svc.TailnetTarget.FQDN) { - node = nn - nodeFound = true - break - } - } - if nodeFound { - for _, addr := range node.Addresses().AsSlice() { - if addr.Addr().Is6() && !ep.nfr.HasIPV6NAT() { - log.Printf("tailnet target %v is an IPv6 address, but this host does not support IPv6 in the chosen firewall mode, skipping.", addr.Addr().String()) - continue - } - addrs = append(addrs, addr.Addr()) - } - // Egress target endpoints configured via FQDN are stored, so - // that we can determine if a netmap update should trigger a - // resync. - mak.Set(&ep.targetFQDNs, svc.TailnetTarget.FQDN, node.Addresses().AsSlice()) - } - return addrs, nil -} - -// shouldResync parses netmap update and returns true if the update contains -// changes for which the egress proxy's firewall should be reconfigured. 
-func (ep *egressProxy) shouldResync(n ipn.Notify) bool { - if n.NetMap == nil { - return false - } - - // If proxy's tailnet addresses have changed, resync. - if !reflect.DeepEqual(n.NetMap.SelfNode.Addresses().AsSlice(), ep.tailnetAddrs) { - log.Printf("node addresses have changed, trigger egress config resync") - ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() - return true - } - - // If the IPs for any of the egress services configured via FQDN have - // changed, resync. - for fqdn, ips := range ep.targetFQDNs { - for _, nn := range n.NetMap.Peers { - if equalFQDNs(nn.Name(), fqdn) { - if !reflect.DeepEqual(ips, nn.Addresses().AsSlice()) { - log.Printf("backend addresses for egress target %q have changed old IPs %v, new IPs %v trigger egress config resync", nn.Name(), ips, nn.Addresses().AsSlice()) - } - return true - } - } - } - return false -} - -// ensureServiceDeleted ensures that any rules for an egress service are removed -// from the firewall configuration. -func ensureServiceDeleted(svcName string, svc *egressservices.ServiceStatus, nfr linuxfw.NetfilterRunner) error { - - // Note that the portmap is needed for iptables based firewall only. - // Nftables group rules for a service in a chain, so there is no need to - // specify individual portmapping based rules. - pms := make([]linuxfw.PortMap, 0) - for pm := range svc.Ports { - pms = append(pms, linuxfw.PortMap{MatchPort: pm.MatchPort, TargetPort: pm.TargetPort, Protocol: pm.Protocol}) - } - - if err := nfr.DeleteSvc(svcName, tailscaleTunInterface, svc.TailnetTargetIPs, pms); err != nil { - return fmt.Errorf("error deleting service %s: %w", svcName, err) - } - return nil -} - -// ensureRulesAdded ensures that all portmapping rules are added to the firewall -// configuration. For any rules that already exist, calling this function is a -// no-op. 
In case of nftables, a service consists of one or two (one per IP -// family) chains that conain the portmapping rules for the service and the -// chains as needed when this function is called. -func ensureRulesAdded(rulesPerSvc map[string][]rule, nfr linuxfw.NetfilterRunner) error { - for svc, rules := range rulesPerSvc { - for _, rule := range rules { - log.Printf("ensureRulesAdded svc %s tailnetTarget %s container port %d tailnet port %d protocol %s", svc, rule.tailnetIP, rule.containerPort, rule.tailnetPort, rule.protocol) - if err := nfr.EnsurePortMapRuleForSvc(svc, tailscaleTunInterface, rule.tailnetIP, linuxfw.PortMap{MatchPort: rule.containerPort, TargetPort: rule.tailnetPort, Protocol: rule.protocol}); err != nil { - return fmt.Errorf("error ensuring rule: %w", err) - } - } - } - return nil -} - -// ensureRulesDeleted ensures that the given rules are deleted from the firewall -// configuration. For any rules that do not exist, calling this funcion is a -// no-op. -func ensureRulesDeleted(rulesPerSvc map[string][]rule, nfr linuxfw.NetfilterRunner) error { - for svc, rules := range rulesPerSvc { - for _, rule := range rules { - log.Printf("ensureRulesDeleted svc %s tailnetTarget %s container port %d tailnet port %d protocol %s", svc, rule.tailnetIP, rule.containerPort, rule.tailnetPort, rule.protocol) - if err := nfr.DeletePortMapRuleForSvc(svc, tailscaleTunInterface, rule.tailnetIP, linuxfw.PortMap{MatchPort: rule.containerPort, TargetPort: rule.tailnetPort, Protocol: rule.protocol}); err != nil { - return fmt.Errorf("error deleting rule: %w", err) - } - } - } - return nil -} - -func lookupCurrentConfig(svcName string, status *egressservices.Status) (*egressservices.ServiceStatus, bool) { - if status == nil || len(status.Services) == 0 { - return nil, false - } - c, ok := status.Services[svcName] - return c, ok -} - -func equalFQDNs(s, s1 string) bool { - s, _ = strings.CutSuffix(s, ".") - s1, _ = strings.CutSuffix(s1, ".") - return strings.EqualFold(s, s1) 
-} - -// rule contains configuration for an egress proxy firewall rule. -type rule struct { - containerPort uint16 // port to match incoming traffic - tailnetPort uint16 // tailnet service port - tailnetIP netip.Addr // tailnet service IP - protocol string -} - -func wantsServicesConfigured(cfgs *egressservices.Configs) bool { - return cfgs != nil && len(*cfgs) != 0 -} - -func hasServicesConfigured(status *egressservices.Status) bool { - return status != nil && len(status.Services) != 0 -} - -func servicesStatusIsEqual(st, st1 *egressservices.Status) bool { - if st == nil && st1 == nil { - return true - } - if st == nil || st1 == nil { - return false - } - st.PodIPv4 = "" - st1.PodIPv4 = "" - return reflect.DeepEqual(*st, *st1) } diff --git a/cmd/containerboot/settings.go b/cmd/containerboot/settings.go index 742713e7700de..5a8be9036b3ca 100644 --- a/cmd/containerboot/settings.go +++ b/cmd/containerboot/settings.go @@ -64,11 +64,23 @@ type settings struct { // when setting up rules to proxy cluster traffic to cluster ingress // target. // Deprecated: use PodIPv4, PodIPv6 instead to support dual stack clusters - PodIP string - PodIPv4 string - PodIPv6 string - HealthCheckAddrPort string - EgressSvcsCfgPath string + PodIP string + PodIPv4 string + PodIPv6 string + PodUID string + HealthCheckAddrPort string + LocalAddrPort string + MetricsEnabled bool + HealthCheckEnabled bool + DebugAddrPort string + EgressProxiesCfgPath string + IngressProxiesCfgPath string + // CertShareMode is set for Kubernetes Pods running cert share mode. + // Possible values are empty (containerboot doesn't run any certs + // logic), 'ro' (for Pods that shold never attempt to issue/renew + // certs) and 'rw' for Pods that should manage the TLS certs shared + // amongst the replicas. 
+ CertShareMode string } func configFromEnv() (*settings, error) { @@ -98,7 +110,13 @@ func configFromEnv() (*settings, error) { PodIP: defaultEnv("POD_IP", ""), EnableForwardingOptimizations: defaultBool("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS", false), HealthCheckAddrPort: defaultEnv("TS_HEALTHCHECK_ADDR_PORT", ""), - EgressSvcsCfgPath: defaultEnv("TS_EGRESS_SERVICES_CONFIG_PATH", ""), + LocalAddrPort: defaultEnv("TS_LOCAL_ADDR_PORT", "[::]:9002"), + MetricsEnabled: defaultBool("TS_ENABLE_METRICS", false), + HealthCheckEnabled: defaultBool("TS_ENABLE_HEALTH_CHECK", false), + DebugAddrPort: defaultEnv("TS_DEBUG_ADDR_PORT", ""), + EgressProxiesCfgPath: defaultEnv("TS_EGRESS_PROXIES_CONFIG_PATH", ""), + IngressProxiesCfgPath: defaultEnv("TS_INGRESS_PROXIES_CONFIG_PATH", ""), + PodUID: defaultEnv("POD_UID", ""), } podIPs, ok := os.LookupEnv("POD_IPS") if ok { @@ -118,12 +136,80 @@ func configFromEnv() (*settings, error) { cfg.PodIPv6 = parsed.String() } } + // If cert share is enabled, set the replica as read or write. Only 0th + // replica should be able to write. + isInCertShareMode := defaultBool("TS_EXPERIMENTAL_CERT_SHARE", false) + if isInCertShareMode { + cfg.CertShareMode = "ro" + podName := os.Getenv("POD_NAME") + if strings.HasSuffix(podName, "-0") { + cfg.CertShareMode = "rw" + } + } + + // See https://github.com/tailscale/tailscale/issues/16108 for context- we + // do this to preserve the previous behaviour where --accept-dns could be + // set either via TS_ACCEPT_DNS or TS_EXTRA_ARGS. + acceptDNS := cfg.AcceptDNS != nil && *cfg.AcceptDNS + tsExtraArgs, acceptDNSNew := parseAcceptDNS(cfg.ExtraArgs, acceptDNS) + cfg.ExtraArgs = tsExtraArgs + if acceptDNS != acceptDNSNew { + cfg.AcceptDNS = &acceptDNSNew + } + if err := cfg.validate(); err != nil { return nil, fmt.Errorf("invalid configuration: %v", err) } return cfg, nil } +// parseAcceptDNS parses any values for Tailscale --accept-dns flag set via +// TS_ACCEPT_DNS and TS_EXTRA_ARGS env vars. 
If TS_EXTRA_ARGS contains +// --accept-dns flag, override the acceptDNS value with the one from +// TS_EXTRA_ARGS. +// The value of extraArgs can be empty string or one or more whitespace-separate +// key value pairs for 'tailscale up' command. The value for boolean flags can +// be omitted (default to true). +func parseAcceptDNS(extraArgs string, acceptDNS bool) (string, bool) { + if !strings.Contains(extraArgs, "--accept-dns") { + return extraArgs, acceptDNS + } + // TODO(irbekrm): we should validate that TS_EXTRA_ARGS contains legit + // 'tailscale up' flag values separated by whitespace. + argsArr := strings.Fields(extraArgs) + i := -1 + for key, val := range argsArr { + if strings.HasPrefix(val, "--accept-dns") { + i = key + break + } + } + if i == -1 { + return extraArgs, acceptDNS + } + a := strings.TrimSpace(argsArr[i]) + var acceptDNSFromExtraArgsS string + keyval := strings.Split(a, "=") + if len(keyval) == 2 { + acceptDNSFromExtraArgsS = keyval[1] + } else if len(keyval) == 1 && keyval[0] == "--accept-dns" { + // If the arg is just --accept-dns, we assume it means true. 
+ acceptDNSFromExtraArgsS = "true" + } else { + log.Printf("TS_EXTRA_ARGS contains --accept-dns, but it is not in the expected format --accept-dns=, ignoring it") + return extraArgs, acceptDNS + } + acceptDNSFromExtraArgs, err := strconv.ParseBool(acceptDNSFromExtraArgsS) + if err != nil { + log.Printf("TS_EXTRA_ARGS contains --accept-dns=%q, which is not a valid boolean value, ignoring it", acceptDNSFromExtraArgsS) + return extraArgs, acceptDNS + } + if acceptDNSFromExtraArgs != acceptDNS { + log.Printf("TS_EXTRA_ARGS contains --accept-dns=%v, which overrides TS_ACCEPT_DNS=%v", acceptDNSFromExtraArgs, acceptDNS) + } + return strings.Join(append(argsArr[:i], argsArr[i+1:]...), " "), acceptDNSFromExtraArgs +} + func (s *settings) validate() error { if s.TailscaledConfigFilePath != "" { dir, file := path.Split(s.TailscaledConfigFilePath) @@ -171,17 +257,37 @@ func (s *settings) validate() error { return errors.New("TS_EXPERIMENTAL_ENABLE_FORWARDING_OPTIMIZATIONS is not supported in userspace mode") } if s.HealthCheckAddrPort != "" { + log.Printf("[warning] TS_HEALTHCHECK_ADDR_PORT is deprecated and will be removed in 1.82.0. 
Please use TS_ENABLE_HEALTH_CHECK and optionally TS_LOCAL_ADDR_PORT instead.") if _, err := netip.ParseAddrPort(s.HealthCheckAddrPort); err != nil { - return fmt.Errorf("error parsing TS_HEALTH_CHECK_ADDR_PORT value %q: %w", s.HealthCheckAddrPort, err) + return fmt.Errorf("error parsing TS_HEALTHCHECK_ADDR_PORT value %q: %w", s.HealthCheckAddrPort, err) + } + } + if s.localMetricsEnabled() || s.localHealthEnabled() || s.EgressProxiesCfgPath != "" { + if _, err := netip.ParseAddrPort(s.LocalAddrPort); err != nil { + return fmt.Errorf("error parsing TS_LOCAL_ADDR_PORT value %q: %w", s.LocalAddrPort, err) } } + if s.DebugAddrPort != "" { + if _, err := netip.ParseAddrPort(s.DebugAddrPort); err != nil { + return fmt.Errorf("error parsing TS_DEBUG_ADDR_PORT value %q: %w", s.DebugAddrPort, err) + } + } + if s.HealthCheckEnabled && s.HealthCheckAddrPort != "" { + return errors.New("TS_HEALTHCHECK_ADDR_PORT is deprecated and will be removed in 1.82.0, use TS_ENABLE_HEALTH_CHECK and optionally TS_LOCAL_ADDR_PORT") + } + if s.EgressProxiesCfgPath != "" && !(s.InKubernetes && s.KubeSecret != "") { + return errors.New("TS_EGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes") + } + if s.IngressProxiesCfgPath != "" && !(s.InKubernetes && s.KubeSecret != "") { + return errors.New("TS_INGRESS_PROXIES_CONFIG_PATH is only supported for Tailscale running on Kubernetes") + } return nil } // setupKube is responsible for doing any necessary configuration and checks to // ensure that tailscale state storage and authentication mechanism will work on // Kubernetes. 
-func (cfg *settings) setupKube(ctx context.Context) error { +func (cfg *settings) setupKube(ctx context.Context, kc *kubeClient) error { if cfg.KubeSecret == "" { return nil } @@ -190,6 +296,7 @@ func (cfg *settings) setupKube(ctx context.Context) error { return fmt.Errorf("some Kubernetes permissions are missing, please check your RBAC configuration: %v", err) } cfg.KubernetesCanPatch = canPatch + kc.canPatch = canPatch s, err := kc.GetSecret(ctx, cfg.KubeSecret) if err != nil { @@ -263,7 +370,7 @@ func isOneStepConfig(cfg *settings) bool { // as an L3 proxy, proxying to an endpoint provided via one of the config env // vars. func isL3Proxy(cfg *settings) bool { - return cfg.ProxyTargetIP != "" || cfg.ProxyTargetDNSName != "" || cfg.TailnetTargetIP != "" || cfg.TailnetTargetFQDN != "" || cfg.AllowProxyingClusterTrafficViaIngress || cfg.EgressSvcsCfgPath != "" + return cfg.ProxyTargetIP != "" || cfg.ProxyTargetDNSName != "" || cfg.TailnetTargetIP != "" || cfg.TailnetTargetFQDN != "" || cfg.AllowProxyingClusterTrafficViaIngress || cfg.EgressProxiesCfgPath != "" || cfg.IngressProxiesCfgPath != "" } // hasKubeStateStore returns true if the state must be stored in a Kubernetes @@ -272,6 +379,18 @@ func hasKubeStateStore(cfg *settings) bool { return cfg.InKubernetes && cfg.KubernetesCanPatch && cfg.KubeSecret != "" } +func (cfg *settings) localMetricsEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.MetricsEnabled +} + +func (cfg *settings) localHealthEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.HealthCheckEnabled +} + +func (cfg *settings) egressSvcsTerminateEPEnabled() bool { + return cfg.LocalAddrPort != "" && cfg.EgressProxiesCfgPath != "" +} + // defaultEnv returns the value of the given envvar name, or defVal if // unset. 
func defaultEnv(name, defVal string) string { diff --git a/cmd/containerboot/settings_test.go b/cmd/containerboot/settings_test.go new file mode 100644 index 0000000000000..dbec066c9ab0d --- /dev/null +++ b/cmd/containerboot/settings_test.go @@ -0,0 +1,108 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package main + +import "testing" + +func Test_parseAcceptDNS(t *testing.T) { + tests := []struct { + name string + extraArgs string + acceptDNS bool + wantExtraArgs string + wantAcceptDNS bool + }{ + { + name: "false_extra_args_unset", + extraArgs: "", + wantExtraArgs: "", + wantAcceptDNS: false, + }, + { + name: "false_unrelated_args_set", + extraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantAcceptDNS: false, + }, + { + name: "true_extra_args_unset", + extraArgs: "", + acceptDNS: true, + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "true_unrelated_args_set", + acceptDNS: true, + extraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes=true --advertise-routes=10.0.0.1/32", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_false", + extraArgs: "--accept-dns=false", + wantExtraArgs: "", + wantAcceptDNS: false, + }, + { + name: "false_extra_args_set_to_true", + extraArgs: "--accept-dns=true", + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "true_extra_args_set_to_false", + extraArgs: "--accept-dns=false", + acceptDNS: true, + wantExtraArgs: "", + wantAcceptDNS: false, + }, + { + name: "true_extra_args_set_to_true", + extraArgs: "--accept-dns=true", + acceptDNS: true, + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_true_implicitly", + extraArgs: "--accept-dns", + wantExtraArgs: "", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_true_implicitly_with_unrelated_args", + extraArgs: 
"--accept-dns --accept-routes --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes --advertise-routes=10.0.0.1/32", + wantAcceptDNS: true, + }, + { + name: "false_extra_args_set_to_true_implicitly_surrounded_with_unrelated_args", + extraArgs: "--accept-routes --accept-dns --advertise-routes=10.0.0.1/32", + wantExtraArgs: "--accept-routes --advertise-routes=10.0.0.1/32", + wantAcceptDNS: true, + }, + { + name: "true_extra_args_set_to_false_with_unrelated_args", + extraArgs: "--accept-routes --accept-dns=false", + acceptDNS: true, + wantExtraArgs: "--accept-routes", + wantAcceptDNS: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotExtraArgs, gotAcceptDNS := parseAcceptDNS(tt.extraArgs, tt.acceptDNS) + if gotExtraArgs != tt.wantExtraArgs { + t.Errorf("parseAcceptDNS() gotExtraArgs = %v, want %v", gotExtraArgs, tt.wantExtraArgs) + } + if gotAcceptDNS != tt.wantAcceptDNS { + t.Errorf("parseAcceptDNS() gotAcceptDNS = %v, want %v", gotAcceptDNS, tt.wantAcceptDNS) + } + }) + } +} diff --git a/cmd/containerboot/tailscaled.go b/cmd/containerboot/tailscaled.go index 53fb7e703be45..f828c52573089 100644 --- a/cmd/containerboot/tailscaled.go +++ b/cmd/containerboot/tailscaled.go @@ -13,14 +13,17 @@ import ( "log" "os" "os/exec" + "path/filepath" + "reflect" "strings" "syscall" "time" - "tailscale.com/client/tailscale" + "github.com/fsnotify/fsnotify" + "tailscale.com/client/local" ) -func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient, *os.Process, error) { +func startTailscaled(ctx context.Context, cfg *settings) (*local.Client, *os.Process, error) { args := tailscaledArgs(cfg) // tailscaled runs without context, since it needs to persist // beyond the startup timeout in ctx. 
@@ -30,28 +33,31 @@ func startTailscaled(ctx context.Context, cfg *settings) (*tailscale.LocalClient cmd.SysProcAttr = &syscall.SysProcAttr{ Setpgid: true, } + if cfg.CertShareMode != "" { + cmd.Env = append(os.Environ(), "TS_CERT_SHARE_MODE="+cfg.CertShareMode) + } log.Printf("Starting tailscaled") if err := cmd.Start(); err != nil { - return nil, nil, fmt.Errorf("starting tailscaled failed: %v", err) + return nil, nil, fmt.Errorf("starting tailscaled failed: %w", err) } // Wait for the socket file to appear, otherwise API ops will racily fail. - log.Printf("Waiting for tailscaled socket") + log.Printf("Waiting for tailscaled socket at %s", cfg.Socket) for { if ctx.Err() != nil { - log.Fatalf("Timed out waiting for tailscaled socket") + return nil, nil, errors.New("timed out waiting for tailscaled socket") } _, err := os.Stat(cfg.Socket) if errors.Is(err, fs.ErrNotExist) { time.Sleep(100 * time.Millisecond) continue } else if err != nil { - log.Fatalf("Waiting for tailscaled socket: %v", err) + return nil, nil, fmt.Errorf("error waiting for tailscaled socket: %w", err) } break } - tsClient := &tailscale.LocalClient{ + tsClient := &local.Client{ Socket: cfg.Socket, UseSocketOnly: true, } @@ -90,6 +96,12 @@ func tailscaledArgs(cfg *settings) []string { if cfg.TailscaledConfigFilePath != "" { args = append(args, "--config="+cfg.TailscaledConfigFilePath) } + // Once enough proxy versions have been released for all the supported + // versions to understand this cfg setting, the operator can stop + // setting TS_TAILSCALED_EXTRA_ARGS for the debug flag. + if cfg.DebugAddrPort != "" && !strings.Contains(cfg.DaemonExtraArgs, cfg.DebugAddrPort) { + args = append(args, "--debug="+cfg.DebugAddrPort) + } if cfg.DaemonExtraArgs != "" { args = append(args, strings.Fields(cfg.DaemonExtraArgs)...) 
} @@ -160,3 +172,75 @@ func tailscaleSet(ctx context.Context, cfg *settings) error { } return nil } + +func watchTailscaledConfigChanges(ctx context.Context, path string, lc *local.Client, errCh chan<- error) { + var ( + tickChan <-chan time.Time + eventChan <-chan fsnotify.Event + errChan <-chan error + tailscaledCfgDir = filepath.Dir(path) + prevTailscaledCfg []byte + ) + if w, err := fsnotify.NewWatcher(); err != nil { + // Creating a new fsnotify watcher would fail for example if inotify was not able to create a new file descriptor. + // See https://github.com/tailscale/tailscale/issues/15081 + log.Printf("tailscaled config watch: failed to create fsnotify watcher, timer-only mode: %v", err) + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + tickChan = ticker.C + } else { + defer w.Close() + if err := w.Add(tailscaledCfgDir); err != nil { + errCh <- fmt.Errorf("failed to add fsnotify watch: %w", err) + return + } + eventChan = w.Events + errChan = w.Errors + } + b, err := os.ReadFile(path) + if err != nil { + errCh <- fmt.Errorf("error reading configfile: %w", err) + return + } + prevTailscaledCfg = b + // kubelet mounts Secrets to Pods using a series of symlinks, one of + // which is /..data that Kubernetes recommends consumers to + // use if they need to monitor changes + // https://github.com/kubernetes/kubernetes/blob/v1.28.1/pkg/volume/util/atomic_writer.go#L39-L61 + const kubeletMountedCfg = "..data" + toWatch := filepath.Join(tailscaledCfgDir, kubeletMountedCfg) + for { + select { + case <-ctx.Done(): + return + case err := <-errChan: + errCh <- fmt.Errorf("watcher error: %w", err) + return + case <-tickChan: + case event := <-eventChan: + if event.Name != toWatch { + continue + } + } + b, err := os.ReadFile(path) + if err != nil { + errCh <- fmt.Errorf("error reading configfile: %w", err) + return + } + // For some proxy types the mounted volume also contains tailscaled state and other files. 
We + // don't want to reload config unnecessarily on unrelated changes to these files. + if reflect.DeepEqual(b, prevTailscaledCfg) { + continue + } + prevTailscaledCfg = b + log.Printf("tailscaled config watch: ensuring that config is up to date") + ok, err := lc.ReloadConfig(ctx) + if err != nil { + errCh <- fmt.Errorf("error reloading tailscaled config: %w", err) + return + } + if ok { + log.Printf("tailscaled config watch: config was reloaded") + } + } +} diff --git a/cmd/derper/bootstrap_dns_test.go b/cmd/derper/bootstrap_dns_test.go index d151bc2b05fdf..9b99103abfe33 100644 --- a/cmd/derper/bootstrap_dns_test.go +++ b/cmd/derper/bootstrap_dns_test.go @@ -20,10 +20,10 @@ import ( ) func BenchmarkHandleBootstrapDNS(b *testing.B) { - tstest.Replace(b, bootstrapDNS, "log.tailscale.io,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com") + tstest.Replace(b, bootstrapDNS, "log.tailscale.com,login.tailscale.com,controlplane.tailscale.com,login.us.tailscale.com") refreshBootstrapDNS() w := new(bitbucketResponseWriter) - req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.io"), nil) + req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape("log.tailscale.com"), nil) b.ReportAllocs() b.ResetTimer() b.RunParallel(func(b *testing.PB) { @@ -63,7 +63,7 @@ func TestUnpublishedDNS(t *testing.T) { nettest.SkipIfNoNetwork(t) const published = "login.tailscale.com" - const unpublished = "log.tailscale.io" + const unpublished = "log.tailscale.com" prev1, prev2 := *bootstrapDNS, *unpublishedDNS *bootstrapDNS = published @@ -119,18 +119,18 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { unpublishedDNSCache.Store(&dnsEntryMap{ IPs: map[string][]net.IP{ - "log.tailscale.io": {}, + "log.tailscale.com": {}, "controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}, }, Percent: map[string]float64{ - "log.tailscale.io": 1.0, + "log.tailscale.com": 1.0, "controlplane.tailscale.com": 1.0, }, }) 
t.Run("CacheMiss", func(t *testing.T) { // One domain in map but empty, one not in map at all - for _, q := range []string{"log.tailscale.io", "login.tailscale.com"} { + for _, q := range []string{"log.tailscale.com", "login.tailscale.com"} { resetMetrics() ips := getBootstrapDNS(t, q) diff --git a/cmd/derper/cert.go b/cmd/derper/cert.go index db84aa515d257..b95755c64d2a7 100644 --- a/cmd/derper/cert.go +++ b/cmd/derper/cert.go @@ -4,15 +4,28 @@ package main import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" "crypto/tls" "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "encoding/pem" "errors" "fmt" + "log" + "math/big" + "net" "net/http" + "os" "path/filepath" "regexp" + "time" "golang.org/x/crypto/acme/autocert" + "tailscale.com/tailcfg" ) var unsafeHostnameCharacters = regexp.MustCompile(`[^a-zA-Z0-9-\.]`) @@ -53,8 +66,9 @@ func certProviderByCertMode(mode, dir, hostname string) (certProvider, error) { } type manualCertManager struct { - cert *tls.Certificate - hostname string + cert *tls.Certificate + hostname string // hostname or IP address of server + noHostname bool // whether hostname is an IP address } // NewManualCertManager returns a cert provider which read certificate by given hostname on create. @@ -63,8 +77,18 @@ func NewManualCertManager(certdir, hostname string) (certProvider, error) { crtPath := filepath.Join(certdir, keyname+".crt") keyPath := filepath.Join(certdir, keyname+".key") cert, err := tls.LoadX509KeyPair(crtPath, keyPath) + hostnameIP := net.ParseIP(hostname) // or nil if hostname isn't an IP address if err != nil { - return nil, fmt.Errorf("can not load x509 key pair for hostname %q: %w", keyname, err) + // If the hostname is an IP address, automatically create a + // self-signed certificate for it. 
+ var certp *tls.Certificate + if os.IsNotExist(err) && hostnameIP != nil { + certp, err = createSelfSignedIPCert(crtPath, keyPath, hostname) + } + if err != nil { + return nil, fmt.Errorf("can not load x509 key pair for hostname %q: %w", keyname, err) + } + cert = *certp + } // ensure hostname matches with the certificate x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) @@ -74,7 +98,23 @@ func NewManualCertManager(certdir, hostname string) (certProvider, error) { if err := x509Cert.VerifyHostname(hostname); err != nil { return nil, fmt.Errorf("cert invalid for hostname %q: %w", hostname, err) } - return &manualCertManager{cert: &cert, hostname: hostname}, nil + if hostnameIP != nil { + // If the hostname is an IP address, print out information on how to + configure this in the derpmap. + dn := &tailcfg.DERPNode{ + Name: "custom", + RegionID: 900, + HostName: hostname, + CertName: fmt.Sprintf("sha256-raw:%-02x", sha256.Sum256(x509Cert.Raw)), + } + dnJSON, _ := json.Marshal(dn) + log.Printf("Using self-signed certificate for IP address %q. 
Configure it in DERPMap using: (https://tailscale.com/s/custom-derp)\n %s", hostname, dnJSON) + } + return &manualCertManager{ + cert: &cert, + hostname: hostname, + noHostname: net.ParseIP(hostname) != nil, + }, nil } func (m *manualCertManager) TLSConfig() *tls.Config { @@ -88,7 +128,7 @@ func (m *manualCertManager) TLSConfig() *tls.Config { } func (m *manualCertManager) getCertificate(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - if hi.ServerName != m.hostname { + if hi.ServerName != m.hostname && !m.noHostname { return nil, fmt.Errorf("cert mismatch with hostname: %q", hi.ServerName) } @@ -103,3 +143,69 @@ func (m *manualCertManager) getCertificate(hi *tls.ClientHelloInfo) (*tls.Certif func (m *manualCertManager) HTTPHandler(fallback http.Handler) http.Handler { return fallback } + +func createSelfSignedIPCert(crtPath, keyPath, ipStr string) (*tls.Certificate, error) { + ip := net.ParseIP(ipStr) + if ip == nil { + return nil, fmt.Errorf("invalid IP address: %s", ipStr) + } + + priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, fmt.Errorf("failed to generate EC private key: %v", err) + } + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %v", err) + } + + now := time.Now() + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + CommonName: ipStr, + }, + NotBefore: now, + NotAfter: now.AddDate(1, 0, 0), // expires in 1 year; a bit over that is rejected by macOS etc + + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Set the IP as a SAN. + template.IPAddresses = []net.IP{ip} + + // Create the self-signed certificate. 
+ derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + + keyBytes, err := x509.MarshalECPrivateKey(priv) + if err != nil { + return nil, fmt.Errorf("unable to marshal EC private key: %v", err) + } + + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes}) + + if err := os.MkdirAll(filepath.Dir(crtPath), 0700); err != nil { + return nil, fmt.Errorf("failed to create directory for certificate: %v", err) + } + if err := os.WriteFile(crtPath, certPEM, 0644); err != nil { + return nil, fmt.Errorf("failed to write certificate to %s: %v", crtPath, err) + } + if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil { + return nil, fmt.Errorf("failed to write key to %s: %v", keyPath, err) + } + + tlsCert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + return nil, fmt.Errorf("failed to create tls.Certificate: %v", err) + } + return &tlsCert, nil +} diff --git a/cmd/derper/cert_test.go b/cmd/derper/cert_test.go new file mode 100644 index 0000000000000..31fd4ea446949 --- /dev/null +++ b/cmd/derper/cert_test.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/net/netmon" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// Verify that in --certmode=manual mode, we can use a bare IP address +// as the --hostname and that GetCertificate will return it. 
+func TestCertIP(t *testing.T) { + dir := t.TempDir() + const hostname = "1.2.3.4" + + priv, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatal(err) + } + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + t.Fatal(err) + } + ip := net.ParseIP(hostname) + if ip == nil { + t.Fatalf("invalid IP address %q", hostname) + } + template := &x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Tailscale Test Corp"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(30 * 24 * time.Hour), + + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IPAddresses: []net.IP{ip}, + } + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv) + if err != nil { + t.Fatal(err) + } + certOut, err := os.Create(filepath.Join(dir, hostname+".crt")) + if err != nil { + t.Fatal(err) + } + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + t.Fatalf("Failed to write data to cert.pem: %v", err) + } + if err := certOut.Close(); err != nil { + t.Fatalf("Error closing cert.pem: %v", err) + } + + keyOut, err := os.OpenFile(filepath.Join(dir, hostname+".key"), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) + if err != nil { + t.Fatal(err) + } + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + t.Fatalf("Unable to marshal private key: %v", err) + } + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + t.Fatalf("Failed to write data to key.pem: %v", err) + } + if err := keyOut.Close(); err != nil { + t.Fatalf("Error closing key.pem: %v", err) + } + + cp, err := certProviderByCertMode("manual", dir, hostname) + if err != nil { + t.Fatal(err) + } + back, err := 
cp.TLSConfig().GetCertificate(&tls.ClientHelloInfo{ + ServerName: "", // no SNI + }) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if back == nil { + t.Fatalf("GetCertificate returned nil") + } +} + +// Test that we can dial a raw IP without using a hostname and without a WebPKI +// cert, validating the cert against the signature of the cert in the DERP map's +// DERPNode. +// +// See https://github.com/tailscale/tailscale/issues/11776. +func TestPinnedCertRawIP(t *testing.T) { + td := t.TempDir() + cp, err := NewManualCertManager(td, "127.0.0.1") + if err != nil { + t.Fatalf("NewManualCertManager: %v", err) + } + + cert, err := cp.TLSConfig().GetCertificate(&tls.ClientHelloInfo{ + ServerName: "127.0.0.1", + }) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + ds := derp.NewServer(key.NewNode(), t.Logf) + + derpHandler := derphttp.Handler(ds) + mux := http.NewServeMux() + mux.Handle("/derp", derpHandler) + + var hs http.Server + hs.Handler = mux + hs.TLSConfig = cp.TLSConfig() + ds.ModifyTLSConfigToAddMetaCert(hs.TLSConfig) + go hs.ServeTLS(ln, "", "") + + lnPort := ln.Addr().(*net.TCPAddr).Port + + reg := &tailcfg.DERPRegion{ + RegionID: 900, + Nodes: []*tailcfg.DERPNode{ + { + RegionID: 900, + HostName: "127.0.0.1", + CertName: fmt.Sprintf("sha256-raw:%-02x", sha256.Sum256(cert.Leaf.Raw)), + DERPPort: lnPort, + }, + }, + } + + netMon := netmon.NewStatic() + dc := derphttp.NewRegionClient(key.NewNode(), t.Logf, netMon, func() *tailcfg.DERPRegion { + return reg + }) + defer dc.Close() + + _, connClose, _, err := dc.DialRegionTLS(context.Background(), reg) + if err != nil { + t.Fatalf("DialRegionTLS: %v", err) + } + defer connClose.Close() +} diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index 417dbcfb0deb7..ca772353079dc 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -27,9 +27,7 
@@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa L github.com/google/nftables/expr from github.com/google/nftables+ L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ L github.com/google/nftables/xt from github.com/google/nftables/expr+ - github.com/google/uuid from tailscale.com/util/fastuuid github.com/hdevalence/ed25519consensus from tailscale.com/tka - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink L đŸ’Ŗ github.com/mdlayher/netlink from github.com/google/nftables+ @@ -37,11 +35,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa L github.com/mdlayher/netlink/nltest from github.com/google/nftables L đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket + github.com/munnerz/goautoneg from github.com/prometheus/common/expfmt đŸ’Ŗ github.com/prometheus/client_golang/prometheus from tailscale.com/tsweb/promvarz github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ - github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs @@ -53,9 +51,11 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ L 
đŸ’Ŗ github.com/tailscale/netlink from tailscale.com/util/linuxfw L đŸ’Ŗ github.com/tailscale/netlink/nl from github.com/tailscale/netlink + github.com/tailscale/setec/client/setec from tailscale.com/cmd/derper + github.com/tailscale/setec/types/api from github.com/tailscale/setec/client/setec L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from tailscale.com/net/tsaddr W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/netmon+ google.golang.org/protobuf/encoding/protodelim from github.com/prometheus/common/expfmt @@ -87,51 +87,57 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa google.golang.org/protobuf/runtime/protoimpl from github.com/prometheus/client_model/go+ google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ tailscale.com from tailscale.com/version - tailscale.com/atomicfile from tailscale.com/cmd/derper+ + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/cmd/derper+ + tailscale.com/client/local from tailscale.com/client/tailscale+ tailscale.com/client/tailscale from tailscale.com/derp - tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale + tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ tailscale.com/derp from tailscale.com/cmd/derper+ + tailscale.com/derp/derpconst from tailscale.com/derp+ tailscale.com/derp/derphttp from tailscale.com/cmd/derper tailscale.com/disco from tailscale.com/derp - tailscale.com/drive from tailscale.com/client/tailscale+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/feature from tailscale.com/tsweb tailscale.com/health from 
tailscale.com/net/tlsdial+ tailscale.com/hostinfo from tailscale.com/net/netmon+ - tailscale.com/ipn from tailscale.com/client/tailscale - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ + tailscale.com/ipn from tailscale.com/client/local + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/metrics from tailscale.com/cmd/derper+ + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial tailscale.com/net/dnscache from tailscale.com/derp/derphttp tailscale.com/net/ktimeout from tailscale.com/cmd/derper tailscale.com/net/netaddr from tailscale.com/ipn+ tailscale.com/net/netknob from tailscale.com/net/netns đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/derp/derphttp+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp - tailscale.com/net/netutil from tailscale.com/client/tailscale + tailscale.com/net/netutil from tailscale.com/client/local + tailscale.com/net/netx from tailscale.com/net/dnscache+ tailscale.com/net/sockstats from tailscale.com/derp/derphttp tailscale.com/net/stun from tailscale.com/net/stunserver tailscale.com/net/stunserver from tailscale.com/cmd/derper L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/derp/derphttp + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/ipn+ đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/derp/derphttp+ - tailscale.com/net/wsconn from tailscale.com/cmd/derper+ - tailscale.com/paths from tailscale.com/client/tailscale - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale + tailscale.com/net/wsconn from tailscale.com/cmd/derper + tailscale.com/paths from tailscale.com/client/local + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local tailscale.com/syncs from tailscale.com/cmd/derper+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ - tailscale.com/tka from 
tailscale.com/client/tailscale+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tka from tailscale.com/client/local+ W tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tstime from tailscale.com/derp+ tailscale.com/tstime/mono from tailscale.com/tstime/rate tailscale.com/tstime/rate from tailscale.com/derp - tailscale.com/tsweb from tailscale.com/cmd/derper - tailscale.com/tsweb/promvarz from tailscale.com/tsweb + tailscale.com/tsweb from tailscale.com/cmd/derper+ + tailscale.com/tsweb/promvarz from tailscale.com/cmd/derper tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/dnstype from tailscale.com/tailcfg+ tailscale.com/types/empty from tailscale.com/ipn tailscale.com/types/ipproto from tailscale.com/tailcfg+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/version+ tailscale.com/types/logger from tailscale.com/cmd/derper+ tailscale.com/types/netmap from tailscale.com/ipn @@ -139,8 +145,9 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa tailscale.com/types/persist from tailscale.com/ipn tailscale.com/types/preftype from tailscale.com/ipn tailscale.com/types/ptr from tailscale.com/hostinfo+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ - tailscale.com/types/tkatype from tailscale.com/client/tailscale+ + tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/ipn+ tailscale.com/util/cibuild from tailscale.com/health tailscale.com/util/clientmetric from tailscale.com/net/netmon+ @@ -150,24 +157,30 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname 
from tailscale.com/hostinfo+ - tailscale.com/util/fastuuid from tailscale.com/tsweb + tailscale.com/util/eventbus from tailscale.com/net/netmon đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns tailscale.com/util/mak from tailscale.com/health+ tailscale.com/util/multierr from tailscale.com/health+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/rands from tailscale.com/tsweb tailscale.com/util/set from tailscale.com/derp+ tailscale.com/util/singleflight from tailscale.com/net/dnscache tailscale.com/util/slicesx from tailscale.com/cmd/derper+ tailscale.com/util/syspolicy from tailscale.com/ipn tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/testenv from tailscale.com/util/syspolicy+ tailscale.com/util/usermetric from tailscale.com/health tailscale.com/util/vizerror from tailscale.com/tailcfg+ W đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/hostinfo+ + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/derp+ tailscale.com/version/distro from 
tailscale.com/envknob+ @@ -182,23 +195,25 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa golang.org/x/crypto/cryptobyte from crypto/ecdsa+ golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ golang.org/x/crypto/curve25519 from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ - W golang.org/x/exp/constraints from tailscale.com/util/winutil - golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting + golang.org/x/exp/constraints from tailscale.com/util/winutil+ + golang.org/x/exp/maps from tailscale.com/util/syspolicy/setting+ L golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http golang.org/x/net/http/httpproxy from net/http+ golang.org/x/net/http2/hpack from net/http golang.org/x/net/idna from golang.org/x/crypto/acme/autocert+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from net+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ - golang.org/x/sys/cpu from github.com/josharian/native+ + golang.org/x/sync/singleflight from github.com/tailscale/setec/client/setec + golang.org/x/sys/cpu from golang.org/x/crypto/argon2+ LD golang.org/x/sys/unix from github.com/google/nftables+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ @@ -217,7 +232,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa 
container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509 @@ -226,19 +241,58 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 
from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from golang.org/x/crypto/acme+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - database/sql/driver from github.com/google/uuid - embed from crypto/internal/nistec+ + embed from google.golang.org/protobuf/internal/editiondefaults+ encoding from encoding/json+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ @@ -249,7 +303,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa encoding/pem from crypto/tls+ errors from bufio+ expvar from github.com/prometheus/client_golang/prometheus+ - flag from tailscale.com/cmd/derper + flag from tailscale.com/cmd/derper+ fmt from compress/flate+ go/token from google.golang.org/protobuf/internal/strs hash from crypto+ @@ -257,9 +311,50 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa hash/fnv from google.golang.org/protobuf/internal/detrand hash/maphash from go4.org/mem html from net/http/pprof+ + html/template from 
tailscale.com/cmd/derper+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/exithook from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/unsafeheader from internal/reflectlite+ io from bufio+ io/fs from crypto/x509+ - io/ioutil from github.com/mitchellh/go-ps+ + L io/ioutil from github.com/mitchellh/go-ps+ iter from maps+ log from expvar+ log/internal from log @@ -268,7 +363,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa math/big from 
crypto/dsa+ math/bits from compress/flate+ math/rand from github.com/mdlayher/netlink+ - math/rand/v2 from tailscale.com/util/fastuuid+ + math/rand/v2 from crypto/ecdsa+ mime from github.com/prometheus/common/expfmt+ mime/multipart from net/http mime/quotedprintable from mime/multipart @@ -276,19 +371,21 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa net/http from expvar+ net/http/httptrace from net/http+ net/http/internal from net/http + net/http/internal/ascii from net/http net/http/pprof from tailscale.com/tsweb net/netip from go4.org/netipx+ net/textproto from golang.org/x/net/http/httpguts+ net/url from crypto/x509+ - os from crypto/rand+ + os from crypto/internal/sysrand+ os/exec from github.com/coreos/go-iptables/iptables+ os/signal from tailscale.com/cmd/derper - W os/user from tailscale.com/util/winutil + W os/user from tailscale.com/util/winutil+ path from github.com/prometheus/client_golang/prometheus/internal+ path/filepath from crypto/x509+ reflect from crypto/x509+ regexp from github.com/coreos/go-iptables/iptables+ regexp/syntax from regexp + runtime from crypto/internal/fips140+ runtime/debug from github.com/prometheus/client_golang/prometheus+ runtime/metrics from github.com/prometheus/client_golang/prometheus+ runtime/pprof from net/http/pprof @@ -299,10 +396,14 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa strings from bufio+ sync from compress/flate+ sync/atomic from context+ - syscall from crypto/rand+ + syscall from crypto/internal/sysrand+ text/tabwriter from runtime/pprof + text/template from html/template + text/template/parse from html/template+ time from compress/gzip+ unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 80c9dc44f138f..3c6fda68c4d59 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -19,6 
+19,7 @@ import ( "expvar" "flag" "fmt" + "html/template" "io" "log" "math" @@ -26,6 +27,7 @@ import ( "net/http" "os" "os/signal" + "path" "path/filepath" "regexp" "runtime" @@ -35,6 +37,7 @@ import ( "syscall" "time" + "github.com/tailscale/setec/client/setec" "golang.org/x/time/rate" "tailscale.com/atomicfile" "tailscale.com/derp" @@ -46,6 +49,9 @@ import ( "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/version" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) var ( @@ -57,18 +63,25 @@ var ( configPath = flag.String("c", "", "config file path") certMode = flag.String("certmode", "letsencrypt", "mode for getting a cert. possible options: manual, letsencrypt") certDir = flag.String("certdir", tsweb.DefaultCertDir("derper-certs"), "directory to store LetsEncrypt certs, if addr's port is :443") - hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443") + hostname = flag.String("hostname", "derp.tailscale.com", "LetsEncrypt host name, if addr's port is :443. When --certmode=manual, this can be an IP address to avoid SNI checks") runSTUN = flag.Bool("stun", true, "whether to run a STUN server. It will bind to the same IP (if any) as the --addr flag value.") runDERP = flag.Bool("derp", true, "whether to run a DERP server. The only reason to set this false is if you're decommissioning a server but want to keep its bootstrap DNS functionality still running.") + flagHome = flag.String("home", "", "what to serve at the root path. It may be left empty (the default, for a default homepage), \"blank\" for a blank page, or a URL to redirect to") meshPSKFile = flag.String("mesh-psk-file", defaultMeshPSKFile(), "if non-empty, path to file containing the mesh pre-shared key file. 
It should contain some hex string; whitespace is trimmed.") - meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list") + meshWith = flag.String("mesh-with", "", "optional comma-separated list of hostnames to mesh with; the server's own hostname can be in the list. If an entry contains a slash, the second part names a hostname to be used when dialing the target.") + secretsURL = flag.String("secrets-url", "", "SETEC server URL for secrets retrieval of mesh key") + secretPrefix = flag.String("secrets-path-prefix", "prod/derp", "setec path prefix for \""+setecMeshKeyName+"\" secret for DERP mesh key") + secretsCacheDir = flag.String("secrets-cache-dir", defaultSetecCacheDir(), "directory to cache setec secrets in (required if --secrets-url is set)") bootstrapDNS = flag.String("bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns") unpublishedDNS = flag.String("unpublished-bootstrap-dns-names", "", "optional comma-separated list of hostnames to make available at /bootstrap-dns and not publish in the list. 
If an entry contains a slash, the second part names a DNS record to poll for its TXT record with a `0` to `100` value for rollout percentage.") + verifyClients = flag.Bool("verify-clients", false, "verify clients to this DERP server through a local tailscaled instance.") verifyClientURL = flag.String("verify-client-url", "", "if non-empty, an admission controller URL for permitting client connections; see tailcfg.DERPAdmitClientRequest") verifyFailOpen = flag.Bool("verify-client-url-fail-open", true, "whether we fail open if --verify-client-url is unreachable") + socket = flag.String("socket", "", "optional alternate path to tailscaled socket (only relevant when using --verify-clients)") + acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") @@ -76,13 +89,21 @@ var ( tcpKeepAlive = flag.Duration("tcp-keepalive-time", 10*time.Minute, "TCP keepalive time") // tcpUserTimeout is intentionally short, so that hung connections are cleaned up promptly. DERPs should be nearby users. tcpUserTimeout = flag.Duration("tcp-user-timeout", 15*time.Second, "TCP user timeout") + // tcpWriteTimeout is the timeout for writing to client TCP connections. It does not apply to mesh connections. + tcpWriteTimeout = flag.Duration("tcp-write-timeout", derp.DefaultTCPWiteTimeout, "TCP write timeout; 0 results in no timeout being set on writes") ) var ( tlsRequestVersion = &metrics.LabelMap{Label: "version"} tlsActiveVersion = &metrics.LabelMap{Label: "version"} + + // Exactly 64 hexadecimal lowercase digits. 
+ validMeshKey = regexp.MustCompile(`^[0-9a-f]{64}$`) ) +const setecMeshKeyName = "meshkey" +const meshKeyEnvVar = "TAILSCALE_DERPER_MESH_KEY" + func init() { expvar.Publish("derper_tls_request_version", tlsRequestVersion) expvar.Publish("gauge_derper_tls_active_version", tlsActiveVersion) @@ -138,6 +159,14 @@ func writeNewConfig() config { return cfg } +func checkMeshKey(key string) (string, error) { + key = strings.TrimSpace(key) + if !validMeshKey.MatchString(key) { + return "", errors.New("key must contain exactly 64 hex digits") + } + return key, nil +} + func main() { flag.Parse() if *versionFlag { @@ -170,26 +199,70 @@ func main() { s := derp.NewServer(cfg.PrivateKey, log.Printf) s.SetVerifyClient(*verifyClients) + s.SetTailscaledSocketPath(*socket) s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURLFailOpen(*verifyFailOpen) + s.SetTCPWriteTimeout(*tcpWriteTimeout) - if *meshPSKFile != "" { - b, err := os.ReadFile(*meshPSKFile) + var meshKey string + if *dev { + meshKey = os.Getenv(meshKeyEnvVar) + if meshKey == "" { + log.Printf("No mesh key specified for dev via %s\n", meshKeyEnvVar) + } else { + log.Printf("Set mesh key from %s\n", meshKeyEnvVar) + } + } else if *secretsURL != "" { + meshKeySecret := path.Join(*secretPrefix, setecMeshKeyName) + fc, err := setec.NewFileCache(*secretsCacheDir) if err != nil { - log.Fatal(err) + log.Fatalf("NewFileCache: %v", err) + } + log.Printf("Setting up setec store from %q", *secretsURL) + st, err := setec.NewStore(ctx, + setec.StoreConfig{ + Client: setec.Client{Server: *secretsURL}, + Secrets: []string{ + meshKeySecret, + }, + Cache: fc, + }) + if err != nil { + log.Fatalf("NewStore: %v", err) } - key := strings.TrimSpace(string(b)) - if matched, _ := regexp.MatchString(`(?i)^[0-9a-f]{64,}$`, key); !matched { - log.Fatalf("key in %s must contain 64+ hex digits", *meshPSKFile) + meshKey = st.Secret(meshKeySecret).GetString() + log.Println("Got mesh key from setec store") + st.Close() + } else if *meshPSKFile 
!= "" { + b, err := setec.StaticFile(*meshPSKFile) + if err != nil { + log.Fatalf("StaticFile failed to get key: %v", err) } + log.Println("Got mesh key from static file") + meshKey = b.GetString() + } + + if meshKey == "" && *dev { + log.Printf("No mesh key configured for --dev mode") + } else if meshKey == "" { + log.Printf("No mesh key configured") + } else if key, err := checkMeshKey(meshKey); err != nil { + log.Fatalf("invalid mesh key: %v", err) + } else { s.SetMeshKey(key) - log.Printf("DERP mesh key configured") + log.Println("DERP mesh key configured") } + if err := startMesh(s); err != nil { log.Fatalf("startMesh: %v", err) } expvar.Publish("derp", s.ExpVar()) + handleHome, ok := getHomeHandler(*flagHome) + if !ok { + log.Fatalf("unknown --home value %q", *flagHome) + } + mux := http.NewServeMux() if *runDERP { derpHandler := derphttp.Handler(s) @@ -210,28 +283,7 @@ func main() { mux.HandleFunc("/bootstrap-dns", tsweb.BrowserHeaderHandlerFunc(handleBootstrapDNS)) mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tsweb.AddBrowserHeaders(w) - w.Header().Set("Content-Type", "text/html; charset=utf-8") - w.WriteHeader(200) - io.WriteString(w, ` -

DERP

-

- This is a Tailscale DERP server. -

-

- Documentation: -

- -`) - if !*runDERP { - io.WriteString(w, `

Status: disabled

`) - } - if tsweb.AllowDebugAccess(r) { - io.WriteString(w, "

Debug info at /debug/.

\n") - } + handleHome.ServeHTTP(w, r) })) mux.Handle("/robots.txt", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tsweb.AddBrowserHeaders(w) @@ -273,6 +325,9 @@ func main() { Control: ktimeout.UserTimeout(*tcpUserTimeout), KeepAlive: *tcpKeepAlive, } + // As of 2025-02-19, MPTCP does not support TCP_USER_TIMEOUT socket option + // set in ktimeout.UserTimeout above. + lc.SetMultipathTCP(false) quietLogger := log.New(logger.HTTPServerLogFilter{Inner: log.Printf}, "", 0) httpsrv := &http.Server{ @@ -387,6 +442,10 @@ func prodAutocertHostPolicy(_ context.Context, host string) error { return errors.New("invalid hostname") } +func defaultSetecCacheDir() string { + return filepath.Join(os.Getenv("HOME"), ".cache", "derper-secrets") +} + func defaultMeshPSKFile() string { try := []string{ "/home/derp/keys/derp-mesh.key", @@ -468,3 +527,84 @@ func init() { return 0 })) } + +type templateData struct { + ShowAbuseInfo bool + Disabled bool + AllowDebug bool +} + +// homePageTemplate renders the home page using [templateData]. +var homePageTemplate = template.Must(template.New("home").Parse(` +

DERP

+

+ This is a Tailscale DERP server. +

+ +

+ It provides STUN, interactive connectivity establishment, and relaying of end-to-end encrypted traffic + for Tailscale clients. +

+ +{{if .ShowAbuseInfo }} +

+ If you suspect abuse, please contact security@tailscale.com. +

+{{end}} + +

+ Documentation: +

+ + + +{{if .Disabled}} +

Status: disabled

+{{end}} + +{{if .AllowDebug}} +

Debug info at /debug/.

+{{end}} + + +`)) + +// getHomeHandler returns a handler for the home page based on a flag string +// as documented on the --home flag. +func getHomeHandler(val string) (_ http.Handler, ok bool) { + if val == "" { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(200) + err := homePageTemplate.Execute(w, templateData{ + ShowAbuseInfo: validProdHostname.MatchString(*hostname), + Disabled: !*runDERP, + AllowDebug: tsweb.AllowDebugAccess(r), + }) + if err != nil { + if r.Context().Err() == nil { + log.Printf("homePageTemplate.Execute: %v", err) + } + return + } + }), true + } + if val == "blank" { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(200) + }), true + } + if strings.HasPrefix(val, "http://") || strings.HasPrefix(val, "https://") { + return http.RedirectHandler(val, http.StatusFound), true + } + return nil, false +} diff --git a/cmd/derper/derper_test.go b/cmd/derper/derper_test.go index 553a78f9f6426..12686ce4eb5f3 100644 --- a/cmd/derper/derper_test.go +++ b/cmd/derper/derper_test.go @@ -4,6 +4,7 @@ package main import ( + "bytes" "context" "net/http" "net/http/httptest" @@ -107,6 +108,76 @@ func TestDeps(t *testing.T) { "gvisor.dev/gvisor/pkg/tcpip/header": "https://github.com/tailscale/tailscale/issues/9756", "tailscale.com/net/packet": "not needed in derper", "github.com/gaissmai/bart": "not needed in derper", + "database/sql/driver": "not needed in derper", // previously came in via github.com/google/uuid }, }.Check(t) } + +func TestTemplate(t *testing.T) { + buf := &bytes.Buffer{} + err := homePageTemplate.Execute(buf, templateData{ + ShowAbuseInfo: true, + Disabled: true, + AllowDebug: true, + }) + if err != nil { + t.Fatal(err) + } + + str := buf.String() + if !strings.Contains(str, "If you suspect abuse") { + t.Error("Output is missing abuse mailto") 
+ } + if !strings.Contains(str, "Tailscale Security Policies") { + t.Error("Output is missing Tailscale Security Policies link") + } + if !strings.Contains(str, "Status:") { + t.Error("Output is missing disabled status") + } + if !strings.Contains(str, "Debug info") { + t.Error("Output is missing debug info") + } +} + +func TestCheckMeshKey(t *testing.T) { + testCases := []struct { + name string + input string + want string + wantErr bool + }{ + { + name: "KeyOkay", + input: "f1ffafffffffffffffffffffffffffffffffffffffffffffffffff2ffffcfff6", + want: "f1ffafffffffffffffffffffffffffffffffffffffffffffffffff2ffffcfff6", + wantErr: false, + }, + { + name: "TrimKeyOkay", + input: " f1ffafffffffffffffffffffffffffffffffffffffffffffffffff2ffffcfff6 ", + want: "f1ffafffffffffffffffffffffffffffffffffffffffffffffffff2ffffcfff6", + wantErr: false, + }, + { + name: "NotAKey", + input: "zzthisisnotakey", + want: "", + wantErr: true, + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + k, err := checkMeshKey(tt.input) + if err != nil && !tt.wantErr { + t.Errorf("unexpected error: %v", err) + } + if k != tt.want && err == nil { + t.Errorf("want: %s doesn't match expected: %s", tt.want, k) + } + + }) + } + +} diff --git a/cmd/derper/mesh.go b/cmd/derper/mesh.go index ee1807f001202..1d8e3ef93c8b3 100644 --- a/cmd/derper/mesh.go +++ b/cmd/derper/mesh.go @@ -10,7 +10,6 @@ import ( "log" "net" "strings" - "time" "tailscale.com/derp" "tailscale.com/derp/derphttp" @@ -25,15 +24,28 @@ func startMesh(s *derp.Server) error { if !s.HasMeshKey() { return errors.New("--mesh-with requires --mesh-psk-file") } - for _, host := range strings.Split(*meshWith, ",") { - if err := startMeshWithHost(s, host); err != nil { + for _, hostTuple := range strings.Split(*meshWith, ",") { + if err := startMeshWithHost(s, hostTuple); err != nil { return err } } return nil } -func startMeshWithHost(s *derp.Server, host string) error { +func startMeshWithHost(s 
*derp.Server, hostTuple string) error { + var host string + var dialHost string + hostParts := strings.Split(hostTuple, "/") + if len(hostParts) > 2 { + return fmt.Errorf("too many components in host tuple %q", hostTuple) + } + host = hostParts[0] + if len(hostParts) == 2 { + dialHost = hostParts[1] + } else { + dialHost = hostParts[0] + } + logf := logger.WithPrefix(log.Printf, fmt.Sprintf("mesh(%q): ", host)) netMon := netmon.NewStatic() // good enough for cmd/derper; no need for netns fanciness c, err := derphttp.NewClient(s.PrivateKey(), "https://"+host+"/derp", logf, netMon) @@ -43,31 +55,20 @@ func startMeshWithHost(s *derp.Server, host string) error { c.MeshKey = s.MeshKey() c.WatchConnectionChanges = true - // For meshed peers within a region, connect via VPC addresses. - c.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) { - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } + logf("will dial %q for %q", dialHost, host) + if dialHost != host { var d net.Dialer - var r net.Resolver - if base, ok := strings.CutSuffix(host, ".tailscale.com"); ok && port == "443" { - subCtx, cancel := context.WithTimeout(ctx, 2*time.Second) - defer cancel() - vpcHost := base + "-vpc.tailscale.com" - ips, _ := r.LookupIP(subCtx, "ip", vpcHost) - if len(ips) > 0 { - vpcAddr := net.JoinHostPort(ips[0].String(), port) - c, err := d.DialContext(subCtx, network, vpcAddr) - if err == nil { - log.Printf("connected to %v (%v) instead of %v", vpcHost, ips[0], base) - return c, nil - } - log.Printf("failed to connect to %v (%v): %v; trying non-VPC route", vpcHost, ips[0], err) + c.SetURLDialer(func(ctx context.Context, network, addr string) (net.Conn, error) { + _, port, err := net.SplitHostPort(addr) + if err != nil { + logf("failed to split %q: %v", addr, err) + return nil, err } - } - return d.DialContext(ctx, network, addr) - }) + dialAddr := net.JoinHostPort(dialHost, port) + logf("dialing %q instead of %q", dialAddr, 
addr) + return d.DialContext(ctx, network, dialAddr) + }) + } add := func(m derp.PeerPresentMessage) { s.AddPacketForwarder(m.Key, c) } remove := func(m derp.PeerGoneMessage) { s.RemovePacketForwarder(m.Peer, c) } diff --git a/cmd/derpprobe/derpprobe.go b/cmd/derpprobe/derpprobe.go index 5b7b77091de7f..2723a31aea471 100644 --- a/cmd/derpprobe/derpprobe.go +++ b/cmd/derpprobe/derpprobe.go @@ -9,26 +9,34 @@ import ( "fmt" "log" "net/http" + "os" "sort" "time" "tailscale.com/prober" "tailscale.com/tsweb" "tailscale.com/version" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) var ( - derpMapURL = flag.String("derp-map", "https://login.tailscale.com/derpmap/default", "URL to DERP map (https:// or file://) or 'local' to use the local tailscaled's DERP map") - versionFlag = flag.Bool("version", false, "print version and exit") - listen = flag.String("listen", ":8030", "HTTP listen address") - probeOnce = flag.Bool("once", false, "probe once and print results, then exit; ignores the listen flag") - spread = flag.Bool("spread", true, "whether to spread probing over time") - interval = flag.Duration("interval", 15*time.Second, "probe interval") - meshInterval = flag.Duration("mesh-interval", 15*time.Second, "mesh probe interval") - stunInterval = flag.Duration("stun-interval", 15*time.Second, "STUN probe interval") - tlsInterval = flag.Duration("tls-interval", 15*time.Second, "TLS probe interval") - bwInterval = flag.Duration("bw-interval", 0, "bandwidth probe interval (0 = no bandwidth probing)") - bwSize = flag.Int64("bw-probe-size-bytes", 1_000_000, "bandwidth probe size") + derpMapURL = flag.String("derp-map", "https://login.tailscale.com/derpmap/default", "URL to DERP map (https:// or file://) or 'local' to use the local tailscaled's DERP map") + versionFlag = flag.Bool("version", false, "print version and exit") + listen = flag.String("listen", ":8030", "HTTP listen address") + probeOnce = flag.Bool("once", false, "probe once and print 
results, then exit; ignores the listen flag") + spread = flag.Bool("spread", true, "whether to spread probing over time") + interval = flag.Duration("interval", 15*time.Second, "probe interval") + meshInterval = flag.Duration("mesh-interval", 15*time.Second, "mesh probe interval") + stunInterval = flag.Duration("stun-interval", 15*time.Second, "STUN probe interval") + tlsInterval = flag.Duration("tls-interval", 15*time.Second, "TLS probe interval") + bwInterval = flag.Duration("bw-interval", 0, "bandwidth probe interval (0 = no bandwidth probing)") + bwSize = flag.Int64("bw-probe-size-bytes", 1_000_000, "bandwidth probe size") + bwTUNIPv4Address = flag.String("bw-tun-ipv4-addr", "", "if specified, bandwidth probes will be performed over a TUN device at this address in order to exercise TCP-in-TCP in similar fashion to TCP over Tailscale via DERP; we will use a /30 subnet including this IP address") + qdPacketsPerSecond = flag.Int("qd-packets-per-second", 0, "if greater than 0, queuing delay will be measured continuously using 260 byte packets (approximate size of a CallMeMaybe packet) sent at this rate per second") + qdPacketTimeout = flag.Duration("qd-packet-timeout", 5*time.Second, "queuing delay packets arriving after this period of time from being sent are treated like dropped packets and don't count toward queuing delay timings") + regionCodeOrID = flag.String("region-code", "", "probe only this region (e.g. 
'lax' or '17'); if left blank, all regions will be probed") ) func main() { @@ -43,9 +51,13 @@ func main() { prober.WithMeshProbing(*meshInterval), prober.WithSTUNProbing(*stunInterval), prober.WithTLSProbing(*tlsInterval), + prober.WithQueuingDelayProbing(*qdPacketsPerSecond, *qdPacketTimeout), } if *bwInterval > 0 { - opts = append(opts, prober.WithBandwidthProbing(*bwInterval, *bwSize)) + opts = append(opts, prober.WithBandwidthProbing(*bwInterval, *bwSize, *bwTUNIPv4Address)) + } + if *regionCodeOrID != "" { + opts = append(opts, prober.WithRegionCodeOrID(*regionCodeOrID)) } dp, err := prober.DERP(p, *derpMapURL, opts...) if err != nil { @@ -64,6 +76,9 @@ func main() { for _, s := range st.bad { log.Printf("bad: %s", s) } + if len(st.bad) > 0 { + os.Exit(1) + } return } @@ -102,7 +117,7 @@ func getOverallStatus(p *prober.Prober) (o overallStatus) { // Do not show probes that have not finished yet. continue } - if i.Result { + if i.Status == prober.ProbeStatusSucceeded { o.addGoodf("%s: %s", p, i.Latency) } else { o.addBadf("%s: %s", p, i.Error) diff --git a/cmd/dist/dist.go b/cmd/dist/dist.go index 05f5bbfb231a8..038ced708e0f0 100644 --- a/cmd/dist/dist.go +++ b/cmd/dist/dist.go @@ -5,11 +5,13 @@ package main import ( + "cmp" "context" "errors" "flag" "log" "os" + "slices" "tailscale.com/release/dist" "tailscale.com/release/dist/cli" @@ -19,9 +21,12 @@ import ( ) var ( - synologyPackageCenter bool - qnapPrivateKeyPath string - qnapCertificatePath string + synologyPackageCenter bool + gcloudCredentialsBase64 string + gcloudProject string + gcloudKeyring string + qnapKeyName string + qnapCertificateBase64 string ) func getTargets() ([]dist.Target, error) { @@ -42,10 +47,11 @@ func getTargets() ([]dist.Target, error) { // To build for package center, run // ./tool/go run ./cmd/dist build --synology-package-center synology ret = append(ret, synology.Targets(synologyPackageCenter, nil)...) 
- if (qnapPrivateKeyPath == "") != (qnapCertificatePath == "") { - return nil, errors.New("both --qnap-private-key-path and --qnap-certificate-path must be set") + qnapSigningArgs := []string{gcloudCredentialsBase64, gcloudProject, gcloudKeyring, qnapKeyName, qnapCertificateBase64} + if cmp.Or(qnapSigningArgs...) != "" && slices.Contains(qnapSigningArgs, "") { + return nil, errors.New("all of --gcloud-credentials, --gcloud-project, --gcloud-keyring, --qnap-key-name and --qnap-certificate must be set") } - ret = append(ret, qnap.Targets(qnapPrivateKeyPath, qnapCertificatePath)...) + ret = append(ret, qnap.Targets(gcloudCredentialsBase64, gcloudProject, gcloudKeyring, qnapKeyName, qnapCertificateBase64)...) return ret, nil } @@ -54,8 +60,11 @@ func main() { for _, subcmd := range cmd.Subcommands { if subcmd.Name == "build" { subcmd.FlagSet.BoolVar(&synologyPackageCenter, "synology-package-center", false, "build synology packages with extra metadata for the official package center") - subcmd.FlagSet.StringVar(&qnapPrivateKeyPath, "qnap-private-key-path", "", "sign qnap packages with given key (must also provide --qnap-certificate-path)") - subcmd.FlagSet.StringVar(&qnapCertificatePath, "qnap-certificate-path", "", "sign qnap packages with given certificate (must also provide --qnap-private-key-path)") + subcmd.FlagSet.StringVar(&gcloudCredentialsBase64, "gcloud-credentials", "", "base64 encoded GCP credentials (used when signing QNAP builds)") + subcmd.FlagSet.StringVar(&gcloudProject, "gcloud-project", "", "name of project in GCP KMS (used when signing QNAP builds)") + subcmd.FlagSet.StringVar(&gcloudKeyring, "gcloud-keyring", "", "path to keyring in GCP KMS (used when signing QNAP builds)") + subcmd.FlagSet.StringVar(&qnapKeyName, "qnap-key-name", "", "name of GCP key to use when signing QNAP builds") + subcmd.FlagSet.StringVar(&qnapCertificateBase64, "qnap-certificate", "", "base64 encoded certificate to use when signing QNAP builds") } } diff --git 
a/cmd/get-authkey/main.go b/cmd/get-authkey/main.go index 777258d64b21b..ec7ab5d2c6158 100644 --- a/cmd/get-authkey/main.go +++ b/cmd/get-authkey/main.go @@ -16,14 +16,10 @@ import ( "strings" "golang.org/x/oauth2/clientcredentials" - "tailscale.com/client/tailscale" + "tailscale.com/internal/client/tailscale" ) func main() { - // Required to use our client API. We're fine with the instability since the - // client lives in the same repo as this code. - tailscale.I_Acknowledge_This_API_Is_Unstable = true - reusable := flag.Bool("reusable", false, "allocate a reusable authkey") ephemeral := flag.Bool("ephemeral", false, "allocate an ephemeral authkey") preauth := flag.Bool("preauth", true, "set the authkey as pre-authorized") @@ -46,7 +42,6 @@ func main() { ClientID: clientID, ClientSecret: clientSecret, TokenURL: baseURL + "/api/v2/oauth/token", - Scopes: []string{"device"}, } ctx := context.Background() diff --git a/cmd/gitops-pusher/gitops-pusher.go b/cmd/gitops-pusher/gitops-pusher.go index c33937ef24959..690ca287056d3 100644 --- a/cmd/gitops-pusher/gitops-pusher.go +++ b/cmd/gitops-pusher/gitops-pusher.go @@ -13,6 +13,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "log" "net/http" "os" @@ -58,8 +59,8 @@ func apply(cache *Cache, client *http.Client, tailnet, apiKey string) func(conte } if cache.PrevETag == "" { - log.Println("no previous etag found, assuming local file is correct and recording that") - cache.PrevETag = localEtag + log.Println("no previous etag found, assuming the latest control etag") + cache.PrevETag = controlEtag } log.Printf("control: %s", controlEtag) @@ -105,8 +106,8 @@ func test(cache *Cache, client *http.Client, tailnet, apiKey string) func(contex } if cache.PrevETag == "" { - log.Println("no previous etag found, assuming local file is correct and recording that") - cache.PrevETag = localEtag + log.Println("no previous etag found, assuming the latest control etag") + cache.PrevETag = controlEtag } log.Printf("control: %s", controlEtag) 
@@ -148,8 +149,8 @@ func getChecksums(cache *Cache, client *http.Client, tailnet, apiKey string) fun } if cache.PrevETag == "" { - log.Println("no previous etag found, assuming local file is correct and recording that") - cache.PrevETag = Shuck(localEtag) + log.Println("no previous etag found, assuming control etag") + cache.PrevETag = Shuck(controlEtag) } log.Printf("control: %s", controlEtag) @@ -405,7 +406,8 @@ func getACLETag(ctx context.Context, client *http.Client, tailnet, apiKey string got := resp.StatusCode want := http.StatusOK if got != want { - return "", fmt.Errorf("wanted HTTP status code %d but got %d", want, got) + errorDetails, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("wanted HTTP status code %d but got %d: %#q", want, got, string(errorDetails)) } return Shuck(resp.Header.Get("ETag")), nil diff --git a/cmd/hello/hello.go b/cmd/hello/hello.go index e4b0ca8278095..fa116b28b15ab 100644 --- a/cmd/hello/hello.go +++ b/cmd/hello/hello.go @@ -18,8 +18,9 @@ import ( "strings" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" ) var ( @@ -31,7 +32,7 @@ var ( //go:embed hello.tmpl.html var embeddedTemplate string -var localClient tailscale.LocalClient +var localClient local.Client func main() { flag.Parse() @@ -134,6 +135,10 @@ func tailscaleIP(who *apitype.WhoIsResponse) string { if who == nil { return "" } + vals, err := tailcfg.UnmarshalNodeCapJSON[string](who.Node.CapMap, tailcfg.NodeAttrNativeIPV4) + if err == nil && len(vals) > 0 { + return vals[0] + } for _, nodeIP := range who.Node.Addresses { if nodeIP.Addr().Is4() && nodeIP.IsSingleIP() { return nodeIP.Addr().String() diff --git a/cmd/k8s-nameserver/main.go b/cmd/k8s-nameserver/main.go index ca4b449358083..76888ebc7f576 100644 --- a/cmd/k8s-nameserver/main.go +++ b/cmd/k8s-nameserver/main.go @@ -11,6 +11,7 @@ package main import ( "context" "encoding/json" + "flag" "fmt" "log" "net" @@ -19,16 +20,38 
@@ import ( "path/filepath" "sync" "syscall" + "time" "github.com/fsnotify/fsnotify" "github.com/miekg/dns" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/fields" + "k8s.io/client-go/informers" + "k8s.io/client-go/kubernetes" + listersv1 "k8s.io/client-go/listers/core/v1" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/cache" operatorutils "tailscale.com/k8s-operator" "tailscale.com/util/dnsname" ) +var ( + // domain is the DNS domain that this nameserver has registered a handler for. + domain = flag.String("domain", "ts.net", "the DNS domain to serve records for") + updateMode = flag.String( + "update-mode", + mountAccessUpdateMode, + fmt.Sprintf( + "how to detect changes to the configMap which contains the DNS entries.\n"+ + "%q watches the mounted configMap for changes.\n"+ + "%q watches the configMap directly via the Kubernetes API.", + mountAccessUpdateMode, + directAccessUpdateMode, + ), + ) +) + const ( - // tsNetDomain is the domain that this DNS nameserver has registered a handler for. - tsNetDomain = "ts.net" // addr is the the address that the UDP and TCP listeners will listen on. addr = ":1053" @@ -37,10 +60,31 @@ const ( // /config is the only supported way for configuring this nameserver. defaultDNSConfigDir = "/config" kubeletMountedConfigLn = "..data" + + // configMapName is the name of the configMap which needs to be watched + // for changes when using the non-mount update mode. + configMapName = "dnsrecords" + // configMapKey is the configMap key which contains the DNS data + configMapKey = "records.json" + + // the update modes define how changes to the configMap are detected. + // Either by watching the mounted file (might be slower due to the time + // needed for syncing) or by watching the configMap directly (needs + // more permissions for the service account the k8s-nameserver runs + // with). 
+ directAccessUpdateMode = "direct-access" + mountAccessUpdateMode = "mount" + + // configMapDefaultNamespace sets the default namespace for reading the + // configMap if the env variable POD_NAMESPACE is not set. Otherwise + // the content of the POD_NAMESPACE env variable determines where to + // read the configMap from. This only matters when using direct access + // mode for updates. + configMapDefaultNamespace = "tailscale" ) // nameserver is a simple nameserver that responds to DNS queries for A records -// for ts.net domain names over UDP or TCP. It serves DNS responses from +// for the names of the given domain over UDP or TCP. It serves DNS responses from // in-memory IPv4 host records. It is intended to be deployed on Kubernetes with // a ConfigMap mounted at /config that should contain the host records. It // dynamically reconfigures its in-memory mappings as the contents of the @@ -49,7 +93,7 @@ type nameserver struct { // configReader returns the latest desired configuration (host records) // for the nameserver. By default it gets set to a reader that reads // from a Kubernetes ConfigMap mounted at /config, but this can be - // overridden in tests. + // overridden. configReader configReaderFunc // configWatcher is a watcher that returns an event when the desired // configuration has changed and the nameserver should update the @@ -63,26 +107,31 @@ type nameserver struct { } func main() { + flag.Parse() ctx, cancel := context.WithCancel(context.Background()) - // Ensure that we watch the kube Configmap mounted at /config for - // nameserver configuration updates and send events when updates happen. 
- c := ensureWatcherForKubeConfigMap(ctx) + if !validUpdateMode(*updateMode) { + log.Fatalf("non valid update mode: %q", *updateMode) + } + reader, watcher, err := configMapReaderAndWatcher(ctx, *updateMode) + if err != nil { + log.Fatalf("can not setup configMap reader: %v", err) + } ns := &nameserver{ - configReader: configMapConfigReader, - configWatcher: c, + configReader: reader, + configWatcher: watcher, } // Ensure that in-memory records get set up to date now and will get // reset when the configuration changes. ns.runRecordsReconciler(ctx) - // Register a DNS server handle for ts.net domain names. Not having a + // Register a DNS server handle for names of the domain. Not having a // handle registered for any other domain names is how we enforce that - // this nameserver can only be used for ts.net domains - querying any + // this nameserver can only be used for the given domain - querying any // other domain names returns Rcode Refused. - dns.HandleFunc(tsNetDomain, ns.handleFunc()) + dns.HandleFunc(*domain, ns.handleFunc()) // Listen for DNS queries over UDP and TCP. udpSig := make(chan os.Signal) @@ -112,7 +161,7 @@ func (n *nameserver) handleFunc() func(w dns.ResponseWriter, r *dns.Msg) { h := func(w dns.ResponseWriter, r *dns.Msg) { m := new(dns.Msg) defer func() { - w.WriteMsg(m) + _ = w.WriteMsg(m) }() if len(r.Question) < 1 { log.Print("[unexpected] nameserver received a request with no questions") @@ -135,7 +184,7 @@ func (n *nameserver) handleFunc() func(w dns.ResponseWriter, r *dns.Msg) { m.RecursionAvailable = false ips := n.lookupIP4(fqdn) - if ips == nil || len(ips) == 0 { + if len(ips) == 0 { // As we are the authoritative nameserver for MagicDNS // names, if we do not have a record for this MagicDNS // name, it does not exist. 
@@ -231,7 +280,7 @@ func (n *nameserver) resetRecords() error { log.Printf("error reading nameserver's configuration: %v", err) return err } - if dnsCfgBytes == nil || len(dnsCfgBytes) < 1 { + if len(dnsCfgBytes) < 1 { log.Print("nameserver's configuration is empty, any in-memory records will be unset") n.mu.Lock() n.ip4 = make(map[dnsname.FQDN][]net.IP) @@ -284,7 +333,7 @@ func listenAndServe(net, addr string, shutdown chan os.Signal) { go func() { <-shutdown log.Printf("shutting down server for %s", net) - s.Shutdown() + _ = s.Shutdown() }() log.Printf("listening for %s queries on %s", net, addr) if err := s.ListenAndServe(); err != nil { @@ -292,10 +341,86 @@ func listenAndServe(net, addr string, shutdown chan os.Signal) { } } -// ensureWatcherForKubeConfigMap sets up a new file watcher for the ConfigMap +func getClientset() (*kubernetes.Clientset, error) { + config, err := rest.InClusterConfig() + if err != nil { + return nil, fmt.Errorf("failed to load in-cluster config: %w", err) + } + clientset, err := kubernetes.NewForConfig(config) + if err != nil { + return nil, fmt.Errorf("failed to create clientset: %w", err) + } + return clientset, nil +} + +func getConfigMapNamespace() string { + namespace := configMapDefaultNamespace + if ns := os.Getenv("POD_NAMESPACE"); ns != "" { + namespace = ns + } + return namespace +} + +func configMapReaderAndWatcher(ctx context.Context, updateMode string) (configReaderFunc, chan string, error) { + switch updateMode { + case mountAccessUpdateMode: + return configMapMountedReader, watchMountedConfigMap(ctx), nil + case directAccessUpdateMode: + cs, err := getClientset() + if err != nil { + return nil, nil, err + } + watcherChannel, cacheReader, err := watchConfigMap(ctx, cs, configMapName, getConfigMapNamespace()) + if err != nil { + return nil, nil, err + } + return configMapCacheReader(cacheReader), watcherChannel, nil + default: + return nil, nil, fmt.Errorf("no implementation for update mode %q", updateMode) + } +} + +// 
watchConfigMap watches the configMap identified by the given name and +// namespace. It emits a message in the returned channel whenever the configMap +// is updated. It also returns a configMapLister which allows access to the cached objects +// retrieved by the API server. +func watchConfigMap(ctx context.Context, cs kubernetes.Interface, configMapName, configMapNamespace string) (chan string, listersv1.ConfigMapLister, error) { + ch := make(chan string) + + fieldSelector := fields.OneTermEqualSelector("metadata.name", configMapName).String() + factory := informers.NewSharedInformerFactoryWithOptions( + cs, + // we resync every hour to account for missed watches + time.Hour, + informers.WithNamespace(configMapNamespace), + informers.WithTweakListOptions(func(options *metav1.ListOptions) { + options.FieldSelector = fieldSelector + }), + ) + cmFactory := factory.Core().V1().ConfigMaps() + _, _ = cmFactory.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{ + AddFunc: func(obj any) { + ch <- fmt.Sprintf("ConfigMap %s added or synced", configMapName) + }, + UpdateFunc: func(oldObj, newObj any) { + ch <- fmt.Sprintf("ConfigMap %s updated", configMapName) + }, + }) + factory.Start(ctx.Done()) + // Wait for cache sync + log.Println("waiting for configMap cache to sync") + if !cache.WaitForCacheSync(ctx.Done(), cmFactory.Informer().HasSynced) { + return nil, nil, fmt.Errorf("configMap cache did not sync successful") + } + log.Println("configMap cache successfully synced") + + return ch, cmFactory.Lister(), nil +} + +// watchMountedConfigMap sets up a new file watcher for the ConfigMap // that's expected to be mounted at /config. Returns a channel that receives an // event every time the contents get updated. 
-func ensureWatcherForKubeConfigMap(ctx context.Context) chan string { +func watchMountedConfigMap(ctx context.Context) chan string { c := make(chan string) watcher, err := fsnotify.NewWatcher() if err != nil { @@ -307,7 +432,9 @@ func ensureWatcherForKubeConfigMap(ctx context.Context) chan string { // https://github.com/kubernetes/kubernetes/blob/v1.28.1/pkg/volume/util/atomic_writer.go#L39-L61 toWatch := filepath.Join(defaultDNSConfigDir, kubeletMountedConfigLn) go func() { - defer watcher.Close() + defer func() { + _ = watcher.Close() + }() log.Printf("starting file watch for %s", defaultDNSConfigDir) for { select { @@ -354,9 +481,24 @@ func ensureWatcherForKubeConfigMap(ctx context.Context) chan string { // configReaderFunc is a function that returns the desired nameserver configuration. type configReaderFunc func() ([]byte, error) -// configMapConfigReader reads the desired nameserver configuration from a +func configMapCacheReader(lister listersv1.ConfigMapLister) configReaderFunc { + return func() ([]byte, error) { + cm, err := lister.ConfigMaps(getConfigMapNamespace()).Get(configMapName) + if err != nil { + return nil, fmt.Errorf("can not read configMap: %w", err) + } + if data, exists := cm.Data[configMapKey]; exists { + return []byte(data), nil + } + // if the configMap is empty we need to return `nil` which will + // be handled by the caller specifically + return nil, nil + } +} + +// configMapMountedReader reads the desired nameserver configuration from a // records.json file in a ConfigMap mounted at /config. 
-var configMapConfigReader configReaderFunc = func() ([]byte, error) { +var configMapMountedReader configReaderFunc = func() ([]byte, error) { if contents, err := os.ReadFile(filepath.Join(defaultDNSConfigDir, operatorutils.DNSRecordsCMKey)); err == nil { return contents, nil } else if os.IsNotExist(err) { @@ -377,3 +519,12 @@ func (n *nameserver) lookupIP4(fqdn dnsname.FQDN) []net.IP { f := n.ip4[fqdn] return f } + +func validUpdateMode(m string) bool { + switch m { + case directAccessUpdateMode, mountAccessUpdateMode: + return true + default: + return false + } +} diff --git a/cmd/k8s-nameserver/main_test.go b/cmd/k8s-nameserver/main_test.go index d9a33c4faffe5..e6ab931996425 100644 --- a/cmd/k8s-nameserver/main_test.go +++ b/cmd/k8s-nameserver/main_test.go @@ -6,11 +6,15 @@ package main import ( + "context" "net" "testing" "github.com/google/go-cmp/cmp" "github.com/miekg/dns" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/kubernetes/fake" "tailscale.com/util/dnsname" ) @@ -195,6 +199,33 @@ func TestResetRecords(t *testing.T) { t.Fatalf("unexpected nameserver.ip4 contents (-got +want): \n%s", diff) } }) + t.Run(tt.name+" (direct access)", func(t *testing.T) { + client := fake.NewSimpleClientset(&corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: configMapName, + Namespace: configMapDefaultNamespace, + }, + Data: map[string]string{ + configMapKey: string(tt.config), + }, + }) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + _, reader, err := watchConfigMap(ctx, client, configMapName, configMapDefaultNamespace) + if err != nil { + t.Fatal(err) + } + ns := &nameserver{ + ip4: tt.hasIp4, + configReader: configMapCacheReader(reader), + } + if err := ns.resetRecords(); err == nil == tt.wantsErr { + t.Errorf("resetRecords() returned err: %v, wantsErr: %v", err, tt.wantsErr) + } + if diff := cmp.Diff(ns.ip4, tt.wantsIp4); diff != "" { + t.Fatalf("unexpected nameserver.ip4 contents 
(-got +want): \n%s", diff) + } + }) } } diff --git a/cmd/k8s-operator/connector.go b/cmd/k8s-operator/connector.go index 016166b4cda29..c243036cbabd9 100644 --- a/cmd/k8s-operator/connector.go +++ b/cmd/k8s-operator/connector.go @@ -10,10 +10,12 @@ import ( "fmt" "net/netip" "slices" + "strings" "sync" "time" - "github.com/pkg/errors" + "errors" + "go.uber.org/zap" xslices "golang.org/x/exp/slices" corev1 "k8s.io/api/core/v1" @@ -34,6 +36,7 @@ import ( const ( reasonConnectorCreationFailed = "ConnectorCreationFailed" + reasonConnectorCreating = "ConnectorCreating" reasonConnectorCreated = "ConnectorCreated" reasonConnectorInvalid = "ConnectorInvalid" @@ -58,6 +61,7 @@ type ConnectorReconciler struct { subnetRouters set.Slice[types.UID] // for subnet routers gauge exitNodes set.Slice[types.UID] // for exit nodes gauge + appConnectors set.Slice[types.UID] // for app connectors gauge } var ( @@ -67,6 +71,8 @@ var ( gaugeConnectorSubnetRouterResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithSubnetRouterCount) // gaugeConnectorExitNodeResources tracks the number of Connectors currently managed by this operator instance that are exit nodes. gaugeConnectorExitNodeResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithExitNodeCount) + // gaugeConnectorAppConnectorResources tracks the number of Connectors currently managed by this operator instance that are app connectors. 
+ gaugeConnectorAppConnectorResources = clientmetric.NewGauge(kubetypes.MetricConnectorWithAppConnectorCount) ) func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { @@ -108,13 +114,12 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque oldCnStatus := cn.Status.DeepCopy() setStatus := func(cn *tsapi.Connector, _ tsapi.ConditionType, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetConnectorCondition(cn, tsapi.ConnectorReady, status, reason, message, cn.Generation, a.clock, logger) - if !apiequality.Semantic.DeepEqual(oldCnStatus, cn.Status) { + var updateErr error + if !apiequality.Semantic.DeepEqual(oldCnStatus, &cn.Status) { // An error encountered here should get returned by the Reconcile function. - if updateErr := a.Client.Status().Update(ctx, cn); updateErr != nil { - err = errors.Wrap(err, updateErr.Error()) - } + updateErr = a.Client.Status().Update(ctx, cn) } - return res, err + return res, errors.Join(err, updateErr) } if !slices.Contains(cn.Finalizers, FinalizerName) { @@ -131,17 +136,24 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque } if err := a.validate(cn); err != nil { - logger.Errorf("error validating Connector spec: %w", err) message := fmt.Sprintf(messageConnectorInvalid, err) a.recorder.Eventf(cn, corev1.EventTypeWarning, reasonConnectorInvalid, message) return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reasonConnectorInvalid, message) } if err = a.maybeProvisionConnector(ctx, logger, cn); err != nil { - logger.Errorf("error creating Connector resources: %w", err) + reason := reasonConnectorCreationFailed message := fmt.Sprintf(messageConnectorCreationFailed, err) - a.recorder.Eventf(cn, corev1.EventTypeWarning, reasonConnectorCreationFailed, message) - return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reasonConnectorCreationFailed, 
message) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonConnectorCreating + message = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(message) + } else { + a.recorder.Eventf(cn, corev1.EventTypeWarning, reason, message) + } + + return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionFalse, reason, message) } logger.Info("Connector resources synced") @@ -150,6 +162,9 @@ func (a *ConnectorReconciler) Reconcile(ctx context.Context, req reconcile.Reque cn.Status.SubnetRoutes = cn.Spec.SubnetRouter.AdvertiseRoutes.Stringify() return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionTrue, reasonConnectorCreated, reasonConnectorCreated) } + if cn.Spec.AppConnector != nil { + cn.Status.IsAppConnector = true + } cn.Status.SubnetRoutes = "" return setStatus(cn, tsapi.ConnectorReady, metav1.ConditionTrue, reasonConnectorCreated, reasonConnectorCreated) } @@ -183,29 +198,44 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge isExitNode: cn.Spec.ExitNode, }, ProxyClassName: proxyClass, + proxyType: proxyTypeConnector, } if cn.Spec.SubnetRouter != nil && len(cn.Spec.SubnetRouter.AdvertiseRoutes) > 0 { sts.Connector.routes = cn.Spec.SubnetRouter.AdvertiseRoutes.Stringify() } + if cn.Spec.AppConnector != nil { + sts.Connector.isAppConnector = true + if len(cn.Spec.AppConnector.Routes) != 0 { + sts.Connector.routes = cn.Spec.AppConnector.Routes.Stringify() + } + } + a.mu.Lock() - if sts.Connector.isExitNode { + if cn.Spec.ExitNode { a.exitNodes.Add(cn.UID) } else { a.exitNodes.Remove(cn.UID) } - if sts.Connector.routes != "" { + if cn.Spec.SubnetRouter != nil { a.subnetRouters.Add(cn.GetUID()) } else { a.subnetRouters.Remove(cn.GetUID()) } + if cn.Spec.AppConnector != nil { + a.appConnectors.Add(cn.GetUID()) + } else { + a.appConnectors.Remove(cn.GetUID()) + } a.mu.Unlock() gaugeConnectorSubnetRouterResources.Set(int64(a.subnetRouters.Len())) 
gaugeConnectorExitNodeResources.Set(int64(a.exitNodes.Len())) + gaugeConnectorAppConnectorResources.Set(int64(a.appConnectors.Len())) var connectors set.Slice[types.UID] connectors.AddSlice(a.exitNodes.Slice()) connectors.AddSlice(a.subnetRouters.Slice()) + connectors.AddSlice(a.appConnectors.Slice()) gaugeConnectorResources.Set(int64(connectors.Len())) _, err := a.ssr.Provision(ctx, logger, sts) @@ -213,27 +243,27 @@ func (a *ConnectorReconciler) maybeProvisionConnector(ctx context.Context, logge return err } - _, tsHost, ips, err := a.ssr.DeviceInfo(ctx, crl) + dev, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { return err } - if tsHost == "" { - logger.Debugf("no Tailscale hostname known yet, waiting for connector pod to finish auth") + if dev == nil || dev.hostname == "" { + logger.Debugf("no Tailscale hostname known yet, waiting for Connector Pod to finish auth") // No hostname yet. Wait for the connector pod to auth. cn.Status.TailnetIPs = nil cn.Status.Hostname = "" return nil } - cn.Status.TailnetIPs = ips - cn.Status.Hostname = tsHost + cn.Status.TailnetIPs = dev.ips + cn.Status.Hostname = dev.hostname return nil } func (a *ConnectorReconciler) maybeCleanupConnector(ctx context.Context, logger *zap.SugaredLogger, cn *tsapi.Connector) (bool, error) { - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(cn.Name, a.tsnamespace, "connector")); err != nil { + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(cn.Name, a.tsnamespace, "connector"), proxyTypeConnector); err != nil { return false, fmt.Errorf("failed to cleanup Connector resources: %w", err) } else if !done { logger.Debugf("Connector cleanup not done yet, waiting for next reconcile") @@ -248,12 +278,15 @@ func (a *ConnectorReconciler) maybeCleanupConnector(ctx context.Context, logger a.mu.Lock() a.subnetRouters.Remove(cn.UID) a.exitNodes.Remove(cn.UID) + a.appConnectors.Remove(cn.UID) a.mu.Unlock() gaugeConnectorExitNodeResources.Set(int64(a.exitNodes.Len())) 
gaugeConnectorSubnetRouterResources.Set(int64(a.subnetRouters.Len())) + gaugeConnectorAppConnectorResources.Set(int64(a.appConnectors.Len())) var connectors set.Slice[types.UID] connectors.AddSlice(a.exitNodes.Slice()) connectors.AddSlice(a.subnetRouters.Slice()) + connectors.AddSlice(a.appConnectors.Slice()) gaugeConnectorResources.Set(int64(connectors.Len())) return true, nil } @@ -262,8 +295,14 @@ func (a *ConnectorReconciler) validate(cn *tsapi.Connector) error { // Connector fields are already validated at apply time with CEL validation // on custom resource fields. The checks here are a backup in case the // CEL validation breaks without us noticing. - if !(cn.Spec.SubnetRouter != nil || cn.Spec.ExitNode) { - return errors.New("invalid spec: a Connector must expose subnet routes or act as an exit node (or both)") + if cn.Spec.SubnetRouter == nil && !cn.Spec.ExitNode && cn.Spec.AppConnector == nil { + return errors.New("invalid spec: a Connector must be configured as at least one of subnet router, exit node or app connector") + } + if (cn.Spec.SubnetRouter != nil || cn.Spec.ExitNode) && cn.Spec.AppConnector != nil { + return errors.New("invalid spec: a Connector that is configured as an app connector must not be also configured as a subnet router or exit node") + } + if cn.Spec.AppConnector != nil { + return validateAppConnector(cn.Spec.AppConnector) } if cn.Spec.SubnetRouter == nil { return nil @@ -272,19 +311,27 @@ func (a *ConnectorReconciler) validate(cn *tsapi.Connector) error { } func validateSubnetRouter(sb *tsapi.SubnetRouter) error { - if len(sb.AdvertiseRoutes) < 1 { + if len(sb.AdvertiseRoutes) == 0 { return errors.New("invalid subnet router spec: no routes defined") } - var err error - for _, route := range sb.AdvertiseRoutes { + return validateRoutes(sb.AdvertiseRoutes) +} + +func validateAppConnector(ac *tsapi.AppConnector) error { + return validateRoutes(ac.Routes) +} + +func validateRoutes(routes tsapi.Routes) error { + var errs []error + for 
_, route := range routes { pfx, e := netip.ParsePrefix(string(route)) if e != nil { - err = errors.Wrap(err, fmt.Sprintf("route %s is invalid: %v", route, err)) + errs = append(errs, fmt.Errorf("route %v is invalid: %v", route, e)) continue } if pfx.Masked() != pfx { - err = errors.Wrap(err, fmt.Sprintf("route %s has non-address bits set; expected %s", pfx, pfx.Masked())) + errs = append(errs, fmt.Errorf("route %s has non-address bits set; expected %s", pfx, pfx.Masked())) } } - return err + return errors.Join(errs...) } diff --git a/cmd/k8s-operator/connector_test.go b/cmd/k8s-operator/connector_test.go index a4ba90d3d6683..f32fe3282020c 100644 --- a/cmd/k8s-operator/connector_test.go +++ b/cmd/k8s-operator/connector_test.go @@ -8,12 +8,14 @@ package main import ( "context" "testing" + "time" "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client/fake" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" @@ -77,8 +79,8 @@ func TestConnector(t *testing.T) { subnetRoutes: "10.40.0.0/14", app: kubetypes.AppConnector, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // Connector status should get updated with the IP/hostname info when available. const hostname = "foo.tailnetxyz.ts.net" @@ -104,7 +106,7 @@ func TestConnector(t *testing.T) { opts.subnetRoutes = "10.40.0.0/14,10.44.0.0/20" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // Remove a route. 
mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -112,7 +114,7 @@ func TestConnector(t *testing.T) { }) opts.subnetRoutes = "10.44.0.0/20" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // Remove the subnet router. mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -120,7 +122,7 @@ func TestConnector(t *testing.T) { }) opts.subnetRoutes = "" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // Re-add the subnet router. mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -130,7 +132,7 @@ func TestConnector(t *testing.T) { }) opts.subnetRoutes = "10.44.0.0/20" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // Delete the Connector. if err = fc.Delete(context.Background(), cn); err != nil { @@ -173,8 +175,8 @@ func TestConnector(t *testing.T) { hostname: "test-connector", app: kubetypes.AppConnector, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // Add an exit node. mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { @@ -182,7 +184,7 @@ func TestConnector(t *testing.T) { }) opts.isExitNode = true expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // Delete the Connector. 
if err = fc.Delete(context.Background(), cn); err != nil { @@ -201,7 +203,7 @@ func TestConnectorWithProxyClass(t *testing.T) { pc := &tsapi.ProxyClass{ ObjectMeta: metav1.ObjectMeta{Name: "custom-metadata"}, Spec: tsapi.ProxyClassSpec{StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"bar.io/foo": "some-val"}, Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}}}, } @@ -259,8 +261,8 @@ func TestConnectorWithProxyClass(t *testing.T) { subnetRoutes: "10.40.0.0/14", app: kubetypes.AppConnector, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 2. Update Connector to specify a ProxyClass. ProxyClass is not yet // ready, so its configuration is NOT applied to the Connector @@ -269,7 +271,7 @@ func TestConnectorWithProxyClass(t *testing.T) { conn.Spec.ProxyClass = "custom-metadata" }) expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 3. ProxyClass is set to Ready by proxy-class reconciler. Connector // get reconciled and configuration from the ProxyClass is applied to @@ -284,7 +286,7 @@ func TestConnectorWithProxyClass(t *testing.T) { }) opts.proxyClass = pc.Name expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 4. 
Connector.spec.proxyClass field is unset, Connector gets // reconciled and configuration from the ProxyClass is removed from the @@ -294,5 +296,102 @@ func TestConnectorWithProxyClass(t *testing.T) { }) opts.proxyClass = "" expectReconciled(t, cr, "", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) +} + +func TestConnectorWithAppConnector(t *testing.T) { + // Setup + cn := &tsapi.Connector{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + UID: types.UID("1234-UID"), + }, + TypeMeta: metav1.TypeMeta{ + Kind: tsapi.ConnectorKind, + APIVersion: "tailscale.io/v1alpha1", + }, + Spec: tsapi.ConnectorSpec{ + AppConnector: &tsapi.AppConnector{}, + }, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(cn). + WithStatusSubresource(cn). + Build() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + cl := tstest.NewClock(tstest.ClockOpts{}) + fr := record.NewFakeRecorder(1) + cr := &ConnectorReconciler{ + Client: fc, + clock: cl, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + recorder: fr, + } + + // 1. 
Connector with app connnector is created and becomes ready + expectReconciled(t, cr, "", "test") + fullName, shortName := findGenName(t, fc, "", "test", "connector") + opts := configOpts{ + stsName: shortName, + secretName: fullName, + parentType: "connector", + hostname: "test-connector", + app: kubetypes.AppConnector, + isAppConnector: true, + } + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) + // Connector's ready condition should be set to true + + cn.ObjectMeta.Finalizers = append(cn.ObjectMeta.Finalizers, "tailscale.com/finalizer") + cn.Status.IsAppConnector = true + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorCreated, + Message: reasonConnectorCreated, + }} + expectEqual(t, fc, cn) + + // 2. Connector with invalid app connector routes has status set to invalid + mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { + conn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("1.2.3.4/5")} + }) + cn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("1.2.3.4/5")} + expectReconciled(t, cr, "", "test") + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionFalse, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorInvalid, + Message: "Connector is invalid: route 1.2.3.4/5 has non-address bits set; expected 0.0.0.0/5", + }} + expectEqual(t, fc, cn) + + // 3. 
Connector with valid app connnector routes becomes ready + mustUpdate[tsapi.Connector](t, fc, "", "test", func(conn *tsapi.Connector) { + conn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("10.88.2.21/32")} + }) + cn.Spec.AppConnector.Routes = tsapi.Routes{tsapi.Route("10.88.2.21/32")} + cn.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ConnectorReady), + Status: metav1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + Reason: reasonConnectorCreated, + Message: reasonConnectorCreated, + }} + expectReconciled(t, cr, "", "test") } diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index b77ea22ef5297..12fb5cf2e5a65 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -9,7 +9,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ L github.com/aws/aws-sdk-go-v2/aws/arn from tailscale.com/ipn/store/awsstore L github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - L github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts L github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts @@ -31,10 +30,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ L github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ L github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/internal/context from 
github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/ssm/internal/endpoints+ L github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ L github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds @@ -69,18 +70,19 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ L github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer L github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ L github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config L github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/waiter from 
github.com/aws/aws-sdk-go-v2/service/ssm github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus - github.com/bits-and-blooms/bitset from github.com/gaissmai/bart đŸ’Ŗ github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus - github.com/coder/websocket from tailscale.com/control/controlhttp+ + github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket @@ -98,8 +100,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/evanphx/json-patch/v5 from sigs.k8s.io/controller-runtime/pkg/client github.com/evanphx/json-patch/v5/internal/json from github.com/evanphx/json-patch/v5 đŸ’Ŗ github.com/fsnotify/fsnotify from sigs.k8s.io/controller-runtime/pkg/certwatcher - github.com/fxamacker/cbor/v2 from tailscale.com/tka + github.com/fxamacker/cbor/v2 from tailscale.com/tka+ github.com/gaissmai/bart from tailscale.com/net/ipset+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart github.com/go-json-experiment/json from tailscale.com/types/opt+ github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+ @@ -114,11 +118,11 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/go-openapi/jsonpointer from github.com/go-openapi/jsonreference github.com/go-openapi/jsonreference from k8s.io/kube-openapi/pkg/internal+ github.com/go-openapi/jsonreference/internal from github.com/go-openapi/jsonreference - github.com/go-openapi/swag from github.com/go-openapi/jsonpointer+ + đŸ’Ŗ 
github.com/go-openapi/swag from github.com/go-openapi/jsonpointer+ L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns đŸ’Ŗ github.com/gogo/protobuf/proto from k8s.io/api/admission/v1+ github.com/gogo/protobuf/sortkeys from k8s.io/api/admission/v1+ - github.com/golang/groupcache/lru from k8s.io/client-go/tools/record+ + github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/golang/protobuf/proto from k8s.io/client-go/discovery+ github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ github.com/google/gnostic-models/compiler from github.com/google/gnostic-models/openapiv2+ @@ -143,15 +147,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/gorilla/csrf from tailscale.com/client/web github.com/gorilla/securecookie from github.com/gorilla/csrf github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ - L đŸ’Ŗ github.com/illarion/gonotify/v2 from tailscale.com/net/dns - github.com/imdario/mergo from k8s.io/client-go/tools/clientcmd - L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/net/tstun - L github.com/insomniacslk/dhcp/iana from github.com/insomniacslk/dhcp/dhcpv4 - L github.com/insomniacslk/dhcp/interfaces from github.com/insomniacslk/dhcp/dhcpv4 - L github.com/insomniacslk/dhcp/rfc1035label from github.com/insomniacslk/dhcp/dhcpv4 + L đŸ’Ŗ github.com/illarion/gonotify/v3 from tailscale.com/net/dns + L github.com/illarion/gonotify/v3/syscallf from github.com/illarion/gonotify/v3 L github.com/jmespath/go-jmespath from github.com/aws/aws-sdk-go-v2/service/ssm github.com/josharian/intern from github.com/mailru/easyjson/jlexer - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink đŸ’Ŗ github.com/json-iterator/go from sigs.k8s.io/structured-merge-diff/v4/fieldpath+ @@ -162,7 +161,6 @@ 
tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd - github.com/kortschak/wol from tailscale.com/ipn/ipnlocal github.com/mailru/easyjson/buffer from github.com/mailru/easyjson/jwriter đŸ’Ŗ github.com/mailru/easyjson/jlexer from github.com/go-openapi/swag github.com/mailru/easyjson/jwriter from github.com/go-openapi/swag @@ -176,13 +174,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket github.com/modern-go/concurrent from github.com/json-iterator/go đŸ’Ŗ github.com/modern-go/reflect2 from github.com/json-iterator/go - github.com/munnerz/goautoneg from k8s.io/kube-openapi/pkg/handler3 + github.com/munnerz/goautoneg from k8s.io/kube-openapi/pkg/handler3+ github.com/opencontainers/go-digest from github.com/distribution/reference - L github.com/pierrec/lz4/v4 from github.com/u-root/uio/uio - L github.com/pierrec/lz4/v4/internal/lz4block from github.com/pierrec/lz4/v4+ - L github.com/pierrec/lz4/v4/internal/lz4errors from github.com/pierrec/lz4/v4+ - L github.com/pierrec/lz4/v4/internal/lz4stream from github.com/pierrec/lz4/v4 - L github.com/pierrec/lz4/v4/internal/xxh32 from github.com/pierrec/lz4/v4/internal/lz4stream github.com/pkg/errors from github.com/evanphx/json-patch/v5+ D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack đŸ’Ŗ github.com/prometheus/client_golang/prometheus from github.com/prometheus/client_golang/prometheus/collectors+ @@ -191,7 +184,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/prometheus/client_golang/prometheus/promhttp from sigs.k8s.io/controller-runtime/pkg/metrics/server+ 
github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from github.com/prometheus/client_golang/prometheus+ - github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs @@ -204,10 +196,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ - github.com/tailscale/golang-x-crypto/acme from tailscale.com/ipn/ipnlocal - LD github.com/tailscale/golang-x-crypto/internal/poly1305 from github.com/tailscale/golang-x-crypto/ssh - LD github.com/tailscale/golang-x-crypto/ssh from tailscale.com/ipn/ipnlocal - LD github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf from github.com/tailscale/golang-x-crypto/ssh github.com/tailscale/goupnp from github.com/tailscale/goupnp/dcps/internetgateway2+ github.com/tailscale/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper github.com/tailscale/goupnp/httpu from github.com/tailscale/goupnp+ @@ -229,9 +217,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/tailscale/wireguard-go/rwcancel from github.com/tailscale/wireguard-go/device+ github.com/tailscale/wireguard-go/tai64n from github.com/tailscale/wireguard-go/device đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ - github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck - L github.com/u-root/uio/rand from 
github.com/insomniacslk/dhcp/dhcpv4 - L github.com/u-root/uio/uio from github.com/insomniacslk/dhcp/dhcpv4+ L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 go.uber.org/multierr from go.uber.org/zap+ @@ -244,7 +229,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ go.uber.org/zap/internal/pool from go.uber.org/zap+ go.uber.org/zap/internal/stacktrace from go.uber.org/zap go.uber.org/zap/zapcore from github.com/go-logr/zapr+ - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from tailscale.com/ipn/ipnlocal+ W đŸ’Ŗ golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/dns+ @@ -256,6 +241,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ google.golang.org/protobuf/internal/descopts from google.golang.org/protobuf/internal/filedesc+ google.golang.org/protobuf/internal/detrand from google.golang.org/protobuf/internal/descfmt+ google.golang.org/protobuf/internal/editiondefaults from google.golang.org/protobuf/internal/filedesc+ + google.golang.org/protobuf/internal/editionssupport from google.golang.org/protobuf/reflect/protodesc google.golang.org/protobuf/internal/encoding/defval from google.golang.org/protobuf/internal/encoding/tag+ google.golang.org/protobuf/internal/encoding/messageset from google.golang.org/protobuf/encoding/prototext+ google.golang.org/protobuf/internal/encoding/tag from google.golang.org/protobuf/internal/impl @@ -281,8 +267,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ google.golang.org/protobuf/types/gofeaturespb from google.golang.org/protobuf/reflect/protodesc google.golang.org/protobuf/types/known/anypb from github.com/google/gnostic-models/compiler+ 
google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ + gopkg.in/evanphx/json-patch.v4 from k8s.io/client-go/testing gopkg.in/inf.v0 from k8s.io/apimachinery/pkg/api/resource - gopkg.in/yaml.v2 from k8s.io/kube-openapi/pkg/util/proto+ gopkg.in/yaml.v3 from github.com/go-openapi/swag+ gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/buffer+ gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/buffer @@ -304,12 +290,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ gvisor.dev/gvisor/pkg/tcpip/hash/jenkins from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/header/parse+ gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/transport/tcp gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/net/tstun+ + gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/wgengine/netstack gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ @@ -351,6 +337,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/api/certificates/v1alpha1 from k8s.io/client-go/applyconfigurations/certificates/v1alpha1+ k8s.io/api/certificates/v1beta1 from 
k8s.io/client-go/applyconfigurations/certificates/v1beta1+ k8s.io/api/coordination/v1 from k8s.io/client-go/applyconfigurations/coordination/v1+ + k8s.io/api/coordination/v1alpha2 from k8s.io/client-go/applyconfigurations/coordination/v1alpha2+ k8s.io/api/coordination/v1beta1 from k8s.io/client-go/applyconfigurations/coordination/v1beta1+ k8s.io/api/core/v1 from k8s.io/api/apps/v1+ k8s.io/api/discovery/v1 from k8s.io/client-go/applyconfigurations/discovery/v1+ @@ -373,7 +360,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/api/rbac/v1 from k8s.io/client-go/applyconfigurations/rbac/v1+ k8s.io/api/rbac/v1alpha1 from k8s.io/client-go/applyconfigurations/rbac/v1alpha1+ k8s.io/api/rbac/v1beta1 from k8s.io/client-go/applyconfigurations/rbac/v1beta1+ - k8s.io/api/resource/v1alpha2 from k8s.io/client-go/applyconfigurations/resource/v1alpha2+ + k8s.io/api/resource/v1alpha3 from k8s.io/client-go/applyconfigurations/resource/v1alpha3+ + k8s.io/api/resource/v1beta1 from k8s.io/client-go/applyconfigurations/resource/v1beta1+ k8s.io/api/scheduling/v1 from k8s.io/client-go/applyconfigurations/scheduling/v1+ k8s.io/api/scheduling/v1alpha1 from k8s.io/client-go/applyconfigurations/scheduling/v1alpha1+ k8s.io/api/scheduling/v1beta1 from k8s.io/client-go/applyconfigurations/scheduling/v1beta1+ @@ -382,14 +370,16 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/api/storage/v1beta1 from k8s.io/client-go/applyconfigurations/storage/v1beta1+ k8s.io/api/storagemigration/v1alpha1 from k8s.io/client-go/applyconfigurations/storagemigration/v1alpha1+ k8s.io/apiextensions-apiserver/pkg/apis/apiextensions from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 - đŸ’Ŗ k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 from sigs.k8s.io/controller-runtime/pkg/webhook/conversion + đŸ’Ŗ k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1 from sigs.k8s.io/controller-runtime/pkg/webhook/conversion+ 
k8s.io/apimachinery/pkg/api/equality from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1+ k8s.io/apimachinery/pkg/api/errors from k8s.io/apimachinery/pkg/util/managedfields/internal+ k8s.io/apimachinery/pkg/api/meta from k8s.io/apimachinery/pkg/api/validation+ + k8s.io/apimachinery/pkg/api/meta/testrestmapper from k8s.io/client-go/testing k8s.io/apimachinery/pkg/api/resource from k8s.io/api/autoscaling/v1+ k8s.io/apimachinery/pkg/api/validation from k8s.io/apimachinery/pkg/util/managedfields/internal+ đŸ’Ŗ k8s.io/apimachinery/pkg/apis/meta/internalversion from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+ k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme from k8s.io/client-go/metadata + k8s.io/apimachinery/pkg/apis/meta/internalversion/validation from k8s.io/client-go/util/watchlist đŸ’Ŗ k8s.io/apimachinery/pkg/apis/meta/v1 from k8s.io/api/admission/v1+ k8s.io/apimachinery/pkg/apis/meta/v1/unstructured from k8s.io/apimachinery/pkg/runtime/serializer/versioning+ k8s.io/apimachinery/pkg/apis/meta/v1/validation from k8s.io/apimachinery/pkg/api/validation+ @@ -401,6 +391,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/apimachinery/pkg/runtime from k8s.io/api/admission/v1+ k8s.io/apimachinery/pkg/runtime/schema from k8s.io/api/admission/v1+ k8s.io/apimachinery/pkg/runtime/serializer from k8s.io/apimachinery/pkg/apis/meta/internalversion/scheme+ + k8s.io/apimachinery/pkg/runtime/serializer/cbor from k8s.io/client-go/dynamic+ + k8s.io/apimachinery/pkg/runtime/serializer/cbor/direct from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1+ + k8s.io/apimachinery/pkg/runtime/serializer/cbor/internal/modes from k8s.io/apimachinery/pkg/runtime/serializer/cbor+ k8s.io/apimachinery/pkg/runtime/serializer/json from k8s.io/apimachinery/pkg/runtime/serializer+ k8s.io/apimachinery/pkg/runtime/serializer/protobuf from k8s.io/apimachinery/pkg/runtime/serializer 
k8s.io/apimachinery/pkg/runtime/serializer/recognizer from k8s.io/apimachinery/pkg/runtime/serializer+ @@ -452,6 +445,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/applyconfigurations/certificates/v1alpha1 from k8s.io/client-go/kubernetes/typed/certificates/v1alpha1 k8s.io/client-go/applyconfigurations/certificates/v1beta1 from k8s.io/client-go/kubernetes/typed/certificates/v1beta1 k8s.io/client-go/applyconfigurations/coordination/v1 from k8s.io/client-go/kubernetes/typed/coordination/v1 + k8s.io/client-go/applyconfigurations/coordination/v1alpha2 from k8s.io/client-go/kubernetes/typed/coordination/v1alpha2 k8s.io/client-go/applyconfigurations/coordination/v1beta1 from k8s.io/client-go/kubernetes/typed/coordination/v1beta1 k8s.io/client-go/applyconfigurations/core/v1 from k8s.io/client-go/applyconfigurations/apps/v1+ k8s.io/client-go/applyconfigurations/discovery/v1 from k8s.io/client-go/kubernetes/typed/discovery/v1 @@ -476,7 +470,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/applyconfigurations/rbac/v1 from k8s.io/client-go/kubernetes/typed/rbac/v1 k8s.io/client-go/applyconfigurations/rbac/v1alpha1 from k8s.io/client-go/kubernetes/typed/rbac/v1alpha1 k8s.io/client-go/applyconfigurations/rbac/v1beta1 from k8s.io/client-go/kubernetes/typed/rbac/v1beta1 - k8s.io/client-go/applyconfigurations/resource/v1alpha2 from k8s.io/client-go/kubernetes/typed/resource/v1alpha2 + k8s.io/client-go/applyconfigurations/resource/v1alpha3 from k8s.io/client-go/kubernetes/typed/resource/v1alpha3 + k8s.io/client-go/applyconfigurations/resource/v1beta1 from k8s.io/client-go/kubernetes/typed/resource/v1beta1 k8s.io/client-go/applyconfigurations/scheduling/v1 from k8s.io/client-go/kubernetes/typed/scheduling/v1 k8s.io/client-go/applyconfigurations/scheduling/v1alpha1 from k8s.io/client-go/kubernetes/typed/scheduling/v1alpha1 
k8s.io/client-go/applyconfigurations/scheduling/v1beta1 from k8s.io/client-go/kubernetes/typed/scheduling/v1beta1 @@ -486,8 +481,80 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/applyconfigurations/storagemigration/v1alpha1 from k8s.io/client-go/kubernetes/typed/storagemigration/v1alpha1 k8s.io/client-go/discovery from k8s.io/client-go/applyconfigurations/meta/v1+ k8s.io/client-go/dynamic from sigs.k8s.io/controller-runtime/pkg/cache/internal+ - k8s.io/client-go/features from k8s.io/client-go/tools/cache - k8s.io/client-go/kubernetes from k8s.io/client-go/tools/leaderelection/resourcelock + k8s.io/client-go/features from k8s.io/client-go/tools/cache+ + k8s.io/client-go/gentype from k8s.io/client-go/kubernetes/typed/admissionregistration/v1+ + k8s.io/client-go/informers from k8s.io/client-go/tools/leaderelection + k8s.io/client-go/informers/admissionregistration from k8s.io/client-go/informers + k8s.io/client-go/informers/admissionregistration/v1 from k8s.io/client-go/informers/admissionregistration + k8s.io/client-go/informers/admissionregistration/v1alpha1 from k8s.io/client-go/informers/admissionregistration + k8s.io/client-go/informers/admissionregistration/v1beta1 from k8s.io/client-go/informers/admissionregistration + k8s.io/client-go/informers/apiserverinternal from k8s.io/client-go/informers + k8s.io/client-go/informers/apiserverinternal/v1alpha1 from k8s.io/client-go/informers/apiserverinternal + k8s.io/client-go/informers/apps from k8s.io/client-go/informers + k8s.io/client-go/informers/apps/v1 from k8s.io/client-go/informers/apps + k8s.io/client-go/informers/apps/v1beta1 from k8s.io/client-go/informers/apps + k8s.io/client-go/informers/apps/v1beta2 from k8s.io/client-go/informers/apps + k8s.io/client-go/informers/autoscaling from k8s.io/client-go/informers + k8s.io/client-go/informers/autoscaling/v1 from k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/autoscaling/v2 from 
k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/autoscaling/v2beta1 from k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/autoscaling/v2beta2 from k8s.io/client-go/informers/autoscaling + k8s.io/client-go/informers/batch from k8s.io/client-go/informers + k8s.io/client-go/informers/batch/v1 from k8s.io/client-go/informers/batch + k8s.io/client-go/informers/batch/v1beta1 from k8s.io/client-go/informers/batch + k8s.io/client-go/informers/certificates from k8s.io/client-go/informers + k8s.io/client-go/informers/certificates/v1 from k8s.io/client-go/informers/certificates + k8s.io/client-go/informers/certificates/v1alpha1 from k8s.io/client-go/informers/certificates + k8s.io/client-go/informers/certificates/v1beta1 from k8s.io/client-go/informers/certificates + k8s.io/client-go/informers/coordination from k8s.io/client-go/informers + k8s.io/client-go/informers/coordination/v1 from k8s.io/client-go/informers/coordination + k8s.io/client-go/informers/coordination/v1alpha2 from k8s.io/client-go/informers/coordination + k8s.io/client-go/informers/coordination/v1beta1 from k8s.io/client-go/informers/coordination + k8s.io/client-go/informers/core from k8s.io/client-go/informers + k8s.io/client-go/informers/core/v1 from k8s.io/client-go/informers/core + k8s.io/client-go/informers/discovery from k8s.io/client-go/informers + k8s.io/client-go/informers/discovery/v1 from k8s.io/client-go/informers/discovery + k8s.io/client-go/informers/discovery/v1beta1 from k8s.io/client-go/informers/discovery + k8s.io/client-go/informers/events from k8s.io/client-go/informers + k8s.io/client-go/informers/events/v1 from k8s.io/client-go/informers/events + k8s.io/client-go/informers/events/v1beta1 from k8s.io/client-go/informers/events + k8s.io/client-go/informers/extensions from k8s.io/client-go/informers + k8s.io/client-go/informers/extensions/v1beta1 from k8s.io/client-go/informers/extensions + k8s.io/client-go/informers/flowcontrol from 
k8s.io/client-go/informers + k8s.io/client-go/informers/flowcontrol/v1 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/flowcontrol/v1beta1 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/flowcontrol/v1beta2 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/flowcontrol/v1beta3 from k8s.io/client-go/informers/flowcontrol + k8s.io/client-go/informers/internalinterfaces from k8s.io/client-go/informers+ + k8s.io/client-go/informers/networking from k8s.io/client-go/informers + k8s.io/client-go/informers/networking/v1 from k8s.io/client-go/informers/networking + k8s.io/client-go/informers/networking/v1alpha1 from k8s.io/client-go/informers/networking + k8s.io/client-go/informers/networking/v1beta1 from k8s.io/client-go/informers/networking + k8s.io/client-go/informers/node from k8s.io/client-go/informers + k8s.io/client-go/informers/node/v1 from k8s.io/client-go/informers/node + k8s.io/client-go/informers/node/v1alpha1 from k8s.io/client-go/informers/node + k8s.io/client-go/informers/node/v1beta1 from k8s.io/client-go/informers/node + k8s.io/client-go/informers/policy from k8s.io/client-go/informers + k8s.io/client-go/informers/policy/v1 from k8s.io/client-go/informers/policy + k8s.io/client-go/informers/policy/v1beta1 from k8s.io/client-go/informers/policy + k8s.io/client-go/informers/rbac from k8s.io/client-go/informers + k8s.io/client-go/informers/rbac/v1 from k8s.io/client-go/informers/rbac + k8s.io/client-go/informers/rbac/v1alpha1 from k8s.io/client-go/informers/rbac + k8s.io/client-go/informers/rbac/v1beta1 from k8s.io/client-go/informers/rbac + k8s.io/client-go/informers/resource from k8s.io/client-go/informers + k8s.io/client-go/informers/resource/v1alpha3 from k8s.io/client-go/informers/resource + k8s.io/client-go/informers/resource/v1beta1 from k8s.io/client-go/informers/resource + k8s.io/client-go/informers/scheduling from k8s.io/client-go/informers + 
k8s.io/client-go/informers/scheduling/v1 from k8s.io/client-go/informers/scheduling + k8s.io/client-go/informers/scheduling/v1alpha1 from k8s.io/client-go/informers/scheduling + k8s.io/client-go/informers/scheduling/v1beta1 from k8s.io/client-go/informers/scheduling + k8s.io/client-go/informers/storage from k8s.io/client-go/informers + k8s.io/client-go/informers/storage/v1 from k8s.io/client-go/informers/storage + k8s.io/client-go/informers/storage/v1alpha1 from k8s.io/client-go/informers/storage + k8s.io/client-go/informers/storage/v1beta1 from k8s.io/client-go/informers/storage + k8s.io/client-go/informers/storagemigration from k8s.io/client-go/informers + k8s.io/client-go/informers/storagemigration/v1alpha1 from k8s.io/client-go/informers/storagemigration + k8s.io/client-go/kubernetes from k8s.io/client-go/tools/leaderelection/resourcelock+ k8s.io/client-go/kubernetes/scheme from k8s.io/client-go/discovery+ k8s.io/client-go/kubernetes/typed/admissionregistration/v1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/admissionregistration/v1alpha1 from k8s.io/client-go/kubernetes @@ -511,6 +578,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/kubernetes/typed/certificates/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/certificates/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/coordination/v1 from k8s.io/client-go/kubernetes+ + k8s.io/client-go/kubernetes/typed/coordination/v1alpha2 from k8s.io/client-go/kubernetes+ k8s.io/client-go/kubernetes/typed/coordination/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/core/v1 from k8s.io/client-go/kubernetes+ k8s.io/client-go/kubernetes/typed/discovery/v1 from k8s.io/client-go/kubernetes @@ -533,7 +601,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/kubernetes/typed/rbac/v1 from k8s.io/client-go/kubernetes 
k8s.io/client-go/kubernetes/typed/rbac/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/rbac/v1beta1 from k8s.io/client-go/kubernetes - k8s.io/client-go/kubernetes/typed/resource/v1alpha2 from k8s.io/client-go/kubernetes + k8s.io/client-go/kubernetes/typed/resource/v1alpha3 from k8s.io/client-go/kubernetes + k8s.io/client-go/kubernetes/typed/resource/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/scheduling/v1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/scheduling/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/scheduling/v1beta1 from k8s.io/client-go/kubernetes @@ -541,6 +610,56 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/kubernetes/typed/storage/v1alpha1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/storage/v1beta1 from k8s.io/client-go/kubernetes k8s.io/client-go/kubernetes/typed/storagemigration/v1alpha1 from k8s.io/client-go/kubernetes + k8s.io/client-go/listers from k8s.io/client-go/listers/admissionregistration/v1+ + k8s.io/client-go/listers/admissionregistration/v1 from k8s.io/client-go/informers/admissionregistration/v1 + k8s.io/client-go/listers/admissionregistration/v1alpha1 from k8s.io/client-go/informers/admissionregistration/v1alpha1 + k8s.io/client-go/listers/admissionregistration/v1beta1 from k8s.io/client-go/informers/admissionregistration/v1beta1 + k8s.io/client-go/listers/apiserverinternal/v1alpha1 from k8s.io/client-go/informers/apiserverinternal/v1alpha1 + k8s.io/client-go/listers/apps/v1 from k8s.io/client-go/informers/apps/v1 + k8s.io/client-go/listers/apps/v1beta1 from k8s.io/client-go/informers/apps/v1beta1 + k8s.io/client-go/listers/apps/v1beta2 from k8s.io/client-go/informers/apps/v1beta2 + k8s.io/client-go/listers/autoscaling/v1 from k8s.io/client-go/informers/autoscaling/v1 + k8s.io/client-go/listers/autoscaling/v2 from 
k8s.io/client-go/informers/autoscaling/v2 + k8s.io/client-go/listers/autoscaling/v2beta1 from k8s.io/client-go/informers/autoscaling/v2beta1 + k8s.io/client-go/listers/autoscaling/v2beta2 from k8s.io/client-go/informers/autoscaling/v2beta2 + k8s.io/client-go/listers/batch/v1 from k8s.io/client-go/informers/batch/v1 + k8s.io/client-go/listers/batch/v1beta1 from k8s.io/client-go/informers/batch/v1beta1 + k8s.io/client-go/listers/certificates/v1 from k8s.io/client-go/informers/certificates/v1 + k8s.io/client-go/listers/certificates/v1alpha1 from k8s.io/client-go/informers/certificates/v1alpha1 + k8s.io/client-go/listers/certificates/v1beta1 from k8s.io/client-go/informers/certificates/v1beta1 + k8s.io/client-go/listers/coordination/v1 from k8s.io/client-go/informers/coordination/v1 + k8s.io/client-go/listers/coordination/v1alpha2 from k8s.io/client-go/informers/coordination/v1alpha2 + k8s.io/client-go/listers/coordination/v1beta1 from k8s.io/client-go/informers/coordination/v1beta1 + k8s.io/client-go/listers/core/v1 from k8s.io/client-go/informers/core/v1 + k8s.io/client-go/listers/discovery/v1 from k8s.io/client-go/informers/discovery/v1 + k8s.io/client-go/listers/discovery/v1beta1 from k8s.io/client-go/informers/discovery/v1beta1 + k8s.io/client-go/listers/events/v1 from k8s.io/client-go/informers/events/v1 + k8s.io/client-go/listers/events/v1beta1 from k8s.io/client-go/informers/events/v1beta1 + k8s.io/client-go/listers/extensions/v1beta1 from k8s.io/client-go/informers/extensions/v1beta1 + k8s.io/client-go/listers/flowcontrol/v1 from k8s.io/client-go/informers/flowcontrol/v1 + k8s.io/client-go/listers/flowcontrol/v1beta1 from k8s.io/client-go/informers/flowcontrol/v1beta1 + k8s.io/client-go/listers/flowcontrol/v1beta2 from k8s.io/client-go/informers/flowcontrol/v1beta2 + k8s.io/client-go/listers/flowcontrol/v1beta3 from k8s.io/client-go/informers/flowcontrol/v1beta3 + k8s.io/client-go/listers/networking/v1 from k8s.io/client-go/informers/networking/v1 + 
k8s.io/client-go/listers/networking/v1alpha1 from k8s.io/client-go/informers/networking/v1alpha1 + k8s.io/client-go/listers/networking/v1beta1 from k8s.io/client-go/informers/networking/v1beta1 + k8s.io/client-go/listers/node/v1 from k8s.io/client-go/informers/node/v1 + k8s.io/client-go/listers/node/v1alpha1 from k8s.io/client-go/informers/node/v1alpha1 + k8s.io/client-go/listers/node/v1beta1 from k8s.io/client-go/informers/node/v1beta1 + k8s.io/client-go/listers/policy/v1 from k8s.io/client-go/informers/policy/v1 + k8s.io/client-go/listers/policy/v1beta1 from k8s.io/client-go/informers/policy/v1beta1 + k8s.io/client-go/listers/rbac/v1 from k8s.io/client-go/informers/rbac/v1 + k8s.io/client-go/listers/rbac/v1alpha1 from k8s.io/client-go/informers/rbac/v1alpha1 + k8s.io/client-go/listers/rbac/v1beta1 from k8s.io/client-go/informers/rbac/v1beta1 + k8s.io/client-go/listers/resource/v1alpha3 from k8s.io/client-go/informers/resource/v1alpha3 + k8s.io/client-go/listers/resource/v1beta1 from k8s.io/client-go/informers/resource/v1beta1 + k8s.io/client-go/listers/scheduling/v1 from k8s.io/client-go/informers/scheduling/v1 + k8s.io/client-go/listers/scheduling/v1alpha1 from k8s.io/client-go/informers/scheduling/v1alpha1 + k8s.io/client-go/listers/scheduling/v1beta1 from k8s.io/client-go/informers/scheduling/v1beta1 + k8s.io/client-go/listers/storage/v1 from k8s.io/client-go/informers/storage/v1 + k8s.io/client-go/listers/storage/v1alpha1 from k8s.io/client-go/informers/storage/v1alpha1 + k8s.io/client-go/listers/storage/v1beta1 from k8s.io/client-go/informers/storage/v1beta1 + k8s.io/client-go/listers/storagemigration/v1alpha1 from k8s.io/client-go/informers/storagemigration/v1alpha1 k8s.io/client-go/metadata from sigs.k8s.io/controller-runtime/pkg/cache/internal+ k8s.io/client-go/openapi from k8s.io/client-go/discovery k8s.io/client-go/pkg/apis/clientauthentication from k8s.io/client-go/pkg/apis/clientauthentication/install+ @@ -552,6 +671,7 @@ 
tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/rest from k8s.io/client-go/discovery+ k8s.io/client-go/rest/watch from k8s.io/client-go/rest k8s.io/client-go/restmapper from sigs.k8s.io/controller-runtime/pkg/client/apiutil + k8s.io/client-go/testing from k8s.io/client-go/gentype k8s.io/client-go/tools/auth from k8s.io/client-go/tools/clientcmd k8s.io/client-go/tools/cache from sigs.k8s.io/controller-runtime/pkg/cache+ k8s.io/client-go/tools/cache/synctrack from k8s.io/client-go/tools/cache @@ -568,11 +688,14 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/client-go/tools/record/util from k8s.io/client-go/tools/record k8s.io/client-go/tools/reference from k8s.io/client-go/kubernetes/typed/core/v1+ k8s.io/client-go/transport from k8s.io/client-go/plugin/pkg/client/auth/exec+ + k8s.io/client-go/util/apply from k8s.io/client-go/dynamic+ k8s.io/client-go/util/cert from k8s.io/client-go/rest+ k8s.io/client-go/util/connrotation from k8s.io/client-go/plugin/pkg/client/auth/exec+ + k8s.io/client-go/util/consistencydetector from k8s.io/client-go/dynamic+ k8s.io/client-go/util/flowcontrol from k8s.io/client-go/kubernetes+ k8s.io/client-go/util/homedir from k8s.io/client-go/tools/clientcmd k8s.io/client-go/util/keyutil from k8s.io/client-go/util/cert + k8s.io/client-go/util/watchlist from k8s.io/client-go/dynamic+ k8s.io/client-go/util/workqueue from k8s.io/client-go/transport+ k8s.io/klog/v2 from k8s.io/apimachinery/pkg/api/meta+ k8s.io/klog/v2/internal/buffer from k8s.io/klog/v2 @@ -593,11 +716,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ k8s.io/utils/buffer from k8s.io/client-go/tools/cache k8s.io/utils/clock from k8s.io/apimachinery/pkg/util/cache+ k8s.io/utils/clock/testing from k8s.io/client-go/util/flowcontrol + k8s.io/utils/internal/third_party/forked/golang/golang-lru from k8s.io/utils/lru 
k8s.io/utils/internal/third_party/forked/golang/net from k8s.io/utils/net + k8s.io/utils/lru from k8s.io/client-go/tools/record k8s.io/utils/net from k8s.io/apimachinery/pkg/util/net+ k8s.io/utils/pointer from k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1+ k8s.io/utils/ptr from k8s.io/client-go/tools/cache+ - k8s.io/utils/strings/slices from k8s.io/apimachinery/pkg/labels k8s.io/utils/trace from k8s.io/client-go/tools/cache sigs.k8s.io/controller-runtime/pkg/builder from tailscale.com/cmd/k8s-operator sigs.k8s.io/controller-runtime/pkg/cache from sigs.k8s.io/controller-runtime/pkg/cluster+ @@ -630,12 +754,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ sigs.k8s.io/controller-runtime/pkg/metrics from sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics+ sigs.k8s.io/controller-runtime/pkg/metrics/server from sigs.k8s.io/controller-runtime/pkg/manager sigs.k8s.io/controller-runtime/pkg/predicate from sigs.k8s.io/controller-runtime/pkg/builder+ - sigs.k8s.io/controller-runtime/pkg/ratelimiter from sigs.k8s.io/controller-runtime/pkg/controller+ sigs.k8s.io/controller-runtime/pkg/reconcile from sigs.k8s.io/controller-runtime/pkg/builder+ sigs.k8s.io/controller-runtime/pkg/recorder from sigs.k8s.io/controller-runtime/pkg/leaderelection+ sigs.k8s.io/controller-runtime/pkg/source from sigs.k8s.io/controller-runtime/pkg/builder+ sigs.k8s.io/controller-runtime/pkg/webhook from sigs.k8s.io/controller-runtime/pkg/manager sigs.k8s.io/controller-runtime/pkg/webhook/admission from sigs.k8s.io/controller-runtime/pkg/builder+ + sigs.k8s.io/controller-runtime/pkg/webhook/admission/metrics from sigs.k8s.io/controller-runtime/pkg/webhook/admission sigs.k8s.io/controller-runtime/pkg/webhook/conversion from sigs.k8s.io/controller-runtime/pkg/builder sigs.k8s.io/controller-runtime/pkg/webhook/internal/metrics from sigs.k8s.io/controller-runtime/pkg/webhook+ sigs.k8s.io/json from k8s.io/apimachinery/pkg/runtime/serializer/json+ @@ -646,11 
+770,12 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ sigs.k8s.io/structured-merge-diff/v4/typed from k8s.io/apimachinery/pkg/util/managedfields+ sigs.k8s.io/structured-merge-diff/v4/value from k8s.io/apimachinery/pkg/runtime+ sigs.k8s.io/yaml from k8s.io/apimachinery/pkg/runtime/serializer/json+ - sigs.k8s.io/yaml/goyaml.v2 from sigs.k8s.io/yaml + sigs.k8s.io/yaml/goyaml.v2 from sigs.k8s.io/yaml+ tailscale.com from tailscale.com/version tailscale.com/appc from tailscale.com/ipn/ipnlocal - tailscale.com/atomicfile from tailscale.com/ipn+ - tailscale.com/client/tailscale from tailscale.com/client/web+ + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ + tailscale.com/client/local from tailscale.com/client/tailscale+ + tailscale.com/client/tailscale from tailscale.com/cmd/k8s-operator+ tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ tailscale.com/client/web from tailscale.com/ipn/ipnlocal tailscale.com/clientupdate from tailscale.com/client/web+ @@ -658,25 +783,31 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlclient from tailscale.com/ipn/ipnlocal+ tailscale.com/control/controlhttp from tailscale.com/control/controlclient + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp+ tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+ tailscale.com/disco from tailscale.com/derp+ tailscale.com/doctor from tailscale.com/ipn/ipnlocal tailscale.com/doctor/ethtool from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/doctor/permissions from tailscale.com/ipn/ipnlocal tailscale.com/doctor/routetable from tailscale.com/ipn/ipnlocal - 
tailscale.com/drive from tailscale.com/client/tailscale+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ + tailscale.com/feature from tailscale.com/ipn/ipnext+ tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal tailscale.com/hostinfo from tailscale.com/client/web+ + tailscale.com/internal/client/tailscale from tailscale.com/cmd/k8s-operator tailscale.com/internal/noiseconn from tailscale.com/control/controlclient - tailscale.com/ipn from tailscale.com/client/tailscale+ + tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/ipn/localapi from tailscale.com/tsnet tailscale.com/ipn/policy from tailscale.com/ipn/ipnlocal tailscale.com/ipn/store from tailscale.com/ipn/ipnlocal+ @@ -684,13 +815,15 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/ipn/store/kubestore from tailscale.com/cmd/k8s-operator+ tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ tailscale.com/k8s-operator from tailscale.com/cmd/k8s-operator + tailscale.com/k8s-operator/api-proxy from tailscale.com/cmd/k8s-operator tailscale.com/k8s-operator/apis from tailscale.com/k8s-operator/apis/v1alpha1 tailscale.com/k8s-operator/apis/v1alpha1 from tailscale.com/cmd/k8s-operator+ - tailscale.com/k8s-operator/sessionrecording from tailscale.com/cmd/k8s-operator + tailscale.com/k8s-operator/sessionrecording from 
tailscale.com/k8s-operator/api-proxy tailscale.com/k8s-operator/sessionrecording/spdy from tailscale.com/k8s-operator/sessionrecording tailscale.com/k8s-operator/sessionrecording/tsrecorder from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/k8s-operator/sessionrecording/ws from tailscale.com/k8s-operator/sessionrecording tailscale.com/kube/egressservices from tailscale.com/cmd/k8s-operator + tailscale.com/kube/ingressservices from tailscale.com/cmd/k8s-operator tailscale.com/kube/kubeapi from tailscale.com/ipn/store/kubestore+ tailscale.com/kube/kubeclient from tailscale.com/ipn/store/kubestore tailscale.com/kube/kubetypes from tailscale.com/cmd/k8s-operator+ @@ -702,13 +835,14 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/logtail/backoff from tailscale.com/control/controlclient+ tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ tailscale.com/metrics from tailscale.com/derp+ + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial+ tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ tailscale.com/net/connstats from tailscale.com/net/tstun+ tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+ tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ tailscale.com/net/dns/recursive from tailscale.com/net/dnsfallback tailscale.com/net/dns/resolvconffile from tailscale.com/cmd/k8s-operator+ - tailscale.com/net/dns/resolver from tailscale.com/net/dns + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ tailscale.com/net/dnscache from tailscale.com/control/controlclient+ tailscale.com/net/dnsfallback from tailscale.com/control/controlclient+ tailscale.com/net/flowtrack from tailscale.com/net/packet+ @@ -722,7 +856,8 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/control/controlclient+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp+ W đŸ’Ŗ 
tailscale.com/net/netstat from tailscale.com/portlist - tailscale.com/net/netutil from tailscale.com/client/tailscale+ + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/net/connstats+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun tailscale.com/net/ping from tailscale.com/net/netcheck+ @@ -734,38 +869,43 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/net/stun from tailscale.com/ipn/localapi+ L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/control/controlclient+ đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ tailscale.com/net/tstun from tailscale.com/tsd+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ + tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock tailscale.com/omit from tailscale.com/ipn/conffile - tailscale.com/paths from tailscale.com/client/tailscale+ + tailscale.com/paths from tailscale.com/client/local+ đŸ’Ŗ tailscale.com/portlist from tailscale.com/ipn/ipnlocal tailscale.com/posture from tailscale.com/ipn/ipnlocal tailscale.com/proxymap from tailscale.com/tsd+ - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ tailscale.com/sessionrecording from tailscale.com/k8s-operator/sessionrecording+ tailscale.com/syncs from tailscale.com/control/controlknobs+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ - tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tempfork/acme from 
tailscale.com/ipn/ipnlocal tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock - tailscale.com/tka from tailscale.com/client/tailscale+ + tailscale.com/tempfork/httprec from tailscale.com/control/controlclient + tailscale.com/tka from tailscale.com/client/local+ tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tsd from tailscale.com/ipn/ipnlocal+ tailscale.com/tsnet from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/derp+ - tailscale.com/tsweb/varz from tailscale.com/util/usermetric + tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal + tailscale.com/types/bools from tailscale.com/tsnet tailscale.com/types/dnstype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/ipn/ipnlocal+ tailscale.com/types/logger from tailscale.com/appc+ tailscale.com/types/logid from tailscale.com/ipn/ipnlocal+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext tailscale.com/types/netlogtype from tailscale.com/net/connstats+ tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/nettype from tailscale.com/ipn/localapi+ @@ -773,24 +913,25 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ tailscale.com/types/ptr from tailscale.com/cmd/k8s-operator+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from 
tailscale.com/control/controlclient+ - tailscale.com/types/tkatype from tailscale.com/client/tailscale+ + tailscale.com/types/tkatype from tailscale.com/client/local+ tailscale.com/types/views from tailscale.com/appc+ tailscale.com/util/cibuild from tailscale.com/health tailscale.com/util/clientmetric from tailscale.com/cmd/k8s-operator+ tailscale.com/util/cloudenv from tailscale.com/hostinfo+ tailscale.com/util/cmpver from tailscale.com/clientupdate+ - tailscale.com/util/ctxkey from tailscale.com/cmd/k8s-operator+ + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+ L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics+ tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/tsd+ tailscale.com/util/execqueue from tailscale.com/appc+ tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal tailscale.com/util/groupmember from tailscale.com/client/web+ đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash - tailscale.com/util/httphdr from tailscale.com/ipn/ipnlocal+ tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns+ tailscale.com/util/mak from tailscale.com/appc+ tailscale.com/util/multierr from tailscale.com/control/controlclient+ @@ -798,9 +939,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto đŸ’Ŗ tailscale.com/util/osdiag from tailscale.com/ipn/localapi W đŸ’Ŗ tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag - tailscale.com/util/osshare from tailscale.com/ipn/ipnlocal tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal - tailscale.com/util/progresstracking from tailscale.com/ipn/localapi 
tailscale.com/util/race from tailscale.com/net/dns/resolver tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ @@ -810,25 +949,26 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/util/slicesx from tailscale.com/appc+ tailscale.com/util/syspolicy from tailscale.com/control/controlclient+ tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock tailscale.com/util/systemd from tailscale.com/control/controlclient+ tailscale.com/util/testenv from tailscale.com/control/controlclient+ tailscale.com/util/truncate from tailscale.com/logtail - tailscale.com/util/uniq from tailscale.com/ipn/ipnlocal+ tailscale.com/util/usermetric from tailscale.com/health+ tailscale.com/util/vizerror from tailscale.com/tailcfg+ đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ tailscale.com/version from 
tailscale.com/client/web+ tailscale.com/version/distro from tailscale.com/client/web+ tailscale.com/wgengine from tailscale.com/ipn/ipnlocal+ - tailscale.com/wgengine/capture from tailscale.com/ipn/ipnlocal+ tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ @@ -844,18 +984,21 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ - LD golang.org/x/crypto/blowfish from github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf - golang.org/x/crypto/chacha20 from github.com/tailscale/golang-x-crypto/ssh+ + LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf + golang.org/x/crypto/chacha20 from golang.org/x/crypto/ssh+ golang.org/x/crypto/chacha20poly1305 from crypto/tls+ golang.org/x/crypto/cryptobyte from crypto/ecdsa+ golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ - golang.org/x/crypto/curve25519 from github.com/tailscale/golang-x-crypto/ssh+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/curve25519 from golang.org/x/crypto/ssh+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ + LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal + 
LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ golang.org/x/exp/maps from sigs.k8s.io/controller-runtime/pkg/cache+ golang.org/x/exp/slices from tailscale.com/cmd/k8s-operator+ @@ -868,6 +1011,10 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/net/http2/hpack from golang.org/x/net/http2+ golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + golang.org/x/net/internal/httpcommon from golang.org/x/net/http2 + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy golang.org/x/net/ipv4 from github.com/miekg/dns+ golang.org/x/net/ipv6 from github.com/miekg/dns+ golang.org/x/net/proxy from tailscale.com/net/netns @@ -877,7 +1024,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/k8s-operator golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ - golang.org/x/sys/cpu from github.com/josharian/native+ + golang.org/x/sys/cpu from github.com/tailscale/certstore+ LD golang.org/x/sys/unix from github.com/fsnotify/fsnotify+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ @@ -900,7 +1047,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509+ @@ -909,22 +1056,62 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by 
github.com/tailscale/ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ 
+ crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls+ crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from github.com/aws/aws-sdk-go-v2/aws/transport/http+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ database/sql from github.com/prometheus/client_golang/prometheus/collectors database/sql/driver from database/sql+ W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe - embed from crypto/internal/nistec+ + embed from github.com/tailscale/web-client-prebuilt+ encoding from encoding/gob+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ @@ -948,12 +1135,54 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ go/scanner from go/ast+ go/token from go/ast+ hash from compress/zlib+ - hash/adler32 from compress/zlib+ + hash/adler32 from compress/zlib hash/crc32 from compress/gzip+ hash/fnv from google.golang.org/protobuf/internal/detrand hash/maphash from go4.org/mem html from html/template+ - html/template from github.com/gorilla/csrf + html/template from github.com/gorilla/csrf+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from 
crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from archive/tar+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/lazyregexp from go/doc + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/exithook from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/unsafeheader from internal/reflectlite+ io from archive/tar+ io/fs from archive/tar+ io/ioutil from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ @@ -962,6 +1191,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ log/internal from log+ log/slog from github.com/go-logr/logr+ log/slog/internal from log/slog + log/slog/internal/buffer from log/slog maps from sigs.k8s.io/controller-runtime/pkg/predicate+ math from archive/tar+ math/big from crypto/dsa+ @@ -973,15 +1203,15 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ mime/quotedprintable from 
mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptest from tailscale.com/control/controlclient net/http/httptrace from github.com/prometheus-community/pro-bing+ net/http/httputil from github.com/aws/smithy-go/transport/http+ net/http/internal from net/http+ + net/http/internal/ascii from net/http+ net/http/pprof from sigs.k8s.io/controller-runtime/pkg/manager+ net/netip from github.com/gaissmai/bart+ net/textproto from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ net/url from crypto/x509+ - os from crypto/rand+ + os from crypto/internal/sysrand+ os/exec from github.com/aws/aws-sdk-go-v2/credentials/processcreds+ os/signal from sigs.k8s.io/controller-runtime/pkg/manager/signals os/user from archive/tar+ @@ -990,6 +1220,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ reflect from archive/tar+ regexp from github.com/aws/aws-sdk-go-v2/internal/endpoints+ regexp/syntax from regexp + runtime from archive/tar+ runtime/debug from github.com/aws/aws-sdk-go-v2/internal/sync/singleflight+ runtime/metrics from github.com/prometheus/client_golang/prometheus+ runtime/pprof from net/http/pprof+ @@ -1009,3 +1240,5 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index c428d5d1e751e..1b9b97186b6ca 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -35,9 +35,13 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} volumes: - - name: oauth - secret: - secretName: operator-oauth + - name: oauth + {{- with .Values.oauthSecretVolume }} + {{- toYaml . 
| nindent 10 }} + {{- else }} + secret: + secretName: operator-oauth + {{- end }} containers: - name: operator {{- with .Values.operatorConfig.securityContext }} @@ -81,6 +85,14 @@ spec: - name: PROXY_DEFAULT_CLASS value: {{ .Values.proxyConfig.defaultProxyClass }} {{- end }} + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid {{- with .Values.operatorConfig.extraEnv }} {{- toYaml . | nindent 12 }} {{- end }} diff --git a/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml b/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml index 2a1fa81b42793..208d58ee10f08 100644 --- a/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/ingressclass.yaml @@ -1,3 +1,4 @@ +{{- if .Values.ingressClass.enabled }} apiVersion: networking.k8s.io/v1 kind: IngressClass metadata: @@ -6,3 +7,4 @@ metadata: spec: controller: tailscale.com/ts-ingress # controller name currently can not be changed # parameters: {} # currently no parameters are supported +{{- end }} diff --git a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml index ede61070b4399..00d8318acdce4 100644 --- a/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/operator-rbac.yaml @@ -6,6 +6,10 @@ kind: ServiceAccount metadata: name: operator namespace: {{ .Release.Namespace }} + {{- with .Values.operatorConfig.serviceAccountAnnotations }} + annotations: + {{- toYaml . 
| nindent 4 }} + {{- end }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRole @@ -21,6 +25,9 @@ rules: - apiGroups: ["networking.k8s.io"] resources: ["ingressclasses"] verbs: ["get", "list", "watch"] +- apiGroups: ["discovery.k8s.io"] + resources: ["endpointslices"] + verbs: ["get", "list", "watch"] - apiGroups: ["tailscale.com"] resources: ["connectors", "connectors/status", "proxyclasses", "proxyclasses/status", "proxygroups", "proxygroups/status"] verbs: ["get", "list", "watch", "update"] @@ -30,6 +37,10 @@ rules: - apiGroups: ["tailscale.com"] resources: ["recorders", "recorders/status"] verbs: ["get", "list", "watch", "update"] +- apiGroups: ["apiextensions.k8s.io"] + resources: ["customresourcedefinitions"] + verbs: ["get", "list", "watch"] + resourceNames: ["servicemonitors.monitoring.coreos.com"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding @@ -55,7 +66,10 @@ rules: verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] - apiGroups: [""] resources: ["pods"] - verbs: ["get","list","watch"] + verbs: ["get","list","watch", "update"] +- apiGroups: [""] + resources: ["pods/status"] + verbs: ["update"] - apiGroups: ["apps"] resources: ["statefulsets", "deployments"] verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] @@ -64,7 +78,10 @@ rules: verbs: ["get", "list", "watch", "create", "update", "deletecollection"] - apiGroups: ["rbac.authorization.k8s.io"] resources: ["roles", "rolebindings"] - verbs: ["get", "create", "patch", "update", "list", "watch"] + verbs: ["get", "create", "patch", "update", "list", "watch", "deletecollection"] +- apiGroups: ["monitoring.coreos.com"] + resources: ["servicemonitors"] + verbs: ["get", "list", "update", "create", "delete"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding diff --git a/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml b/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml index 
1c15c9119f971..fa552a7c7e39a 100644 --- a/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/proxy-rbac.yaml @@ -16,6 +16,9 @@ rules: - apiGroups: [""] resources: ["secrets"] verbs: ["create","delete","deletecollection","get","list","patch","update","watch"] +- apiGroups: [""] + resources: ["events"] + verbs: ["create", "patch", "get"] --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index e6f4cada44de7..2d1effc255dc5 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -3,11 +3,26 @@ # Operator oauth credentials. If set a Kubernetes Secret with the provided # values will be created in the operator namespace. If unset a Secret named -# operator-oauth must be precreated. +# operator-oauth must be precreated or oauthSecretVolume needs to be adjusted. +# This block will be overridden by oauthSecretVolume, if set. oauth: {} # clientId: "" # clientSecret: "" +# Secret volume. +# If set it defines the volume the oauth secrets will be mounted from. +# The volume needs to contain two files named `client_id` and `client_secret`. +# If unset the volume will reference the Secret named operator-oauth. +# This block will override the oauth block. +oauthSecretVolume: {} + # csi: + # driver: secrets-store.csi.k8s.io + # readOnly: true + # volumeAttributes: + # secretProviderClass: tailscale-oauth + # + ## NAME is pre-defined! + # installCRDs determines whether tailscale.com CRDs should be installed as part # of chart installation. We do not use Helm's CRD installation mechanism as that # does not allow for upgrading CRDs. 
@@ -40,6 +55,9 @@ operatorConfig: podAnnotations: {} podLabels: {} + serviceAccountAnnotations: {} + # eks.amazonaws.com/role-arn: arn:aws:iam::123456789012:role/tailscale-operator-role + tolerations: [] affinity: {} @@ -54,6 +72,9 @@ operatorConfig: # - name: EXTRA_VAR2 # value: "value2" +# In the case that you already have a tailscale ingressclass in your cluster (or vcluster), you can disable the creation here +ingressClass: + enabled: true # proxyConfig contains configuraton that will be applied to any ingress/egress # proxies created by the operator. diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml index 9614f74e6b162..d645e39228062 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_connectors.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: connectors.tailscale.com spec: group: tailscale.com @@ -24,10 +24,17 @@ spec: jsonPath: .status.isExitNode name: IsExitNode type: string + - description: Whether this Connector instance is an app connector. + jsonPath: .status.isAppConnector + name: IsAppConnector + type: string - description: Status of the deployed Connector resources. jsonPath: .status.conditions[?(@.type == "ConnectorReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -66,10 +73,40 @@ spec: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status type: object properties: + appConnector: + description: |- + AppConnector defines whether the Connector device should act as a Tailscale app connector. 
A Connector that is + configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + Connector does not act as an app connector. + Note that you will need to manually configure the permissions and the domains for the app connector via the + Admin panel. + Note also that the main tested and supported use case of this config option is to deploy an app connector on + Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + tested or optimised for. + If you are using the app connector to access SaaS applications because you need a predictable egress IP that + can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + device with a static IP address. + https://tailscale.com/kb/1281/app-connectors + type: object + properties: + routes: + description: |- + Routes are optional preconfigured routes for the domains routed via the app connector. + If not set, routes for the domains will be discovered dynamically. + If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + also dynamically discover other routes. + https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + type: array + minItems: 1 + items: + type: string + format: cidr exitNode: description: |- - ExitNode defines whether the Connector node should act as a - Tailscale exit node. Defaults to false. + ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + This field is mutually exclusive with the appConnector field. 
https://tailscale.com/kb/1103/exit-nodes type: boolean hostname: @@ -90,9 +127,11 @@ spec: type: string subnetRouter: description: |- - SubnetRouter defines subnet routes that the Connector node should - expose to tailnet. If unset, none are exposed. + SubnetRouter defines subnet routes that the Connector device should + expose to tailnet as a Tailscale subnet router. https://tailscale.com/kb/1019/subnets/ + If this field is unset, the device does not get configured as a Tailscale subnet router. + This field is mutually exclusive with the appConnector field. type: object required: - advertiseRoutes @@ -125,8 +164,10 @@ spec: type: string pattern: ^tag:[a-zA-Z][a-zA-Z0-9-]*$ x-kubernetes-validations: - - rule: has(self.subnetRouter) || self.exitNode == true - message: A Connector needs to be either an exit node or a subnet router, or both. + - rule: has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector) + message: A Connector needs to have at least one of exit node, subnet router or app connector configured. + - rule: '!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))' + message: The appConnector field is mutually exclusive with exitNode and subnetRouter fields. status: description: |- ConnectorStatus describes the status of the Connector. This is set @@ -200,6 +241,9 @@ spec: If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the node. type: string + isAppConnector: + description: IsAppConnector is set to true if the Connector acts as an app connector. + type: boolean isExitNode: description: IsExitNode is set to true if the Connector acts as an exit node. 
type: boolean diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml index 13aee9b9e9ebf..a1e5c51d64ca2 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: dnsconfigs.tailscale.com spec: group: tailscale.com @@ -20,6 +20,9 @@ spec: jsonPath: .status.nameserver.ip name: NameserverIP type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -88,6 +91,122 @@ spec: when a DNSConfig is applied. type: object properties: + cmd: + description: Cmd can be used to overwrite the command used when running the nameserver image. + type: array + items: + type: string + env: + description: |- + Env can be used to pass environment variables to the nameserver + container. + type: array + items: + description: EnvVar represents an environment variable present in a Container. + type: object + required: + - name + properties: + name: + description: Name of the environment variable. Must be a C_IDENTIFIER. + type: string + value: + description: |- + Variable references $(VAR_NAME) are expanded + using the previously defined environment variables in the container and + any service environment variables. If a variable cannot be resolved, + the reference in the input string will be unchanged. Double $$ are reduced + to a single $, which allows for escaping the $(VAR_NAME) syntax: i.e. + "$$(VAR_NAME)" will produce the string literal "$(VAR_NAME)". + Escaped references will never be expanded, regardless of whether the variable + exists or not. + Defaults to "". 
+ type: string + valueFrom: + description: Source for the environment variable's value. Cannot be used if value is not empty. + type: object + properties: + configMapKeyRef: + description: Selects a key of a ConfigMap. + type: object + required: + - key + properties: + key: + description: The key to select. + type: string + name: + description: |- + Name of the referent. + This field is effectively required, but due to backwards compatibility is + allowed to be empty. Instances of this type with an empty value here are + almost certainly wrong. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + default: "" + optional: + description: Specify whether the ConfigMap or its key must be defined + type: boolean + x-kubernetes-map-type: atomic + fieldRef: + description: |- + Selects a field of the pod: supports metadata.name, metadata.namespace, `metadata.labels['']`, `metadata.annotations['']`, + spec.nodeName, spec.serviceAccountName, status.hostIP, status.podIP, status.podIPs. + type: object + required: + - fieldPath + properties: + apiVersion: + description: Version of the schema the FieldPath is written in terms of, defaults to "v1". + type: string + fieldPath: + description: Path of the field to select in the specified API version. + type: string + x-kubernetes-map-type: atomic + resourceFieldRef: + description: |- + Selects a resource of the container: only resources limits and requests + (limits.cpu, limits.memory, limits.ephemeral-storage, requests.cpu, requests.memory and requests.ephemeral-storage) are currently supported. 
+ type: object + required: + - resource + properties: + containerName: + description: 'Container name: required for volumes, optional for env vars' + type: string + divisor: + description: Specifies the output format of the exposed resources, defaults to "1" + pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ + anyOf: + - type: integer + - type: string + x-kubernetes-int-or-string: true + resource: + description: 'Required: resource to select' + type: string + x-kubernetes-map-type: atomic + secretKeyRef: + description: Selects a key of a secret in the pod's namespace + type: object + required: + - key + properties: + key: + description: The key of the secret to select from. Must be a valid secret key. + type: string + name: + description: |- + Name of the referent. + This field is effectively required, but due to backwards compatibility is + allowed to be empty. Instances of this type with an empty value here are + almost certainly wrong. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + default: "" + optional: + description: Specify whether the Secret or its key must be defined + type: boolean + x-kubernetes-map-type: atomic image: description: Nameserver image. Defaults to tailscale/k8s-nameserver:unstable. type: object @@ -98,6 +217,13 @@ spec: tag: description: Tag defaults to unstable. type: string + podLabels: + description: |- + PodLabels are the labels which will be attached to the nameserver + pod. They can be used to define network policies. + type: object + additionalProperties: + type: string status: description: |- Status describes the status of the DNSConfig. 
This is set diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml index 0fff30516a132..1541234755029 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxyclasses.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxyclasses.tailscale.com spec: group: tailscale.com @@ -18,6 +18,9 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyClassReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -73,9 +76,45 @@ spec: enable: description: |- Setting enable to true will make the proxy serve Tailscale metrics - at :9001/debug/metrics. + at :9002/metrics. + A metrics Service named -metrics will also be created in the operator's namespace and will + serve the metrics at :9002/metrics. + + In 1.78.x and 1.80.x, this field also serves as the default value for + .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + fields will independently default to false. + Defaults to false. type: boolean + serviceMonitor: + description: |- + Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. 
+ The ingested metrics for each Service monitor will have labels to identify the proxy: + ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + job: ts__[]_ + type: object + required: + - enable + properties: + enable: + description: If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + type: boolean + labels: + description: |- + Labels to add to the ServiceMonitor. + Labels must be valid Kubernetes labels. + https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + type: object + additionalProperties: + type: string + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ + x-kubernetes-validations: + - rule: '!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)' + message: ServiceMonitor can only be enabled if metrics are enabled statefulSet: description: |- Configuration parameters for the proxy's StatefulSet. Tailscale @@ -107,6 +146,8 @@ spec: type: object additionalProperties: type: string + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ pod: description: Configuration for the proxy Pod. type: object @@ -390,7 +431,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -405,7 +446,7 @@ spec: pod labels will be ignored. The default value is empty. 
The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -562,7 +603,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -577,7 +618,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -735,7 +776,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -750,7 +791,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. 
- This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -907,7 +948,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -922,7 +963,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -1036,6 +1077,8 @@ spec: type: object additionalProperties: type: string + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ nodeName: description: |- Proxy Pod's node name. @@ -1134,6 +1177,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. type: integer format: int64 + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". + + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. 
+ + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. + + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. + If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -1182,18 +1251,28 @@ spec: type: string supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). 
If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. type: array items: type: integer format: int64 x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -1249,6 +1328,25 @@ spec: description: Configuration for the proxy container running tailscale. type: object properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + type: object + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. 
+ type: boolean env: description: |- List of environment variables to set in the container. @@ -1330,6 +1428,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map @@ -1360,11 +1464,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context type: object properties: @@ -1433,7 +1538,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. 
@@ -1553,6 +1658,25 @@ spec: description: Configuration for the proxy init container that enables forwarding. type: object properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + type: object + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean env: description: |- List of environment variables to set in the container. @@ -1634,6 +1758,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map @@ -1664,11 +1794,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. 
+ By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context type: object properties: @@ -1737,7 +1868,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -1896,6 +2027,182 @@ spec: Value is the taint value the toleration matches to. If the operator is Exists, the value should be empty, otherwise just a regular string. type: string + topologySpreadConstraints: + description: |- + Proxy Pod's topology spread constraints. + By default Tailscale Kubernetes operator does not apply any topology spread constraints. + https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ + type: array + items: + description: TopologySpreadConstraint specifies how to spread matching pods among the given topology. + type: object + required: + - maxSkew + - topologyKey + - whenUnsatisfiable + properties: + labelSelector: + description: |- + LabelSelector is used to find matching pods. + Pods that match this label selector are counted to determine the number of pods + in their corresponding topology domain. 
+ type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select the pods over which + spreading will be calculated. The keys are used to lookup values from the + incoming pod labels, those key-value labels are ANDed with labelSelector + to select the group of existing pods over which spreading will be calculated + for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector. + MatchLabelKeys cannot be set when LabelSelector isn't set. + Keys that don't exist in the incoming pod labels will + be ignored. 
A null or empty list means only match against labelSelector. + + This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default). + type: array + items: + type: string + x-kubernetes-list-type: atomic + maxSkew: + description: |- + MaxSkew describes the degree to which pods may be unevenly distributed. + When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference + between the number of matching pods in the target topology and the global minimum. + The global minimum is the minimum number of matching pods in an eligible domain + or zero if the number of eligible domains is less than MinDomains. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 2/2/1: + In this case, the global minimum is 1. + | zone1 | zone2 | zone3 | + | P P | P P | P | + - if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2; + scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2) + violate MaxSkew(1). + - if MaxSkew is 2, incoming pod can be scheduled onto any zone. + When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence + to topologies that satisfy it. + It's a required field. Default value is 1 and 0 is not allowed. + type: integer + format: int32 + minDomains: + description: |- + MinDomains indicates a minimum number of eligible domains. + When the number of eligible domains with matching topology keys is less than minDomains, + Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed. + And when the number of eligible domains with matching topology keys equals or greater than minDomains, + this value has no effect on scheduling. + As a result, when the number of eligible domains is less than minDomains, + scheduler won't schedule more than maxSkew Pods to those domains. + If value is nil, the constraint behaves as if MinDomains is equal to 1. 
+ Valid values are integers greater than 0. + When value is not nil, WhenUnsatisfiable must be DoNotSchedule. + + For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same + labelSelector spread as 2/2/2: + | zone1 | zone2 | zone3 | + | P P | P P | P P | + The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0. + In this situation, new pod with the same labelSelector cannot be scheduled, + because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones, + it will violate MaxSkew. + type: integer + format: int32 + nodeAffinityPolicy: + description: |- + NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector + when calculating pod topology spread skew. Options are: + - Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations. + - Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations. + + If this value is nil, the behavior is equivalent to the Honor policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + nodeTaintsPolicy: + description: |- + NodeTaintsPolicy indicates how we will treat node taints when calculating + pod topology spread skew. Options are: + - Honor: nodes without taints, along with tainted nodes for which the incoming pod + has a toleration, are included. + - Ignore: node taints are ignored. All nodes are included. + + If this value is nil, the behavior is equivalent to the Ignore policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + topologyKey: + description: |- + TopologyKey is the key of node labels. Nodes that have a label with this key + and identical values are considered to be in the same topology. + We consider each as a "bucket", and try to put balanced number + of pods into each bucket. 
+ We define a domain as a particular instance of a topology. + Also, we define an eligible domain as a domain whose nodes meet the requirements of + nodeAffinityPolicy and nodeTaintsPolicy. + e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology. + And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology. + It's a required field. + type: string + whenUnsatisfiable: + description: |- + WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy + the spread constraint. + - DoNotSchedule (default) tells the scheduler not to schedule it. + - ScheduleAnyway tells the scheduler to schedule the pod in any location, + but giving higher precedence to topologies that would help reduce the + skew. + A constraint is considered "Unsatisfiable" for an incoming pod + if and only if every possible node assignment for that pod would violate + "MaxSkew" on some topology. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 3/1/1: + | zone1 | zone2 | zone3 | + | P P P | P | P | + If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled + to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies + MaxSkew(1). In other words, the cluster can still be imbalanced, but scheduler + won't make it *more* imbalanced. + It's a required field. + type: string tailscale: description: |- TailscaleConfig contains options to configure the tailscale-specific @@ -1911,6 +2218,22 @@ spec: https://tailscale.com/kb/1019/subnets#use-your-subnet-routes-from-other-devices Defaults to false. type: boolean + useLetsEncryptStagingEnvironment: + description: |- + Set UseLetsEncryptStagingEnvironment to true to issue TLS + certificates for any HTTPS endpoints exposed to the tailnet from + LetsEncrypt's staging environment. 
+ https://letsencrypt.org/docs/staging-environment/ + This setting only affects Tailscale Ingress resources. + By default Ingress TLS certificates are issued from LetsEncrypt's + production environment. + Changing this setting true -> false, will result in any + existing certs being re-issued from the production environment. + Changing this setting false (default) -> true, when certs have already + been provisioned from production environment will NOT result in certs + being re-issued from the staging environment before they need to be + renewed. + type: boolean status: description: |- Status of the ProxyClass. This is set and managed automatically. diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml index 66701bdf4afbd..4b9149e23e55b 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_proxygroups.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxygroups.tailscale.com spec: group: tailscale.com @@ -20,9 +20,27 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyGroupReady")].reason name: Status type: string + - description: ProxyGroup type. + jsonPath: .spec.type + name: Type + type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + ProxyGroup defines a set of Tailscale devices that will act as proxies. + Currently only egress ProxyGroups are supported. + + Use the tailscale.com/proxy-group annotation on a Service to specify that + the egress proxy should be implemented by a ProxyGroup instead of a single + dedicated proxy. 
In addition to running a highly available set of proxies, + ProxyGroup also allows for serving many annotated Services from a single + set of proxies to minimise resource consumption. + + More info: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress type: object required: - spec @@ -73,6 +91,7 @@ spec: Defaults to 2. type: integer format: int32 + minimum: 0 tags: description: |- Tags that the Tailscale devices will be tagged with. Defaults to [tag:k8s]. @@ -86,10 +105,16 @@ spec: type: string pattern: ^tag:[a-zA-Z][a-zA-Z0-9-]*$ type: - description: Type of the ProxyGroup proxies. Currently the only supported type is egress. + description: |- + Type of the ProxyGroup proxies. Supported types are egress and ingress. + Type is immutable once a ProxyGroup is created. type: string enum: - egress + - ingress + x-kubernetes-validations: + - rule: self == oldSelf + message: ProxyGroup type is immutable status: description: |- ProxyGroupStatus describes the status of the ProxyGroup resources. This is diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml index fda8bcebdbe53..0f3dcfcca52c8 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_recorders.yaml @@ -2,7 +2,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: recorders.tailscale.com spec: group: tailscale.com @@ -24,9 +24,18 @@ spec: jsonPath: .status.devices[?(@.url != "")].url name: URL type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + Recorder defines a tsrecorder device for recording SSH sessions. By default, + it will store recordings in a local ephemeral volume. 
If you want to persist + recordings, you can configure an S3-compatible API for storage. + + More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder type: object required: - spec @@ -366,7 +375,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -381,7 +390,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -538,7 +547,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -553,7 +562,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). 
type: array items: type: string @@ -711,7 +720,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -726,7 +735,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -883,7 +892,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -898,7 +907,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). type: array items: type: string @@ -1060,6 +1069,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. 
type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string x-kubernetes-list-map-keys: - name x-kubernetes-list-type: map @@ -1159,7 +1174,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -1395,6 +1410,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. type: integer format: int64 + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". + + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. + + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. 
+ + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. + If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -1443,18 +1484,28 @@ spec: type: string supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. 
type: array items: type: integer format: int64 x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -1506,6 +1557,36 @@ spec: May also be set in PodSecurityContext. If set in both SecurityContext and PodSecurityContext, the value specified in SecurityContext takes precedence. type: string + serviceAccount: + description: |- + Config for the ServiceAccount to create for the Recorder's StatefulSet. + By default, the operator will create a ServiceAccount with the same + name as the Recorder resource. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + type: object + properties: + annotations: + description: |- + Annotations to add to the ServiceAccount. + https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set + + You can use this to add IAM roles to the ServiceAccount (IRSA) instead of + providing static S3 credentials in a Secret. + https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html + + For example: + eks.amazonaws.com/role-arn: arn:aws:iam:::role/ + type: object + additionalProperties: + type: string + name: + description: |- + Name of the ServiceAccount to create. Defaults to the name of the + Recorder resource. 
+ https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + type: string + maxLength: 253 + pattern: ^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$ tolerations: description: |- Tolerations for Recorder Pods. By default, the operator does not apply diff --git a/cmd/k8s-operator/deploy/manifests/nameserver/deploy.yaml b/cmd/k8s-operator/deploy/manifests/nameserver/deploy.yaml index c3a16e03e9a42..57ad783755453 100644 --- a/cmd/k8s-operator/deploy/manifests/nameserver/deploy.yaml +++ b/cmd/k8s-operator/deploy/manifests/nameserver/deploy.yaml @@ -18,6 +18,11 @@ spec: containers: - imagePullPolicy: IfNotPresent name: nameserver + env: + - name: POD_NAMESPACE + valueFrom: + fieldRef: + fieldPath: metadata.namespace ports: - name: tcp protocol: TCP diff --git a/cmd/k8s-operator/deploy/manifests/nameserver/role.yaml b/cmd/k8s-operator/deploy/manifests/nameserver/role.yaml new file mode 100644 index 0000000000000..4771ab3115c17 --- /dev/null +++ b/cmd/k8s-operator/deploy/manifests/nameserver/role.yaml @@ -0,0 +1,9 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: dnsrecords-watcher +rules: + - apiGroups: [""] + resources: ["configmaps"] + resourceNames: ["dnsrecords"] + verbs: ["get", "list", "watch"] diff --git a/cmd/k8s-operator/deploy/manifests/nameserver/rolebinding.yaml b/cmd/k8s-operator/deploy/manifests/nameserver/rolebinding.yaml new file mode 100644 index 0000000000000..54498462f8b64 --- /dev/null +++ b/cmd/k8s-operator/deploy/manifests/nameserver/rolebinding.yaml @@ -0,0 +1,12 @@ +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: dnsrecords-watcher-binding +subjects: + - kind: ServiceAccount + name: nameserver + namespace: tailscale +roleRef: + kind: Role + name: dnsrecords-watcher + apiGroup: rbac.authorization.k8s.io diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 1a812b7362757..02415400a2bb3 100644 --- 
a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -31,7 +31,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: connectors.tailscale.com spec: group: tailscale.com @@ -53,10 +53,17 @@ spec: jsonPath: .status.isExitNode name: IsExitNode type: string + - description: Whether this Connector instance is an app connector. + jsonPath: .status.isAppConnector + name: IsAppConnector + type: string - description: Status of the deployed Connector resources. jsonPath: .status.conditions[?(@.type == "ConnectorReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -91,10 +98,40 @@ spec: More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#spec-and-status properties: + appConnector: + description: |- + AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + Connector does not act as an app connector. + Note that you will need to manually configure the permissions and the domains for the app connector via the + Admin panel. + Note also that the main tested and supported use case of this config option is to deploy an app connector on + Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + tested or optimised for. 
+ If you are using the app connector to access SaaS applications because you need a predictable egress IP that + can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + device with a static IP address. + https://tailscale.com/kb/1281/app-connectors + properties: + routes: + description: |- + Routes are optional preconfigured routes for the domains routed via the app connector. + If not set, routes for the domains will be discovered dynamically. + If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + also dynamically discover other routes. + https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + items: + format: cidr + type: string + minItems: 1 + type: array + type: object exitNode: description: |- - ExitNode defines whether the Connector node should act as a - Tailscale exit node. Defaults to false. + ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. + This field is mutually exclusive with the appConnector field. https://tailscale.com/kb/1103/exit-nodes type: boolean hostname: @@ -115,9 +152,11 @@ spec: type: string subnetRouter: description: |- - SubnetRouter defines subnet routes that the Connector node should - expose to tailnet. If unset, none are exposed. + SubnetRouter defines subnet routes that the Connector device should + expose to tailnet as a Tailscale subnet router. https://tailscale.com/kb/1019/subnets/ + If this field is unset, the device does not get configured as a Tailscale subnet router. + This field is mutually exclusive with the appConnector field. properties: advertiseRoutes: description: |- @@ -151,8 +190,10 @@ spec: type: array type: object x-kubernetes-validations: - - message: A Connector needs to be either an exit node or a subnet router, or both. 
- rule: has(self.subnetRouter) || self.exitNode == true + - message: A Connector needs to have at least one of exit node, subnet router or app connector configured. + rule: has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector) + - message: The appConnector field is mutually exclusive with exitNode and subnetRouter fields. + rule: '!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))' status: description: |- ConnectorStatus describes the status of the Connector. This is set @@ -225,6 +266,9 @@ spec: If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the node. type: string + isAppConnector: + description: IsAppConnector is set to true if the Connector acts as an app connector. + type: boolean isExitNode: description: IsExitNode is set to true if the Connector acts as an exit node. type: boolean @@ -253,7 +297,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: dnsconfigs.tailscale.com spec: group: tailscale.com @@ -271,6 +315,9 @@ spec: jsonPath: .status.nameserver.ip name: NameserverIP type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -332,6 +379,122 @@ spec: Tailscale Ingresses. The operator will always deploy this nameserver when a DNSConfig is applied. properties: + cmd: + description: Cmd can be used to overwrite the command used when running the nameserver image. + items: + type: string + type: array + env: + description: |- + Env can be used to pass environment variables to the nameserver + container. + items: + description: EnvVar represents an environment variable present in a Container. + properties: + name: + description: Name of the environment variable. Must be a C_IDENTIFIER. 
+ type: string + value: + description: |- + Variable references $(VAR_NAME) are expanded + using the previously defined environment variables in the container and + any service environment variables. If a variable cannot be resolved, + the reference in the input string will be unchanged. Double $$ are reduced + to a single $, which allows for escaping the $(VAR_NAME) syntax: i.e. + "$$(VAR_NAME)" will produce the string literal "$(VAR_NAME)". + Escaped references will never be expanded, regardless of whether the variable + exists or not. + Defaults to "". + type: string + valueFrom: + description: Source for the environment variable's value. Cannot be used if value is not empty. + properties: + configMapKeyRef: + description: Selects a key of a ConfigMap. + properties: + key: + description: The key to select. + type: string + name: + default: "" + description: |- + Name of the referent. + This field is effectively required, but due to backwards compatibility is + allowed to be empty. Instances of this type with an empty value here are + almost certainly wrong. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the ConfigMap or its key must be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + fieldRef: + description: |- + Selects a field of the pod: supports metadata.name, metadata.namespace, `metadata.labels['']`, `metadata.annotations['']`, + spec.nodeName, spec.serviceAccountName, status.hostIP, status.podIP, status.podIPs. + properties: + apiVersion: + description: Version of the schema the FieldPath is written in terms of, defaults to "v1". + type: string + fieldPath: + description: Path of the field to select in the specified API version. 
+ type: string + required: + - fieldPath + type: object + x-kubernetes-map-type: atomic + resourceFieldRef: + description: |- + Selects a resource of the container: only resources limits and requests + (limits.cpu, limits.memory, limits.ephemeral-storage, requests.cpu, requests.memory and requests.ephemeral-storage) are currently supported. + properties: + containerName: + description: 'Container name: required for volumes, optional for env vars' + type: string + divisor: + anyOf: + - type: integer + - type: string + description: Specifies the output format of the exposed resources, defaults to "1" + pattern: ^(\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))(([KMGTPE]i)|[numkMGTPE]|([eE](\+|-)?(([0-9]+(\.[0-9]*)?)|(\.[0-9]+))))?$ + x-kubernetes-int-or-string: true + resource: + description: 'Required: resource to select' + type: string + required: + - resource + type: object + x-kubernetes-map-type: atomic + secretKeyRef: + description: Selects a key of a secret in the pod's namespace + properties: + key: + description: The key of the secret to select from. Must be a valid secret key. + type: string + name: + default: "" + description: |- + Name of the referent. + This field is effectively required, but due to backwards compatibility is + allowed to be empty. Instances of this type with an empty value here are + almost certainly wrong. + More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + type: string + optional: + description: Specify whether the Secret or its key must be defined + type: boolean + required: + - key + type: object + x-kubernetes-map-type: atomic + type: object + required: + - name + type: object + type: array image: description: Nameserver image. Defaults to tailscale/k8s-nameserver:unstable. properties: @@ -342,6 +505,13 @@ spec: description: Tag defaults to unstable. 
type: string type: object + podLabels: + additionalProperties: + type: string + description: |- + PodLabels are the labels which will be attached to the nameserver + pod. They can be used to define network policies. + type: object type: object required: - nameserver @@ -435,7 +605,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxyclasses.tailscale.com spec: group: tailscale.com @@ -451,6 +621,9 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyClassReady")].reason name: Status type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: @@ -499,12 +672,48 @@ spec: enable: description: |- Setting enable to true will make the proxy serve Tailscale metrics - at :9001/debug/metrics. + at :9002/metrics. + A metrics Service named -metrics will also be created in the operator's namespace and will + serve the metrics at :9002/metrics. + + In 1.78.x and 1.80.x, this field also serves as the default value for + .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + fields will independently default to false. + Defaults to false. type: boolean + serviceMonitor: + description: |- + Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. 
+ The ingested metrics for each Service monitor will have labels to identify the proxy: + ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + job: ts__[]_ + properties: + enable: + description: If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + type: boolean + labels: + additionalProperties: + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ + type: string + description: |- + Labels to add to the ServiceMonitor. + Labels must be valid Kubernetes labels. + https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + type: object + required: + - enable + type: object required: - enable type: object + x-kubernetes-validations: + - message: ServiceMonitor can only be enabled if metrics are enabled + rule: '!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)' statefulSet: description: |- Configuration parameters for the proxy's StatefulSet. Tailscale @@ -525,6 +734,8 @@ spec: type: object labels: additionalProperties: + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ type: string description: |- Labels that will be added to the StatefulSet created for the proxy. @@ -807,7 +1018,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). 
items: type: string type: array @@ -822,7 +1033,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -983,7 +1194,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -998,7 +1209,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1152,7 +1363,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1167,7 +1378,7 @@ spec: pod labels will be ignored. The default value is empty. 
The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1328,7 +1539,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1343,7 +1554,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -1455,6 +1666,8 @@ spec: type: array labels: additionalProperties: + maxLength: 63 + pattern: ^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$ type: string description: |- Labels that will be added to the proxy Pod. @@ -1560,6 +1773,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. format: int64 type: integer + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". 
+ + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. + + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. + + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. + If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -1608,18 +1847,28 @@ spec: type: object supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. 
Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. items: format: int64 type: integer type: array x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -1675,6 +1924,25 @@ spec: tailscaleContainer: description: Configuration for the proxy container running tailscale. properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". 
The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean + type: object env: description: |- List of environment variables to set in the container. @@ -1751,6 +2019,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string required: - name type: object @@ -1786,11 +2060,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. 
+ You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context properties: allowPrivilegeEscalation: @@ -1858,7 +2133,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -1979,6 +2254,25 @@ spec: tailscaleInitContainer: description: Configuration for the proxy init container that enables forwarding. properties: + debug: + description: |- + Configuration for enabling extra debug information in the container. + Not recommended for production use. + properties: + enable: + description: |- + Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + and internal debug metrics endpoint at :9001/debug/metrics, where + 9001 is a container port named "debug". The endpoints and their responses + may change in backwards incompatible ways in the future, and should not + be considered stable. + + In 1.78.x and 1.80.x, this setting will default to the value of + .spec.metrics.enable, and requests to the "metrics" port matching the + mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + this setting will default to false, and no requests will be proxied. + type: boolean + type: object env: description: |- List of environment variables to set in the container. @@ -2055,6 +2349,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. 
type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string required: - name type: object @@ -2090,11 +2390,12 @@ spec: securityContext: description: |- Container security context. - Security context specified here will override the security context by the operator. - By default the operator: - - sets 'privileged: true' for the init container - - set NET_ADMIN capability for tailscale container for proxies that - are created for Services or Connector. + Security context specified here will override the security context set by the operator. + By default the operator sets the Tailscale container and the Tailscale init container to privileged + for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. + You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + installing device plugin in your cluster and configuring the proxies tun device to be created + by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context properties: allowPrivilegeEscalation: @@ -2162,7 +2463,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -2323,6 +2624,182 @@ spec: type: string type: object type: array + topologySpreadConstraints: + description: |- + Proxy Pod's topology spread constraints. 
+ By default Tailscale Kubernetes operator does not apply any topology spread constraints. + https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ + items: + description: TopologySpreadConstraint specifies how to spread matching pods among the given topology. + properties: + labelSelector: + description: |- + LabelSelector is used to find matching pods. + Pods that match this label selector are counted to determine the number of pods + in their corresponding topology domain. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. 
+ type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select the pods over which + spreading will be calculated. The keys are used to lookup values from the + incoming pod labels, those key-value labels are ANDed with labelSelector + to select the group of existing pods over which spreading will be calculated + for the incoming pod. The same key is forbidden to exist in both MatchLabelKeys and LabelSelector. + MatchLabelKeys cannot be set when LabelSelector isn't set. + Keys that don't exist in the incoming pod labels will + be ignored. A null or empty list means only match against labelSelector. + + This is a beta field and requires the MatchLabelKeysInPodTopologySpread feature gate to be enabled (enabled by default). + items: + type: string + type: array + x-kubernetes-list-type: atomic + maxSkew: + description: |- + MaxSkew describes the degree to which pods may be unevenly distributed. + When `whenUnsatisfiable=DoNotSchedule`, it is the maximum permitted difference + between the number of matching pods in the target topology and the global minimum. + The global minimum is the minimum number of matching pods in an eligible domain + or zero if the number of eligible domains is less than MinDomains. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 2/2/1: + In this case, the global minimum is 1. + | zone1 | zone2 | zone3 | + | P P | P P | P | + - if MaxSkew is 1, incoming pod can only be scheduled to zone3 to become 2/2/2; + scheduling it onto zone1(zone2) would make the ActualSkew(3-1) on zone1(zone2) + violate MaxSkew(1). + - if MaxSkew is 2, incoming pod can be scheduled onto any zone. + When `whenUnsatisfiable=ScheduleAnyway`, it is used to give higher precedence + to topologies that satisfy it. + It's a required field. Default value is 1 and 0 is not allowed. 
+ format: int32 + type: integer + minDomains: + description: |- + MinDomains indicates a minimum number of eligible domains. + When the number of eligible domains with matching topology keys is less than minDomains, + Pod Topology Spread treats "global minimum" as 0, and then the calculation of Skew is performed. + And when the number of eligible domains with matching topology keys equals or greater than minDomains, + this value has no effect on scheduling. + As a result, when the number of eligible domains is less than minDomains, + scheduler won't schedule more than maxSkew Pods to those domains. + If value is nil, the constraint behaves as if MinDomains is equal to 1. + Valid values are integers greater than 0. + When value is not nil, WhenUnsatisfiable must be DoNotSchedule. + + For example, in a 3-zone cluster, MaxSkew is set to 2, MinDomains is set to 5 and pods with the same + labelSelector spread as 2/2/2: + | zone1 | zone2 | zone3 | + | P P | P P | P P | + The number of domains is less than 5(MinDomains), so "global minimum" is treated as 0. + In this situation, new pod with the same labelSelector cannot be scheduled, + because computed skew will be 3(3 - 0) if new Pod is scheduled to any of the three zones, + it will violate MaxSkew. + format: int32 + type: integer + nodeAffinityPolicy: + description: |- + NodeAffinityPolicy indicates how we will treat Pod's nodeAffinity/nodeSelector + when calculating pod topology spread skew. Options are: + - Honor: only nodes matching nodeAffinity/nodeSelector are included in the calculations. + - Ignore: nodeAffinity/nodeSelector are ignored. All nodes are included in the calculations. + + If this value is nil, the behavior is equivalent to the Honor policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + nodeTaintsPolicy: + description: |- + NodeTaintsPolicy indicates how we will treat node taints when calculating + pod topology spread skew. 
Options are: + - Honor: nodes without taints, along with tainted nodes for which the incoming pod + has a toleration, are included. + - Ignore: node taints are ignored. All nodes are included. + + If this value is nil, the behavior is equivalent to the Ignore policy. + This is a beta-level feature default enabled by the NodeInclusionPolicyInPodTopologySpread feature flag. + type: string + topologyKey: + description: |- + TopologyKey is the key of node labels. Nodes that have a label with this key + and identical values are considered to be in the same topology. + We consider each as a "bucket", and try to put balanced number + of pods into each bucket. + We define a domain as a particular instance of a topology. + Also, we define an eligible domain as a domain whose nodes meet the requirements of + nodeAffinityPolicy and nodeTaintsPolicy. + e.g. If TopologyKey is "kubernetes.io/hostname", each Node is a domain of that topology. + And, if TopologyKey is "topology.kubernetes.io/zone", each zone is a domain of that topology. + It's a required field. + type: string + whenUnsatisfiable: + description: |- + WhenUnsatisfiable indicates how to deal with a pod if it doesn't satisfy + the spread constraint. + - DoNotSchedule (default) tells the scheduler not to schedule it. + - ScheduleAnyway tells the scheduler to schedule the pod in any location, + but giving higher precedence to topologies that would help reduce the + skew. + A constraint is considered "Unsatisfiable" for an incoming pod + if and only if every possible node assignment for that pod would violate + "MaxSkew" on some topology. + For example, in a 3-zone cluster, MaxSkew is set to 1, and pods with the same + labelSelector spread as 3/1/1: + | zone1 | zone2 | zone3 | + | P P P | P | P | + If WhenUnsatisfiable is set to DoNotSchedule, incoming pod can only be scheduled + to zone2(zone3) to become 3/2/1(3/1/2) as ActualSkew(2-1) on zone2(zone3) satisfies + MaxSkew(1). 
In other words, the cluster can still be imbalanced, but scheduler + won't make it *more* imbalanced. + It's a required field. + type: string + required: + - maxSkew + - topologyKey + - whenUnsatisfiable + type: object + type: array type: object type: object tailscale: @@ -2340,6 +2817,22 @@ spec: Defaults to false. type: boolean type: object + useLetsEncryptStagingEnvironment: + description: |- + Set UseLetsEncryptStagingEnvironment to true to issue TLS + certificates for any HTTPS endpoints exposed to the tailnet from + LetsEncrypt's staging environment. + https://letsencrypt.org/docs/staging-environment/ + This setting only affects Tailscale Ingress resources. + By default Ingress TLS certificates are issued from LetsEncrypt's + production environment. + Changing this setting true -> false, will result in any + existing certs being re-issued from the production environment. + Changing this setting false (default) -> true, when certs have already + been provisioned from production environment will NOT result in certs + being re-issued from the staging environment before they need to be + renewed. + type: boolean type: object status: description: |- @@ -2420,7 +2913,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: proxygroups.tailscale.com spec: group: tailscale.com @@ -2438,9 +2931,27 @@ spec: jsonPath: .status.conditions[?(@.type == "ProxyGroupReady")].reason name: Status type: string + - description: ProxyGroup type. + jsonPath: .spec.type + name: Type + type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + ProxyGroup defines a set of Tailscale devices that will act as proxies. + Currently only egress ProxyGroups are supported. 
+ + Use the tailscale.com/proxy-group annotation on a Service to specify that + the egress proxy should be implemented by a ProxyGroup instead of a single + dedicated proxy. In addition to running a highly available set of proxies, + ProxyGroup also allows for serving many annotated Services from a single + set of proxies to minimise resource consumption. + + More info: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress properties: apiVersion: description: |- @@ -2484,6 +2995,7 @@ spec: Replicas specifies how many replicas to create the StatefulSet with. Defaults to 2. format: int32 + minimum: 0 type: integer tags: description: |- @@ -2498,10 +3010,16 @@ spec: type: string type: array type: - description: Type of the ProxyGroup proxies. Currently the only supported type is egress. + description: |- + Type of the ProxyGroup proxies. Supported types are egress and ingress. + Type is immutable once a ProxyGroup is created. enum: - egress + - ingress type: string + x-kubernetes-validations: + - message: ProxyGroup type is immutable + rule: self == oldSelf required: - type type: object @@ -2608,7 +3126,7 @@ apiVersion: apiextensions.k8s.io/v1 kind: CustomResourceDefinition metadata: annotations: - controller-gen.kubebuilder.io/version: v0.15.1-0.20240618033008-7824932b0cab + controller-gen.kubebuilder.io/version: v0.17.0 name: recorders.tailscale.com spec: group: tailscale.com @@ -2630,9 +3148,18 @@ spec: jsonPath: .status.devices[?(@.url != "")].url name: URL type: string + - jsonPath: .metadata.creationTimestamp + name: Age + type: date name: v1alpha1 schema: openAPIV3Schema: + description: |- + Recorder defines a tsrecorder device for recording SSH sessions. By default, + it will store recordings in a local ephemeral volume. If you want to persist + recordings, you can configure an S3-compatible API for storage. 
+ + More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder properties: apiVersion: description: |- @@ -2956,7 +3483,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -2971,7 +3498,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3132,7 +3659,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3147,7 +3674,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). 
items: type: string type: array @@ -3301,7 +3828,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3316,7 +3843,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3477,7 +4004,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both matchLabelKeys and labelSelector. Also, matchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3492,7 +4019,7 @@ spec: pod labels will be ignored. The default value is empty. The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. Also, mismatchLabelKeys cannot be set when labelSelector isn't set. - This is an alpha field and requires enabling MatchLabelKeysInPodAffinity feature gate. + This is a beta field and requires enabling MatchLabelKeysInPodAffinity feature gate (enabled by default). items: type: string type: array @@ -3654,6 +4181,12 @@ spec: the Pod where this field is used. It makes that resource available inside a container. 
type: string + request: + description: |- + Request is the name chosen for a request in the referenced claim. + If empty, everything from the claim is made available, otherwise + only the result of this request. + type: string required: - name type: object @@ -3757,7 +4290,7 @@ spec: procMount: description: |- procMount denotes the type of proc mount to use for the containers. - The default is DefaultProcMount which uses the container runtime defaults for + The default value is Default which uses the container runtime defaults for readonly paths and masked paths. This requires the ProcMountType feature flag to be enabled. Note that this field cannot be set when spec.os.name is windows. @@ -3994,6 +4527,32 @@ spec: Note that this field cannot be set when spec.os.name is windows. format: int64 type: integer + seLinuxChangePolicy: + description: |- + seLinuxChangePolicy defines how the container's SELinux label is applied to all volumes used by the Pod. + It has no effect on nodes that do not support SELinux or to volumes does not support SELinux. + Valid values are "MountOption" and "Recursive". + + "Recursive" means relabeling of all files on all Pod volumes by the container runtime. + This may be slow for large volumes, but allows mixing privileged and unprivileged Pods sharing the same volume on the same node. + + "MountOption" mounts all eligible Pod volumes with `-o context` mount option. + This requires all Pods that share the same volume to use the same SELinux label. + It is not possible to share the same volume among privileged and unprivileged Pods. + Eligible volumes are in-tree FibreChannel and iSCSI volumes, and all CSI volumes + whose CSI driver announces SELinux support by setting spec.seLinuxMount: true in their + CSIDriver instance. Other volumes are always re-labelled recursively. + "MountOption" value is allowed only when SELinuxMount feature gate is enabled. + + If not specified and SELinuxMount feature gate is enabled, "MountOption" is used. 
+ If not specified and SELinuxMount feature gate is disabled, "MountOption" is used for ReadWriteOncePod volumes + and "Recursive" for all other volumes. + + This field affects only Pods that have SELinux label set, either in PodSecurityContext or in SecurityContext of all containers. + + All Pods that use the same volume should use the same seLinuxChangePolicy, otherwise some pods can get stuck in ContainerCreating state. + Note that this field cannot be set when spec.os.name is windows. + type: string seLinuxOptions: description: |- The SELinux context to be applied to all containers. @@ -4042,18 +4601,28 @@ spec: type: object supplementalGroups: description: |- - A list of groups applied to the first process run in each container, in addition - to the container's primary GID, the fsGroup (if specified), and group memberships - defined in the container image for the uid of the container process. If unspecified, - no additional groups are added to any container. Note that group memberships - defined in the container image for the uid of the container process are still effective, - even if they are not included in this list. + A list of groups applied to the first process run in each container, in + addition to the container's primary GID and fsGroup (if specified). If + the SupplementalGroupsPolicy feature is enabled, the + supplementalGroupsPolicy field determines whether these are in addition + to or instead of any group memberships defined in the container image. + If unspecified, no additional groups are added, though group memberships + defined in the container image may still be used, depending on the + supplementalGroupsPolicy field. Note that this field cannot be set when spec.os.name is windows. items: format: int64 type: integer type: array x-kubernetes-list-type: atomic + supplementalGroupsPolicy: + description: |- + Defines how supplemental groups of the first container processes are calculated. + Valid values are "Merge" and "Strict". 
If not specified, "Merge" is used. + (Alpha) Using the field requires the SupplementalGroupsPolicy feature gate to be enabled + and the container runtime must implement support for this feature. + Note that this field cannot be set when spec.os.name is windows. + type: string sysctls: description: |- Sysctls hold a list of namespaced sysctls used for the pod. Pods with unsupported @@ -4106,6 +4675,36 @@ spec: type: string type: object type: object + serviceAccount: + description: |- + Config for the ServiceAccount to create for the Recorder's StatefulSet. + By default, the operator will create a ServiceAccount with the same + name as the Recorder resource. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + properties: + annotations: + additionalProperties: + type: string + description: |- + Annotations to add to the ServiceAccount. + https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set + + You can use this to add IAM roles to the ServiceAccount (IRSA) instead of + providing static S3 credentials in a Secret. + https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html + + For example: + eks.amazonaws.com/role-arn: arn:aws:iam:::role/ + type: object + name: + description: |- + Name of the ServiceAccount to create. Defaults to the name of the + Recorder resource. + https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + maxLength: 253 + pattern: ^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$ + type: string + type: object tolerations: description: |- Tolerations for Recorder Pods. 
By default, the operator does not apply @@ -4352,6 +4951,14 @@ rules: - get - list - watch + - apiGroups: + - discovery.k8s.io + resources: + - endpointslices + verbs: + - get + - list + - watch - apiGroups: - tailscale.com resources: @@ -4386,6 +4993,16 @@ rules: - list - watch - update + - apiGroups: + - apiextensions.k8s.io + resourceNames: + - servicemonitors.monitoring.coreos.com + resources: + - customresourcedefinitions + verbs: + - get + - list + - watch --- apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding @@ -4429,6 +5046,13 @@ rules: - get - list - watch + - update + - apiGroups: + - "" + resources: + - pods/status + verbs: + - update - apiGroups: - apps resources: @@ -4466,6 +5090,17 @@ rules: - update - list - watch + - deletecollection + - apiGroups: + - monitoring.coreos.com + resources: + - servicemonitors + verbs: + - get + - list + - update + - create + - delete --- apiVersion: rbac.authorization.k8s.io/v1 kind: Role @@ -4486,6 +5121,14 @@ rules: - patch - update - watch + - apiGroups: + - "" + resources: + - events + verbs: + - create + - patch + - get --- apiVersion: rbac.authorization.k8s.io/v1 kind: RoleBinding @@ -4558,6 +5201,14 @@ spec: value: "false" - name: PROXY_FIREWALL_MODE value: auto + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid image: tailscale/k8s-operator:unstable imagePullPolicy: Always name: operator diff --git a/cmd/k8s-operator/deploy/manifests/proxy.yaml b/cmd/k8s-operator/deploy/manifests/proxy.yaml index a79d48d73ce0f..3c9a3eaa36c56 100644 --- a/cmd/k8s-operator/deploy/manifests/proxy.yaml +++ b/cmd/k8s-operator/deploy/manifests/proxy.yaml @@ -30,7 +30,13 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: 
true diff --git a/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml b/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml index 46b49a57b1909..6617f6d4b52fe 100644 --- a/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml +++ b/cmd/k8s-operator/deploy/manifests/userspace-proxy.yaml @@ -24,3 +24,11 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid diff --git a/cmd/k8s-operator/dnsrecords.go b/cmd/k8s-operator/dnsrecords.go index bba87bf255910..f91dd49ec255e 100644 --- a/cmd/k8s-operator/dnsrecords.go +++ b/cmd/k8s-operator/dnsrecords.go @@ -10,6 +10,7 @@ import ( "encoding/json" "fmt" "slices" + "strings" "go.uber.org/zap" corev1 "k8s.io/api/core/v1" @@ -98,7 +99,15 @@ func (dnsRR *dnsRecordsReconciler) Reconcile(ctx context.Context, req reconcile. return reconcile.Result{}, nil } - return reconcile.Result{}, dnsRR.maybeProvision(ctx, headlessSvc, logger) + if err := dnsRR.maybeProvision(ctx, headlessSvc, logger); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } // maybeProvision ensures that dnsrecords ConfigMap contains a record for the diff --git a/cmd/k8s-operator/dnsrecords_test.go b/cmd/k8s-operator/dnsrecords_test.go index 389461b85f340..4e73e6c9e33ba 100644 --- a/cmd/k8s-operator/dnsrecords_test.go +++ b/cmd/k8s-operator/dnsrecords_test.go @@ -22,6 +22,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client/fake" operatorutils "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" "tailscale.com/tstest" "tailscale.com/types/ptr" ) @@ -163,10 +164,10 @@ func headlessSvcForParent(o client.Object, typ string) *corev1.Service { Name: o.GetName(), Namespace: "tailscale", Labels: 
map[string]string{ - LabelManaged: "true", - LabelParentName: o.GetName(), - LabelParentNamespace: o.GetNamespace(), - LabelParentType: typ, + kubetypes.LabelManaged: "true", + LabelParentName: o.GetName(), + LabelParentNamespace: o.GetNamespace(), + LabelParentType: typ, }, }, Spec: corev1.ServiceSpec{ diff --git a/cmd/k8s-operator/e2e/ingress_test.go b/cmd/k8s-operator/e2e/ingress_test.go new file mode 100644 index 0000000000000..373dd2c7dc88f --- /dev/null +++ b/cmd/k8s-operator/e2e/ingress_test.go @@ -0,0 +1,108 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package e2e + +import ( + "context" + "fmt" + "net/http" + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/wait" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/config" + kube "tailscale.com/k8s-operator" + "tailscale.com/tstest" +) + +// See [TestMain] for test requirements. 
+func TestIngress(t *testing.T) { + if tsClient == nil { + t.Skip("TestIngress requires credentials for a tailscale client") + } + + ctx := context.Background() + cfg := config.GetConfigOrDie() + cl, err := client.New(cfg, client.Options{}) + if err != nil { + t.Fatal(err) + } + // Apply nginx + createAndCleanup(t, ctx, cl, &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "nginx", + Namespace: "default", + Labels: map[string]string{ + "app.kubernetes.io/name": "nginx", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "nginx", + Image: "nginx", + }, + }, + }, + }) + // Apply service to expose it as ingress + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + Annotations: map[string]string{ + "tailscale.com/expose": "true", + }, + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{ + "app.kubernetes.io/name": "nginx", + }, + Ports: []corev1.ServicePort{ + { + Name: "http", + Protocol: "TCP", + Port: 80, + }, + }, + }, + } + createAndCleanup(t, ctx, cl, svc) + + // TODO: instead of timing out only when test times out, cancel context after 60s or so. + if err := wait.PollUntilContextCancel(ctx, time.Millisecond*100, true, func(ctx context.Context) (done bool, err error) { + maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta("default", "test-ingress")} + if err := get(ctx, cl, maybeReadySvc); err != nil { + return false, err + } + isReady := kube.SvcIsReady(maybeReadySvc) + if isReady { + t.Log("Service is ready") + } + return isReady, nil + }); err != nil { + t.Fatalf("error waiting for the Service to become Ready: %v", err) + } + + var resp *http.Response + if err := tstest.WaitFor(time.Second*60, func() error { + // TODO(tomhjp): Get the tailnet DNS name from the associated secret instead. + // If we are not the first tailnet node with the requested name, we'll get + // a -N suffix. 
+ resp, err = tsClient.HTTPClient.Get(fmt.Sprintf("http://%s-%s:80", svc.Namespace, svc.Name)) + if err != nil { + return err + } + return nil + }); err != nil { + t.Fatalf("error trying to reach service: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("unexpected status: %v; response body s", resp.StatusCode) + } +} diff --git a/cmd/k8s-operator/e2e/main_test.go b/cmd/k8s-operator/e2e/main_test.go new file mode 100644 index 0000000000000..5a1364e09d0d7 --- /dev/null +++ b/cmd/k8s-operator/e2e/main_test.go @@ -0,0 +1,193 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package e2e + +import ( + "context" + "errors" + "fmt" + "log" + "os" + "slices" + "strings" + "testing" + + "github.com/go-logr/zapr" + "github.com/tailscale/hujson" + "go.uber.org/zap/zapcore" + "golang.org/x/oauth2/clientcredentials" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + logf "sigs.k8s.io/controller-runtime/pkg/log" + kzap "sigs.k8s.io/controller-runtime/pkg/log/zap" + "tailscale.com/internal/client/tailscale" +) + +const ( + e2eManagedComment = "// This is managed by the k8s-operator e2e tests" +) + +var ( + tsClient *tailscale.Client + testGrants = map[string]string{ + "test-proxy": `{ + "src": ["tag:e2e-test-proxy"], + "dst": ["tag:k8s-operator"], + "app": { + "tailscale.com/cap/kubernetes": [{ + "impersonate": { + "groups": ["ts:e2e-test-proxy"], + }, + }], + }, + }`, + } +) + +// This test suite is currently not run in CI. +// It requires some setup not handled by this code: +// - Kubernetes cluster with tailscale operator installed +// - Current kubeconfig context set to connect to that cluster (directly, no operator proxy) +// - Operator installed with --set apiServerProxyConfig.mode="true" +// - ACLs that define tag:e2e-test-proxy tag. 
TODO(tomhjp): Can maybe replace this prereq onwards with an API key +// - OAuth client ID and secret in TS_API_CLIENT_ID and TS_API_CLIENT_SECRET env +// - OAuth client must have auth_keys and policy_file write for tag:e2e-test-proxy tag +func TestMain(m *testing.M) { + code, err := runTests(m) + if err != nil { + log.Fatal(err) + } + os.Exit(code) +} + +func runTests(m *testing.M) (int, error) { + zlog := kzap.NewRaw([]kzap.Opts{kzap.UseDevMode(true), kzap.Level(zapcore.DebugLevel)}...).Sugar() + logf.SetLogger(zapr.NewLogger(zlog.Desugar())) + + if clientID := os.Getenv("TS_API_CLIENT_ID"); clientID != "" { + cleanup, err := setupClientAndACLs() + if err != nil { + return 0, err + } + defer func() { + err = errors.Join(err, cleanup()) + }() + } + + return m.Run(), nil +} + +func setupClientAndACLs() (cleanup func() error, _ error) { + ctx := context.Background() + credentials := clientcredentials.Config{ + ClientID: os.Getenv("TS_API_CLIENT_ID"), + ClientSecret: os.Getenv("TS_API_CLIENT_SECRET"), + TokenURL: "https://login.tailscale.com/api/v2/oauth/token", + Scopes: []string{"auth_keys", "policy_file"}, + } + tsClient = tailscale.NewClient("-", nil) + tsClient.HTTPClient = credentials.Client(ctx) + + if err := patchACLs(ctx, tsClient, func(acls *hujson.Value) { + for test, grant := range testGrants { + deleteTestGrants(test, acls) + addTestGrant(test, grant, acls) + } + }); err != nil { + return nil, err + } + + return func() error { + return patchACLs(ctx, tsClient, func(acls *hujson.Value) { + for test := range testGrants { + deleteTestGrants(test, acls) + } + }) + }, nil +} + +func patchACLs(ctx context.Context, tsClient *tailscale.Client, patchFn func(*hujson.Value)) error { + acls, err := tsClient.ACLHuJSON(ctx) + if err != nil { + return err + } + hj, err := hujson.Parse([]byte(acls.ACL)) + if err != nil { + return err + } + + patchFn(&hj) + + hj.Format() + acls.ACL = hj.String() + if _, err := tsClient.SetACLHuJSON(ctx, *acls, true); err != nil { + return 
err + } + + return nil +} + +func addTestGrant(test, grant string, acls *hujson.Value) error { + v, err := hujson.Parse([]byte(grant)) + if err != nil { + return err + } + + // Add the managed comment to the first line of the grant object contents. + v.Value.(*hujson.Object).Members[0].Name.BeforeExtra = hujson.Extra(fmt.Sprintf("%s: %s\n", e2eManagedComment, test)) + + if err := acls.Patch([]byte(fmt.Sprintf(`[{"op": "add", "path": "/grants/-", "value": %s}]`, v.String()))); err != nil { + return err + } + + return nil +} + +func deleteTestGrants(test string, acls *hujson.Value) error { + grants := acls.Find("/grants") + + var patches []string + for i, g := range grants.Value.(*hujson.Array).Elements { + members := g.Value.(*hujson.Object).Members + if len(members) == 0 { + continue + } + comment := strings.TrimSpace(string(members[0].Name.BeforeExtra)) + if name, found := strings.CutPrefix(comment, e2eManagedComment+": "); found && name == test { + patches = append(patches, fmt.Sprintf(`{"op": "remove", "path": "/grants/%d"}`, i)) + } + } + + // Remove in reverse order so we don't affect the found indices as we mutate. 
+ slices.Reverse(patches) + + if err := acls.Patch([]byte(fmt.Sprintf("[%s]", strings.Join(patches, ",")))); err != nil { + return err + } + + return nil +} + +func objectMeta(namespace, name string) metav1.ObjectMeta { + return metav1.ObjectMeta{ + Namespace: namespace, + Name: name, + } +} + +func createAndCleanup(t *testing.T, ctx context.Context, cl client.Client, obj client.Object) { + t.Helper() + if err := cl.Create(ctx, obj); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := cl.Delete(ctx, obj); err != nil { + t.Errorf("error cleaning up %s %s/%s: %s", obj.GetObjectKind().GroupVersionKind(), obj.GetNamespace(), obj.GetName(), err) + } + }) +} + +func get(ctx context.Context, cl client.Client, obj client.Object) error { + return cl.Get(ctx, client.ObjectKeyFromObject(obj), obj) +} diff --git a/cmd/k8s-operator/e2e/proxy_test.go b/cmd/k8s-operator/e2e/proxy_test.go new file mode 100644 index 0000000000000..eac983e88d613 --- /dev/null +++ b/cmd/k8s-operator/e2e/proxy_test.go @@ -0,0 +1,156 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package e2e + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/client-go/rest" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/config" + "tailscale.com/client/tailscale" + "tailscale.com/tsnet" + "tailscale.com/tstest" +) + +// See [TestMain] for test requirements. +func TestProxy(t *testing.T) { + if tsClient == nil { + t.Skip("TestProxy requires credentials for a tailscale client") + } + + ctx := context.Background() + cfg := config.GetConfigOrDie() + cl, err := client.New(cfg, client.Options{}) + if err != nil { + t.Fatal(err) + } + + // Create role and role binding to allow a group we'll impersonate to do stuff. 
+ createAndCleanup(t, ctx, cl, &rbacv1.Role{ + ObjectMeta: objectMeta("tailscale", "read-secrets"), + Rules: []rbacv1.PolicyRule{{ + APIGroups: []string{""}, + Verbs: []string{"get"}, + Resources: []string{"secrets"}, + }}, + }) + createAndCleanup(t, ctx, cl, &rbacv1.RoleBinding{ + ObjectMeta: objectMeta("tailscale", "read-secrets"), + Subjects: []rbacv1.Subject{{ + Kind: "Group", + Name: "ts:e2e-test-proxy", + }}, + RoleRef: rbacv1.RoleRef{ + Kind: "Role", + Name: "read-secrets", + }, + }) + + // Get operator host name from kube secret. + operatorSecret := corev1.Secret{ + ObjectMeta: objectMeta("tailscale", "operator"), + } + if err := get(ctx, cl, &operatorSecret); err != nil { + t.Fatal(err) + } + + // Connect to tailnet with test-specific tag so we can use the + // [testGrants] ACLs when connecting to the API server proxy + ts := tsnetServerWithTag(t, ctx, "tag:e2e-test-proxy") + proxyCfg := &rest.Config{ + Host: fmt.Sprintf("https://%s:443", hostNameFromOperatorSecret(t, operatorSecret)), + Dial: ts.Dial, + } + proxyCl, err := client.New(proxyCfg, client.Options{}) + if err != nil { + t.Fatal(err) + } + + // Expect success. + allowedSecret := corev1.Secret{ + ObjectMeta: objectMeta("tailscale", "operator"), + } + // Wait for up to a minute the first time we use the proxy, to give it time + // to provision the TLS certs. + if err := tstest.WaitFor(time.Second*60, func() error { + return get(ctx, proxyCl, &allowedSecret) + }); err != nil { + t.Fatal(err) + } + + // Expect forbidden. 
+ forbiddenSecret := corev1.Secret{ + ObjectMeta: objectMeta("default", "operator"), + } + if err := get(ctx, proxyCl, &forbiddenSecret); err == nil || !apierrors.IsForbidden(err) { + t.Fatalf("expected forbidden error fetching secret from default namespace: %s", err) + } +} + +func tsnetServerWithTag(t *testing.T, ctx context.Context, tag string) *tsnet.Server { + caps := tailscale.KeyCapabilities{ + Devices: tailscale.KeyDeviceCapabilities{ + Create: tailscale.KeyDeviceCreateCapabilities{ + Reusable: false, + Preauthorized: true, + Ephemeral: true, + Tags: []string{tag}, + }, + }, + } + + authKey, authKeyMeta, err := tsClient.CreateKey(ctx, caps) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := tsClient.DeleteKey(ctx, authKeyMeta.ID); err != nil { + t.Errorf("error deleting auth key: %s", err) + } + }) + + ts := &tsnet.Server{ + Hostname: "test-proxy", + Ephemeral: true, + Dir: t.TempDir(), + AuthKey: authKey, + } + _, err = ts.Up(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := ts.Close(); err != nil { + t.Errorf("error shutting down tsnet.Server: %s", err) + } + }) + + return ts +} + +func hostNameFromOperatorSecret(t *testing.T, s corev1.Secret) string { + profiles := map[string]any{} + if err := json.Unmarshal(s.Data["_profiles"], &profiles); err != nil { + t.Fatal(err) + } + key, ok := strings.CutPrefix(string(s.Data["_current-profile"]), "profile-") + if !ok { + t.Fatal(string(s.Data["_current-profile"])) + } + profile, ok := profiles[key] + if !ok { + t.Fatal(profiles) + } + + return ((profile.(map[string]any))["Name"]).(string) +} diff --git a/cmd/k8s-operator/egress-eps.go b/cmd/k8s-operator/egress-eps.go index 85992abed9e37..3441e12ba93ec 100644 --- a/cmd/k8s-operator/egress-eps.go +++ b/cmd/k8s-operator/egress-eps.go @@ -20,7 +20,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" - tsoperator 
"tailscale.com/k8s-operator" "tailscale.com/kube/egressservices" "tailscale.com/types/ptr" ) @@ -71,25 +70,27 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ if err != nil { return res, fmt.Errorf("error retrieving ExternalName Service: %w", err) } - if !tsoperator.EgressServiceIsValidAndConfigured(svc) { - l.Infof("Cluster resources for ExternalName Service %s/%s are not yet configured", svc.Namespace, svc.Name) - return res, nil - } // TODO(irbekrm): currently this reconcile loop runs all the checks every time it's triggered, which is // wasteful. Once we have a Ready condition for ExternalName Services for ProxyGroup, use the condition to // determine if a reconcile is needed. oldEps := eps.DeepCopy() - proxyGroupName := eps.Labels[labelProxyGroup] tailnetSvc := tailnetSvcName(svc) l = l.With("tailnet-service-name", tailnetSvc) // Retrieve the desired tailnet service configuration from the ConfigMap. + proxyGroupName := eps.Labels[labelProxyGroup] _, cfgs, err := egressSvcsConfigs(ctx, er.Client, proxyGroupName, er.tsNamespace) if err != nil { return res, fmt.Errorf("error retrieving tailnet services configuration: %w", err) } + if cfgs == nil { + // TODO(irbekrm): this path would be hit if egress service was once exposed on a ProxyGroup that later + // got deleted. Probably the EndpointSlices then need to be deleted too- need to rethink this flow. 
+ l.Debugf("No egress config found, likely because ProxyGroup has not been created") + return res, nil + } cfg, ok := (*cfgs)[tailnetSvc] if !ok { l.Infof("[unexpected] configuration for tailnet service %s not found", tailnetSvc) diff --git a/cmd/k8s-operator/egress-eps_test.go b/cmd/k8s-operator/egress-eps_test.go index a64f3e4e1bb50..bd81071cb5e4f 100644 --- a/cmd/k8s-operator/egress-eps_test.go +++ b/cmd/k8s-operator/egress-eps_test.go @@ -112,7 +112,7 @@ func TestTailscaleEgressEndpointSlices(t *testing.T) { Terminating: pointer.ToBool(false), }, }) - expectEqual(t, fc, eps, nil) + expectEqual(t, fc, eps) }) t.Run("status_does_not_match_pod_ip", func(t *testing.T) { _, stateS := podAndSecretForProxyGroup("foo") // replica Pod has IP 10.0.0.1 @@ -122,7 +122,7 @@ func TestTailscaleEgressEndpointSlices(t *testing.T) { }) expectReconciled(t, er, "operator-ns", "foo") eps.Endpoints = []discoveryv1.Endpoint{} - expectEqual(t, fc, eps, nil) + expectEqual(t, fc, eps) }) } diff --git a/cmd/k8s-operator/egress-pod-readiness.go b/cmd/k8s-operator/egress-pod-readiness.go new file mode 100644 index 0000000000000..05cf1aa1abfed --- /dev/null +++ b/cmd/k8s-operator/egress-pod-readiness.go @@ -0,0 +1,274 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "errors" + "fmt" + "net/http" + "slices" + "strings" + "sync/atomic" + "time" + + "go.uber.org/zap" + xslices "golang.org/x/exp/slices" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/logtail/backoff" + "tailscale.com/tstime" + "tailscale.com/util/httpm" +) + +const tsEgressReadinessGate = "tailscale.com/egress-services" + +// 
egressPodsReconciler is responsible for setting tailscale.com/egress-services condition on egress ProxyGroup Pods. +// The condition is used as a readiness gate for the Pod, meaning that kubelet will not mark the Pod as ready before the +// condition is set. The ProxyGroup StatefulSet updates are rolled out in such a way that no Pod is restarted, before +// the previous Pod is marked as ready, so ensuring that the Pod does not get marked as ready when it is not yet able to +// route traffic for egress service prevents downtime during restarts caused by no available endpoints left because +// every Pod has been recreated and is not yet added to endpoints. +// https://kubernetes.io/docs/concepts/workloads/pods/pod-lifecycle/#pod-readiness-gate +type egressPodsReconciler struct { + client.Client + logger *zap.SugaredLogger + tsNamespace string + clock tstime.Clock + httpClient doer // http client that can be set to a mock client in tests + maxBackoff time.Duration // max backoff period between health check calls +} + +// Reconcile reconciles an egress ProxyGroup Pods on changes to those Pods and ProxyGroup EndpointSlices. It ensures +// that for each Pod who is ready to route traffic to all egress services for the ProxyGroup, the Pod has a +// tailscale.com/egress-services condition to set, so that kubelet will mark the Pod as ready. +// +// For the Pod to be ready +// to route traffic to the egress service, the kube proxy needs to have set up the Pod's IP as an endpoint for the +// ClusterIP Service corresponding to the egress service. +// +// Note that the endpoints for the ClusterIP Service are configured by the operator itself using custom +// EndpointSlices(egress-eps-reconciler), so the routing is not blocked on Pod's readiness. +// +// Each egress service has a corresponding ClusterIP Service, that exposes all user configured +// tailnet ports, as well as a health check port for the proxy. 
+// +// The reconciler calls the health check endpoint of each Service up to N number of times, where N is the number of +// replicas for the ProxyGroup x 3, and checks if the received response is healthy response from the Pod being reconciled. +// +// The health check response contains a header with the +// Pod's IP address- this is used to determine whether the response is received from this Pod. +// +// If the Pod does not appear to be serving the health check endpoint (pre-v1.80 proxies), the reconciler just sets the +// readiness condition for backwards compatibility reasons. +func (er *egressPodsReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + l := er.logger.With("Pod", req.NamespacedName) + l.Debugf("starting reconcile") + defer l.Debugf("reconcile finished") + + pod := new(corev1.Pod) + err = er.Get(ctx, req.NamespacedName, pod) + if apierrors.IsNotFound(err) { + return reconcile.Result{}, nil + } + if err != nil { + return reconcile.Result{}, fmt.Errorf("failed to get Pod: %w", err) + } + if !pod.DeletionTimestamp.IsZero() { + l.Debugf("Pod is being deleted, do nothing") + return res, nil + } + if pod.Labels[LabelParentType] != proxyTypeProxyGroup { + l.Infof("[unexpected] reconciler called for a Pod that is not a ProxyGroup Pod") + return res, nil + } + + // If the Pod does not have the readiness gate set, there is no need to add the readiness condition. In practice + // this will happen if the user has configured custom TS_LOCAL_ADDR_PORT, thus disabling the graceful failover. 
+ if !slices.ContainsFunc(pod.Spec.ReadinessGates, func(r corev1.PodReadinessGate) bool { + return r.ConditionType == tsEgressReadinessGate + }) { + l.Debug("Pod does not have egress readiness gate set, skipping") + return res, nil + } + + proxyGroupName := pod.Labels[LabelParentName] + pg := new(tsapi.ProxyGroup) + if err := er.Get(ctx, types.NamespacedName{Name: proxyGroupName}, pg); err != nil { + return res, fmt.Errorf("error getting ProxyGroup %q: %w", proxyGroupName, err) + } + if pg.Spec.Type != typeEgress { + l.Infof("[unexpected] reconciler called for %q ProxyGroup Pod", pg.Spec.Type) + return res, nil + } + // Get all ClusterIP Services for all egress targets exposed to cluster via this ProxyGroup. + lbls := map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: proxyGroupName, + labelSvcType: typeEgress, + } + svcs := &corev1.ServiceList{} + if err := er.List(ctx, svcs, client.InNamespace(er.tsNamespace), client.MatchingLabels(lbls)); err != nil { + return res, fmt.Errorf("error listing ClusterIP Services") + } + + idx := xslices.IndexFunc(pod.Status.Conditions, func(c corev1.PodCondition) bool { + return c.Type == tsEgressReadinessGate + }) + if idx != -1 { + l.Debugf("Pod is already ready, do nothing") + return res, nil + } + + var routesMissing atomic.Bool + errChan := make(chan error, len(svcs.Items)) + for _, svc := range svcs.Items { + s := svc + go func() { + ll := l.With("service_name", s.Name) + d := retrieveClusterDomain(er.tsNamespace, ll) + healthCheckAddr := healthCheckForSvc(&s, d) + if healthCheckAddr == "" { + ll.Debugf("ClusterIP Service does not expose a health check endpoint, unable to verify if routing is set up") + errChan <- nil + return + } + + var routesSetup bool + bo := backoff.NewBackoff(s.Name, ll.Infof, er.maxBackoff) + for range numCalls(pgReplicas(pg)) { + if ctx.Err() != nil { + errChan <- nil + return + } + state, err := er.lookupPodRouteViaSvc(ctx, pod, healthCheckAddr, ll) + if err != nil { + errChan <- 
+ fmt.Errorf("error validating if routing has been set up for Pod: %w", err) + return + } + if state == healthy || state == cannotVerify { + routesSetup = true + break + } + if state == unreachable || state == unhealthy || state == podNotReady { + bo.BackOff(ctx, errors.New("backoff")) + } + } + if !routesSetup { + ll.Debugf("Pod is not yet configured as Service endpoint") + routesMissing.Store(true) + } + errChan <- nil + }() + } + for range len(svcs.Items) { + e := <-errChan + err = errors.Join(err, e) + } + if err != nil { + return res, fmt.Errorf("error verifying connectivity: %w", err) + } + if rm := routesMissing.Load(); rm { + l.Info("Pod is not yet added as an endpoint for all egress targets, waiting...") + return reconcile.Result{RequeueAfter: shortRequeue}, nil + } + if err := er.setPodReady(ctx, pod, l); err != nil { + return res, fmt.Errorf("error setting Pod as ready: %w", err) + } + return res, nil +} + +func (er *egressPodsReconciler) setPodReady(ctx context.Context, pod *corev1.Pod, l *zap.SugaredLogger) error { + if slices.ContainsFunc(pod.Status.Conditions, func(c corev1.PodCondition) bool { + return c.Type == tsEgressReadinessGate + }) { + return nil + } + l.Infof("Pod is ready to route traffic to all egress targets") + pod.Status.Conditions = append(pod.Status.Conditions, corev1.PodCondition{ + Type: tsEgressReadinessGate, + Status: corev1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: er.clock.Now()}, + }) + return er.Status().Update(ctx, pod) +} + +// healthCheckState is the result of a single request to an egress Service health check endpoint with a goal to hit a +// specific backend Pod.
+type healthCheckState int8 + +const ( + cannotVerify healthCheckState = iota // not verifiable for this setup (i.e. earlier proxy version) + unreachable // no backends or another network error + notFound // hit another backend + unhealthy // not 200 + podNotReady // Pod is not ready, i.e. does not have an IP address yet + healthy // 200 +) + +// lookupPodRouteViaSvc attempts to reach a Pod using a health check endpoint served by a Service and returns the state of the health check. +func (er *egressPodsReconciler) lookupPodRouteViaSvc(ctx context.Context, pod *corev1.Pod, healthCheckAddr string, l *zap.SugaredLogger) (healthCheckState, error) { + if !slices.ContainsFunc(pod.Spec.Containers[0].Env, func(e corev1.EnvVar) bool { + return e.Name == "TS_ENABLE_HEALTH_CHECK" && e.Value == "true" + }) { + l.Debugf("Pod does not have health check enabled, unable to verify if it is currently routable via Service") + return cannotVerify, nil + } + wantsIP, err := podIPv4(pod) + if err != nil { + return -1, fmt.Errorf("error determining Pod's IP address: %w", err) + } + if wantsIP == "" { + return podNotReady, nil + } + + ctx, cancel := context.WithTimeout(ctx, time.Second*3) + defer cancel() + req, err := http.NewRequestWithContext(ctx, httpm.GET, healthCheckAddr, nil) + if err != nil { + return -1, fmt.Errorf("error creating new HTTP request: %w", err) + } + // Do not re-use the same connection for the next request so as to maximize the chance of hitting all backends equally. + req.Close = true + resp, err := er.httpClient.Do(req) + if err != nil { + // This is most likely because this is the first Pod and is not yet added to Service endpoints. Other + // error types are possible, but checking for those would likely make the system too fragile.
+ return unreachable, nil + } + defer resp.Body.Close() + gotIP := resp.Header.Get(kubetypes.PodIPv4Header) + if gotIP == "" { + l.Debugf("Health check does not return Pod's IP header, unable to verify if Pod is currently routable via Service") + return cannotVerify, nil + } + if !strings.EqualFold(wantsIP, gotIP) { + return notFound, nil + } + if resp.StatusCode != http.StatusOK { + return unhealthy, nil + } + return healthy, nil +} + +// numCalls returns the number of times an endpoint on a ProxyGroup Service should be called until it can be safely +// assumed that, if none of the responses came back from a specific Pod, then traffic for the Service is currently not +// being routed to that Pod. This assumes that traffic for the Service is routed via round robin, so +// InternalTrafficPolicy must be 'Cluster' and session affinity must be None. +func numCalls(replicas int32) int32 { + return replicas * 3 +} + +// doer is an interface for HTTP client that can be set to a mock client in tests.
+type doer interface { + Do(*http.Request) (*http.Response, error) +} diff --git a/cmd/k8s-operator/egress-pod-readiness_test.go b/cmd/k8s-operator/egress-pod-readiness_test.go new file mode 100644 index 0000000000000..3c35d9043ebe6 --- /dev/null +++ b/cmd/k8s-operator/egress-pod-readiness_test.go @@ -0,0 +1,525 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "bytes" + "errors" + "fmt" + "io" + "log" + "net/http" + "sync" + "testing" + "time" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tstest" + "tailscale.com/types/ptr" +) + +func TestEgressPodReadiness(t *testing.T) { + // We need to pass a Pod object to WithStatusSubresource because of some quirks in how the fake client + // works. Without this code we would not be able to update Pod's status further down. + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithStatusSubresource(&corev1.Pod{}). 
+ Build() + zl, _ := zap.NewDevelopment() + cl := tstest.NewClock(tstest.ClockOpts{}) + rec := &egressPodsReconciler{ + tsNamespace: "operator-ns", + Client: fc, + logger: zl.Sugar(), + clock: cl, + } + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "dev", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: "egress", + Replicas: ptr.To(int32(3)), + }, + } + mustCreate(t, fc, pg) + podIP := "10.0.0.2" + podTemplate := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "operator-ns", + Name: "pod", + Labels: map[string]string{ + LabelParentType: "proxygroup", + LabelParentName: "dev", + }, + }, + Spec: corev1.PodSpec{ + ReadinessGates: []corev1.PodReadinessGate{{ + ConditionType: tsEgressReadinessGate, + }}, + Containers: []corev1.Container{{ + Name: "tailscale", + Env: []corev1.EnvVar{{ + Name: "TS_ENABLE_HEALTH_CHECK", + Value: "true", + }}, + }}, + }, + Status: corev1.PodStatus{ + PodIPs: []corev1.PodIP{{IP: podIP}}, + }, + } + + t.Run("no_egress_services", func(t *testing.T) { + pod := podTemplate.DeepCopy() + mustCreate(t, fc, pod) + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod) + }) + t.Run("one_svc_already_routed_to", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + resp := readyResps(podIP, 1) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resp}, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + + // A subsequent reconcile should not change the Pod. 
+ expectReconciled(t, rec, "operator-ns", pod.Name) + expectEqual(t, fc, pod) + + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_many_backends_eventually_routed_to", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // For a 3 replica ProxyGroup the healthcheck endpoint should be called 9 times, make the 9th time only + // return with the right Pod IP. + resps := append(readyResps("10.0.0.3", 4), append(readyResps("10.0.0.4", 4), readyResps(podIP, 1)...)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resps}, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_eventually_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // For a 3 replica ProxyGroup the healthcheck endpoint should be called 9 times, make the 9th time only + // return with 200 status code. + resps := append(unreadyResps(podIP, 8), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resps}, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_never_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // For a 3 replica ProxyGroup the healthcheck endpoint should be called 9 times and Pod should be + // requeued if neither of those succeed. 
+ resps := readyResps("10.0.0.3", 9) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{hep: resps}, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_many_backends_already_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + resps := readyResps(podIP, 1) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps, + hep3: resps, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_eventually_routable_and_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + resps := append(readyResps("10.0.0.3", 7), readyResps(podIP, 1)...) + resps2 := append(readyResps("10.0.0.3", 5), readyResps(podIP, 1)...) + resps3 := append(unreadyResps(podIP, 4), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. 
+ podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_never_routable_and_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried 9 times and the Pod + // will be requeued if neither succeeds. + resps := readyResps("10.0.0.3", 9) + resps2 := append(readyResps("10.0.0.3", 5), readyResps("10.0.0.4", 4)...) + resps3 := unreadyResps(podIP, 9) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_one_never_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried 9 times and the Pod + // will be requeued if any one never succeeds. + resps := readyResps(podIP, 9) + resps2 := readyResps(podIP, 9) + resps3 := append(readyResps("10.0.0.3", 5), readyResps("10.0.0.4", 4)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. 
+ expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_one_never_healthy", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + svc2, hep2 := newSvc("svc-2", 9002) + svc3, hep3 := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried 9 times and the Pod + // will be requeued if any one never succeeds. + resps := readyResps(podIP, 9) + resps2 := unreadyResps(podIP, 9) + resps3 := readyResps(podIP, 9) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("one_svc_many_backends_different_ports_eventually_healthy_and_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9003) + svc2, hep2 := newSvc("svc-2", 9004) + svc3, hep3 := newSvc("svc-3", 9010) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + // For a ProxyGroup with 3 replicas, each Service's health endpoint will be tried up to 9 times and + // marked as success as soon as one try succeeds. + resps := append(readyResps("10.0.0.3", 7), readyResps(podIP, 1)...) + resps2 := append(readyResps("10.0.0.3", 5), readyResps(podIP, 1)...) + resps3 := append(unreadyResps(podIP, 4), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + hep2: resps2, + hep3: resps3, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. 
+ podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + // Proxies of 1.78 and earlier did not set the Pod IP header. + t.Run("pod_does_not_return_ip_header", func(t *testing.T) { + pod := podTemplate.DeepCopy() + pod.Name = "foo-bar" + + svc, hep := newSvc("foo-bar", 9002) + mustCreateAll(t, fc, svc, pod) + // If a response does not contain Pod IP header, we assume that this is an earlier proxy version, + // readiness cannot be verified so the readiness gate is just set to true. + resps := unreadyResps("", 1) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_eventually_healthy_and_routable", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + svc, hep := newSvc("svc", 9002) + mustCreateAll(t, fc, svc, pod) + // If a response errors, it is probably because the Pod is not yet properly running, so retry. + resps := append(erroredResps(8), readyResps(podIP, 1)...) + httpCl := fakeHTTPClient{ + t: t, + state: map[string][]fakeResponse{ + hep: resps, + }, + } + rec.httpClient = &httpCl + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. + podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("one_svc_one_backend_svc_does_not_have_health_port", func(t *testing.T) { + pod := podTemplate.DeepCopy() + + // If a Service does not have health port set, we assume that it is not possible to determine Pod's + // readiness and set it to ready. + svc, _ := newSvc("svc", -1) + mustCreateAll(t, fc, svc, pod) + rec.httpClient = nil + expectReconciled(t, rec, "operator-ns", pod.Name) + + // Pod should have readiness gate condition set. 
+ podSetReady(pod, cl) + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc) + }) + t.Run("error_setting_up_healthcheck", func(t *testing.T) { + pod := podTemplate.DeepCopy() + // This is not a realistic reason for error, but we are just testing the behaviour of a healthcheck + // lookup failing. + pod.Status.PodIPs = []corev1.PodIP{{IP: "not-an-ip"}} + + svc, _ := newSvc("svc", 9002) + svc2, _ := newSvc("svc-2", 9002) + svc3, _ := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + rec.httpClient = nil + expectError(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. + expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) + t.Run("pod_does_not_have_an_ip_address", func(t *testing.T) { + pod := podTemplate.DeepCopy() + pod.Status.PodIPs = nil + + svc, _ := newSvc("svc", 9002) + svc2, _ := newSvc("svc-2", 9002) + svc3, _ := newSvc("svc-3", 9002) + mustCreateAll(t, fc, svc, svc2, svc3, pod) + rec.httpClient = nil + expectRequeue(t, rec, "operator-ns", pod.Name) + + // Pod should not have readiness gate condition set. 
+ expectEqual(t, fc, pod) + mustDeleteAll(t, fc, pod, svc, svc2, svc3) + }) +} + +func readyResps(ip string, num int) (resps []fakeResponse) { + for range num { + resps = append(resps, fakeResponse{statusCode: 200, podIP: ip}) + } + return resps +} + +func unreadyResps(ip string, num int) (resps []fakeResponse) { + for range num { + resps = append(resps, fakeResponse{statusCode: 503, podIP: ip}) + } + return resps +} + +func erroredResps(num int) (resps []fakeResponse) { + for range num { + resps = append(resps, fakeResponse{err: errors.New("timeout")}) + } + return resps +} + +func newSvc(name string, port int32) (*corev1.Service, string) { + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: "operator-ns", + Name: name, + Labels: map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: "dev", + labelSvcType: typeEgress, + }, + }, + Spec: corev1.ServiceSpec{}, + } + if port != -1 { + svc.Spec.Ports = []corev1.ServicePort{ + { + Name: tsHealthCheckPortName, + Port: port, + TargetPort: intstr.FromInt(9002), + Protocol: "TCP", + }, + } + } + return svc, fmt.Sprintf("http://%s.operator-ns.svc.cluster.local:%d/healthz", name, port) +} + +func podSetReady(pod *corev1.Pod, cl *tstest.Clock) { + pod.Status.Conditions = append(pod.Status.Conditions, corev1.PodCondition{ + Type: tsEgressReadinessGate, + Status: corev1.ConditionTrue, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + }) +} + +// fakeHTTPClient is a mock HTTP client with a preset map of request URLs to list of responses. When it receives a +// request for a specific URL, it returns the preset response for that URL. It errors if an unexpected request is +// received. 
+type fakeHTTPClient struct { + t *testing.T + mu sync.Mutex // protects following + state map[string][]fakeResponse +} + +func (f *fakeHTTPClient) Do(req *http.Request) (*http.Response, error) { + f.mu.Lock() + resps := f.state[req.URL.String()] + if len(resps) == 0 { + f.mu.Unlock() + log.Printf("\n\n\nURL %q\n\n\n", req.URL) + f.t.Fatalf("fakeHTTPClient received an unexpected request for %q", req.URL) + } + defer func() { + if len(resps) == 1 { + delete(f.state, req.URL.String()) + f.mu.Unlock() + return + } + f.state[req.URL.String()] = f.state[req.URL.String()][1:] + f.mu.Unlock() + }() + + resp := resps[0] + if resp.err != nil { + return nil, resp.err + } + r := http.Response{ + StatusCode: resp.statusCode, + Header: make(http.Header), + Body: io.NopCloser(bytes.NewReader([]byte{})), + } + r.Header.Add(kubetypes.PodIPv4Header, resp.podIP) + return &r, nil +} + +type fakeResponse struct { + err error + statusCode int + podIP string // for the Pod IP header +} diff --git a/cmd/k8s-operator/egress-services-readiness.go b/cmd/k8s-operator/egress-services-readiness.go index f6991145f88fc..5e95a52790395 100644 --- a/cmd/k8s-operator/egress-services-readiness.go +++ b/cmd/k8s-operator/egress-services-readiness.go @@ -48,11 +48,12 @@ type egressSvcsReadinessReconciler struct { // service to determine how many replicas are currently able to route traffic. 
func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { l := esrr.logger.With("Service", req.NamespacedName) - defer l.Info("reconcile finished") + l.Debugf("starting reconcile") + defer l.Debugf("reconcile finished") svc := new(corev1.Service) if err = esrr.Get(ctx, req.NamespacedName, svc); apierrors.IsNotFound(err) { - l.Info("Service not found") + l.Debugf("Service not found") return res, nil } else if err != nil { return res, fmt.Errorf("failed to get Service: %w", err) @@ -64,7 +65,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re oldStatus := svc.Status.DeepCopy() defer func() { tsoperator.SetServiceCondition(svc, tsapi.EgressSvcReady, st, reason, msg, esrr.clock, l) - if !apiequality.Semantic.DeepEqual(oldStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldStatus, &svc.Status) { err = errors.Join(err, esrr.Status().Update(ctx, svc)) } }() @@ -127,16 +128,16 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re return res, err } if pod == nil { - l.Infof("[unexpected] ProxyGroup is ready, but replica %d was not found", i) + l.Warnf("[unexpected] ProxyGroup is ready, but replica %d was not found", i) reason, msg = reasonClusterResourcesNotReady, reasonClusterResourcesNotReady return res, nil } - l.Infof("looking at Pod with IPs %v", pod.Status.PodIPs) + l.Debugf("looking at Pod with IPs %v", pod.Status.PodIPs) ready := false for _, ep := range eps.Endpoints { - l.Infof("looking at endpoint with addresses %v", ep.Addresses) + l.Debugf("looking at endpoint with addresses %v", ep.Addresses) if endpointReadyForPod(&ep, pod, l) { - l.Infof("endpoint is ready for Pod") + l.Debugf("endpoint is ready for Pod") ready = true break } @@ -165,7 +166,7 @@ func (esrr *egressSvcsReadinessReconciler) Reconcile(ctx context.Context, req re func endpointReadyForPod(ep *discoveryv1.Endpoint, pod *corev1.Pod, l 
*zap.SugaredLogger) bool { podIP, err := podIPv4(pod) if err != nil { - l.Infof("[unexpected] error retrieving Pod's IPv4 address: %v", err) + l.Warnf("[unexpected] error retrieving Pod's IPv4 address: %v", err) return false } // Currently we only ever set a single address on and Endpoint and nothing else is meant to modify this. diff --git a/cmd/k8s-operator/egress-services-readiness_test.go b/cmd/k8s-operator/egress-services-readiness_test.go index 052eb1a493801..ce947329ddfb8 100644 --- a/cmd/k8s-operator/egress-services-readiness_test.go +++ b/cmd/k8s-operator/egress-services-readiness_test.go @@ -67,24 +67,24 @@ func TestEgressServiceReadiness(t *testing.T) { setClusterNotReady(egressSvc, cl, zl.Sugar()) t.Run("endpointslice_does_not_exist", func(t *testing.T) { expectReconciled(t, rec, "dev", "my-app") - expectEqual(t, fc, egressSvc, nil) // not ready + expectEqual(t, fc, egressSvc) // not ready }) t.Run("proxy_group_does_not_exist", func(t *testing.T) { mustCreate(t, fc, eps) expectReconciled(t, rec, "dev", "my-app") - expectEqual(t, fc, egressSvc, nil) // still not ready + expectEqual(t, fc, egressSvc) // still not ready }) t.Run("proxy_group_not_ready", func(t *testing.T) { mustCreate(t, fc, pg) expectReconciled(t, rec, "dev", "my-app") - expectEqual(t, fc, egressSvc, nil) // still not ready + expectEqual(t, fc, egressSvc) // still not ready }) t.Run("no_ready_replicas", func(t *testing.T) { setPGReady(pg, cl, zl.Sugar()) mustUpdateStatus(t, fc, pg.Namespace, pg.Name, func(p *tsapi.ProxyGroup) { p.Status = pg.Status }) - expectEqual(t, fc, pg, nil) + expectEqual(t, fc, pg) for i := range pgReplicas(pg) { p := pod(pg, i) mustCreate(t, fc, p) @@ -94,7 +94,7 @@ func TestEgressServiceReadiness(t *testing.T) { } expectReconciled(t, rec, "dev", "my-app") setNotReady(egressSvc, cl, zl.Sugar(), pgReplicas(pg)) - expectEqual(t, fc, egressSvc, nil) // still not ready + expectEqual(t, fc, egressSvc) // still not ready }) t.Run("one_ready_replica", func(t *testing.T) 
{ setEndpointForReplica(pg, 0, eps) @@ -103,7 +103,7 @@ func TestEgressServiceReadiness(t *testing.T) { }) setReady(egressSvc, cl, zl.Sugar(), pgReplicas(pg), 1) expectReconciled(t, rec, "dev", "my-app") - expectEqual(t, fc, egressSvc, nil) // partially ready + expectEqual(t, fc, egressSvc) // partially ready }) t.Run("all_replicas_ready", func(t *testing.T) { for i := range pgReplicas(pg) { @@ -114,7 +114,7 @@ func TestEgressServiceReadiness(t *testing.T) { }) setReady(egressSvc, cl, zl.Sugar(), pgReplicas(pg), pgReplicas(pg)) expectReconciled(t, rec, "dev", "my-app") - expectEqual(t, fc, egressSvc, nil) // ready + expectEqual(t, fc, egressSvc) // ready }) } diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index 98ed943669cd0..b0fb1110f9c54 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -51,14 +51,16 @@ const ( labelSvcType = "tailscale.com/svc-type" // ingress or egress typeEgress = "egress" // maxPorts is the maximum number of ports that can be exposed on a - // container. In practice this will be ports in range [3000 - 4000). The + // container. In practice this will be ports in range [10000 - 11000). The // high range should make it easier to distinguish container ports from // the tailnet target ports for debugging purposes (i.e when reading - // netfilter rules). The limit of 10000 is somewhat arbitrary, the + // netfilter rules). The limit of 1000 is somewhat arbitrary, the // assumption is that this would not be hit in practice. - maxPorts = 10000 + maxPorts = 1000 indexEgressProxyGroup = ".metadata.annotations.egress-proxy-group" + + tsHealthCheckPortName = "tailscale-health-check" ) var gaugeEgressServices = clientmetric.NewGauge(kubetypes.MetricEgressServiceCount) @@ -68,10 +70,11 @@ var gaugeEgressServices = clientmetric.NewGauge(kubetypes.MetricEgressServiceCou // on whose proxies it should be exposed. 
type egressSvcsReconciler struct { client.Client - logger *zap.SugaredLogger - recorder record.EventRecorder - clock tstime.Clock - tsNamespace string + logger *zap.SugaredLogger + recorder record.EventRecorder + clock tstime.Clock + tsNamespace string + validationOpts validationOpts mu sync.Mutex // protects following svcs set.Slice[types.UID] // UIDs of all currently managed egress Services for ProxyGroup @@ -123,7 +126,7 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re oldStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldStatus, &svc.Status) { err = errors.Join(err, esr.Status().Update(ctx, svc)) } }() @@ -136,9 +139,8 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re } if !slices.Contains(svc.Finalizers, FinalizerName) { - l.Infof("configuring tailnet service") // logged exactly once svc.Finalizers = append(svc.Finalizers, FinalizerName) - if err := esr.Update(ctx, svc); err != nil { + if err := esr.updateSvcSpec(ctx, svc); err != nil { err := fmt.Errorf("failed to add finalizer: %w", err) r := svcConfiguredReason(svc, false, l) tsoperator.SetServiceCondition(svc, tsapi.EgressSvcConfigured, metav1.ConditionFalse, r, err.Error(), esr.clock, l) @@ -157,7 +159,15 @@ func (esr *egressSvcsReconciler) Reconcile(ctx context.Context, req reconcile.Re return res, err } - return res, esr.maybeProvision(ctx, svc, l) + if err := esr.maybeProvision(ctx, svc, l); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + l.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return res, nil } func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1.Service, l *zap.SugaredLogger) (err error) { @@ -198,7 +208,7 @@ func (esr *egressSvcsReconciler) maybeProvision(ctx context.Context, svc *corev1 if svc.Spec.ExternalName 
!= clusterIPSvcFQDN { l.Infof("Configuring ExternalName Service to point to ClusterIP Service %s", clusterIPSvcFQDN) svc.Spec.ExternalName = clusterIPSvcFQDN - if err = esr.Update(ctx, svc); err != nil { + if err = esr.updateSvcSpec(ctx, svc); err != nil { err = fmt.Errorf("error updating ExternalName Service: %w", err) return err } @@ -222,6 +232,16 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s found := false for _, wantsPM := range svc.Spec.Ports { if wantsPM.Port == pm.Port && strings.EqualFold(string(wantsPM.Protocol), string(pm.Protocol)) { + // We want to both preserve the user set port names for ease of debugging, but also + // ensure that we name all unnamed ports as the ClusterIP Service that we create will + // always have at least two ports. + // https://kubernetes.io/docs/concepts/services-networking/service/#multi-port-services + // See also https://github.com/tailscale/tailscale/issues/13406#issuecomment-2507230388 + if wantsPM.Name != "" { + clusterIPSvc.Spec.Ports[i].Name = wantsPM.Name + } else { + clusterIPSvc.Spec.Ports[i].Name = "tailscale-unnamed" + } found = true break } @@ -236,6 +256,12 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s // ClusterIP Service produce new target port and add a portmapping to // the ClusterIP Service. for _, wantsPM := range svc.Spec.Ports { + // Because we add a healthcheck port of our own, we will always have at least two ports. That + // means that we cannot have ports with name not set. 
+ // https://kubernetes.io/docs/concepts/services-networking/service/#multi-port-services + if wantsPM.Name == "" { + wantsPM.Name = "tailscale-unnamed" + } found := false for _, gotPM := range clusterIPSvc.Spec.Ports { if wantsPM.Port == gotPM.Port && strings.EqualFold(string(wantsPM.Protocol), string(gotPM.Protocol)) { @@ -246,7 +272,7 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s if !found { // Calculate a free port to expose on container and add // a new PortMap to the ClusterIP Service. - if usedPorts.Len() == maxPorts { + if usedPorts.Len() >= maxPorts { // TODO(irbekrm): refactor to avoid extra reconciles here. Low priority as in practice, // the limit should not be hit. return nil, false, fmt.Errorf("unable to allocate additional ports on ProxyGroup %s, %d ports already used. Create another ProxyGroup or open an issue if you believe this is unexpected.", proxyGroupName, maxPorts) @@ -262,6 +288,25 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s }) } } + var healthCheckPort int32 = defaultLocalAddrPort + + for { + if !slices.ContainsFunc(svc.Spec.Ports, func(p corev1.ServicePort) bool { + return p.Port == healthCheckPort + }) { + break + } + healthCheckPort++ + if healthCheckPort > 10002 { + return nil, false, fmt.Errorf("unable to find a free port for internal health check in range [9002, 10002]") + } + } + clusterIPSvc.Spec.Ports = append(clusterIPSvc.Spec.Ports, corev1.ServicePort{ + Name: tsHealthCheckPortName, + Port: healthCheckPort, + TargetPort: intstr.FromInt(defaultLocalAddrPort), + Protocol: "TCP", + }) if !reflect.DeepEqual(clusterIPSvc, oldClusterIPSvc) { if clusterIPSvc, err = createOrUpdate(ctx, esr.Client, esr.tsNamespace, clusterIPSvc, func(svc *corev1.Service) { svc.Labels = clusterIPSvc.Labels @@ -304,7 +349,7 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s } tailnetSvc := tailnetSvcName(svc) gotCfg := (*cfgs)[tailnetSvc] - 
wantsCfg := egressSvcCfg(svc, clusterIPSvc) + wantsCfg := egressSvcCfg(svc, clusterIPSvc, esr.tsNamespace, l) if !reflect.DeepEqual(gotCfg, wantsCfg) { l.Debugf("updating egress services ConfigMap %s", cm.Name) mak.Set(cfgs, tailnetSvc, wantsCfg) @@ -479,14 +524,7 @@ func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, s tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, err } - if !tsoperator.ProxyGroupIsReady(pg) { - l.Infof("ProxyGroup %s is not ready, waiting...", proxyGroupName) - tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, l) - tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) - return false, nil - } - - if violations := validateEgressService(svc, pg); len(violations) > 0 { + if violations := validateEgressService(svc, pg, esr.validationOpts); len(violations) > 0 { msg := fmt.Sprintf("invalid egress Service: %s", strings.Join(violations, ", ")) esr.recorder.Event(svc, corev1.EventTypeWarning, "INVALIDSERVICE", msg) l.Info(msg) @@ -494,13 +532,36 @@ func (esr *egressSvcsReconciler) validateClusterResources(ctx context.Context, s tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) return false, nil } + if !tsoperator.ProxyGroupIsReady(pg) { + tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, esr.clock, l) + tsoperator.RemoveServiceCondition(svc, tsapi.EgressSvcConfigured) + } + l.Debugf("egress service is valid") tsoperator.SetServiceCondition(svc, tsapi.EgressSvcValid, metav1.ConditionTrue, reasonEgressSvcValid, reasonEgressSvcValid, esr.clock, l) return true, nil } -func validateEgressService(svc *corev1.Service, pg *tsapi.ProxyGroup) []string { - violations := validateService(svc) +func egressSvcCfg(externalNameSvc, clusterIPSvc *corev1.Service, ns string, l *zap.SugaredLogger) 
egressservices.Config { + d := retrieveClusterDomain(ns, l) + tt := tailnetTargetFromSvc(externalNameSvc) + hep := healthCheckForSvc(clusterIPSvc, d) + cfg := egressservices.Config{ + TailnetTarget: tt, + HealthCheckEndpoint: hep, + } + for _, svcPort := range clusterIPSvc.Spec.Ports { + if svcPort.Name == tsHealthCheckPortName { + continue // exclude healthcheck from egress svcs configs + } + pm := portMap(svcPort) + mak.Set(&cfg.Ports, pm, struct{}{}) + } + return cfg +} + +func validateEgressService(svc *corev1.Service, pg *tsapi.ProxyGroup, opts validationOpts) []string { + violations := validateService(svc, opts) // We check that only one of these two is set in the earlier validateService function. if svc.Annotations[AnnotationTailnetTargetFQDN] == "" && svc.Annotations[AnnotationTailnetTargetIP] == "" { @@ -540,13 +601,13 @@ func svcNameBase(s string) string { } } -// unusedPort returns a port in range [3000 - 4000). The caller must ensure that -// usedPorts does not contain all ports in range [3000 - 4000). +// unusedPort returns a port in range [10000 - 11000). The caller must ensure that +// usedPorts does not contain all ports in range [10000 - 11000). func unusedPort(usedPorts sets.Set[int32]) int32 { foundFreePort := false var suggestPort int32 for !foundFreePort { - suggestPort = rand.Int32N(maxPorts) + 3000 + suggestPort = rand.Int32N(maxPorts) + 10000 if !usedPorts.Has(suggestPort) { foundFreePort = true } @@ -568,19 +629,13 @@ func tailnetTargetFromSvc(svc *corev1.Service) egressservices.TailnetTarget { } } -func egressSvcCfg(externalNameSvc, clusterIPSvc *corev1.Service) egressservices.Config { - tt := tailnetTargetFromSvc(externalNameSvc) - cfg := egressservices.Config{TailnetTarget: tt} - for _, svcPort := range clusterIPSvc.Spec.Ports { - pm := portMap(svcPort) - mak.Set(&cfg.Ports, pm, struct{}{}) - } - return cfg -} - func portMap(p corev1.ServicePort) egressservices.PortMap { // TODO (irbekrm): out of bounds check? 
- return egressservices.PortMap{Protocol: string(p.Protocol), MatchPort: uint16(p.TargetPort.IntVal), TargetPort: uint16(p.Port)} + return egressservices.PortMap{ + Protocol: string(p.Protocol), + MatchPort: uint16(p.TargetPort.IntVal), + TargetPort: uint16(p.Port), + } } func isEgressSvcForProxyGroup(obj client.Object) bool { @@ -602,7 +657,11 @@ func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, ts Namespace: tsNamespace, }, } - if err := cl.Get(ctx, client.ObjectKeyFromObject(cm), cm); err != nil { + err = cl.Get(ctx, client.ObjectKeyFromObject(cm), cm) + if apierrors.IsNotFound(err) { // ProxyGroup resources have not been created (yet) + return nil, nil, nil + } + if err != nil { return nil, nil, fmt.Errorf("error retrieving egress services ConfigMap %s: %v", name, err) } cfgs = &egressservices.Configs{} @@ -622,12 +681,12 @@ func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, ts // should probably validate and truncate (?) the names is they are too long. func egressSvcChildResourceLabels(svc *corev1.Service) map[string]string { return map[string]string{ - LabelManaged: "true", - LabelParentType: "svc", - LabelParentName: svc.Name, - LabelParentNamespace: svc.Namespace, - labelProxyGroup: svc.Annotations[AnnotationProxyGroup], - labelSvcType: typeEgress, + kubetypes.LabelManaged: "true", + LabelParentType: "svc", + LabelParentName: svc.Name, + LabelParentNamespace: svc.Namespace, + labelProxyGroup: svc.Annotations[AnnotationProxyGroup], + labelSvcType: typeEgress, } } @@ -714,3 +773,27 @@ func epsPortsFromSvc(svc *corev1.Service) (ep []discoveryv1.EndpointPort) { } return ep } + +// updateSvcSpec ensures that the given Service's spec is updated in cluster, but the local Service object still retains +// the not-yet-applied status. +// TODO(irbekrm): once we do SSA for these patch updates, this will no longer be needed. 
+func (esr *egressSvcsReconciler) updateSvcSpec(ctx context.Context, svc *corev1.Service) error { + st := svc.Status.DeepCopy() + err := esr.Update(ctx, svc) + svc.Status = *st + return err +} + +// healthCheckForSvc returns the URL of the containerboot's health check endpoint served by this Service or empty string. +func healthCheckForSvc(svc *corev1.Service, clusterDomain string) string { + // This version of the operator always sets health check port on the egress Services. However, it is possible + // that this reconcile loop runs during a proxy upgrade from a version that did not set the health check port + // and parses a Service that does not have the port set yet. + i := slices.IndexFunc(svc.Spec.Ports, func(port corev1.ServicePort) bool { + return port.Name == tsHealthCheckPortName + }) + if i == -1 { + return "" + } + return fmt.Sprintf("http://%s.%s.svc.%s:%d/healthz", svc.Name, svc.Namespace, clusterDomain, svc.Spec.Ports[i].Port) +} diff --git a/cmd/k8s-operator/egress-services_test.go b/cmd/k8s-operator/egress-services_test.go index ac77339853ebe..d8a5dfd32c1c2 100644 --- a/cmd/k8s-operator/egress-services_test.go +++ b/cmd/k8s-operator/egress-services_test.go @@ -18,6 +18,7 @@ import ( discoveryv1 "k8s.io/api/discovery/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" @@ -78,55 +79,41 @@ func TestTailscaleEgressServices(t *testing.T) { Selector: nil, Ports: []corev1.ServicePort{ { - Name: "http", Protocol: "TCP", Port: 80, }, - { - Name: "https", - Protocol: "TCP", - Port: 443, - }, }, }, } - t.Run("proxy_group_not_ready", func(t *testing.T) { + t.Run("service_one_unnamed_port", func(t *testing.T) { mustCreate(t, fc, svc) expectReconciled(t, esr, "default", "test") - // Service should have EgressSvcValid condition set to Unknown.
- svc.Status.Conditions = []metav1.Condition{condition(tsapi.EgressSvcValid, metav1.ConditionUnknown, reasonProxyGroupNotReady, reasonProxyGroupNotReady, clock)} - expectEqual(t, fc, svc, nil) + validateReadyService(t, fc, esr, svc, clock, zl, cm) }) - - t.Run("proxy_group_ready", func(t *testing.T) { - mustUpdateStatus(t, fc, "", "foo", func(pg *tsapi.ProxyGroup) { - pg.Status.Conditions = []metav1.Condition{ - condition(tsapi.ProxyGroupReady, metav1.ConditionTrue, "", "", clock), - } + t.Run("service_add_two_named_ports", func(t *testing.T) { + svc.Spec.Ports = []corev1.ServicePort{{Protocol: "TCP", Port: 80, Name: "http"}, {Protocol: "TCP", Port: 443, Name: "https"}} + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports }) - // Quirks of the fake client. - mustUpdateStatus(t, fc, "default", "test", func(svc *corev1.Service) { - svc.Status.Conditions = []metav1.Condition{} + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_add_udp_port", func(t *testing.T) { + svc.Spec.Ports = append(svc.Spec.Ports, corev1.ServicePort{Port: 53, Protocol: "UDP", Name: "dns"}) + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports }) expectReconciled(t, esr, "default", "test") - // Verify that a ClusterIP Service has been created. - name := findGenNameForEgressSvcResources(t, fc, svc) - expectEqual(t, fc, clusterIPSvc(name, svc), removeTargetPortsFromSvc) - clusterSvc := mustGetClusterIPSvc(t, fc, name) - // Verify that an EndpointSlice has been created. - expectEqual(t, fc, endpointSlice(name, svc, clusterSvc), nil) - // Verify that ConfigMap contains configuration for the new egress service. - mustHaveConfigForSvc(t, fc, svc, clusterSvc, cm) - r := svcConfiguredReason(svc, true, zl.Sugar()) - // Verify that the user-created ExternalName Service has Configured set to true and ExternalName pointing to the - // CluterIP Service. 
- svc.Status.Conditions = []metav1.Condition{ - condition(tsapi.EgressSvcConfigured, metav1.ConditionTrue, r, r, clock), - } - svc.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} - svc.Spec.ExternalName = fmt.Sprintf("%s.operator-ns.svc.cluster.local", name) - expectEqual(t, fc, svc, nil) + validateReadyService(t, fc, esr, svc, clock, zl, cm) + }) + t.Run("service_change_protocol", func(t *testing.T) { + svc.Spec.Ports = []corev1.ServicePort{{Protocol: "TCP", Port: 80, Name: "http"}, {Protocol: "TCP", Port: 443, Name: "https"}, {Port: 53, Protocol: "TCP", Name: "tcp_dns"}} + mustUpdate(t, fc, "default", "test", func(s *corev1.Service) { + s.Spec.Ports = svc.Spec.Ports + }) + expectReconciled(t, esr, "default", "test") + validateReadyService(t, fc, esr, svc, clock, zl, cm) }) t.Run("delete_external_name_service", func(t *testing.T) { @@ -143,6 +130,29 @@ func TestTailscaleEgressServices(t *testing.T) { }) } +func validateReadyService(t *testing.T, fc client.WithWatch, esr *egressSvcsReconciler, svc *corev1.Service, clock *tstest.Clock, zl *zap.Logger, cm *corev1.ConfigMap) { + expectReconciled(t, esr, "default", "test") + // Verify that a ClusterIP Service has been created. + name := findGenNameForEgressSvcResources(t, fc, svc) + expectEqual(t, fc, clusterIPSvc(name, svc), removeTargetPortsFromSvc) + clusterSvc := mustGetClusterIPSvc(t, fc, name) + // Verify that an EndpointSlice has been created. + expectEqual(t, fc, endpointSlice(name, svc, clusterSvc)) + // Verify that ConfigMap contains configuration for the new egress service. + mustHaveConfigForSvc(t, fc, svc, clusterSvc, cm, zl) + r := svcConfiguredReason(svc, true, zl.Sugar()) + // Verify that the user-created ExternalName Service has Configured set to true and ExternalName pointing to the + // CluterIP Service. 
+ svc.Status.Conditions = []metav1.Condition{ + condition(tsapi.EgressSvcValid, metav1.ConditionTrue, "EgressSvcValid", "EgressSvcValid", clock), + condition(tsapi.EgressSvcConfigured, metav1.ConditionTrue, r, r, clock), + } + svc.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} + svc.Spec.ExternalName = fmt.Sprintf("%s.operator-ns.svc.cluster.local", name) + expectEqual(t, fc, svc) + +} + func condition(typ tsapi.ConditionType, st metav1.ConditionStatus, r, msg string, clock tstime.Clock) metav1.Condition { return metav1.Condition{ Type: string(typ), @@ -168,6 +178,23 @@ func findGenNameForEgressSvcResources(t *testing.T, client client.Client, svc *c func clusterIPSvc(name string, extNSvc *corev1.Service) *corev1.Service { labels := egressSvcChildResourceLabels(extNSvc) + ports := make([]corev1.ServicePort, len(extNSvc.Spec.Ports)) + for i, port := range extNSvc.Spec.Ports { + ports[i] = corev1.ServicePort{ // Copy the port to avoid modifying the original. + Name: port.Name, + Port: port.Port, + Protocol: port.Protocol, + } + if port.Name == "" { + ports[i].Name = "tailscale-unnamed" + } + } + ports = append(ports, corev1.ServicePort{ + Name: "tailscale-health-check", + Port: 9002, + TargetPort: intstr.FromInt(9002), + Protocol: "TCP", + }) return &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -177,7 +204,7 @@ func clusterIPSvc(name string, extNSvc *corev1.Service) *corev1.Service { }, Spec: corev1.ServiceSpec{ Type: corev1.ServiceTypeClusterIP, - Ports: extNSvc.Spec.Ports, + Ports: ports, }, } } @@ -222,9 +249,9 @@ func portsForEndpointSlice(svc *corev1.Service) []discoveryv1.EndpointPort { return ports } -func mustHaveConfigForSvc(t *testing.T, cl client.Client, extNSvc, clusterIPSvc *corev1.Service, cm *corev1.ConfigMap) { +func mustHaveConfigForSvc(t *testing.T, cl client.Client, extNSvc, clusterIPSvc *corev1.Service, cm *corev1.ConfigMap, l *zap.Logger) { t.Helper() - wantsCfg := egressSvcCfg(extNSvc, clusterIPSvc) + wantsCfg := 
egressSvcCfg(extNSvc, clusterIPSvc, clusterIPSvc.Namespace, l.Sugar()) if err := cl.Get(context.Background(), client.ObjectKeyFromObject(cm), cm); err != nil { t.Fatalf("Error retrieving ConfigMap: %v", err) } diff --git a/cmd/k8s-operator/ingress-for-pg.go b/cmd/k8s-operator/ingress-for-pg.go new file mode 100644 index 0000000000000..5f9c549407717 --- /dev/null +++ b/cmd/k8s-operator/ingress-for-pg.go @@ -0,0 +1,1115 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/rand/v2" + "net/http" + "reflect" + "slices" + "strings" + "sync" + "time" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + rbacv1 "k8s.io/api/rbac/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + serveConfigKey = "serve-config.json" + TailscaleSvcOwnerRef = "tailscale.com/k8s-operator:owned-by:%s" + // FinalizerNamePG is the finalizer used by the IngressPGReconciler + FinalizerNamePG = "tailscale.com/ingress-pg-finalizer" + + indexIngressProxyGroup = ".metadata.annotations.ingress-proxy-group" + // annotationHTTPEndpoint can be used to configure the Ingress to expose an HTTP endpoint to tailnet (as + // well as the default HTTPS endpoint). 
+ annotationHTTPEndpoint = "tailscale.com/http-endpoint" + + labelDomain = "tailscale.com/domain" + msgFeatureFlagNotEnabled = "Tailscale Service feature flag is not enabled for this tailnet, skipping provisioning. " + + "Please contact Tailscale support through https://tailscale.com/contact/support to enable the feature flag, then recreate the operator's Pod." + + warningTailscaleServiceFeatureFlagNotEnabled = "TailscaleServiceFeatureFlagNotEnabled" + managedTSServiceComment = "This Tailscale Service is managed by the Tailscale Kubernetes Operator, do not modify" +) + +var gaugePGIngressResources = clientmetric.NewGauge(kubetypes.MetricIngressPGResourceCount) + +// HAIngressReconciler is a controller that reconciles Tailscale Ingresses that +// should be exposed on an ingress ProxyGroup (in HA mode). +type HAIngressReconciler struct { + client.Client + + recorder record.EventRecorder + logger *zap.SugaredLogger + tsClient tsClient + tsnetServer tsnetServer + tsNamespace string + lc localClient + defaultTags []string + operatorID string // stableID of the operator's Tailscale device + + mu sync.Mutex // protects following + // managedIngresses is a set of all ingress resources that we're currently + // managing. This is only used for metrics. + managedIngresses set.Slice[types.UID] +} + +// Reconcile reconciles Ingresses that should be exposed over Tailscale in HA +// mode (on a ProxyGroup). It looks at all Ingresses with +// tailscale.com/proxy-group annotation. For each such Ingress, it ensures that +// a TailscaleService named after the hostname of the Ingress exists and is up to +// date. It also ensures that the serve config for the ingress ProxyGroup is +// updated to route traffic for the Tailscale Service to the Ingress's backend +// Services. Ingress hostname change also results in the Tailscale Service for the +// previous hostname being cleaned up and a new Tailscale Service being created for the +// new hostname.
+// HA Ingresses support multi-cluster Ingress setup. +// Each Tailscale Service contains a list of owner references that uniquely identify +// the Ingress resource and the operator. When an Ingress that acts as a +// backend is being deleted, the corresponding Tailscale Service is only deleted if the +// only owner reference that it contains is for this Ingress. If other owner +// references are found, then cleanup operation only removes this Ingress' owner +// reference. +func (r *HAIngressReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + logger := r.logger.With("Ingress", req.NamespacedName) + logger.Debugf("starting reconcile") + defer logger.Debugf("reconcile finished") + + ing := new(networkingv1.Ingress) + err = r.Get(ctx, req.NamespacedName, ing) + if apierrors.IsNotFound(err) { + // Request object not found, could have been deleted after reconcile request. + logger.Debugf("Ingress not found, assuming it was deleted") + return res, nil + } else if err != nil { + return res, fmt.Errorf("failed to get Ingress: %w", err) + } + + // hostname is the name of the Tailscale Service that will be created + // for this Ingress as well as the first label in the MagicDNS name of + // the Ingress. + hostname := hostnameForIngress(ing) + logger = logger.With("hostname", hostname) + + // needsRequeue is set to true if the underlying Tailscale Service has + // changed as a result of this reconcile. If that is the case, we + // reconcile the Ingress one more time to ensure that concurrent updates + // to the Tailscale Service in a multi-cluster Ingress setup have not + // resulted in another actor overwriting our Tailscale Service update. 
+ needsRequeue := false + if !ing.DeletionTimestamp.IsZero() || !r.shouldExpose(ing) { + needsRequeue, err = r.maybeCleanup(ctx, hostname, ing, logger) + } else { + needsRequeue, err = r.maybeProvision(ctx, hostname, ing, logger) + } + if err != nil { + return res, err + } + if needsRequeue { + res = reconcile.Result{RequeueAfter: requeueInterval()} + } + return res, nil +} + +// maybeProvision ensures that a Tailscale Service for this Ingress exists and is up to date and that the serve config for the +// corresponding ProxyGroup contains the Ingress backend's definition. +// If a Tailscale Service does not exist, it will be created. +// If a Tailscale Service exists, but only with owner references from other operator instances, an owner reference for this +// operator instance is added. +// If a Tailscale Service exists, but does not have an owner reference from any operator, we error +// out assuming that this is an owner reference created by an unknown actor. +// Returns true if the operation resulted in a Tailscale Service update. +func (r *HAIngressReconciler) maybeProvision(ctx context.Context, hostname string, ing *networkingv1.Ingress, logger *zap.SugaredLogger) (svcsChanged bool, err error) { + // Currently (2025-05) Tailscale Services are behind an alpha feature flag that + // needs to be explicitly enabled for a tailnet to be able to use them. 
+ serviceName := tailcfg.ServiceName("svc:" + hostname) + existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName) + if isErrorFeatureFlagNotEnabled(err) { + logger.Warn(msgFeatureFlagNotEnabled) + r.recorder.Event(ing, corev1.EventTypeWarning, warningTailscaleServiceFeatureFlagNotEnabled, msgFeatureFlagNotEnabled) + return false, nil + } + if err != nil && !isErrorTailscaleServiceNotFound(err) { + return false, fmt.Errorf("error getting Tailscale Service %q: %w", hostname, err) + } + + if err := validateIngressClass(ctx, r.Client); err != nil { + logger.Infof("error validating tailscale IngressClass: %v.", err) + return false, nil + } + // Get and validate ProxyGroup readiness + pgName := ing.Annotations[AnnotationProxyGroup] + if pgName == "" { + logger.Infof("[unexpected] no ProxyGroup annotation, skipping Tailscale Service provisioning") + return false, nil + } + logger = logger.With("ProxyGroup", pgName) + + pg := &tsapi.ProxyGroup{} + if err := r.Get(ctx, client.ObjectKey{Name: pgName}, pg); err != nil { + if apierrors.IsNotFound(err) { + logger.Infof("ProxyGroup does not exist") + return false, nil + } + return false, fmt.Errorf("getting ProxyGroup %q: %w", pgName, err) + } + if !tsoperator.ProxyGroupIsReady(pg) { + logger.Infof("ProxyGroup is not (yet) ready") + return false, nil + } + + // Validate Ingress configuration + if err := r.validateIngress(ctx, ing, pg); err != nil { + logger.Infof("invalid Ingress configuration: %v", err) + r.recorder.Event(ing, corev1.EventTypeWarning, "InvalidIngressConfiguration", err.Error()) + return false, nil + } + + if !IsHTTPSEnabledOnTailnet(r.tsnetServer) { + r.recorder.Event(ing, corev1.EventTypeWarning, "HTTPSNotEnabled", "HTTPS is not enabled on the tailnet; ingress may not work") + } + + if !slices.Contains(ing.Finalizers, FinalizerNamePG) { + // This log line is printed exactly once during initial provisioning, + // because once the finalizer is in place this block gets skipped. 
So, + // this is a nice place to tell the operator that the high level, + // multi-reconcile operation is underway. + logger.Infof("exposing Ingress over tailscale") + ing.Finalizers = append(ing.Finalizers, FinalizerNamePG) + if err := r.Update(ctx, ing); err != nil { + return false, fmt.Errorf("failed to add finalizer: %w", err) + } + r.mu.Lock() + r.managedIngresses.Add(ing.UID) + gaugePGIngressResources.Set(int64(r.managedIngresses.Len())) + r.mu.Unlock() + } + + // 1. Ensure that if Ingress' hostname has changed, any Tailscale Service + // resources corresponding to the old hostname are cleaned up. + // In practice, this function will ensure that any Tailscale Services that are + // associated with the provided ProxyGroup and no longer owned by an + // Ingress are cleaned up. This is fine- it is not expensive and ensures + // that in edge cases (a single update changed both hostname and removed + // ProxyGroup annotation) the Tailscale Service is more likely to be + // (eventually) removed. + svcsChanged, err = r.maybeCleanupProxyGroup(ctx, pgName, logger) + if err != nil { + return false, fmt.Errorf("failed to cleanup Tailscale Service resources for ProxyGroup: %w", err) + } + + // 2. Ensure that there isn't a Tailscale Service with the same hostname + // already created and not owned by this Ingress. + // TODO(irbekrm): perhaps in future we could have record names being + // stored on Tailscale Services. I am not certain if there might not be edge + // cases (custom domains, etc?) where attempting to determine the DNS + // name of the Tailscale Service in this way won't be incorrect. + + // Generate the Tailscale Service owner annotation for a new or existing Tailscale Service. + // This checks and ensures that Tailscale Service's owner references are updated + // for this Ingress and errors if that is not possible (i.e. because it + // appears that the Tailscale Service has been created by a non-operator actor). 
	// Ensure that this operator instance is recorded as an owner of the
	// Tailscale Service before applying any further changes to it.
	updatedAnnotations, err := r.ownerAnnotations(existingTSSvc)
	if err != nil {
		// NOTE(review): instr is missing the closing backtick after
		// .spec.tls.hosts[0] in the user-facing message — confirm and fix.
		const instr = "To proceed, you can either manually delete the existing Tailscale Service or choose a different MagicDNS name at `.spec.tls.hosts[0] in the Ingress definition"
		msg := fmt.Sprintf("error ensuring ownership of Tailscale Service %s: %v. %s", hostname, err, instr)
		logger.Warn(msg)
		r.recorder.Event(ing, corev1.EventTypeWarning, "InvalidTailscaleService", msg)
		return false, nil
	}
	// 3. Ensure that TLS Secret and RBAC exists
	tcd, err := r.tailnetCertDomain(ctx)
	if err != nil {
		return false, fmt.Errorf("error determining DNS name base: %w", err)
	}
	dnsName := hostname + "." + tcd
	if err := r.ensureCertResources(ctx, pgName, dnsName, ing); err != nil {
		return false, fmt.Errorf("error ensuring cert resources: %w", err)
	}

	// 4. Ensure that the serve config for the ProxyGroup contains the Tailscale Service.
	cm, cfg, err := r.proxyGroupServeConfig(ctx, pgName)
	if err != nil {
		return false, fmt.Errorf("error getting Ingress serve config: %w", err)
	}
	if cm == nil {
		logger.Infof("no Ingress serve config ConfigMap found, unable to update serve config. Ensure that ProxyGroup is healthy.")
		return svcsChanged, nil
	}
	ep := ipn.HostPort(fmt.Sprintf("%s:443", dnsName))
	handlers, err := handlersForIngress(ctx, ing, r.Client, r.recorder, dnsName, logger)
	if err != nil {
		return false, fmt.Errorf("failed to get handlers for Ingress: %w", err)
	}
	// Desired serve config entry for this Ingress: HTTPS on port 443 plus
	// the web handlers derived from the Ingress rules.
	ingCfg := &ipn.ServiceConfig{
		TCP: map[uint16]*ipn.TCPPortHandler{
			443: {
				HTTPS: true,
			},
		},
		Web: map[ipn.HostPort]*ipn.WebServerConfig{
			ep: {
				Handlers: handlers,
			},
		},
	}

	// Add HTTP endpoint if configured.
	if isHTTPEndpointEnabled(ing) {
		logger.Infof("exposing Ingress over HTTP")
		epHTTP := ipn.HostPort(fmt.Sprintf("%s:80", dnsName))
		ingCfg.TCP[80] = &ipn.TCPPortHandler{
			HTTP: true,
		}
		ingCfg.Web[epHTTP] = &ipn.WebServerConfig{
			Handlers: handlers,
		}
	}

	// Only write the ConfigMap back if the desired config differs from the
	// currently stored one.
	var gotCfg *ipn.ServiceConfig
	if cfg != nil && cfg.Services != nil {
		gotCfg = cfg.Services[serviceName]
	}
	if !reflect.DeepEqual(gotCfg, ingCfg) {
		logger.Infof("Updating serve config")
		mak.Set(&cfg.Services, serviceName, ingCfg)
		cfgBytes, err := json.Marshal(cfg)
		if err != nil {
			return false, fmt.Errorf("error marshaling serve config: %w", err)
		}
		mak.Set(&cm.BinaryData, serveConfigKey, cfgBytes)
		if err := r.Update(ctx, cm); err != nil {
			return false, fmt.Errorf("error updating serve config: %w", err)
		}
	}

	// 5. Ensure that the Tailscale Service exists and is up to date.
	tags := r.defaultTags
	if tstr, ok := ing.Annotations[AnnotationTags]; ok {
		tags = strings.Split(tstr, ",")
	}

	tsSvcPorts := []string{"tcp:443"} // always 443 for Ingress
	if isHTTPEndpointEnabled(ing) {
		tsSvcPorts = append(tsSvcPorts, "tcp:80")
	}

	tsSvc := &tailscale.VIPService{
		Name:        serviceName,
		Tags:        tags,
		Ports:       tsSvcPorts,
		Comment:     managedTSServiceComment,
		Annotations: updatedAnnotations,
	}
	// Preserve any addresses already allocated for the Service by control.
	if existingTSSvc != nil {
		tsSvc.Addrs = existingTSSvc.Addrs
	}
	// TODO(irbekrm): right now if two Ingress resources attempt to apply different Tailscale Service configs (different
	// tags, or HTTP endpoint settings) we can end up reconciling those in a loop. We should detect when an Ingress
	// with the same generation number has been reconciled ~more than N times and stop attempting to apply updates.
	if existingTSSvc == nil ||
		!reflect.DeepEqual(tsSvc.Tags, existingTSSvc.Tags) ||
		!reflect.DeepEqual(tsSvc.Ports, existingTSSvc.Ports) ||
		!ownersAreSetAndEqual(tsSvc, existingTSSvc) {
		logger.Infof("Ensuring Tailscale Service exists and is up to date")
		if err := r.tsClient.CreateOrUpdateVIPService(ctx, tsSvc); err != nil {
			return false, fmt.Errorf("error creating Tailscale Service: %w", err)
		}
	}

	// 6. Update tailscaled's AdvertiseServices config, which should add the Tailscale Service
	// IPs to the ProxyGroup Pods' AllowedIPs in the next netmap update if approved.
	mode := serviceAdvertisementHTTPS
	if isHTTPEndpointEnabled(ing) {
		mode = serviceAdvertisementHTTPAndHTTPS
	}
	if err = r.maybeUpdateAdvertiseServicesConfig(ctx, pg.Name, serviceName, mode, logger); err != nil {
		return false, fmt.Errorf("failed to update tailscaled config: %w", err)
	}

	// 7. Update Ingress status if ProxyGroup Pods are ready.
	count, err := r.numberPodsAdvertising(ctx, pg.Name, serviceName)
	if err != nil {
		return false, fmt.Errorf("failed to check if any Pods are configured: %w", err)
	}

	oldStatus := ing.Status.DeepCopy()

	switch count {
	case 0:
		ing.Status.LoadBalancer.Ingress = nil
	default:
		var ports []networkingv1.IngressPortStatus
		hasCerts, err := r.hasCerts(ctx, serviceName)
		if err != nil {
			return false, fmt.Errorf("error checking TLS credentials provisioned for Ingress: %w", err)
		}
		// If TLS certs have not been issued (yet), do not set port 443.
		if hasCerts {
			ports = append(ports, networkingv1.IngressPortStatus{
				Protocol: "TCP",
				Port:     443,
			})
		}
		if isHTTPEndpointEnabled(ing) {
			ports = append(ports, networkingv1.IngressPortStatus{
				Protocol: "TCP",
				Port:     80,
			})
		}
		// Set Ingress status hostname only if either port 443 or 80 is advertised.
		var hostname string
		if len(ports) != 0 {
			hostname = dnsName
		}
		ing.Status.LoadBalancer.Ingress = []networkingv1.IngressLoadBalancerIngress{
			{
				Hostname: hostname,
				Ports:    ports,
			},
		}
	}
	// Skip the API call if the status is already up to date.
	if apiequality.Semantic.DeepEqual(oldStatus, &ing.Status) {
		return svcsChanged, nil
	}

	const prefix = "Updating Ingress status"
	if count == 0 {
		logger.Infof("%s. No Pods are advertising Tailscale Service yet", prefix)
	} else {
		logger.Infof("%s. %d Pod(s) advertising Tailscale Service", prefix, count)
	}

	if err := r.Status().Update(ctx, ing); err != nil {
		return false, fmt.Errorf("failed to update Ingress status: %w", err)
	}
	return svcsChanged, nil
}

// maybeCleanupProxyGroup ensures that any Tailscale Services that are
// associated with the provided ProxyGroup and no longer needed for any
// Ingresses exposed on this ProxyGroup are deleted, if not owned by other
// operator instances, else the owner reference is cleaned up. Returns true if
// the operation resulted in an update to an existing Tailscale Service (owner
// reference removal).
func (r *HAIngressReconciler) maybeCleanupProxyGroup(ctx context.Context, proxyGroupName string, logger *zap.SugaredLogger) (svcsChanged bool, err error) {
	// Get serve config for the ProxyGroup
	cm, cfg, err := r.proxyGroupServeConfig(ctx, proxyGroupName)
	if err != nil {
		return false, fmt.Errorf("getting serve config: %w", err)
	}
	if cfg == nil {
		// ProxyGroup does not have any Tailscale Services associated with it.
		return false, nil
	}

	ingList := &networkingv1.IngressList{}
	if err := r.List(ctx, ingList); err != nil {
		return false, fmt.Errorf("listing Ingresses: %w", err)
	}
	serveConfigChanged := false
	// For each Tailscale Service in serve config...
	for tsSvcName := range cfg.Services {
		// ...check if there is currently an Ingress with this hostname
		found := false
		for _, i := range ingList.Items {
			ingressHostname := hostnameForIngress(&i)
			if ingressHostname == tsSvcName.WithoutPrefix() {
				found = true
				break
			}
		}

		if !found {
			logger.Infof("Tailscale Service %q is not owned by any Ingress, cleaning up", tsSvcName)
			tsService, err := r.tsClient.GetVIPService(ctx, tsSvcName)
			if isErrorFeatureFlagNotEnabled(err) {
				msg := fmt.Sprintf("Unable to proceed with cleanup: %s.", msgFeatureFlagNotEnabled)
				logger.Warn(msg)
				return false, nil
			}
			// NOTE(review): this early return aborts cleanup of any remaining
			// Services in the loop and discards serve-config changes
			// accumulated so far; a `continue` may be intended — verify.
			if isErrorTailscaleServiceNotFound(err) {
				return false, nil
			}
			if err != nil {
				return false, fmt.Errorf("getting Tailscale Service %q: %w", tsSvcName, err)
			}

			// Delete the Tailscale Service from control if necessary.
			// NOTE(review): svcsChanged is overwritten on every iteration, so
			// a true result from an earlier iteration can be lost — verify
			// whether the results should be OR-ed instead.
			svcsChanged, err = r.cleanupTailscaleService(ctx, tsService, logger)
			if err != nil {
				return false, fmt.Errorf("deleting Tailscale Service %q: %w", tsSvcName, err)
			}

			// Make sure the Tailscale Service is not advertised in tailscaled or serve config.
			if err = r.maybeUpdateAdvertiseServicesConfig(ctx, proxyGroupName, tsSvcName, serviceAdvertisementOff, logger); err != nil {
				return false, fmt.Errorf("failed to update tailscaled config services: %w", err)
			}
			_, ok := cfg.Services[tsSvcName]
			if ok {
				logger.Infof("Removing Tailscale Service %q from serve config", tsSvcName)
				delete(cfg.Services, tsSvcName)
				serveConfigChanged = true
			}
			if err := r.cleanupCertResources(ctx, proxyGroupName, tsSvcName); err != nil {
				return false, fmt.Errorf("failed to clean up cert resources: %w", err)
			}
		}
	}

	// Persist the serve config only if at least one Service was removed.
	if serveConfigChanged {
		cfgBytes, err := json.Marshal(cfg)
		if err != nil {
			return false, fmt.Errorf("marshaling serve config: %w", err)
		}
		mak.Set(&cm.BinaryData, serveConfigKey, cfgBytes)
		if err := r.Update(ctx, cm); err != nil {
			return false, fmt.Errorf("updating serve config: %w", err)
		}
	}
	return svcsChanged, nil
}

// maybeCleanup ensures that any resources, such as a Tailscale Service created for this Ingress, are cleaned up when the
// Ingress is being deleted or is unexposed. The cleanup is safe for a multi-cluster setup — the Tailscale Service is only
// deleted if it does not contain any other owner references. If it does, the cleanup only removes the owner reference
// corresponding to this Ingress.
+func (r *HAIngressReconciler) maybeCleanup(ctx context.Context, hostname string, ing *networkingv1.Ingress, logger *zap.SugaredLogger) (svcChanged bool, err error) { + logger.Debugf("Ensuring any resources for Ingress are cleaned up") + ix := slices.Index(ing.Finalizers, FinalizerNamePG) + if ix < 0 { + logger.Debugf("no finalizer, nothing to do") + return false, nil + } + logger.Infof("Ensuring that Tailscale Service %q configuration is cleaned up", hostname) + serviceName := tailcfg.ServiceName("svc:" + hostname) + svc, err := r.tsClient.GetVIPService(ctx, serviceName) + if err != nil { + if isErrorFeatureFlagNotEnabled(err) { + msg := fmt.Sprintf("Unable to proceed with cleanup: %s.", msgFeatureFlagNotEnabled) + logger.Warn(msg) + r.recorder.Event(ing, corev1.EventTypeWarning, warningTailscaleServiceFeatureFlagNotEnabled, msg) + return false, nil + } + if isErrorTailscaleServiceNotFound(err) { + return false, nil + } + return false, fmt.Errorf("error getting Tailscale Service: %w", err) + } + + // Ensure that if cleanup succeeded Ingress finalizers are removed. + defer func() { + if err != nil { + return + } + err = r.deleteFinalizer(ctx, ing, logger) + }() + + // 1. Check if there is a Tailscale Service associated with this Ingress. + pg := ing.Annotations[AnnotationProxyGroup] + cm, cfg, err := r.proxyGroupServeConfig(ctx, pg) + if err != nil { + return false, fmt.Errorf("error getting ProxyGroup serve config: %w", err) + } + + // Tailscale Service is always first added to serve config and only then created in the Tailscale API, so if it is not + // found in the serve config, we can assume that there is no Tailscale Service. (If the serve config does not exist at + // all, it is possible that the ProxyGroup has been deleted before cleaning up the Ingress, so carry on with + // cleanup). + if cfg != nil && cfg.Services != nil && cfg.Services[serviceName] == nil { + return false, nil + } + + // 2. Clean up the Tailscale Service resources. 
+ svcChanged, err = r.cleanupTailscaleService(ctx, svc, logger) + if err != nil { + return false, fmt.Errorf("error deleting Tailscale Service: %w", err) + } + + // 3. Clean up any cluster resources + if err := r.cleanupCertResources(ctx, pg, serviceName); err != nil { + return false, fmt.Errorf("failed to clean up cert resources: %w", err) + } + + if cfg == nil || cfg.Services == nil { // user probably deleted the ProxyGroup + return svcChanged, nil + } + + // 4. Unadvertise the Tailscale Service in tailscaled config. + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, pg, serviceName, serviceAdvertisementOff, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config services: %w", err) + } + + // 5. Remove the Tailscale Service from the serve config for the ProxyGroup. + logger.Infof("Removing TailscaleService %q from serve config for ProxyGroup %q", hostname, pg) + delete(cfg.Services, serviceName) + cfgBytes, err := json.Marshal(cfg) + if err != nil { + return false, fmt.Errorf("error marshaling serve config: %w", err) + } + mak.Set(&cm.BinaryData, serveConfigKey, cfgBytes) + return svcChanged, r.Update(ctx, cm) +} + +func (r *HAIngressReconciler) deleteFinalizer(ctx context.Context, ing *networkingv1.Ingress, logger *zap.SugaredLogger) error { + found := false + ing.Finalizers = slices.DeleteFunc(ing.Finalizers, func(f string) bool { + found = true + return f == FinalizerNamePG + }) + if !found { + return nil + } + logger.Debug("ensure %q finalizer is removed", FinalizerNamePG) + + if err := r.Update(ctx, ing); err != nil { + return fmt.Errorf("failed to remove finalizer %q: %w", FinalizerNamePG, err) + } + r.mu.Lock() + defer r.mu.Unlock() + r.managedIngresses.Remove(ing.UID) + gaugePGIngressResources.Set(int64(r.managedIngresses.Len())) + return nil +} + +func pgIngressCMName(pg string) string { + return fmt.Sprintf("%s-ingress-config", pg) +} + +func (r *HAIngressReconciler) proxyGroupServeConfig(ctx context.Context, pg string) 
(cm *corev1.ConfigMap, cfg *ipn.ServeConfig, err error) { + name := pgIngressCMName(pg) + cm = &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: r.tsNamespace, + }, + } + if err := r.Get(ctx, client.ObjectKeyFromObject(cm), cm); err != nil && !apierrors.IsNotFound(err) { + return nil, nil, fmt.Errorf("error retrieving ingress serve config ConfigMap %s: %v", name, err) + } + if apierrors.IsNotFound(err) { + return nil, nil, nil + } + cfg = &ipn.ServeConfig{} + if len(cm.BinaryData[serveConfigKey]) != 0 { + if err := json.Unmarshal(cm.BinaryData[serveConfigKey], cfg); err != nil { + return nil, nil, fmt.Errorf("error unmarshaling ingress serve config %v: %w", cm.BinaryData[serveConfigKey], err) + } + } + return cm, cfg, nil +} + +type localClient interface { + StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) +} + +// tailnetCertDomain returns the base domain (TCD) of the current tailnet. +func (r *HAIngressReconciler) tailnetCertDomain(ctx context.Context) (string, error) { + st, err := r.lc.StatusWithoutPeers(ctx) + if err != nil { + return "", fmt.Errorf("error getting tailscale status: %w", err) + } + return st.CurrentTailnet.MagicDNSSuffix, nil +} + +// shouldExpose returns true if the Ingress should be exposed over Tailscale in HA mode (on a ProxyGroup). +func (r *HAIngressReconciler) shouldExpose(ing *networkingv1.Ingress) bool { + isTSIngress := ing != nil && + ing.Spec.IngressClassName != nil && + *ing.Spec.IngressClassName == tailscaleIngressClassName + pgAnnot := ing.Annotations[AnnotationProxyGroup] + return isTSIngress && pgAnnot != "" +} + +// validateIngress validates that the Ingress is properly configured. 
+// Currently validates: +// - Any tags provided via tailscale.com/tags annotation are valid Tailscale ACL tags +// - The derived hostname is a valid DNS label +// - The referenced ProxyGroup exists and is of type 'ingress' +// - Ingress' TLS block is invalid +func (r *HAIngressReconciler) validateIngress(ctx context.Context, ing *networkingv1.Ingress, pg *tsapi.ProxyGroup) error { + var errs []error + + // Validate tags if present + if tstr, ok := ing.Annotations[AnnotationTags]; ok { + tags := strings.Split(tstr, ",") + for _, tag := range tags { + tag = strings.TrimSpace(tag) + if err := tailcfg.CheckTag(tag); err != nil { + errs = append(errs, fmt.Errorf("tailscale.com/tags annotation contains invalid tag %q: %w", tag, err)) + } + } + } + + // Validate TLS configuration + if ing.Spec.TLS != nil && len(ing.Spec.TLS) > 0 && (len(ing.Spec.TLS) > 1 || len(ing.Spec.TLS[0].Hosts) > 1) { + errs = append(errs, fmt.Errorf("Ingress contains invalid TLS block %v: only a single TLS entry with a single host is allowed", ing.Spec.TLS)) + } + + // Validate that the hostname will be a valid DNS label + hostname := hostnameForIngress(ing) + if err := dnsname.ValidLabel(hostname); err != nil { + errs = append(errs, fmt.Errorf("invalid hostname %q: %w. Ensure that the hostname is a valid DNS label", hostname, err)) + } + + // Validate ProxyGroup type + if pg.Spec.Type != tsapi.ProxyGroupTypeIngress { + errs = append(errs, fmt.Errorf("ProxyGroup %q is of type %q but must be of type %q", + pg.Name, pg.Spec.Type, tsapi.ProxyGroupTypeIngress)) + } + + // Validate ProxyGroup readiness + if !tsoperator.ProxyGroupIsReady(pg) { + errs = append(errs, fmt.Errorf("ProxyGroup %q is not ready", pg.Name)) + } + + // It is invalid to have multiple Ingress resources for the same Tailscale Service in one cluster. 
+ ingList := &networkingv1.IngressList{} + if err := r.List(ctx, ingList); err != nil { + errs = append(errs, fmt.Errorf("[unexpected] error listing Ingresses: %w", err)) + return errors.Join(errs...) + } + for _, i := range ingList.Items { + if r.shouldExpose(&i) && hostnameForIngress(&i) == hostname && i.Name != ing.Name { + errs = append(errs, fmt.Errorf("found duplicate Ingress %q for hostname %q - multiple Ingresses for the same hostname in the same cluster are not allowed", i.Name, hostname)) + } + } + return errors.Join(errs...) +} + +// cleanupTailscaleService deletes any Tailscale Service by the provided name if it is not owned by operator instances other than this one. +// If a Tailscale Service is found, but contains other owner references, only removes this operator's owner reference. +// If a Tailscale Service by the given name is not found or does not contain this operator's owner reference, do nothing. +// It returns true if an existing Tailscale Service was updated to remove owner reference, as well as any error that occurred. +func (r *HAIngressReconciler) cleanupTailscaleService(ctx context.Context, svc *tailscale.VIPService, logger *zap.SugaredLogger) (updated bool, _ error) { + if svc == nil { + return false, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return false, fmt.Errorf("error parsing Tailscale Service's owner annotation") + } + if o == nil || len(o.OwnerRefs) == 0 { + return false, nil + } + // Comparing with the operatorID only means that we will not be able to + // clean up Tailscale Service in cases where the operator was deleted from the + // cluster before deleting the Ingress. Perhaps the comparison could be + // 'if or.OperatorID === r.operatorID || or.ingressUID == r.ingressUID'. 
+ ix := slices.IndexFunc(o.OwnerRefs, func(or OwnerRef) bool { + return or.OperatorID == r.operatorID + }) + if ix == -1 { + return false, nil + } + if len(o.OwnerRefs) == 1 { + logger.Infof("Deleting Tailscale Service %q", svc.Name) + return false, r.tsClient.DeleteVIPService(ctx, svc.Name) + } + o.OwnerRefs = slices.Delete(o.OwnerRefs, ix, ix+1) + logger.Infof("Deleting Tailscale Service %q", svc.Name) + json, err := json.Marshal(o) + if err != nil { + return false, fmt.Errorf("error marshalling updated Tailscale Service owner reference: %w", err) + } + svc.Annotations[ownerAnnotation] = string(json) + return true, r.tsClient.CreateOrUpdateVIPService(ctx, svc) +} + +// isHTTPEndpointEnabled returns true if the Ingress has been configured to expose an HTTP endpoint to tailnet. +func isHTTPEndpointEnabled(ing *networkingv1.Ingress) bool { + if ing == nil { + return false + } + return ing.Annotations[annotationHTTPEndpoint] == "enabled" +} + +// serviceAdvertisementMode describes the desired state of a Tailscale Service. +type serviceAdvertisementMode int + +const ( + serviceAdvertisementOff serviceAdvertisementMode = iota // Should not be advertised + serviceAdvertisementHTTPS // Port 443 should be advertised + serviceAdvertisementHTTPAndHTTPS // Both ports 80 and 443 should be advertised +) + +func (a *HAIngressReconciler) maybeUpdateAdvertiseServicesConfig(ctx context.Context, pgName string, serviceName tailcfg.ServiceName, mode serviceAdvertisementMode, logger *zap.SugaredLogger) (err error) { + // Get all config Secrets for this ProxyGroup. + secrets := &corev1.SecretList{} + if err := a.List(ctx, secrets, client.InNamespace(a.tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, "config"))); err != nil { + return fmt.Errorf("failed to list config Secrets: %w", err) + } + + // Verify that TLS cert for the Tailscale Service has been successfully issued + // before attempting to advertise the service. 
+ // This is so that in multi-cluster setups where some Ingresses succeed + // to issue certs and some do not (rate limits), clients are not pinned + // to a backend that is not able to serve HTTPS. + // The only exception is Ingresses with an HTTP endpoint enabled - if an + // Ingress has an HTTP endpoint enabled, it will be advertised even if the + // TLS cert is not yet provisioned. + hasCert, err := a.hasCerts(ctx, serviceName) + if err != nil { + return fmt.Errorf("error checking TLS credentials provisioned for service %q: %w", serviceName, err) + } + shouldBeAdvertised := (mode == serviceAdvertisementHTTPAndHTTPS) || + (mode == serviceAdvertisementHTTPS && hasCert) // if we only expose port 443 and don't have certs (yet), do not advertise + + for _, secret := range secrets.Items { + var updated bool + for fileName, confB := range secret.Data { + var conf ipn.ConfigVAlpha + if err := json.Unmarshal(confB, &conf); err != nil { + return fmt.Errorf("error unmarshalling ProxyGroup config: %w", err) + } + + // Update the services to advertise if required. + idx := slices.Index(conf.AdvertiseServices, serviceName.String()) + isAdvertised := idx >= 0 + switch { + case isAdvertised == shouldBeAdvertised: + // Already up to date. + continue + case isAdvertised: + // Needs to be removed. + conf.AdvertiseServices = slices.Delete(conf.AdvertiseServices, idx, idx+1) + case shouldBeAdvertised: + // Needs to be added. + conf.AdvertiseServices = append(conf.AdvertiseServices, serviceName.String()) + } + + // Update the Secret. 
+ confB, err := json.Marshal(conf) + if err != nil { + return fmt.Errorf("error marshalling ProxyGroup config: %w", err) + } + mak.Set(&secret.Data, fileName, confB) + updated = true + } + + if updated { + if err := a.Update(ctx, &secret); err != nil { + return fmt.Errorf("error updating ProxyGroup config Secret: %w", err) + } + } + } + + return nil +} + +func (a *HAIngressReconciler) numberPodsAdvertising(ctx context.Context, pgName string, serviceName tailcfg.ServiceName) (int, error) { + // Get all state Secrets for this ProxyGroup. + secrets := &corev1.SecretList{} + if err := a.List(ctx, secrets, client.InNamespace(a.tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, "state"))); err != nil { + return 0, fmt.Errorf("failed to list ProxyGroup %q state Secrets: %w", pgName, err) + } + + var count int + for _, secret := range secrets.Items { + prefs, ok, err := getDevicePrefs(&secret) + if err != nil { + return 0, fmt.Errorf("error getting node metadata: %w", err) + } + if !ok { + continue + } + if slices.Contains(prefs.AdvertiseServices, serviceName.String()) { + count++ + } + } + + return count, nil +} + +const ownerAnnotation = "tailscale.com/owner-references" + +// ownerAnnotationValue is the content of the TailscaleService.Annotation[ownerAnnotation] field. +type ownerAnnotationValue struct { + // OwnerRefs is a list of owner references that identify all operator + // instances that manage this Tailscale Services. + OwnerRefs []OwnerRef `json:"ownerRefs,omitempty"` +} + +// OwnerRef is an owner reference that uniquely identifies a Tailscale +// Kubernetes operator instance. +type OwnerRef struct { + // OperatorID is the stable ID of the operator's Tailscale device. + OperatorID string `json:"operatorID,omitempty"` +} + +// ownerAnnotations returns the updated annotations required to ensure this +// instance of the operator is included as an owner. 
If the Tailscale Service is not +// nil, but does not contain an owner reference we return an error as this likely means +// that the Service was created by somthing other than a Tailscale +// Kubernetes operator. +func (r *HAIngressReconciler) ownerAnnotations(svc *tailscale.VIPService) (map[string]string, error) { + ref := OwnerRef{ + OperatorID: r.operatorID, + } + if svc == nil { + c := ownerAnnotationValue{OwnerRefs: []OwnerRef{ref}} + json, err := json.Marshal(c) + if err != nil { + return nil, fmt.Errorf("[unexpected] unable to marshal Tailscale Service's owner annotation contents: %w, please report this", err) + } + return map[string]string{ + ownerAnnotation: string(json), + }, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return nil, err + } + if o == nil || len(o.OwnerRefs) == 0 { + return nil, fmt.Errorf("Tailscale Service %s exists, but does not contain owner annotation with owner references; not proceeding as this is likely a resource created by something other than the Tailscale Kubernetes operator", svc.Name) + } + if slices.Contains(o.OwnerRefs, ref) { // up to date + return svc.Annotations, nil + } + o.OwnerRefs = append(o.OwnerRefs, ref) + json, err := json.Marshal(o) + if err != nil { + return nil, fmt.Errorf("error marshalling updated owner references: %w", err) + } + + newAnnots := make(map[string]string, len(svc.Annotations)+1) + for k, v := range svc.Annotations { + newAnnots[k] = v + } + newAnnots[ownerAnnotation] = string(json) + return newAnnots, nil +} + +// parseOwnerAnnotation returns nil if no valid owner found. 
+func parseOwnerAnnotation(tsSvc *tailscale.VIPService) (*ownerAnnotationValue, error) { + if tsSvc.Annotations == nil || tsSvc.Annotations[ownerAnnotation] == "" { + return nil, nil + } + o := &ownerAnnotationValue{} + if err := json.Unmarshal([]byte(tsSvc.Annotations[ownerAnnotation]), o); err != nil { + return nil, fmt.Errorf("error parsing Tailscale Service's %s annotation %q: %w", ownerAnnotation, tsSvc.Annotations[ownerAnnotation], err) + } + return o, nil +} + +func ownersAreSetAndEqual(a, b *tailscale.VIPService) bool { + return a != nil && b != nil && + a.Annotations != nil && b.Annotations != nil && + a.Annotations[ownerAnnotation] != "" && + b.Annotations[ownerAnnotation] != "" && + strings.EqualFold(a.Annotations[ownerAnnotation], b.Annotations[ownerAnnotation]) +} + +// ensureCertResources ensures that the TLS Secret for an HA Ingress and RBAC +// resources that allow proxies to manage the Secret are created. +// Note that Tailscale Service's name validation matches Kubernetes +// resource name validation, so we can be certain that the Tailscale Service name +// (domain) is a valid Kubernetes resource name. 
+// https://github.com/tailscale/tailscale/blob/8b1e7f646ee4730ad06c9b70c13e7861b964949b/util/dnsname/dnsname.go#L99 +// https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names +func (r *HAIngressReconciler) ensureCertResources(ctx context.Context, pgName, domain string, ing *networkingv1.Ingress) error { + secret := certSecret(pgName, r.tsNamespace, domain, ing) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, secret, nil); err != nil { + return fmt.Errorf("failed to create or update Secret %s: %w", secret.Name, err) + } + role := certSecretRole(pgName, r.tsNamespace, domain) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, role, nil); err != nil { + return fmt.Errorf("failed to create or update Role %s: %w", role.Name, err) + } + rb := certSecretRoleBinding(pgName, r.tsNamespace, domain) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, rb, nil); err != nil { + return fmt.Errorf("failed to create or update RoleBinding %s: %w", rb.Name, err) + } + return nil +} + +// cleanupCertResources ensures that the TLS Secret and associated RBAC +// resources that allow proxies to read/write to the Secret are deleted. 
+func (r *HAIngressReconciler) cleanupCertResources(ctx context.Context, pgName string, name tailcfg.ServiceName) error { + domainName, err := r.dnsNameForService(ctx, tailcfg.ServiceName(name)) + if err != nil { + return fmt.Errorf("error getting DNS name for Tailscale Service %s: %w", name, err) + } + labels := certResourceLabels(pgName, domainName) + if err := r.DeleteAllOf(ctx, &rbacv1.RoleBinding{}, client.InNamespace(r.tsNamespace), client.MatchingLabels(labels)); err != nil { + return fmt.Errorf("error deleting RoleBinding for domain name %s: %w", domainName, err) + } + if err := r.DeleteAllOf(ctx, &rbacv1.Role{}, client.InNamespace(r.tsNamespace), client.MatchingLabels(labels)); err != nil { + return fmt.Errorf("error deleting Role for domain name %s: %w", domainName, err) + } + if err := r.DeleteAllOf(ctx, &corev1.Secret{}, client.InNamespace(r.tsNamespace), client.MatchingLabels(labels)); err != nil { + return fmt.Errorf("error deleting Secret for domain name %s: %w", domainName, err) + } + return nil +} + +// requeueInterval returns a time duration between 5 and 10 minutes, which is +// the period of time after which an HA Ingress, whose Tailscale Service has been newly +// created or changed, needs to be requeued. This is to protect against +// Tailscale Service's owner references being overwritten as a result of concurrent +// updates during multi-clutster Ingress create/update operations. +func requeueInterval() time.Duration { + return time.Duration(rand.N(5)+5) * time.Minute +} + +// certSecretRole creates a Role that will allow proxies to manage the TLS +// Secret for the given domain. Domain must be a valid Kubernetes resource name. 
func certSecretRole(pgName, namespace, domain string) *rbacv1.Role {
	return &rbacv1.Role{
		ObjectMeta: metav1.ObjectMeta{
			Name:      domain,
			Namespace: namespace,
			Labels:    certResourceLabels(pgName, domain),
		},
		Rules: []rbacv1.PolicyRule{
			{
				APIGroups:     []string{""}, // core API group (Secrets)
				Resources:     []string{"secrets"},
				ResourceNames: []string{domain}, // scoped to the single Secret named after the domain
				Verbs: []string{
					"get",
					"list",
					"patch",
					"update",
				},
			},
		},
	}
}

// certSecretRoleBinding creates a RoleBinding for Role that will allow proxies
// to manage the TLS Secret for the given domain. Domain must be a valid
// Kubernetes resource name.
func certSecretRoleBinding(pgName, namespace, domain string) *rbacv1.RoleBinding {
	return &rbacv1.RoleBinding{
		ObjectMeta: metav1.ObjectMeta{
			Name:      domain,
			Namespace: namespace,
			Labels:    certResourceLabels(pgName, domain),
		},
		Subjects: []rbacv1.Subject{
			{
				// The ProxyGroup Pods run under a ServiceAccount named
				// after the ProxyGroup.
				Kind:      "ServiceAccount",
				Name:      pgName,
				Namespace: namespace,
			},
		},
		// NOTE(review): RoleRef omits APIGroup ("rbac.authorization.k8s.io");
		// verify that the API server defaults it as expected.
		RoleRef: rbacv1.RoleRef{
			Kind: "Role",
			Name: domain,
		},
	}
}

// certSecret creates a Secret that will store the TLS certificate and private
// key for the given domain. Domain must be a valid Kubernetes resource name.
func certSecret(pgName, namespace, domain string, ing *networkingv1.Ingress) *corev1.Secret {
	labels := certResourceLabels(pgName, domain)
	labels[kubetypes.LabelSecretType] = "certs"
	// Labels that let us identify the Ingress resource let us reconcile
	// the Ingress when the TLS Secret is updated (for example, when TLS
	// certs have been provisioned).
	labels[LabelParentName] = ing.Name
	labels[LabelParentNamespace] = ing.Namespace
	return &corev1.Secret{
		TypeMeta: metav1.TypeMeta{
			APIVersion: "v1",
			Kind:       "Secret",
		},
		ObjectMeta: metav1.ObjectMeta{
			Name:      domain,
			Namespace: namespace,
			Labels:    labels,
		},
		// Empty placeholders; the proxies fill in cert and key data once
		// the TLS cert has been issued.
		Data: map[string][]byte{
			corev1.TLSCertKey:       nil,
			corev1.TLSPrivateKeyKey: nil,
		},
		Type: corev1.SecretTypeTLS,
	}
}

// certResourceLabels returns the labels shared by the cert Secret, Role and
// RoleBinding created for a given ProxyGroup and domain; used both to create
// and to bulk-delete these resources.
func certResourceLabels(pgName, domain string) map[string]string {
	return map[string]string{
		kubetypes.LabelManaged: "true",
		labelProxyGroup:        pgName,
		labelDomain:            domain,
	}
}

// dnsNameForService returns the DNS name for the given Tailscale Service's name.
func (r *HAIngressReconciler) dnsNameForService(ctx context.Context, svc tailcfg.ServiceName) (string, error) {
	s := svc.WithoutPrefix()
	tcd, err := r.tailnetCertDomain(ctx)
	if err != nil {
		return "", fmt.Errorf("error determining DNS name base: %w", err)
	}
	return s + "." + tcd, nil
}

// hasCerts checks if the TLS Secret for the given service has non-zero cert and key data.
func (r *HAIngressReconciler) hasCerts(ctx context.Context, svc tailcfg.ServiceName) (bool, error) {
	domain, err := r.dnsNameForService(ctx, svc)
	if err != nil {
		return false, fmt.Errorf("failed to get DNS name for service: %w", err)
	}
	secret := &corev1.Secret{}
	err = r.Get(ctx, client.ObjectKey{
		Namespace: r.tsNamespace,
		Name:      domain,
	}, secret)
	if err != nil {
		// A missing Secret simply means no certs yet, not an error.
		if apierrors.IsNotFound(err) {
			return false, nil
		}
		return false, fmt.Errorf("failed to get TLS Secret: %w", err)
	}

	cert := secret.Data[corev1.TLSCertKey]
	key := secret.Data[corev1.TLSPrivateKeyKey]

	return len(cert) > 0 && len(key) > 0, nil
}

// isErrorFeatureFlagNotEnabled reports whether err is the control-plane error
// returned when the Tailscale Services feature flag is not enabled for the
// tailnet.
func isErrorFeatureFlagNotEnabled(err error) bool {
	// messageFFNotEnabled is the error message returned by
	// Tailscale control plane when a Tailscale Service API call is made for a
	// tailnet that does not have the Tailscale Services feature flag enabled.
+ const messageFFNotEnabled = "feature unavailable for tailnet" + return err != nil && strings.Contains(err.Error(), messageFFNotEnabled) +} + +func isErrorTailscaleServiceNotFound(err error) bool { + var errResp tailscale.ErrResponse + ok := errors.As(err, &errResp) + return ok && errResp.Status == http.StatusNotFound +} diff --git a/cmd/k8s-operator/ingress-for-pg_test.go b/cmd/k8s-operator/ingress-for-pg_test.go new file mode 100644 index 0000000000000..f155963030b4f --- /dev/null +++ b/cmd/k8s-operator/ingress-for-pg_test.go @@ -0,0 +1,838 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "fmt" + "maps" + "reflect" + "slices" + "testing" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + networkingv1 "k8s.io/api/networking/v1" + rbacv1 "k8s.io/api/rbac/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +func TestIngressPGReconciler(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: 
networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + mustCreate(t, fc, ing) + + // Verify initial reconciliation + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyServeConfig(t, fc, "svc:my-svc", false) + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + verifyTailscaledConfig(t, fc, []string{"svc:my-svc"}) + + // Verify that Role and RoleBinding have been created for the first Ingress. + // Do not verify the cert Secret as that was already verified implicitly above. + expectEqual(t, fc, certSecretRole("test-pg", "operator-ns", "my-svc.ts.net")) + expectEqual(t, fc, certSecretRoleBinding("test-pg", "operator-ns", "my-svc.ts.net")) + + mustUpdate(t, fc, "default", "test-ingress", func(ing *networkingv1.Ingress) { + ing.Annotations["tailscale.com/tags"] = "tag:custom,tag:test" + }) + expectReconciled(t, ingPGR, "default", "test-ingress") + + // Verify Tailscale Service uses custom tags + tsSvc, err := ft.GetVIPService(context.Background(), "svc:my-svc") + if err != nil { + t.Fatalf("getting Tailscale Service: %v", err) + } + if tsSvc == nil { + t.Fatal("Tailscale Service not created") + } + wantTags := []string{"tag:custom", "tag:test"} // custom tags only + gotTags := slices.Clone(tsSvc.Tags) + slices.Sort(gotTags) + slices.Sort(wantTags) + if !slices.Equal(gotTags, wantTags) { + t.Errorf("incorrect Tailscale Service tags: got %v, want %v", gotTags, wantTags) + } + + // Create second Ingress + ing2 := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "my-other-ingress", + Namespace: "default", + UID: types.UID("5678-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: 
networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-other-svc.tailnetxyz.ts.net"}}, + }, + }, + } + mustCreate(t, fc, ing2) + + // Verify second Ingress reconciliation + expectReconciled(t, ingPGR, "default", "my-other-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-other-svc.ts.net") + expectReconciled(t, ingPGR, "default", "my-other-ingress") + verifyServeConfig(t, fc, "svc:my-other-svc", false) + verifyTailscaleService(t, ft, "svc:my-other-svc", []string{"tcp:443"}) + + // Verify that Role and RoleBinding have been created for the first Ingress. + // Do not verify the cert Secret as that was already verified implicitly above. + expectEqual(t, fc, certSecretRole("test-pg", "operator-ns", "my-other-svc.ts.net")) + expectEqual(t, fc, certSecretRoleBinding("test-pg", "operator-ns", "my-other-svc.ts.net")) + + // Verify first Ingress is still working + verifyServeConfig(t, fc, "svc:my-svc", false) + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + + verifyTailscaledConfig(t, fc, []string{"svc:my-svc", "svc:my-other-svc"}) + + // Delete second Ingress + if err := fc.Delete(context.Background(), ing2); err != nil { + t.Fatalf("deleting second Ingress: %v", err) + } + expectReconciled(t, ingPGR, "default", "my-other-ingress") + + // Verify second Ingress cleanup + cm := &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfg := &ipn.ServeConfig{} + if err := json.Unmarshal(cm.BinaryData[serveConfigKey], cfg); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + // Verify first Ingress is still 
configured + if cfg.Services["svc:my-svc"] == nil { + t.Error("first Ingress service config was incorrectly removed") + } + // Verify second Ingress was cleaned up + if cfg.Services["svc:my-other-svc"] != nil { + t.Error("second Ingress service config was not cleaned up") + } + + verifyTailscaledConfig(t, fc, []string{"svc:my-svc"}) + expectMissing[corev1.Secret](t, fc, "operator-ns", "my-other-svc.ts.net") + expectMissing[rbacv1.Role](t, fc, "operator-ns", "my-other-svc.ts.net") + expectMissing[rbacv1.RoleBinding](t, fc, "operator-ns", "my-other-svc.ts.net") + + // Delete the first Ingress and verify cleanup + if err := fc.Delete(context.Background(), ing); err != nil { + t.Fatalf("deleting Ingress: %v", err) + } + + expectReconciled(t, ingPGR, "default", "test-ingress") + + // Verify the ConfigMap was cleaned up + cm = &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfg = &ipn.ServeConfig{} + if err := json.Unmarshal(cm.BinaryData[serveConfigKey], cfg); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + if len(cfg.Services) > 0 { + t.Error("serve config not cleaned up") + } + verifyTailscaledConfig(t, fc, nil) + + // Add verification that cert resources were cleaned up + expectMissing[corev1.Secret](t, fc, "operator-ns", "my-svc.ts.net") + expectMissing[rbacv1.Role](t, fc, "operator-ns", "my-svc.ts.net") + expectMissing[rbacv1.RoleBinding](t, fc, "operator-ns", "my-svc.ts.net") +} + +func TestIngressPGReconciler_UpdateIngressHostname(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": 
"test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + mustCreate(t, fc, ing) + + // Verify initial reconciliation + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyServeConfig(t, fc, "svc:my-svc", false) + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + verifyTailscaledConfig(t, fc, []string{"svc:my-svc"}) + + // Update the Ingress hostname and make sure the original Tailscale Service is deleted. + mustUpdate(t, fc, "default", "test-ingress", func(ing *networkingv1.Ingress) { + ing.Spec.TLS[0].Hosts[0] = "updated-svc" + }) + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "updated-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyServeConfig(t, fc, "svc:updated-svc", false) + verifyTailscaleService(t, ft, "svc:updated-svc", []string{"tcp:443"}) + verifyTailscaledConfig(t, fc, []string{"svc:updated-svc"}) + + _, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName("svc:my-svc")) + if err == nil { + t.Fatalf("svc:my-svc not cleaned up") + } + if !isErrorTailscaleServiceNotFound(err) { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateIngress(t *testing.T) { + baseIngress := &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + Annotations: map[string]string{ + AnnotationProxyGroup: "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + TLS: []networkingv1.IngressTLS{ + {Hosts: 
[]string{"test"}}, + }, + }, + } + + readyProxyGroup := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg", + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + Status: tsapi.ProxyGroupStatus{ + Conditions: []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupReady), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }, + }, + }, + } + + tests := []struct { + name string + ing *networkingv1.Ingress + pg *tsapi.ProxyGroup + existingIngs []networkingv1.Ingress + wantErr string + }{ + { + name: "valid_ingress_with_hostname", + ing: &networkingv1.Ingress{ + ObjectMeta: baseIngress.ObjectMeta, + Spec: networkingv1.IngressSpec{ + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test.example.com"}}, + }, + }, + }, + pg: readyProxyGroup, + }, + { + name: "valid_ingress_with_default_hostname", + ing: baseIngress, + pg: readyProxyGroup, + }, + { + name: "invalid_tags", + ing: &networkingv1.Ingress{ + ObjectMeta: metav1.ObjectMeta{ + Name: baseIngress.Name, + Namespace: baseIngress.Namespace, + Annotations: map[string]string{ + AnnotationTags: "tag:invalid!", + }, + }, + }, + pg: readyProxyGroup, + wantErr: "tailscale.com/tags annotation contains invalid tag \"tag:invalid!\": tag names can only contain numbers, letters, or dashes", + }, + { + name: "multiple_TLS_entries", + ing: &networkingv1.Ingress{ + ObjectMeta: baseIngress.ObjectMeta, + Spec: networkingv1.IngressSpec{ + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test1.example.com"}}, + {Hosts: []string{"test2.example.com"}}, + }, + }, + }, + pg: readyProxyGroup, + wantErr: "Ingress contains invalid TLS block [{[test1.example.com] } {[test2.example.com] }]: only a single TLS entry with a single host is allowed", + }, + { + name: "multiple_hosts_in_TLS_entry", + ing: &networkingv1.Ingress{ + ObjectMeta: baseIngress.ObjectMeta, + Spec: networkingv1.IngressSpec{ + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test1.example.com", 
"test2.example.com"}}, + }, + }, + }, + pg: readyProxyGroup, + wantErr: "Ingress contains invalid TLS block [{[test1.example.com test2.example.com] }]: only a single TLS entry with a single host is allowed", + }, + { + name: "wrong_proxy_group_type", + ing: baseIngress, + pg: &tsapi.ProxyGroup{ + ObjectMeta: readyProxyGroup.ObjectMeta, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupType("foo"), + }, + Status: readyProxyGroup.Status, + }, + wantErr: "ProxyGroup \"test-pg\" is of type \"foo\" but must be of type \"ingress\"", + }, + { + name: "proxy_group_not_ready", + ing: baseIngress, + pg: &tsapi.ProxyGroup{ + ObjectMeta: readyProxyGroup.ObjectMeta, + Spec: readyProxyGroup.Spec, + Status: tsapi.ProxyGroupStatus{ + Conditions: []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupReady), + Status: metav1.ConditionFalse, + ObservedGeneration: 1, + }, + }, + }, + }, + wantErr: "ProxyGroup \"test-pg\" is not ready", + }, + { + name: "duplicate_hostname", + ing: baseIngress, + pg: readyProxyGroup, + existingIngs: []networkingv1.Ingress{{ + ObjectMeta: metav1.ObjectMeta{ + Name: "existing-ingress", + Namespace: "default", + Annotations: map[string]string{ + AnnotationProxyGroup: "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"test"}}, + }, + }, + }}, + wantErr: `found duplicate Ingress "existing-ingress" for hostname "test" - multiple Ingresses for the same hostname in the same cluster are not allowed`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(tt.ing). + WithLists(&networkingv1.IngressList{Items: tt.existingIngs}). 
+ Build() + r := &HAIngressReconciler{Client: fc} + err := r.validateIngress(context.Background(), tt.ing, tt.pg) + if (err == nil && tt.wantErr != "") || (err != nil && err.Error() != tt.wantErr) { + t.Errorf("validateIngress() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestIngressPGReconciler_HTTPEndpoint(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + + // Create test Ingress with HTTP endpoint enabled + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + "tailscale.com/http-endpoint": "enabled", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + if err := fc.Create(context.Background(), ing); err != nil { + t.Fatal(err) + } + + // Verify initial reconciliation with HTTP enabled + expectReconciled(t, ingPGR, "default", "test-ingress") + populateTLSSecret(context.Background(), fc, "test-pg", "my-svc.ts.net") + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:80", "tcp:443"}) + verifyServeConfig(t, fc, "svc:my-svc", true) + + // Verify Ingress status + ing = &networkingv1.Ingress{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-ingress", + Namespace: "default", + }, ing); err != nil { + t.Fatal(err) + } + + // Status will be empty until the Tailscale Service shows up in prefs. 
+ if !reflect.DeepEqual(ing.Status.LoadBalancer.Ingress, []networkingv1.IngressLoadBalancerIngress(nil)) { + t.Errorf("incorrect Ingress status: got %v, want empty", + ing.Status.LoadBalancer.Ingress) + } + + // Add the Tailscale Service to prefs to have the Ingress recognised as ready. + mustCreate(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-0", + Namespace: "operator-ns", + Labels: pgSecretLabels("test-pg", "state"), + }, + Data: map[string][]byte{ + "_current-profile": []byte("profile-foo"), + "profile-foo": []byte(`{"AdvertiseServices":["svc:my-svc"],"Config":{"NodeID":"node-foo"}}`), + }, + }) + + // Reconcile and re-fetch Ingress. + expectReconciled(t, ingPGR, "default", "test-ingress") + if err := fc.Get(context.Background(), client.ObjectKeyFromObject(ing), ing); err != nil { + t.Fatal(err) + } + + wantStatus := []networkingv1.IngressPortStatus{ + {Port: 443, Protocol: "TCP"}, + {Port: 80, Protocol: "TCP"}, + } + if !reflect.DeepEqual(ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) { + t.Errorf("incorrect status ports: got %v, want %v", + ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) + } + + // Remove HTTP endpoint annotation + mustUpdate(t, fc, "default", "test-ingress", func(ing *networkingv1.Ingress) { + delete(ing.Annotations, "tailscale.com/http-endpoint") + }) + + // Verify reconciliation after removing HTTP + expectReconciled(t, ingPGR, "default", "test-ingress") + verifyTailscaleService(t, ft, "svc:my-svc", []string{"tcp:443"}) + verifyServeConfig(t, fc, "svc:my-svc", false) + + // Verify Ingress status + ing = &networkingv1.Ingress{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-ingress", + Namespace: "default", + }, ing); err != nil { + t.Fatal(err) + } + + wantStatus = []networkingv1.IngressPortStatus{ + {Port: 443, Protocol: "TCP"}, + } + if !reflect.DeepEqual(ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) { + t.Errorf("incorrect status ports: got %v, want %v", + 
ing.Status.LoadBalancer.Ingress[0].Ports, wantStatus) + } +} + +func verifyTailscaleService(t *testing.T, ft *fakeTSClient, serviceName string, wantPorts []string) { + t.Helper() + tsSvc, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName(serviceName)) + if err != nil { + t.Fatalf("getting Tailscale Service %q: %v", serviceName, err) + } + if tsSvc == nil { + t.Fatalf("Tailscale Service %q not created", serviceName) + } + gotPorts := slices.Clone(tsSvc.Ports) + slices.Sort(gotPorts) + slices.Sort(wantPorts) + if !slices.Equal(gotPorts, wantPorts) { + t.Errorf("incorrect ports for Tailscale Service %q: got %v, want %v", serviceName, gotPorts, wantPorts) + } +} + +func verifyServeConfig(t *testing.T, fc client.Client, serviceName string, wantHTTP bool) { + t.Helper() + + cm := &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfg := &ipn.ServeConfig{} + if err := json.Unmarshal(cm.BinaryData["serve-config.json"], cfg); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + t.Logf("Looking for service %q in config: %+v", serviceName, cfg) + + svc := cfg.Services[tailcfg.ServiceName(serviceName)] + if svc == nil { + t.Fatalf("service %q not found in serve config, services: %+v", serviceName, maps.Keys(cfg.Services)) + } + + wantHandlers := 1 + if wantHTTP { + wantHandlers = 2 + } + + // Check TCP handlers + if len(svc.TCP) != wantHandlers { + t.Errorf("incorrect number of TCP handlers for service %q: got %d, want %d", serviceName, len(svc.TCP), wantHandlers) + } + if wantHTTP { + if h, ok := svc.TCP[uint16(80)]; !ok { + t.Errorf("HTTP (port 80) handler not found for service %q", serviceName) + } else if !h.HTTP { + t.Errorf("HTTP not enabled for port 80 handler for service %q", serviceName) + } + } + if h, ok := svc.TCP[uint16(443)]; !ok { + t.Errorf("HTTPS (port 443) handler 
not found for service %q", serviceName) + } else if !h.HTTPS { + t.Errorf("HTTPS not enabled for port 443 handler for service %q", serviceName) + } + + // Check Web handlers + if len(svc.Web) != wantHandlers { + t.Errorf("incorrect number of Web handlers for service %q: got %d, want %d", serviceName, len(svc.Web), wantHandlers) + } +} + +func verifyTailscaledConfig(t *testing.T, fc client.Client, expectedServices []string) { + t.Helper() + var expected string + if expectedServices != nil && len(expectedServices) > 0 { + expectedServicesJSON, err := json.Marshal(expectedServices) + if err != nil { + t.Fatalf("marshaling expected services: %v", err) + } + expected = fmt.Sprintf(`,"AdvertiseServices":%s`, expectedServicesJSON) + } + expectEqual(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName("test-pg", 0), + Namespace: "operator-ns", + Labels: pgSecretLabels("test-pg", "config"), + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(106): []byte(fmt.Sprintf(`{"Version":""%s}`, expected)), + }, + }) +} + +func setupIngressTest(t *testing.T) (*HAIngressReconciler, client.Client, *fakeTSClient) { + tsIngressClass := &networkingv1.IngressClass{ + ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, + Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}, + } + + // Pre-create the ProxyGroup + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg", + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + + // Pre-create the ConfigMap for the ProxyGroup + pgConfigMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, + BinaryData: map[string][]byte{ + "serve-config.json": []byte(`{"Services":{}}`), + }, + } + + // Pre-create a config Secret for the ProxyGroup + pgCfgSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName("test-pg", 0), + 
Namespace: "operator-ns", + Labels: pgSecretLabels("test-pg", "config"), + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(106): []byte("{}"), + }, + } + + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pg, pgCfgSecret, pgConfigMap, tsIngressClass). + WithStatusSubresource(pg). + Build() + + // Set ProxyGroup status to ready + pg.Status.Conditions = []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupReady), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }, + } + if err := fc.Status().Update(context.Background(), pg); err != nil { + t.Fatal(err) + } + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + lc := &fakeLocalClient{ + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{ + MagicDNSSuffix: "ts.net", + }, + }, + } + + ingPGR := &HAIngressReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + tsNamespace: "operator-ns", + tsnetServer: fakeTsnetServer, + logger: zl.Sugar(), + recorder: record.NewFakeRecorder(10), + lc: lc, + } + + return ingPGR, fc, ft +} + +func TestIngressPGReconciler_MultiCluster(t *testing.T) { + ingPGR, fc, ft := setupIngressTest(t) + ingPGR.operatorID = "operator-1" + + // Create initial Ingress + ing := &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"my-svc"}}, + }, + }, + } + mustCreate(t, fc, ing) + + // Simulate existing Tailscale Service from another cluster + existingVIPSvc := &tailscale.VIPService{ + Name: "svc:my-svc", + 
Annotations: map[string]string{ + ownerAnnotation: `{"ownerrefs":[{"operatorID":"operator-2"}]}`, + }, + } + ft.vipServices = map[tailcfg.ServiceName]*tailscale.VIPService{ + "svc:my-svc": existingVIPSvc, + } + + // Verify reconciliation adds our operator reference + expectReconciled(t, ingPGR, "default", "test-ingress") + + tsSvc, err := ft.GetVIPService(context.Background(), "svc:my-svc") + if err != nil { + t.Fatalf("getting Tailscale Service: %v", err) + } + if tsSvc == nil { + t.Fatal("Tailscale Service not found") + } + + o, err := parseOwnerAnnotation(tsSvc) + if err != nil { + t.Fatalf("parsing owner annotation: %v", err) + } + + wantOwnerRefs := []OwnerRef{ + {OperatorID: "operator-2"}, + {OperatorID: "operator-1"}, + } + if !reflect.DeepEqual(o.OwnerRefs, wantOwnerRefs) { + t.Errorf("incorrect owner refs\ngot: %+v\nwant: %+v", o.OwnerRefs, wantOwnerRefs) + } + + // Delete the Ingress and verify Tailscale Service still exists with one owner ref + if err := fc.Delete(context.Background(), ing); err != nil { + t.Fatalf("deleting Ingress: %v", err) + } + expectRequeue(t, ingPGR, "default", "test-ingress") + + tsSvc, err = ft.GetVIPService(context.Background(), "svc:my-svc") + if err != nil { + t.Fatalf("getting Tailscale Service after deletion: %v", err) + } + if tsSvc == nil { + t.Fatal("Tailscale Service was incorrectly deleted") + } + + o, err = parseOwnerAnnotation(tsSvc) + if err != nil { + t.Fatalf("parsing owner annotation: %v", err) + } + + wantOwnerRefs = []OwnerRef{ + {OperatorID: "operator-2"}, + } + if !reflect.DeepEqual(o.OwnerRefs, wantOwnerRefs) { + t.Errorf("incorrect owner refs after deletion\ngot: %+v\nwant: %+v", o.OwnerRefs, wantOwnerRefs) + } +} + +func populateTLSSecret(ctx context.Context, c client.Client, pgName, domain string) error { + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: domain, + Namespace: "operator-ns", + Labels: map[string]string{ + kubetypes.LabelManaged: "true", + labelProxyGroup: pgName, + 
labelDomain: domain, + kubetypes.LabelSecretType: "certs", + }, + }, + Type: corev1.SecretTypeTLS, + Data: map[string][]byte{ + corev1.TLSCertKey: []byte("fake-cert"), + corev1.TLSPrivateKeyKey: []byte("fake-key"), + }, + } + + _, err := createOrUpdate(ctx, c, "operator-ns", secret, func(s *corev1.Secret) { + s.Data = secret.Data + }) + return err +} diff --git a/cmd/k8s-operator/ingress.go b/cmd/k8s-operator/ingress.go index acc90d465093a..6c50e10b2ba94 100644 --- a/cmd/k8s-operator/ingress.go +++ b/cmd/k8s-operator/ingress.go @@ -26,6 +26,7 @@ import ( "tailscale.com/kube/kubetypes" "tailscale.com/types/opt" "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" "tailscale.com/util/set" ) @@ -58,7 +59,7 @@ var ( ) func (a *IngressReconciler) Reconcile(ctx context.Context, req reconcile.Request) (_ reconcile.Result, err error) { - logger := a.logger.With("ingress-ns", req.Namespace, "ingress-name", req.Name) + logger := a.logger.With("Ingress", req.NamespacedName) logger.Debugf("starting reconcile") defer logger.Debugf("reconcile finished") @@ -72,11 +73,20 @@ func (a *IngressReconciler) Reconcile(ctx context.Context, req reconcile.Request return reconcile.Result{}, fmt.Errorf("failed to get ing: %w", err) } if !ing.DeletionTimestamp.IsZero() || !a.shouldExpose(ing) { + // TODO(irbekrm): this message is confusing if the Ingress is an HA Ingress logger.Debugf("ingress is being deleted or should not be exposed, cleaning up") return reconcile.Result{}, a.maybeCleanup(ctx, logger, ing) } - return reconcile.Result{}, a.maybeProvision(ctx, logger, ing) + if err := a.maybeProvision(ctx, logger, ing); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } func (a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, ing *networkingv1.Ingress) error { @@ -90,7 +100,7 @@ func 
(a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare return nil } - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(ing.Name, ing.Namespace, "ingress")); err != nil { + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(ing.Name, ing.Namespace, "ingress"), proxyTypeIngressResource); err != nil { return fmt.Errorf("failed to cleanup: %w", err) } else if !done { logger.Debugf("cleanup not done yet, waiting for next reconcile") @@ -120,9 +130,8 @@ func (a *IngressReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare // This function adds a finalizer to ing, ensuring that we can handle orderly // deprovisioning later. func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.SugaredLogger, ing *networkingv1.Ingress) error { - if err := a.validateIngressClass(ctx); err != nil { + if err := validateIngressClass(ctx, a.Client); err != nil { logger.Warnf("error validating tailscale IngressClass: %v. In future this might be a terminal error.", err) - } if !slices.Contains(ing.Finalizers, FinalizerName) { // This log line is printed exactly once during initial provisioning, @@ -151,7 +160,7 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga gaugeIngressResources.Set(int64(a.managedIngresses.Len())) a.mu.Unlock() - if !a.ssr.IsHTTPSEnabledOnTailnet() { + if !IsHTTPSEnabledOnTailnet(a.ssr.tsnetServer) { a.recorder.Event(ing, corev1.EventTypeWarning, "HTTPSNotEnabled", "HTTPS is not enabled on the tailnet; ingress may not work") } @@ -177,73 +186,16 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga } web := sc.Web[magic443] - addIngressBackend := func(b *networkingv1.IngressBackend, path string) { - if b == nil { - return - } - if b.Service == nil { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q is missing service", path) - return - } - var svc corev1.Service - if err := a.Get(ctx, 
types.NamespacedName{Namespace: ing.Namespace, Name: b.Service.Name}, &svc); err != nil { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "failed to get service %q for path %q: %v", b.Service.Name, path, err) - return - } - if svc.Spec.ClusterIP == "" || svc.Spec.ClusterIP == "None" { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid ClusterIP", path) - return - } - var port int32 - if b.Service.Port.Name != "" { - for _, p := range svc.Spec.Ports { - if p.Name == b.Service.Port.Name { - port = p.Port - break - } - } - } else { - port = b.Service.Port.Number - } - if port == 0 { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid port", path) - return - } - proto := "http://" - if port == 443 || b.Service.Port.Name == "https" { - proto = "https+insecure://" - } - web.Handlers[path] = &ipn.HTTPHandler{ - Proxy: proto + svc.Spec.ClusterIP + ":" + fmt.Sprint(port) + path, - } - } - addIngressBackend(ing.Spec.DefaultBackend, "/") var tlsHost string // hostname or FQDN or empty if ing.Spec.TLS != nil && len(ing.Spec.TLS) > 0 && len(ing.Spec.TLS[0].Hosts) > 0 { tlsHost = ing.Spec.TLS[0].Hosts[0] } - for _, rule := range ing.Spec.Rules { - // Host is optional, but if it's present it must match the TLS host - // otherwise we ignore the rule. - if rule.Host != "" && rule.Host != tlsHost { - a.recorder.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "rule with host %q ignored, unsupported", rule.Host) - continue - } - for _, p := range rule.HTTP.Paths { - // Send a warning if folks use Exact path type - to make - // it easier for us to support Exact path type matching - // in the future if needed. 
- // https://kubernetes.io/docs/concepts/services-networking/ingress/#path-types - if *p.PathType == networkingv1.PathTypeExact { - msg := "Exact path type strict matching is currently not supported and requests will be routed as for Prefix path type. This behaviour might change in the future." - logger.Warnf(fmt.Sprintf("Unsupported Path type exact for path %s. %s", p.Path, msg)) - a.recorder.Eventf(ing, corev1.EventTypeWarning, "UnsupportedPathTypeExact", msg) - } - addIngressBackend(&p.Backend, p.Path) - } + handlers, err := handlersForIngress(ctx, ing, a.Client, a.recorder, tlsHost, logger) + if err != nil { + return fmt.Errorf("failed to get handlers for ingress: %w", err) } - + web.Handlers = handlers if len(web.Handlers) == 0 { logger.Warn("Ingress contains no valid backends") a.recorder.Eventf(ing, corev1.EventTypeWarning, "NoValidBackends", "no valid backends") @@ -255,10 +207,7 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga if tstr, ok := ing.Annotations[AnnotationTags]; ok { tags = strings.Split(tstr, ",") } - hostname := ing.Namespace + "-" + ing.Name + "-ingress" - if tlsHost != "" { - hostname, _, _ = strings.Cut(tlsHost, ".") - } + hostname := hostnameForIngress(ing) sts := &tailscaleSTSConfig{ Hostname: hostname, @@ -268,6 +217,7 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga Tags: tags, ChildResourceLabels: crl, ProxyClassName: proxyClass, + proxyType: proxyTypeIngressResource, } if val := ing.GetAnnotations()[AnnotationExperimentalForwardClusterTrafficViaL7IngresProxy]; val == "true" { @@ -278,12 +228,12 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return fmt.Errorf("failed to provision: %w", err) } - _, tsHost, _, err := a.ssr.DeviceInfo(ctx, crl) + dev, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { - return fmt.Errorf("failed to get device ID: %w", err) + return fmt.Errorf("failed to retrieve Ingress HTTPS endpoint status: 
%w", err) } - if tsHost == "" { - logger.Debugf("no Tailscale hostname known yet, waiting for proxy pod to finish auth") + if dev == nil || dev.ingressDNSName == "" { + logger.Debugf("no Ingress DNS name known yet, waiting for proxy Pod initialize and start serving Ingress") // No hostname yet. Wait for the proxy pod to auth. ing.Status.LoadBalancer.Ingress = nil if err := a.Status().Update(ctx, ing); err != nil { @@ -292,10 +242,10 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - logger.Debugf("setting ingress hostname to %q", tsHost) + logger.Debugf("setting Ingress hostname to %q", dev.ingressDNSName) ing.Status.LoadBalancer.Ingress = []networkingv1.IngressLoadBalancerIngress{ { - Hostname: tsHost, + Hostname: dev.ingressDNSName, Ports: []networkingv1.IngressPortStatus{ { Protocol: "TCP", @@ -313,28 +263,112 @@ func (a *IngressReconciler) maybeProvision(ctx context.Context, logger *zap.Suga func (a *IngressReconciler) shouldExpose(ing *networkingv1.Ingress) bool { return ing != nil && ing.Spec.IngressClassName != nil && - *ing.Spec.IngressClassName == tailscaleIngressClassName + *ing.Spec.IngressClassName == tailscaleIngressClassName && + ing.Annotations[AnnotationProxyGroup] == "" } // validateIngressClass attempts to validate that 'tailscale' IngressClass // included in Tailscale installation manifests exists and has not been modified // to attempt to enable features that we do not support. -func (a *IngressReconciler) validateIngressClass(ctx context.Context) error { +func validateIngressClass(ctx context.Context, cl client.Client) error { ic := &networkingv1.IngressClass{ ObjectMeta: metav1.ObjectMeta{ Name: tailscaleIngressClassName, }, } - if err := a.Get(ctx, client.ObjectKeyFromObject(ic), ic); apierrors.IsNotFound(err) { - return errors.New("Tailscale IngressClass not found in cluster. 
Latest installation manifests include a tailscale IngressClass - please update") + if err := cl.Get(ctx, client.ObjectKeyFromObject(ic), ic); apierrors.IsNotFound(err) { + return errors.New("'tailscale' IngressClass not found in cluster.") } else if err != nil { return fmt.Errorf("error retrieving 'tailscale' IngressClass: %w", err) } if ic.Spec.Controller != tailscaleIngressControllerName { - return fmt.Errorf("Tailscale Ingress class controller name %s does not match tailscale Ingress controller name %s. Ensure that you are using 'tailscale' IngressClass from latest Tailscale installation manifests", ic.Spec.Controller, tailscaleIngressControllerName) + return fmt.Errorf("'tailscale' Ingress class controller name %s does not match tailscale Ingress controller name %s. Ensure that you are using 'tailscale' IngressClass from latest Tailscale installation manifests", ic.Spec.Controller, tailscaleIngressControllerName) } if ic.GetAnnotations()[ingressClassDefaultAnnotation] != "" { return fmt.Errorf("%s annotation is set on 'tailscale' IngressClass, but Tailscale Ingress controller does not support default Ingress class. 
Ensure that you are using 'tailscale' IngressClass from latest Tailscale installation manifests", ingressClassDefaultAnnotation) } return nil } + +func handlersForIngress(ctx context.Context, ing *networkingv1.Ingress, cl client.Client, rec record.EventRecorder, tlsHost string, logger *zap.SugaredLogger) (handlers map[string]*ipn.HTTPHandler, err error) { + addIngressBackend := func(b *networkingv1.IngressBackend, path string) { + if path == "" { + path = "/" + rec.Eventf(ing, corev1.EventTypeNormal, "PathUndefined", "configured backend is missing a path, defaulting to '/'") + } + + if b == nil { + return + } + + if b.Service == nil { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q is missing service", path) + return + } + var svc corev1.Service + if err := cl.Get(ctx, types.NamespacedName{Namespace: ing.Namespace, Name: b.Service.Name}, &svc); err != nil { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "failed to get service %q for path %q: %v", b.Service.Name, path, err) + return + } + if svc.Spec.ClusterIP == "" || svc.Spec.ClusterIP == "None" { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid ClusterIP", path) + return + } + var port int32 + if b.Service.Port.Name != "" { + for _, p := range svc.Spec.Ports { + if p.Name == b.Service.Port.Name { + port = p.Port + break + } + } + } else { + port = b.Service.Port.Number + } + if port == 0 { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "backend for path %q has invalid port", path) + return + } + proto := "http://" + if port == 443 || b.Service.Port.Name == "https" { + proto = "https+insecure://" + } + mak.Set(&handlers, path, &ipn.HTTPHandler{ + Proxy: proto + svc.Spec.ClusterIP + ":" + fmt.Sprint(port) + path, + }) + } + addIngressBackend(ing.Spec.DefaultBackend, "/") + for _, rule := range ing.Spec.Rules { + // Host is optional, but if it's present it must match the TLS host + 
// otherwise we ignore the rule. + if rule.Host != "" && rule.Host != tlsHost { + rec.Eventf(ing, corev1.EventTypeWarning, "InvalidIngressBackend", "rule with host %q ignored, unsupported", rule.Host) + continue + } + for _, p := range rule.HTTP.Paths { + // Send a warning if folks use Exact path type - to make + // it easier for us to support Exact path type matching + // in the future if needed. + // https://kubernetes.io/docs/concepts/services-networking/ingress/#path-types + if *p.PathType == networkingv1.PathTypeExact { + msg := "Exact path type strict matching is currently not supported and requests will be routed as for Prefix path type. This behaviour might change in the future." + logger.Warnf(fmt.Sprintf("Unsupported Path type exact for path %s. %s", p.Path, msg)) + rec.Eventf(ing, corev1.EventTypeWarning, "UnsupportedPathTypeExact", msg) + } + addIngressBackend(&p.Backend, p.Path) + } + } + return handlers, nil +} + +// hostnameForIngress returns the hostname for an Ingress resource. +// If the Ingress has TLS configured with a host, it returns the first component of that host. +// Otherwise, it returns a hostname derived from the Ingress name and namespace. 
+func hostnameForIngress(ing *networkingv1.Ingress) string { + if ing.Spec.TLS != nil && len(ing.Spec.TLS) > 0 && len(ing.Spec.TLS[0].Hosts) > 0 { + h := ing.Spec.TLS[0].Hosts[0] + hostname, _, _ := strings.Cut(h, ".") + return hostname + } + return ing.Namespace + "-" + ing.Name + "-ingress" +} diff --git a/cmd/k8s-operator/ingress_test.go b/cmd/k8s-operator/ingress_test.go index 38a041dde07f9..a975fec7a6329 100644 --- a/cmd/k8s-operator/ingress_test.go +++ b/cmd/k8s-operator/ingress_test.go @@ -6,25 +6,29 @@ package main import ( + "context" "testing" "go.uber.org/zap" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" networkingv1 "k8s.io/api/networking/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "tailscale.com/ipn" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" + "tailscale.com/tstest" "tailscale.com/types/ptr" "tailscale.com/util/mak" ) func TestTailscaleIngress(t *testing.T) { - tsIngressClass := &networkingv1.IngressClass{ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}} - fc := fake.NewFakeClient(tsIngressClass) + fc := fake.NewFakeClient(ingressClass()) ft := &fakeTSClient{} fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} zl, err := zap.NewDevelopment() @@ -45,45 +49,8 @@ func TestTailscaleIngress(t *testing.T) { } // 1. Resources get created for regular Ingress - ing := &networkingv1.Ingress{ - TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - // The apiserver is supposed to set the UID, but the fake client - // doesn't. 
So, set it explicitly because other code later depends - // on it being set. - UID: types.UID("1234-UID"), - }, - Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), - DefaultBackend: &networkingv1.IngressBackend{ - Service: &networkingv1.IngressServiceBackend{ - Name: "test", - Port: networkingv1.ServiceBackendPort{ - Number: 8080, - }, - }, - }, - TLS: []networkingv1.IngressTLS{ - {Hosts: []string{"default-test"}}, - }, - }, - } - mustCreate(t, fc, ing) - mustCreate(t, fc, &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - }, - Spec: corev1.ServiceSpec{ - ClusterIP: "1.2.3.4", - Ports: []corev1.ServicePort{{ - Port: 8080, - Name: "http"}, - }, - }, - }) + mustCreate(t, fc, ingress()) + mustCreate(t, fc, service()) expectReconciled(t, ingR, "default", "test") @@ -102,9 +69,9 @@ func TestTailscaleIngress(t *testing.T) { } opts.serveConfig = serveConfig - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "ingress"), nil) - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 2. Ingress status gets updated with ingress proxy's MagicDNS name // once that becomes available. 
@@ -113,13 +80,16 @@ func TestTailscaleIngress(t *testing.T) { mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) }) expectReconciled(t, ingR, "default", "test") + + // Get the ingress and update it with expected changes + ing := ingress() ing.Finalizers = append(ing.Finalizers, "tailscale.com/finalizer") ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ Ingress: []networkingv1.IngressLoadBalancerIngress{ {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, }, } - expectEqual(t, fc, ing, nil) + expectEqual(t, fc, ing) // 3. Resources get created for Ingress that should allow forwarding // cluster traffic @@ -128,7 +98,7 @@ func TestTailscaleIngress(t *testing.T) { }) opts.shouldEnableForwardingClusterTrafficViaIngress = true expectReconciled(t, ingR, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 4. Resources get cleaned up when Ingress class is unset mustUpdate(t, fc, "default", "test", func(ing *networkingv1.Ingress) { @@ -141,19 +111,130 @@ func TestTailscaleIngress(t *testing.T) { expectMissing[corev1.Secret](t, fc, "operator-ns", fullName) } +func TestTailscaleIngressHostname(t *testing.T) { + fc := fake.NewFakeClient(ingressClass()) + ft := &fakeTSClient{} + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + + // 1. 
Resources get created for regular Ingress + mustCreate(t, fc, ingress()) + mustCreate(t, fc, service()) + + expectReconciled(t, ingR, "default", "test") + + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + mustCreate(t, fc, &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fullName, + Namespace: "operator-ns", + UID: "test-uid", + }, + }) + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "ingress", + hostname: "default-test", + app: kubetypes.AppIngressResource, + } + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts.serveConfig = serveConfig + + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation, removeResourceReqs) + + // 2. Ingress proxy with capability version >= 110 does not have an HTTPS endpoint set + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + }) + expectReconciled(t, ingR, "default", "test") + + // Get the ingress and update it with expected changes + ing := ingress() + ing.Finalizers = append(ing.Finalizers, "tailscale.com/finalizer") + expectEqual(t, fc, ing) + + // 3. 
Ingress proxy with capability version >= 110 advertises HTTPS endpoint + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("foo.tailnetxyz.ts.net")) + }) + expectReconciled(t, ingR, "default", "test") + ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, + }, + } + expectEqual(t, fc, ing) + + // 4. Ingress proxy with capability version >= 110 does not have an HTTPS endpoint ready + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("test-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("no-https")) + }) + expectReconciled(t, ingR, "default", "test") + ing.Status.LoadBalancer.Ingress = nil + expectEqual(t, fc, ing) + + // 5. 
Ingress proxy's state has https_endpoints set, but its capver is not matching Pod UID (downgrade) + mustUpdate(t, fc, "operator-ns", opts.secretName, func(secret *corev1.Secret) { + mak.Set(&secret.Data, "device_id", []byte("1234")) + mak.Set(&secret.Data, "tailscale_capver", []byte("110")) + mak.Set(&secret.Data, "pod_uid", []byte("not-the-right-uid")) + mak.Set(&secret.Data, "device_fqdn", []byte("foo.tailnetxyz.ts.net")) + mak.Set(&secret.Data, "https_endpoint", []byte("bar.tailnetxyz.ts.net")) + }) + ing.Status.LoadBalancer = networkingv1.IngressLoadBalancerStatus{ + Ingress: []networkingv1.IngressLoadBalancerIngress{ + {Hostname: "foo.tailnetxyz.ts.net", Ports: []networkingv1.IngressPortStatus{{Port: 443, Protocol: "TCP"}}}, + }, + } + expectReconciled(t, ingR, "default", "test") + expectEqual(t, fc, ing) +} + func TestTailscaleIngressWithProxyClass(t *testing.T) { // Setup pc := &tsapi.ProxyClass{ ObjectMeta: metav1.ObjectMeta{Name: "custom-metadata"}, Spec: tsapi.ProxyClassSpec{StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"bar.io/foo": "some-val"}, Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}}}, } - tsIngressClass := &networkingv1.IngressClass{ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}} fc := fake.NewClientBuilder(). WithScheme(tsapi.GlobalScheme). - WithObjects(pc, tsIngressClass). + WithObjects(pc, ingressClass()). WithStatusSubresource(pc). Build() ft := &fakeTSClient{} @@ -177,45 +258,8 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { // 1. Ingress is created with no ProxyClass specified, default proxy // resources get configured. 
- ing := &networkingv1.Ingress{ - TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - // The apiserver is supposed to set the UID, but the fake client - // doesn't. So, set it explicitly because other code later depends - // on it being set. - UID: types.UID("1234-UID"), - }, - Spec: networkingv1.IngressSpec{ - IngressClassName: ptr.To("tailscale"), - DefaultBackend: &networkingv1.IngressBackend{ - Service: &networkingv1.IngressServiceBackend{ - Name: "test", - Port: networkingv1.ServiceBackendPort{ - Number: 8080, - }, - }, - }, - TLS: []networkingv1.IngressTLS{ - {Hosts: []string{"default-test"}}, - }, - }, - } - mustCreate(t, fc, ing) - mustCreate(t, fc, &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - }, - Spec: corev1.ServiceSpec{ - ClusterIP: "1.2.3.4", - Ports: []corev1.ServicePort{{ - Port: 8080, - Name: "http"}, - }, - }, - }) + mustCreate(t, fc, ingress()) + mustCreate(t, fc, service()) expectReconciled(t, ingR, "default", "test") @@ -234,9 +278,9 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { } opts.serveConfig = serveConfig - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "ingress"), nil) - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 2. Ingress is updated to specify a ProxyClass, ProxyClass is not yet // ready, so proxy resource configuration does not change. 
@@ -244,7 +288,7 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { mak.Set(&ing.ObjectMeta.Labels, LabelProxyClass, "custom-metadata") }) expectReconciled(t, ingR, "default", "test") - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 3. ProxyClass is set to Ready by proxy-class reconciler. Ingress get // reconciled and configuration from the ProxyClass is applied to the @@ -259,7 +303,7 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { }) expectReconciled(t, ingR, "default", "test") opts.proxyClass = pc.Name - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 4. tailscale.com/proxy-class label is removed from the Ingress, the // Ingress gets reconciled and the custom ProxyClass configuration is @@ -269,5 +313,388 @@ func TestTailscaleIngressWithProxyClass(t *testing.T) { }) expectReconciled(t, ingR, "default", "test") opts.proxyClass = "" - expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation, removeResourceReqs) +} + +func TestTailscaleIngressWithServiceMonitor(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{Name: "metrics", Generation: 1}, + Spec: tsapi.ProxyClassSpec{}, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Status: metav1.ConditionTrue, + Type: string(tsapi.ProxyClassReady), + ObservedGeneration: 1, + }}}, + } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + + // Create fake client with ProxyClass, IngressClass, Ingress with metrics ProxyClass, and Service + ing := ingress() + ing.Labels = map[string]string{ + LabelProxyClass: "metrics", + } + fc := fake.NewClientBuilder(). 
+ WithScheme(tsapi.GlobalScheme). + WithObjects(pc, ingressClass(), ing, service()). + WithStatusSubresource(pc). + Build() + + ft := &fakeTSClient{} + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + expectReconciled(t, ingR, "default", "test") + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + tailscaleNamespace: "operator-ns", + parentType: "ingress", + hostname: "default-test", + app: kubetypes.AppIngressResource, + namespaced: true, + proxyType: proxyTypeIngressResource, + serveConfig: serveConfig, + resourceVersion: "1", + } + + // 1. Enable metrics- expect metrics Service to be created + mustUpdate(t, fc, "", "metrics", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics = &tsapi.Metrics{Enable: true} + }) + opts.enableMetrics = true + + expectReconciled(t, ingR, "default", "test") + + expectEqual(t, fc, expectedMetricsService(opts)) + + // 2. Enable ServiceMonitor - should not error when there is no ServiceMonitor CRD in cluster + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true, Labels: tsapi.Labels{"foo": "bar"}} + }) + expectReconciled(t, ingR, "default", "test") + expectEqual(t, fc, expectedMetricsService(opts)) + + // 3. 
Create ServiceMonitor CRD and reconcile- ServiceMonitor should get created + mustCreate(t, fc, crd) + expectReconciled(t, ingR, "default", "test") + opts.serviceMonitorLabels = tsapi.Labels{"foo": "bar"} + expectEqual(t, fc, expectedMetricsService(opts)) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 4. Update ServiceMonitor CRD and reconcile- ServiceMonitor should get updated + mustUpdate(t, fc, pc.Namespace, pc.Name, func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics.ServiceMonitor.Labels = nil + }) + expectReconciled(t, ingR, "default", "test") + opts.serviceMonitorLabels = nil + opts.resourceVersion = "2" + expectEqual(t, fc, expectedMetricsService(opts)) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 5. Disable metrics - metrics resources should get deleted. + mustUpdate(t, fc, pc.Namespace, pc.Name, func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics = nil + }) + expectReconciled(t, ingR, "default", "test") + expectMissing[corev1.Service](t, fc, "operator-ns", metricsResourceName(shortName)) + // ServiceMonitor gets garbage collected when the Service is deleted - we cannot test that here. +} + +func TestIngressLetsEncryptStaging(t *testing.T) { + cl := tstest.NewClock(tstest.ClockOpts{}) + zl := zap.Must(zap.NewDevelopment()) + + pcLEStaging, pcLEStagingFalse, pcOther := proxyClassesForLEStagingTest() + + testCases := testCasesForLEStagingTests(pcLEStaging, pcLEStagingFalse, pcOther) + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + builder := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme) + + builder = builder.WithObjects(pcLEStaging, pcLEStagingFalse, pcOther). 
+ WithStatusSubresource(pcLEStaging, pcLEStagingFalse, pcOther) + + fc := builder.Build() + + if tt.proxyClassPerResource != "" || tt.defaultProxyClass != "" { + name := tt.proxyClassPerResource + if name == "" { + name = tt.defaultProxyClass + } + setProxyClassReady(t, fc, cl, name) + } + + mustCreate(t, fc, ingressClass()) + mustCreate(t, fc, service()) + ing := ingress() + if tt.proxyClassPerResource != "" { + ing.Labels = map[string]string{ + LabelProxyClass: tt.proxyClassPerResource, + } + } + mustCreate(t, fc, ing) + + ingR := &IngressReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: &fakeTSClient{}, + tsnetServer: &fakeTSNetServer{certDomains: []string{"test-host"}}, + defaultTags: []string{"tag:test"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale:test", + }, + logger: zl.Sugar(), + defaultProxyClass: tt.defaultProxyClass, + } + + expectReconciled(t, ingR, "default", "test") + + _, shortName := findGenName(t, fc, "default", "test", "ingress") + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: "operator-ns", Name: shortName}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + if tt.useLEStagingEndpoint { + verifyEnvVar(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL", letsEncryptStagingEndpoint) + } else { + verifyEnvVarNotPresent(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL") + } + }) + } +} + +func TestEmptyPath(t *testing.T) { + testCases := []struct { + name string + paths []networkingv1.HTTPIngressPath + expectedEvents []string + }{ + { + name: "empty_path_with_prefix_type", + paths: []networkingv1.HTTPIngressPath{ + { + PathType: ptrPathType(networkingv1.PathTypePrefix), + Path: "", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + { + name: "empty_path_with_implementation_specific_type", + paths: []networkingv1.HTTPIngressPath{ + { + 
PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + Path: "", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + { + name: "empty_path_with_exact_type", + paths: []networkingv1.HTTPIngressPath{ + { + PathType: ptrPathType(networkingv1.PathTypeExact), + Path: "", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Warning UnsupportedPathTypeExact Exact path type strict matching is currently not supported and requests will be routed as for Prefix path type. This behaviour might change in the future.", + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + { + name: "two_competing_but_not_identical_paths_including_one_empty", + paths: []networkingv1.HTTPIngressPath{ + { + PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + Path: "", + Backend: *backend(), + }, + { + PathType: ptrPathType(networkingv1.PathTypeImplementationSpecific), + Path: "/", + Backend: *backend(), + }, + }, + expectedEvents: []string{ + "Normal PathUndefined configured backend is missing a path, defaulting to '/'", + }, + }, + } + + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + fc := fake.NewFakeClient(ingressClass()) + ft := &fakeTSClient{} + fr := record.NewFakeRecorder(3) // bump this if you expect a test case to throw more events + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + ingR := &IngressReconciler{ + recorder: fr, + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + tsnetServer: fakeTsnetServer, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + } + + // 1. 
Resources get created for regular Ingress + mustCreate(t, fc, ingressWithPaths(tt.paths)) + mustCreate(t, fc, service()) + + expectReconciled(t, ingR, "default", "test") + + fullName, shortName := findGenName(t, fc, "default", "test", "ingress") + mustCreate(t, fc, &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fullName, + Namespace: "operator-ns", + UID: "test-uid", + }, + }) + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "ingress", + hostname: "foo", + app: kubetypes.AppIngressResource, + } + serveConfig := &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{443: {HTTPS: true}}, + Web: map[ipn.HostPort]*ipn.WebServerConfig{"${TS_CERT_DOMAIN}:443": {Handlers: map[string]*ipn.HTTPHandler{"/": {Proxy: "http://1.2.3.4:8080/"}}}}, + } + opts.serveConfig = serveConfig + + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "ingress")) + expectEqual(t, fc, expectedSTSUserspace(t, fc, opts), removeHashAnnotation, removeResourceReqs) + + expectEvents(t, fr, tt.expectedEvents) + }) + } +} + +// ptrPathType is a helper function to return a pointer to the pathtype string (required for TestEmptyPath) +func ptrPathType(p networkingv1.PathType) *networkingv1.PathType { + return &p +} + +func ingressClass() *networkingv1.IngressClass { + return &networkingv1.IngressClass{ + ObjectMeta: metav1.ObjectMeta{Name: "tailscale"}, + Spec: networkingv1.IngressClassSpec{Controller: "tailscale.com/ts-ingress"}, + } +} + +func service() *corev1.Service { + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "1.2.3.4", + Ports: []corev1.ServicePort{{ + Port: 8080, + Name: "http"}, + }, + }, + } +} + +func ingress() *networkingv1.Ingress { + return &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ 
+ Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + DefaultBackend: backend(), + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"default-test"}}, + }, + }, + } +} + +func ingressWithPaths(paths []networkingv1.HTTPIngressPath) *networkingv1.Ingress { + return &networkingv1.Ingress{ + TypeMeta: metav1.TypeMeta{Kind: "Ingress", APIVersion: "networking.k8s.io/v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + }, + Spec: networkingv1.IngressSpec{ + IngressClassName: ptr.To("tailscale"), + Rules: []networkingv1.IngressRule{ + { + Host: "foo.tailnetxyz.ts.net", + IngressRuleValue: networkingv1.IngressRuleValue{ + HTTP: &networkingv1.HTTPIngressRuleValue{ + Paths: paths, + }, + }, + }, + }, + TLS: []networkingv1.IngressTLS{ + {Hosts: []string{"foo.tailnetxyz.ts.net"}}, + }, + }, + } +} + +func backend() *networkingv1.IngressBackend { + return &networkingv1.IngressBackend{ + Service: &networkingv1.IngressServiceBackend{ + Name: "test", + Port: networkingv1.ServiceBackendPort{ + Number: 8080, + }, + }, + } } diff --git a/cmd/k8s-operator/metrics_resources.go b/cmd/k8s-operator/metrics_resources.go new file mode 100644 index 0000000000000..0579e34661a11 --- /dev/null +++ b/cmd/k8s-operator/metrics_resources.go @@ -0,0 +1,296 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "reflect" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/controller-runtime/pkg/client" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" +) + +const ( + 
labelMetricsTarget = "tailscale.com/metrics-target" + + // These labels get transferred from the metrics Service to the ingested Prometheus metrics. + labelPromProxyType = "ts_proxy_type" + labelPromProxyParentName = "ts_proxy_parent_name" + labelPromProxyParentNamespace = "ts_proxy_parent_namespace" + labelPromJob = "ts_prom_job" + + serviceMonitorCRD = "servicemonitors.monitoring.coreos.com" +) + +// ServiceMonitor contains a subset of fields of servicemonitors.monitoring.coreos.com Custom Resource Definition. +// Duplicating it here allows us to avoid importing prometheus-operator library. +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L40 +type ServiceMonitor struct { + metav1.TypeMeta `json:",inline"` + metav1.ObjectMeta `json:"metadata"` + Spec ServiceMonitorSpec `json:"spec"` +} + +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L55 +type ServiceMonitorSpec struct { + // Endpoints defines the endpoints to be scraped on the selected Service(s). + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L82 + Endpoints []ServiceMonitorEndpoint `json:"endpoints"` + // JobLabel is the label on the Service whose value will become the value of the Prometheus job label for the metrics ingested via this ServiceMonitor. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L66 + JobLabel string `json:"jobLabel"` + // NamespaceSelector selects the namespace of Service(s) that this ServiceMonitor allows to scrape. 
+ // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L88 + NamespaceSelector ServiceMonitorNamespaceSelector `json:"namespaceSelector,omitempty"` + // Selector is the label selector for Service(s) that this ServiceMonitor allows to scrape. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L85 + Selector metav1.LabelSelector `json:"selector"` + // TargetLabels are labels on the selected Service that should be applied as Prometheus labels to the ingested metrics. + // https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L72 + TargetLabels []string `json:"targetLabels"` +} + +// ServiceMonitorNamespaceSelector selects namespaces in which Prometheus operator will attempt to find Services for +// this ServiceMonitor. +// https://github.com/prometheus-operator/prometheus-operator/blob/bb4514e0d5d69f20270e29cfd4ad39b87865ccdf/pkg/apis/monitoring/v1/servicemonitor_types.go#L88 +type ServiceMonitorNamespaceSelector struct { + MatchNames []string `json:"matchNames,omitempty"` +} + +// ServiceMonitorEndpoint defines an endpoint of Service to scrape. We only define port here. Prometheus by default +// scrapes /metrics path, which is what we want. +type ServiceMonitorEndpoint struct { + // Port is the name of the Service port that Prometheus will scrape. + Port string `json:"port,omitempty"` +} + +func reconcileMetricsResources(ctx context.Context, logger *zap.SugaredLogger, opts *metricsOpts, pc *tsapi.ProxyClass, cl client.Client) error { + if opts.proxyType == proxyTypeEgress { + // Metrics are currently not being enabled for standalone egress proxies. 
+ return nil + } + if pc == nil || pc.Spec.Metrics == nil || !pc.Spec.Metrics.Enable { + return maybeCleanupMetricsResources(ctx, opts, cl) + } + metricsSvc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: metricsResourceName(opts.proxyStsName), + Namespace: opts.tsNamespace, + Labels: metricsResourceLabels(opts), + }, + Spec: corev1.ServiceSpec{ + Selector: opts.proxyLabels, + Type: corev1.ServiceTypeClusterIP, + Ports: []corev1.ServicePort{{Protocol: "TCP", Port: 9002, Name: "metrics"}}, + }, + } + var err error + metricsSvc, err = createOrUpdate(ctx, cl, opts.tsNamespace, metricsSvc, func(svc *corev1.Service) { + svc.Spec.Ports = metricsSvc.Spec.Ports + svc.Spec.Selector = metricsSvc.Spec.Selector + }) + if err != nil { + return fmt.Errorf("error ensuring metrics Service: %w", err) + } + + crdExists, err := hasServiceMonitorCRD(ctx, cl) + if err != nil { + return fmt.Errorf("error verifying that %q CRD exists: %w", serviceMonitorCRD, err) + } + if !crdExists { + return nil + } + + if pc.Spec.Metrics.ServiceMonitor == nil || !pc.Spec.Metrics.ServiceMonitor.Enable { + return maybeCleanupServiceMonitor(ctx, cl, opts.proxyStsName, opts.tsNamespace) + } + + logger.Infof("ensuring ServiceMonitor for metrics Service %s/%s", metricsSvc.Namespace, metricsSvc.Name) + svcMonitor, err := newServiceMonitor(metricsSvc, pc.Spec.Metrics.ServiceMonitor) + if err != nil { + return fmt.Errorf("error creating ServiceMonitor: %w", err) + } + + // We don't use createOrUpdate here because that does not work with unstructured types. 
+ existing := svcMonitor.DeepCopy() + err = cl.Get(ctx, client.ObjectKeyFromObject(metricsSvc), existing) + if apierrors.IsNotFound(err) { + if err := cl.Create(ctx, svcMonitor); err != nil { + return fmt.Errorf("error creating ServiceMonitor: %w", err) + } + return nil + } + if err != nil { + return fmt.Errorf("error getting ServiceMonitor: %w", err) + } + // Currently, we only update labels on the ServiceMonitor as those are the only values that can change. + if !reflect.DeepEqual(existing.GetLabels(), svcMonitor.GetLabels()) { + existing.SetLabels(svcMonitor.GetLabels()) + if err := cl.Update(ctx, existing); err != nil { + return fmt.Errorf("error updating ServiceMonitor: %w", err) + } + } + return nil +} + +// maybeCleanupMetricsResources ensures that any metrics resources created for a proxy are deleted. Only metrics Service +// gets deleted explicitly because the ServiceMonitor has Service's owner reference, so gets garbage collected +// automatically. +func maybeCleanupMetricsResources(ctx context.Context, opts *metricsOpts, cl client.Client) error { + sel := metricsSvcSelector(opts.proxyLabels, opts.proxyType) + return cl.DeleteAllOf(ctx, &corev1.Service{}, client.InNamespace(opts.tsNamespace), client.MatchingLabels(sel)) +} + +// maybeCleanupServiceMonitor cleans up any ServiceMonitor created for the named proxy StatefulSet. 
+func maybeCleanupServiceMonitor(ctx context.Context, cl client.Client, stsName, ns string) error { + smName := metricsResourceName(stsName) + sm := serviceMonitorTemplate(smName, ns) + u, err := serviceMonitorToUnstructured(sm) + if err != nil { + return fmt.Errorf("error building ServiceMonitor: %w", err) + } + err = cl.Get(ctx, types.NamespacedName{Name: smName, Namespace: ns}, u) + if apierrors.IsNotFound(err) { + return nil // nothing to do + } + if err != nil { + return fmt.Errorf("error verifying if ServiceMonitor %s/%s exists: %w", ns, stsName, err) + } + return cl.Delete(ctx, u) +} + +// newServiceMonitor takes a metrics Service created for a proxy and constructs and returns a ServiceMonitor for that +// proxy that can be applied to the kube API server. +// The ServiceMonitor is returned as Unstructured type - this allows us to avoid importing prometheus-operator API server client/schema. +func newServiceMonitor(metricsSvc *corev1.Service, spec *tsapi.ServiceMonitor) (*unstructured.Unstructured, error) { + sm := serviceMonitorTemplate(metricsSvc.Name, metricsSvc.Namespace) + sm.ObjectMeta.Labels = metricsSvc.Labels + if spec != nil && len(spec.Labels) > 0 { + sm.ObjectMeta.Labels = mergeMapKeys(sm.ObjectMeta.Labels, spec.Labels.Parse()) + } + + sm.ObjectMeta.OwnerReferences = []metav1.OwnerReference{*metav1.NewControllerRef(metricsSvc, corev1.SchemeGroupVersion.WithKind("Service"))} + sm.Spec = ServiceMonitorSpec{ + Selector: metav1.LabelSelector{MatchLabels: metricsSvc.Labels}, + Endpoints: []ServiceMonitorEndpoint{{ + Port: "metrics", + }}, + NamespaceSelector: ServiceMonitorNamespaceSelector{ + MatchNames: []string{metricsSvc.Namespace}, + }, + JobLabel: labelPromJob, + TargetLabels: []string{ + labelPromProxyParentName, + labelPromProxyParentNamespace, + labelPromProxyType, + }, + } + return serviceMonitorToUnstructured(sm) +} + +// serviceMonitorToUnstructured takes a ServiceMonitor and converts it to Unstructured type that can be used by the c/r +// 
client in Kubernetes API server calls. +func serviceMonitorToUnstructured(sm *ServiceMonitor) (*unstructured.Unstructured, error) { + contents, err := runtime.DefaultUnstructuredConverter.ToUnstructured(sm) + if err != nil { + return nil, fmt.Errorf("error converting ServiceMonitor to Unstructured: %w", err) + } + u := &unstructured.Unstructured{} + u.SetUnstructuredContent(contents) + u.SetGroupVersionKind(sm.GroupVersionKind()) + return u, nil +} + +// metricsResourceName returns name for metrics Service and ServiceMonitor for a proxy StatefulSet. +func metricsResourceName(stsName string) string { + // Maximum length of StatefulSet name if 52 chars, so this is fine. + return fmt.Sprintf("%s-metrics", stsName) +} + +// metricsResourceLabels constructs labels that will be applied to metrics Service and metrics ServiceMonitor for a +// proxy. +func metricsResourceLabels(opts *metricsOpts) map[string]string { + lbls := map[string]string{ + kubetypes.LabelManaged: "true", + labelMetricsTarget: opts.proxyStsName, + labelPromProxyType: opts.proxyType, + labelPromProxyParentName: opts.proxyLabels[LabelParentName], + } + // Include namespace label for proxies created for a namespaced type. + if isNamespacedProxyType(opts.proxyType) { + lbls[labelPromProxyParentNamespace] = opts.proxyLabels[LabelParentNamespace] + } + lbls[labelPromJob] = promJobName(opts) + return lbls +} + +// promJobName constructs the value of the Prometheus job label that will apply to all metrics for a ServiceMonitor. +func promJobName(opts *metricsOpts) string { + // Include parent resource namespace for proxies created for namespaced types. 
+ if opts.proxyType == proxyTypeIngressResource || opts.proxyType == proxyTypeIngressService { + return fmt.Sprintf("ts_%s_%s_%s", opts.proxyType, opts.proxyLabels[LabelParentNamespace], opts.proxyLabels[LabelParentName]) + } + return fmt.Sprintf("ts_%s_%s", opts.proxyType, opts.proxyLabels[LabelParentName]) +} + +// metricsSvcSelector returns the minimum label set to uniquely identify a metrics Service for a proxy. +func metricsSvcSelector(proxyLabels map[string]string, proxyType string) map[string]string { + sel := map[string]string{ + labelPromProxyType: proxyType, + labelPromProxyParentName: proxyLabels[LabelParentName], + } + // Include namespace label for proxies created for a namespaced type. + if isNamespacedProxyType(proxyType) { + sel[labelPromProxyParentNamespace] = proxyLabels[LabelParentNamespace] + } + return sel +} + +// serviceMonitorTemplate returns a base ServiceMonitor type that, when converted to Unstructured, is a valid type that +// can be used in kube API server calls via the c/r client. 
+func serviceMonitorTemplate(name, ns string) *ServiceMonitor { + return &ServiceMonitor{ + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceMonitor", + APIVersion: "monitoring.coreos.com/v1", + }, + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: ns, + }, + } +} + +type metricsOpts struct { + proxyStsName string // name of StatefulSet for proxy + tsNamespace string // namespace in which Tailscale is installed + proxyLabels map[string]string // labels of the proxy StatefulSet + proxyType string +} + +func isNamespacedProxyType(typ string) bool { + return typ == proxyTypeIngressResource || typ == proxyTypeIngressService +} + +func mergeMapKeys(a, b map[string]string) map[string]string { + m := make(map[string]string, len(a)+len(b)) + for key, val := range b { + m[key] = val + } + for key, val := range a { + m[key] = val + } + return m +} diff --git a/cmd/k8s-operator/nameserver.go b/cmd/k8s-operator/nameserver.go index 52577c929acea..d042aff0c7de0 100644 --- a/cmd/k8s-operator/nameserver.go +++ b/cmd/k8s-operator/nameserver.go @@ -9,6 +9,7 @@ import ( "context" "fmt" "slices" + "strings" "sync" _ "embed" @@ -18,6 +19,7 @@ import ( xslices "golang.org/x/exp/slices" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -86,7 +88,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ return reconcile.Result{}, nil } logger.Info("Cleaning up DNSConfig resources") - if err := a.maybeCleanup(ctx, &dnsCfg, logger); err != nil { + if err := a.maybeCleanup(&dnsCfg); err != nil { logger.Errorf("error cleaning up reconciler resource: %v", err) return res, err } @@ -100,9 +102,9 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ } oldCnStatus := dnsCfg.Status.DeepCopy() - setStatus := func(dnsCfg *tsapi.DNSConfig, conditionType 
tsapi.ConditionType, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { + setStatus := func(dnsCfg *tsapi.DNSConfig, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetDNSConfigCondition(dnsCfg, tsapi.NameserverReady, status, reason, message, dnsCfg.Generation, a.clock, logger) - if !apiequality.Semantic.DeepEqual(oldCnStatus, dnsCfg.Status) { + if !apiequality.Semantic.DeepEqual(oldCnStatus, &dnsCfg.Status) { // An error encountered here should get returned by the Reconcile function. if updateErr := a.Client.Status().Update(ctx, dnsCfg); updateErr != nil { err = errors.Wrap(err, updateErr.Error()) @@ -118,7 +120,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ msg := "invalid cluster configuration: more than one tailscale.com/dnsconfigs found. Please ensure that no more than one is created." logger.Error(msg) a.recorder.Event(&dnsCfg, corev1.EventTypeWarning, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) - setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionFalse, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) + setStatus(&dnsCfg, metav1.ConditionFalse, reasonMultipleDNSConfigsPresent, messageMultipleDNSConfigsPresent) } if !slices.Contains(dnsCfg.Finalizers, FinalizerName) { @@ -127,11 +129,16 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ if err := a.Update(ctx, &dnsCfg); err != nil { msg := fmt.Sprintf(messageNameserverCreationFailed, err) logger.Error(msg) - return setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionFalse, reasonNameserverCreationFailed, msg) + return setStatus(&dnsCfg, metav1.ConditionFalse, reasonNameserverCreationFailed, msg) } } if err := a.maybeProvision(ctx, &dnsCfg, logger); err != nil { - return reconcile.Result{}, fmt.Errorf("error provisioning nameserver resources: %w", err) + if strings.Contains(err.Error(), optimisticLockErrorMsg) 
{ + logger.Infof("optimistic lock error, retrying: %s", err) + return reconcile.Result{}, nil + } else { + return reconcile.Result{}, fmt.Errorf("error provisioning nameserver resources: %w", err) + } } a.mu.Lock() @@ -149,7 +156,7 @@ func (a *NameserverReconciler) Reconcile(ctx context.Context, req reconcile.Requ dnsCfg.Status.Nameserver = &tsapi.NameserverStatus{ IP: ip, } - return setStatus(&dnsCfg, tsapi.NameserverReady, metav1.ConditionTrue, reasonNameserverCreated, reasonNameserverCreated) + return setStatus(&dnsCfg, metav1.ConditionTrue, reasonNameserverCreated, reasonNameserverCreated) } logger.Info("nameserver Service does not have an IP address allocated, waiting...") return reconcile.Result{}, nil @@ -162,12 +169,33 @@ func nameserverResourceLabels(name, namespace string) map[string]string { return labels } -func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsapi.DNSConfig, logger *zap.SugaredLogger) error { - labels := nameserverResourceLabels(tsDNSCfg.Name, a.tsNamespace) +// mergeEnvVars merges `source` with `other` while prioritizing the values from +// `other` if there a duplicate environment variables found. 
+func mergeEnvVars(source []corev1.EnvVar, other []corev1.EnvVar) []corev1.EnvVar {
+	// Seed the result with `other` so its values win on duplicate names.
+	// (Capacity is pre-sized for the worst case of no overlapping names.)
+	merged := make([]corev1.EnvVar, len(other), len(other)+len(source))
+	copy(merged, other)
+
+	// create a map to track existing env var names in `other`
+	existing := make(map[string]bool, len(other))
+	for _, env := range other {
+		existing[env.Name] = true
+	}
+	// now we add the missing env variable names from source if they do not
+	// already exist
+	for _, env := range source {
+		if !existing[env.Name] {
+			merged = append(merged, env)
+		}
+	}
+	return merged
+}
+
+func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsapi.DNSConfig, _ *zap.SugaredLogger) error {
+	resourceLabels := nameserverResourceLabels(tsDNSCfg.Name, a.tsNamespace)
 	dCfg := &deployConfig{
 		ownerRefs: []metav1.OwnerReference{*metav1.NewControllerRef(tsDNSCfg, tsapi.SchemeGroupVersion.WithKind("DNSConfig"))},
 		namespace: a.tsNamespace,
-		labels:    labels,
+		labels:    resourceLabels,
 		imageRepo: defaultNameserverImageRepo,
 		imageTag:  defaultNameserverImageTag,
 	}
@@ -177,7 +205,13 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa
 	if tsDNSCfg.Spec.Nameserver.Image != nil && tsDNSCfg.Spec.Nameserver.Image.Tag != "" {
 		dCfg.imageTag = tsDNSCfg.Spec.Nameserver.Image.Tag
 	}
-	for _, deployable := range []deployable{saDeployable, deployDeployable, svcDeployable, cmDeployable} {
+	if len(tsDNSCfg.Spec.Nameserver.Cmd) > 0 {
+		dCfg.cmd = tsDNSCfg.Spec.Nameserver.Cmd
+	}
+	dCfg.env = tsDNSCfg.Spec.Nameserver.Env
+	dCfg.podLabels = tsDNSCfg.Spec.Nameserver.PodLabels
+
+	for _, deployable := range []deployable{saDeployable, roleDeployable, rolebindingDeployable, deployDeployable, svcDeployable, cmDeployable} {
 		if err := deployable.updateObj(ctx, dCfg, a.Client); err != nil {
 			return fmt.Errorf("error reconciling %s: %w", deployable.kind, err)
 		}
@@ -188,7 +222,7 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa
 // maybeCleanup removes DNSConfig from being tracked.
The cluster resources // created, will be automatically garbage collected as they are owned by the // DNSConfig. -func (a *NameserverReconciler) maybeCleanup(ctx context.Context, dnsCfg *tsapi.DNSConfig, logger *zap.SugaredLogger) error { +func (a *NameserverReconciler) maybeCleanup(dnsCfg *tsapi.DNSConfig) error { a.mu.Lock() a.managedNameservers.Remove(dnsCfg.UID) a.mu.Unlock() @@ -205,6 +239,9 @@ type deployConfig struct { imageRepo string imageTag string labels map[string]string + podLabels map[string]string + cmd []string + env []corev1.EnvVar ownerRefs []metav1.OwnerReference namespace string } @@ -218,6 +255,10 @@ var ( saYaml []byte //go:embed deploy/manifests/nameserver/svc.yaml svcYaml []byte + //go:embed deploy/manifests/nameserver/role.yaml + roleYaml []byte + //go:embed deploy/manifests/nameserver/rolebinding.yaml + roleBindingYaml []byte deployDeployable = deployable{ kind: "Deployment", @@ -230,6 +271,17 @@ var ( d.ObjectMeta.Namespace = cfg.namespace d.ObjectMeta.Labels = cfg.labels d.ObjectMeta.OwnerReferences = cfg.ownerRefs + if d.Spec.Template.Labels == nil { + d.Spec.Template.Labels = make(map[string]string) + } + for key, value := range cfg.podLabels { + d.Spec.Template.Labels[key] = value + } + if len(cfg.cmd) > 0 { + d.Spec.Template.Spec.Containers[0].Command = cfg.cmd + } + d.Spec.Template.Spec.Containers[0].Env = mergeEnvVars(d.Spec.Template.Spec.Containers[0].Env, cfg.env) + updateF := func(oldD *appsv1.Deployment) { oldD.Spec = d.Spec } @@ -279,4 +331,35 @@ var ( return err }, } + roleDeployable = deployable{ + kind: "Role", + updateObj: func(ctx context.Context, cfg *deployConfig, kubeClient client.Client) error { + role := new(rbacv1.Role) + if err := yaml.Unmarshal(roleYaml, &role); err != nil { + return fmt.Errorf("error unmarshalling role yaml: %w", err) + } + role.ObjectMeta.Labels = cfg.labels + role.ObjectMeta.OwnerReferences = cfg.ownerRefs + role.ObjectMeta.Namespace = cfg.namespace + _, err := createOrUpdate[rbacv1.Role](ctx, 
kubeClient, cfg.namespace, role, func(*rbacv1.Role) {}) + return err + }, + } + rolebindingDeployable = deployable{ + kind: "RoleBinding", + updateObj: func(ctx context.Context, cfg *deployConfig, kubeClient client.Client) error { + roleBinding := new(rbacv1.RoleBinding) + if err := yaml.Unmarshal(roleBindingYaml, &roleBinding); err != nil { + return fmt.Errorf("error unmarshalling rolebinding yaml: %w", err) + } + roleBinding.ObjectMeta.Labels = cfg.labels + roleBinding.ObjectMeta.OwnerReferences = cfg.ownerRefs + roleBinding.ObjectMeta.Namespace = cfg.namespace + if len(roleBinding.Subjects) > 0 { + roleBinding.Subjects[0].Namespace = cfg.namespace + } + _, err := createOrUpdate[rbacv1.RoleBinding](ctx, kubeClient, cfg.namespace, roleBinding, func(*rbacv1.RoleBinding) {}) + return err + }, + } ) diff --git a/cmd/k8s-operator/nameserver_test.go b/cmd/k8s-operator/nameserver_test.go index 695710212e57b..cec95b84ee719 100644 --- a/cmd/k8s-operator/nameserver_test.go +++ b/cmd/k8s-operator/nameserver_test.go @@ -69,7 +69,7 @@ func TestNameserverReconciler(t *testing.T) { wantsDeploy.Namespace = "tailscale" labels := nameserverResourceLabels("test", "tailscale") wantsDeploy.ObjectMeta.Labels = labels - expectEqual(t, fc, wantsDeploy, nil) + expectEqual(t, fc, wantsDeploy) // Verify that DNSConfig advertizes the nameserver's Service IP address, // has the ready status condition and tailscale finalizer. @@ -88,7 +88,7 @@ func TestNameserverReconciler(t *testing.T) { Message: reasonNameserverCreated, LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, }) - expectEqual(t, fc, dnsCfg, nil) + expectEqual(t, fc, dnsCfg) // // Verify that nameserver image gets updated to match DNSConfig spec. 
mustUpdate(t, fc, "", "test", func(dnsCfg *tsapi.DNSConfig) { @@ -96,7 +96,7 @@ func TestNameserverReconciler(t *testing.T) { }) expectReconciled(t, nr, "", "test") wantsDeploy.Spec.Template.Spec.Containers[0].Image = "test:v0.0.2" - expectEqual(t, fc, wantsDeploy, nil) + expectEqual(t, fc, wantsDeploy) // Verify that when another actor sets ConfigMap data, it does not get // overwritten by nameserver reconciler. @@ -114,7 +114,7 @@ func TestNameserverReconciler(t *testing.T) { TypeMeta: metav1.TypeMeta{Kind: "ConfigMap", APIVersion: "v1"}, Data: map[string]string{"records.json": string(bs)}, } - expectEqual(t, fc, wantCm, nil) + expectEqual(t, fc, wantCm) // Verify that if dnsconfig.spec.nameserver.image.{repo,tag} are unset, // the nameserver image defaults to tailscale/k8s-nameserver:unstable. @@ -123,5 +123,5 @@ func TestNameserverReconciler(t *testing.T) { }) expectReconciled(t, nr, "", "test") wantsDeploy.Spec.Template.Spec.Containers[0].Image = "tailscale/k8s-nameserver:unstable" - expectEqual(t, fc, wantsDeploy, nil) + expectEqual(t, fc, wantsDeploy) } diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index d8dd403cc6097..dd97ce66728fb 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -9,22 +9,27 @@ package main import ( "context" + "fmt" + "net/http" "os" "regexp" + "strconv" "strings" "time" "github.com/go-logr/zapr" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "golang.org/x/oauth2/clientcredentials" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" discoveryv1 "k8s.io/api/discovery/v1" networkingv1 "k8s.io/api/networking/v1" rbacv1 "k8s.io/api/rbac/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" + "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/rest" + toolscache "k8s.io/client-go/tools/cache" "sigs.k8s.io/controller-runtime/pkg/builder" "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" @@ 
-35,10 +40,12 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/manager/signals" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/client/local" "tailscale.com/client/tailscale" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/store/kubestore" + apiproxy "tailscale.com/k8s-operator/api-proxy" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" "tailscale.com/tsnet" @@ -56,6 +63,10 @@ import ( // Generate CRD API docs. //go:generate go run github.com/elastic/crd-ref-docs --renderer=markdown --source-path=../../k8s-operator/apis/ --config=../../k8s-operator/api-docs-config.yaml --output-path=../../k8s-operator/api.md +const ( + defaultDomain = "ts.net" +) + func main() { // Required to use our client API. We're fine with the instability since the // client lives in the same repo as this code. @@ -70,6 +81,15 @@ func main() { tsFirewallMode = defaultEnv("PROXY_FIREWALL_MODE", "") defaultProxyClass = defaultEnv("PROXY_DEFAULT_CLASS", "") isDefaultLoadBalancer = defaultBool("OPERATOR_DEFAULT_LOAD_BALANCER", false) + baseDomain = defaultEnv("OPERATOR_DOMAIN", defaultDomain) + // relaxedDomainValidation can be used to only validate that the + // base domain and at least 1 sub domain in the FQDN of + // services exists. The default is to make sure that there are + // exactly 2 sub-domains in the FQDN. 
+ relaxedDomainValidation = defaultBool("OPERATOR_RELAXED_DOMAIN_VALIDATION", false) + // noFqdnDotAppend prevents the dot ('.') appending to the FQDN + // of egress proxy services + noFqdnDotAppend = defaultBool("OPERATOR_NO_FQDN_DOT_APPEND", false) ) var opts []kzap.Opts @@ -84,24 +104,33 @@ func main() { zlog := kzap.NewRaw(opts...).Sugar() logf.SetLogger(zapr.NewLogger(zlog.Desugar())) + if tsNamespace == "" { + const namespaceFile = "/var/run/secrets/kubernetes.io/serviceaccount/namespace" + b, err := os.ReadFile(namespaceFile) + if err != nil { + zlog.Fatalf("Could not get operator namespace from OPERATOR_NAMESPACE environment variable or default projected volume: %v", err) + } + tsNamespace = strings.TrimSpace(string(b)) + } + // The operator can run either as a plain operator or it can // additionally act as api-server proxy // https://tailscale.com/kb/1236/kubernetes-operator/?q=kubernetes#accessing-the-kubernetes-control-plane-using-an-api-server-proxy. - mode := parseAPIProxyMode() - if mode == apiserverProxyModeDisabled { + mode := apiproxy.ParseAPIProxyMode() + if mode == apiproxy.APIServerProxyModeDisabled { hostinfo.SetApp(kubetypes.AppOperator) } else { hostinfo.SetApp(kubetypes.AppAPIServerProxy) } - s, tsClient := initTSNet(zlog) + s, tsc := initTSNet(zlog) defer s.Close() restConfig := config.GetConfigOrDie() - maybeLaunchAPIServerProxy(zlog, restConfig, s, mode) + apiproxy.MaybeLaunchAPIServerProxy(zlog, restConfig, s, mode) rOpts := reconcilerOpts{ log: zlog, tsServer: s, - tsClient: tsClient, + tsClient: tsc, tailscaleNamespace: tsNamespace, restConfig: restConfig, proxyImage: image, @@ -110,6 +139,11 @@ func main() { proxyTags: tags, proxyFirewallMode: tsFirewallMode, defaultProxyClass: defaultProxyClass, + validationOpts: validationOpts{ + baseDomain: baseDomain, + relaxedDomainValidation: relaxedDomainValidation, + }, + noFqdnDotAppend: noFqdnDotAppend, } runReconcilers(rOpts) } @@ -117,10 +151,11 @@ func main() { // initTSNet initializes 
the tsnet.Server and logs in to Tailscale. It uses the // CLIENT_ID_FILE and CLIENT_SECRET_FILE environment variables to authenticate // with Tailscale. -func initTSNet(zlog *zap.SugaredLogger) (*tsnet.Server, *tailscale.Client) { +func initTSNet(zlog *zap.SugaredLogger) (*tsnet.Server, tsClient) { var ( clientIDPath = defaultEnv("CLIENT_ID_FILE", "") clientSecretPath = defaultEnv("CLIENT_SECRET_FILE", "") + controlURL = defaultEnv("CONTROL_URL", "") hostname = defaultEnv("OPERATOR_HOSTNAME", "tailscale-operator") kubeSecret = defaultEnv("OPERATOR_SECRET", "") operatorTags = defaultEnv("OPERATOR_INITIAL_TAGS", "tag:k8s-operator") @@ -129,26 +164,22 @@ func initTSNet(zlog *zap.SugaredLogger) (*tsnet.Server, *tailscale.Client) { if clientIDPath == "" || clientSecretPath == "" { startlog.Fatalf("CLIENT_ID_FILE and CLIENT_SECRET_FILE must be set") } - clientID, err := os.ReadFile(clientIDPath) - if err != nil { - startlog.Fatalf("reading client ID %q: %v", clientIDPath, err) - } - clientSecret, err := os.ReadFile(clientSecretPath) + tsc, err := newTSClient(context.Background(), clientIDPath, clientSecretPath) if err != nil { - startlog.Fatalf("reading client secret %q: %v", clientSecretPath, err) + startlog.Fatalf("error creating Tailscale client: %v", err) } - credentials := clientcredentials.Config{ - ClientID: string(clientID), - ClientSecret: string(clientSecret), - TokenURL: "https://login.tailscale.com/api/v2/oauth/token", - } - tsClient := tailscale.NewClient("-", nil) - tsClient.UserAgent = "tailscale-k8s-operator" - tsClient.HTTPClient = credentials.Client(context.Background()) s := &tsnet.Server{ - Hostname: hostname, - Logf: zlog.Named("tailscaled").Debugf, + ControlURL: controlURL, + Hostname: hostname, + Logf: zlog.Named("tailscaled").Debugf, + } + if p := os.Getenv("TS_PORT"); p != "" { + port, err := strconv.ParseUint(p, 10, 16) + if err != nil { + startlog.Fatalf("TS_PORT %q cannot be parsed as uint16: %v", p, err) + } + s.Port = uint16(port) } if 
kubeSecret != "" { st, err := kubestore.New(logger.Discard, kubeSecret) @@ -191,7 +222,7 @@ waitOnline: }, }, } - authkey, _, err := tsClient.CreateKey(ctx, caps) + authkey, _, err := tsc.CreateKey(ctx, caps) if err != nil { startlog.Fatalf("creating operator authkey: %v", err) } @@ -215,7 +246,7 @@ waitOnline: } time.Sleep(time.Second) } - return s, tsClient + return s, tsc } // runReconcilers starts the controller-runtime manager and registers the @@ -231,21 +262,32 @@ func runReconcilers(opts reconcilerOpts) { nsFilter := cache.ByObject{ Field: client.InNamespace(opts.tailscaleNamespace).AsSelector(), } + + // We watch the ServiceMonitor CRD to ensure that reconcilers are re-triggered if user's workflows result in the + // ServiceMonitor CRD applied after some of our resources that define ServiceMonitor creation. This selector + // ensures that we only watch the ServiceMonitor CRD and that we don't cache full contents of it. + serviceMonitorSelector := cache.ByObject{ + Field: fields.SelectorFromSet(fields.Set{"metadata.name": serviceMonitorCRD}), + Transform: crdTransformer(startlog), + } + + // TODO (irbekrm): stricter filtering what we watch/cache/call + // reconcilers on. c/r by default starts a watch on any + // resources that we GET via the controller manager's client. mgrOpts := manager.Options{ - // TODO (irbekrm): stricter filtering what we watch/cache/call - // reconcilers on. c/r by default starts a watch on any - // resources that we GET via the controller manager's client. + // The cache will apply the specified filters only to the object types listed below via ByObject. + // Other object types (e.g., EndpointSlices) can still be fetched or watched using the cached client, but they will not have any filtering applied. 
Cache: cache.Options{ ByObject: map[client.Object]cache.ByObject{ - &corev1.Secret{}: nsFilter, - &corev1.ServiceAccount{}: nsFilter, - &corev1.Pod{}: nsFilter, - &corev1.ConfigMap{}: nsFilter, - &appsv1.StatefulSet{}: nsFilter, - &appsv1.Deployment{}: nsFilter, - &discoveryv1.EndpointSlice{}: nsFilter, - &rbacv1.Role{}: nsFilter, - &rbacv1.RoleBinding{}: nsFilter, + &corev1.Secret{}: nsFilter, + &corev1.ServiceAccount{}: nsFilter, + &corev1.Pod{}: nsFilter, + &corev1.ConfigMap{}: nsFilter, + &appsv1.StatefulSet{}: nsFilter, + &appsv1.Deployment{}: nsFilter, + &rbacv1.Role{}: nsFilter, + &rbacv1.RoleBinding{}: nsFilter, + &apiextensionsv1.CustomResourceDefinition{}: serviceMonitorSelector, }, }, Scheme: tsapi.GlobalScheme, @@ -271,6 +313,7 @@ func runReconcilers(opts reconcilerOpts) { proxyImage: opts.proxyImage, proxyPriorityClassName: opts.proxyPriorityClassName, tsFirewallMode: opts.proxyFirewallMode, + controlUrl: opts.tsServer.ControlURL, } err = builder. ControllerManagedBy(mgr). @@ -288,6 +331,8 @@ func runReconcilers(opts reconcilerOpts) { tsNamespace: opts.tailscaleNamespace, clock: tstime.DefaultClock{}, defaultProxyClass: opts.defaultProxyClass, + validationOpts: opts.validationOpts, + noFqdnDotAppend: opts.noFqdnDotAppend, }) if err != nil { startlog.Fatalf("could not create service reconciler: %v", err) @@ -301,6 +346,7 @@ func runReconcilers(opts reconcilerOpts) { err = builder. ControllerManagedBy(mgr). For(&networkingv1.Ingress{}). + Named("ingress-reconciler"). Watches(&appsv1.StatefulSet{}, ingressChildFilter). Watches(&corev1.Secret{}, ingressChildFilter). Watches(&corev1.Service{}, svcHandlerForIngress). 
@@ -315,6 +361,67 @@ func runReconcilers(opts reconcilerOpts) { if err != nil { startlog.Fatalf("could not create ingress reconciler: %v", err) } + lc, err := opts.tsServer.LocalClient() + if err != nil { + startlog.Fatalf("could not get local client: %v", err) + } + id, err := id(context.Background(), lc) + if err != nil { + startlog.Fatalf("error determining stable ID of the operator's Tailscale device: %v", err) + } + ingressProxyGroupFilter := handler.EnqueueRequestsFromMapFunc(ingressesFromIngressProxyGroup(mgr.GetClient(), opts.log)) + err = builder. + ControllerManagedBy(mgr). + For(&networkingv1.Ingress{}). + Named("ingress-pg-reconciler"). + Watches(&corev1.Service{}, handler.EnqueueRequestsFromMapFunc(serviceHandlerForIngressPG(mgr.GetClient(), startlog))). + Watches(&corev1.Secret{}, handler.EnqueueRequestsFromMapFunc(HAIngressesFromSecret(mgr.GetClient(), startlog))). + Watches(&tsapi.ProxyGroup{}, ingressProxyGroupFilter). + Complete(&HAIngressReconciler{ + recorder: eventRecorder, + tsClient: opts.tsClient, + tsnetServer: opts.tsServer, + defaultTags: strings.Split(opts.proxyTags, ","), + Client: mgr.GetClient(), + logger: opts.log.Named("ingress-pg-reconciler"), + lc: lc, + operatorID: id, + tsNamespace: opts.tailscaleNamespace, + }) + if err != nil { + startlog.Fatalf("could not create ingress-pg-reconciler: %v", err) + } + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(networkingv1.Ingress), indexIngressProxyGroup, indexPGIngresses); err != nil { + startlog.Fatalf("failed setting up indexer for HA Ingresses: %v", err) + } + + ingressSvcFromEpsFilter := handler.EnqueueRequestsFromMapFunc(ingressSvcFromEps(mgr.GetClient(), opts.log.Named("service-pg-reconciler"))) + err = builder. + ControllerManagedBy(mgr). + For(&corev1.Service{}). + Named("service-pg-reconciler"). + Watches(&corev1.Secret{}, handler.EnqueueRequestsFromMapFunc(HAServicesFromSecret(mgr.GetClient(), startlog))). 
+ Watches(&tsapi.ProxyGroup{}, ingressProxyGroupFilter). + Watches(&discoveryv1.EndpointSlice{}, ingressSvcFromEpsFilter). + Complete(&HAServiceReconciler{ + recorder: eventRecorder, + tsClient: opts.tsClient, + tsnetServer: opts.tsServer, + defaultTags: strings.Split(opts.proxyTags, ","), + Client: mgr.GetClient(), + logger: opts.log.Named("service-pg-reconciler"), + lc: lc, + clock: tstime.DefaultClock{}, + operatorID: id, + tsNamespace: opts.tailscaleNamespace, + validationOpts: opts.validationOpts, + }) + if err != nil { + startlog.Fatalf("could not create service-pg-reconciler: %v", err) + } + if err := mgr.GetFieldIndexer().IndexField(context.Background(), new(corev1.Service), indexIngressProxyGroup, indexPGIngresses); err != nil { + startlog.Fatalf("failed setting up indexer for HA Services: %v", err) + } connectorFilter := handler.EnqueueRequestsFromMapFunc(managedResourceHandlerForType("connector")) // If a ProxyClassChanges, enqueue all Connectors that have @@ -322,6 +429,7 @@ func runReconcilers(opts reconcilerOpts) { proxyClassFilterForConnector := handler.EnqueueRequestsFromMapFunc(proxyClassHandlerForConnector(mgr.GetClient(), startlog)) err = builder.ControllerManagedBy(mgr). For(&tsapi.Connector{}). + Named("connector-reconciler"). Watches(&appsv1.StatefulSet{}, connectorFilter). Watches(&corev1.Secret{}, connectorFilter). Watches(&tsapi.ProxyClass{}, proxyClassFilterForConnector). @@ -341,6 +449,7 @@ func runReconcilers(opts reconcilerOpts) { nameserverFilter := handler.EnqueueRequestsFromMapFunc(managedResourceHandlerForType("nameserver")) err = builder.ControllerManagedBy(mgr). For(&tsapi.DNSConfig{}). + Named("nameserver-reconciler"). Watches(&appsv1.Deployment{}, nameserverFilter). Watches(&corev1.ConfigMap{}, nameserverFilter). Watches(&corev1.Service{}, nameserverFilter). @@ -364,11 +473,12 @@ func runReconcilers(opts reconcilerOpts) { Watches(&corev1.Service{}, egressSvcFilter). Watches(&tsapi.ProxyGroup{}, egressProxyGroupFilter). 
Complete(&egressSvcsReconciler{ - Client: mgr.GetClient(), - tsNamespace: opts.tailscaleNamespace, - recorder: eventRecorder, - clock: tstime.DefaultClock{}, - logger: opts.log.Named("egress-svcs-reconciler"), + Client: mgr.GetClient(), + tsNamespace: opts.tailscaleNamespace, + recorder: eventRecorder, + clock: tstime.DefaultClock{}, + logger: opts.log.Named("egress-svcs-reconciler"), + validationOpts: opts.validationOpts, }) if err != nil { startlog.Fatalf("could not create egress Services reconciler: %v", err) @@ -414,8 +524,32 @@ func runReconcilers(opts reconcilerOpts) { startlog.Fatalf("could not create egress EndpointSlices reconciler: %v", err) } + podsForEps := handler.EnqueueRequestsFromMapFunc(podsFromEgressEps(mgr.GetClient(), opts.log, opts.tailscaleNamespace)) + podsER := handler.EnqueueRequestsFromMapFunc(egressPodsHandler) + err = builder. + ControllerManagedBy(mgr). + Named("egress-pods-readiness-reconciler"). + Watches(&discoveryv1.EndpointSlice{}, podsForEps). + Watches(&corev1.Pod{}, podsER). + Complete(&egressPodsReconciler{ + Client: mgr.GetClient(), + tsNamespace: opts.tailscaleNamespace, + clock: tstime.DefaultClock{}, + logger: opts.log.Named("egress-pods-readiness-reconciler"), + httpClient: http.DefaultClient, + }) + if err != nil { + startlog.Fatalf("could not create egress Pods readiness reconciler: %v", err) + } + + // ProxyClass reconciler gets triggered on ServiceMonitor CRD changes to ensure that any ProxyClasses, that + // define that a ServiceMonitor should be created, were set to invalid because the CRD did not exist get + // reconciled if the CRD is applied at a later point. + serviceMonitorFilter := handler.EnqueueRequestsFromMapFunc(proxyClassesWithServiceMonitor(mgr.GetClient(), opts.log)) err = builder.ControllerManagedBy(mgr). For(&tsapi.ProxyClass{}). + Named("proxyclass-reconciler"). + Watches(&apiextensionsv1.CustomResourceDefinition{}, serviceMonitorFilter). 
Complete(&ProxyClassReconciler{ Client: mgr.GetClient(), recorder: eventRecorder, @@ -458,6 +592,7 @@ func runReconcilers(opts reconcilerOpts) { recorderFilter := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &tsapi.Recorder{}) err = builder.ControllerManagedBy(mgr). For(&tsapi.Recorder{}). + Named("recorder-reconciler"). Watches(&appsv1.StatefulSet{}, recorderFilter). Watches(&corev1.ServiceAccount{}, recorderFilter). Watches(&corev1.Secret{}, recorderFilter). @@ -475,12 +610,14 @@ func runReconcilers(opts reconcilerOpts) { startlog.Fatalf("could not create Recorder reconciler: %v", err) } - // Recorder reconciler. + // ProxyGroup reconciler. ownedByProxyGroupFilter := handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &tsapi.ProxyGroup{}) proxyClassFilterForProxyGroup := handler.EnqueueRequestsFromMapFunc(proxyClassHandlerForProxyGroup(mgr.GetClient(), startlog)) err = builder.ControllerManagedBy(mgr). For(&tsapi.ProxyGroup{}). + Named("proxygroup-reconciler"). Watches(&appsv1.StatefulSet{}, ownedByProxyGroupFilter). + Watches(&corev1.ConfigMap{}, ownedByProxyGroupFilter). Watches(&corev1.ServiceAccount{}, ownedByProxyGroupFilter). Watches(&corev1.Secret{}, ownedByProxyGroupFilter). Watches(&rbacv1.Role{}, ownedByProxyGroupFilter). @@ -509,10 +646,22 @@ func runReconcilers(opts reconcilerOpts) { } } +type validationOpts struct { + baseDomain string + relaxedDomainValidation bool +} + +func (v validationOpts) domain() string { + if v.baseDomain == "" { + return defaultDomain + } + return v.baseDomain +} + type reconcilerOpts struct { log *zap.SugaredLogger tsServer *tsnet.Server - tsClient *tailscale.Client + tsClient tsClient tailscaleNamespace string // namespace in which operator resources will be deployed restConfig *rest.Config // config for connecting to the kube API server proxyImage string // : @@ -547,6 +696,12 @@ type reconcilerOpts struct { // class for proxies that do not have a ProxyClass set. 
// this is defined by an operator env variable. defaultProxyClass string + // validationOpts are used to control validations happening in the + // reconcilers + validationOpts validationOpts + // noFqdnDotAppend prevents the addition of a dot ('.") to the end of + // destination FQDNs in egress proxies + noFqdnDotAppend bool } // enqueueAllIngressEgressProxySvcsinNS returns a reconcile request for each @@ -557,8 +712,8 @@ func enqueueAllIngressEgressProxySvcsInNS(ns string, cl client.Client, logger *z // Get all headless Services for proxies configured using Service. svcProxyLabels := map[string]string{ - LabelManaged: "true", - LabelParentType: "svc", + kubetypes.LabelManaged: "true", + LabelParentType: "svc", } svcHeadlessSvcList := &corev1.ServiceList{} if err := cl.List(ctx, svcHeadlessSvcList, client.InNamespace(ns), client.MatchingLabels(svcProxyLabels)); err != nil { @@ -571,8 +726,8 @@ func enqueueAllIngressEgressProxySvcsInNS(ns string, cl client.Client, logger *z // Get all headless Services for proxies configured using Ingress. 
ingProxyLabels := map[string]string{ - LabelManaged: "true", - LabelParentType: "ingress", + kubetypes.LabelManaged: "true", + LabelParentType: "ingress", } ingHeadlessSvcList := &corev1.ServiceList{} if err := cl.List(ctx, ingHeadlessSvcList, client.InNamespace(ns), client.MatchingLabels(ingProxyLabels)); err != nil { @@ -637,15 +792,9 @@ func dnsRecordsReconcilerIngressHandler(ns string, isDefaultLoadBalancer bool, c } } -type tsClient interface { - CreateKey(ctx context.Context, caps tailscale.KeyCapabilities) (string, *tailscale.Key, error) - Device(ctx context.Context, deviceID string, fields *tailscale.DeviceFieldsOpts) (*tailscale.Device, error) - DeleteDevice(ctx context.Context, nodeStableID string) error -} - func isManagedResource(o client.Object) bool { ls := o.GetLabels() - return ls[LabelManaged] == "true" + return ls[kubetypes.LabelManaged] == "true" } func isManagedByType(o client.Object, typ string) bool { @@ -735,7 +884,7 @@ func proxyClassHandlerForConnector(cl client.Client, logger *zap.SugaredLogger) } } -// proxyClassHandlerForConnector returns a handler that, for a given ProxyClass, +// proxyClassHandlerForProxyGroup returns a handler that, for a given ProxyClass, // returns a list of reconcile requests for all Connectors that have // .spec.proxyClass set. func proxyClassHandlerForProxyGroup(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { @@ -778,6 +927,10 @@ func serviceHandlerForIngress(cl client.Client, logger *zap.SugaredLogger) handl if ing.Spec.IngressClassName == nil || *ing.Spec.IngressClassName != tailscaleIngressClassName { return nil } + if hasProxyGroupAnnotation(&ing) { + // We don't want to reconcile backend Services for Ingresses for ProxyGroups. 
+ continue + } if ing.Spec.DefaultBackend != nil && ing.Spec.DefaultBackend.Service != nil && ing.Spec.DefaultBackend.Service.Name == o.GetName() { reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) } @@ -821,9 +974,13 @@ func serviceHandler(_ context.Context, o client.Object) []reconcile.Request { } // isMagicDNSName reports whether name is a full tailnet node FQDN (with or -// without final dot). -func isMagicDNSName(name string) bool { - validMagicDNSName := regexp.MustCompile(`^[a-zA-Z0-9-]+\.[a-zA-Z0-9-]+\.ts\.net\.?$`) +// without final dot). The behaviour can be controlled with the given opts. +func isMagicDNSName(name string, opts validationOpts) bool { + baseDomainEscaped := strings.ReplaceAll(opts.domain(), `.`, `\.`) + validMagicDNSName := regexp.MustCompile(`^[a-zA-Z0-9-]+\.[a-zA-Z0-9-]+\.` + baseDomainEscaped + `\.?$`) + if opts.relaxedDomainValidation { + validMagicDNSName = regexp.MustCompile(`^([a-zA-Z0-9-]+\.)+` + baseDomainEscaped + `\.?$`) + } return validMagicDNSName.MatchString(name) } @@ -860,11 +1017,25 @@ func egressEpsHandler(_ context.Context, o client.Object) []reconcile.Request { } } +func egressPodsHandler(_ context.Context, o client.Object) []reconcile.Request { + if typ := o.GetLabels()[LabelParentType]; typ != proxyTypeProxyGroup { + return nil + } + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: o.GetNamespace(), + Name: o.GetName(), + }, + }, + } +} + // egressEpsFromEgressPods returns a Pod event handler that checks if Pod is a replica for a ProxyGroup and if it is, // returns reconciler requests for all egress EndpointSlices for that ProxyGroup. 
func egressEpsFromPGPods(cl client.Client, ns string) handler.MapFunc { return func(_ context.Context, o client.Object) []reconcile.Request { - if v, ok := o.GetLabels()[LabelManaged]; !ok || v != "true" { + if v, ok := o.GetLabels()[kubetypes.LabelManaged]; !ok || v != "true" { return nil } // TODO(irbekrm): for now this is good enough as all ProxyGroups are egress. Add a type check once we @@ -884,15 +1055,13 @@ func egressEpsFromPGPods(cl client.Client, ns string) handler.MapFunc { // returns reconciler requests for all egress EndpointSlices for that ProxyGroup. func egressEpsFromPGStateSecrets(cl client.Client, ns string) handler.MapFunc { return func(_ context.Context, o client.Object) []reconcile.Request { - if v, ok := o.GetLabels()[LabelManaged]; !ok || v != "true" { + if v, ok := o.GetLabels()[kubetypes.LabelManaged]; !ok || v != "true" { return nil } - // TODO(irbekrm): for now this is good enough as all ProxyGroups are egress. Add a type check once we - // have ingress ProxyGroups. 
if parentType := o.GetLabels()[LabelParentType]; parentType != "proxygroup" { return nil } - if secretType := o.GetLabels()[labelSecretType]; secretType != "state" { + if secretType := o.GetLabels()[kubetypes.LabelSecretType]; secretType != "state" { return nil } pg, ok := o.GetLabels()[LabelParentName] @@ -903,13 +1072,43 @@ func egressEpsFromPGStateSecrets(cl client.Client, ns string) handler.MapFunc { } } +func ingressSvcFromEps(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + svcName := o.GetLabels()[discoveryv1.LabelServiceName] + if svcName == "" { + return nil + } + + svc := &corev1.Service{} + ns := o.GetNamespace() + if err := cl.Get(ctx, types.NamespacedName{Name: svcName, Namespace: ns}, svc); err != nil { + logger.Errorf("failed to get service: %v", err) + return nil + } + + pgName := svc.Annotations[AnnotationProxyGroup] + if pgName == "" { + return nil + } + + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: ns, + Name: svcName, + }, + }, + } + } +} + // egressSvcFromEps is an event handler for EndpointSlices. If an EndpointSlice is for an egress ExternalName Service // meant to be exposed on a ProxyGroup, returns a reconcile request for the Service. func egressSvcFromEps(_ context.Context, o client.Object) []reconcile.Request { if typ := o.GetLabels()[labelSvcType]; typ != typeEgress { return nil } - if v, ok := o.GetLabels()[LabelManaged]; !ok || v != "true" { + if v, ok := o.GetLabels()[kubetypes.LabelManaged]; !ok || v != "true" { return nil } svcName, ok := o.GetLabels()[LabelParentName] @@ -949,10 +1148,103 @@ func reconcileRequestsForPG(pg string, cl client.Client, ns string) []reconcile. 
return reqs } +func isTLSSecret(secret *corev1.Secret) bool { + return secret.Type == corev1.SecretTypeTLS && + secret.ObjectMeta.Labels[kubetypes.LabelManaged] == "true" && + secret.ObjectMeta.Labels[kubetypes.LabelSecretType] == "certs" && + secret.ObjectMeta.Labels[labelDomain] != "" && + secret.ObjectMeta.Labels[labelProxyGroup] != "" +} + +func isPGStateSecret(secret *corev1.Secret) bool { + return secret.ObjectMeta.Labels[kubetypes.LabelManaged] == "true" && + secret.ObjectMeta.Labels[LabelParentType] == "proxygroup" && + secret.ObjectMeta.Labels[kubetypes.LabelSecretType] == "state" +} + +// HAIngressesFromSecret returns a handler that returns reconcile requests for +// all HA Ingresses that should be reconciled in response to a Secret event. +func HAIngressesFromSecret(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + secret, ok := o.(*corev1.Secret) + if !ok { + logger.Infof("[unexpected] Secret handler triggered for an object that is not a Secret") + return nil + } + if isTLSSecret(secret) { + return []reconcile.Request{ + { + NamespacedName: types.NamespacedName{ + Namespace: secret.ObjectMeta.Labels[LabelParentNamespace], + Name: secret.ObjectMeta.Labels[LabelParentName], + }, + }, + } + } + if !isPGStateSecret(secret) { + return nil + } + pgName, ok := secret.ObjectMeta.Labels[LabelParentName] + if !ok { + return nil + } + + ingList := &networkingv1.IngressList{} + if err := cl.List(ctx, ingList, client.MatchingFields{indexIngressProxyGroup: pgName}); err != nil { + logger.Infof("error listing Ingresses, skipping a reconcile for event on Secret %s: %v", secret.Name, err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, ing := range ingList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: ing.Namespace, + Name: ing.Name, + }, + }) + } + return reqs + } +} + +// HAServiceFromSecret returns a handler 
that returns reconcile requests for +// all HA Services that should be reconciled in response to a Secret event. +func HAServicesFromSecret(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + secret, ok := o.(*corev1.Secret) + if !ok { + logger.Infof("[unexpected] Secret handler triggered for an object that is not a Secret") + return nil + } + if !isPGStateSecret(secret) { + return nil + } + pgName, ok := secret.ObjectMeta.Labels[LabelParentName] + if !ok { + return nil + } + svcList := &corev1.ServiceList{} + if err := cl.List(ctx, svcList, client.MatchingFields{indexIngressProxyGroup: pgName}); err != nil { + logger.Infof("error listing Services, skipping a reconcile for event on Secret %s: %v", secret.Name, err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, svc := range svcList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: svc.Namespace, + Name: svc.Name, + }, + }) + } + return reqs + } +} + // egressSvcsFromEgressProxyGroup is an event handler for egress ProxyGroups. It returns reconcile requests for all // user-created ExternalName Services that should be exposed on this ProxyGroup. 
func egressSvcsFromEgressProxyGroup(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { - return func(_ context.Context, o client.Object) []reconcile.Request { + return func(ctx context.Context, o client.Object) []reconcile.Request { pg, ok := o.(*tsapi.ProxyGroup) if !ok { logger.Infof("[unexpected] ProxyGroup handler triggered for an object that is not a ProxyGroup") @@ -962,7 +1254,7 @@ func egressSvcsFromEgressProxyGroup(cl client.Client, logger *zap.SugaredLogger) return nil } svcList := &corev1.ServiceList{} - if err := cl.List(context.Background(), svcList, client.MatchingFields{indexEgressProxyGroup: pg.Name}); err != nil { + if err := cl.List(ctx, svcList, client.MatchingFields{indexEgressProxyGroup: pg.Name}); err != nil { logger.Infof("error listing Services: %v, skipping a reconcile for event on ProxyGroup %s", err, pg.Name) return nil } @@ -979,10 +1271,40 @@ func egressSvcsFromEgressProxyGroup(cl client.Client, logger *zap.SugaredLogger) } } +// ingressesFromIngressProxyGroup is an event handler for ingress ProxyGroups. It returns reconcile requests for all +// user-created Ingresses that should be exposed on this ProxyGroup. 
+func ingressesFromIngressProxyGroup(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + pg, ok := o.(*tsapi.ProxyGroup) + if !ok { + logger.Infof("[unexpected] ProxyGroup handler triggered for an object that is not a ProxyGroup") + return nil + } + if pg.Spec.Type != tsapi.ProxyGroupTypeIngress { + return nil + } + ingList := &networkingv1.IngressList{} + if err := cl.List(ctx, ingList, client.MatchingFields{indexIngressProxyGroup: pg.Name}); err != nil { + logger.Infof("error listing Ingresses: %v, skipping a reconcile for event on ProxyGroup %s", err, pg.Name) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, svc := range ingList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: svc.Namespace, + Name: svc.Name, + }, + }) + } + return reqs + } +} + // epsFromExternalNameService is an event handler for ExternalName Services that define a Tailscale egress service that // should be exposed on a ProxyGroup. It returns reconcile requests for EndpointSlices created for this Service. 
func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger, ns string) handler.MapFunc { - return func(_ context.Context, o client.Object) []reconcile.Request { + return func(ctx context.Context, o client.Object) []reconcile.Request { svc, ok := o.(*corev1.Service) if !ok { logger.Infof("[unexpected] Service handler triggered for an object that is not a Service") @@ -992,7 +1314,7 @@ func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger, ns return nil } epsList := &discoveryv1.EndpointSliceList{} - if err := cl.List(context.Background(), epsList, client.InNamespace(ns), + if err := cl.List(ctx, epsList, client.InNamespace(ns), client.MatchingLabels(egressSvcChildResourceLabels(svc))); err != nil { logger.Infof("error listing EndpointSlices: %v, skipping a reconcile for event on Service %s", err, svc.Name) return nil @@ -1010,7 +1332,87 @@ func epsFromExternalNameService(cl client.Client, logger *zap.SugaredLogger, ns } } -// indexEgressServices adds a local index to a cached Tailscale egress Services meant to be exposed on a ProxyGroup. 
The +func podsFromEgressEps(cl client.Client, logger *zap.SugaredLogger, ns string) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + eps, ok := o.(*discoveryv1.EndpointSlice) + if !ok { + logger.Infof("[unexpected] EndpointSlice handler triggered for an object that is not a EndpointSlice") + return nil + } + if eps.Labels[labelProxyGroup] == "" { + return nil + } + if eps.Labels[labelSvcType] != "egress" { + return nil + } + podLabels := map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentType: "proxygroup", + LabelParentName: eps.Labels[labelProxyGroup], + } + podList := &corev1.PodList{} + if err := cl.List(ctx, podList, client.InNamespace(ns), + client.MatchingLabels(podLabels)); err != nil { + logger.Infof("error listing EndpointSlices: %v, skipping a reconcile for event on EndpointSlice %s", err, eps.Name) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, pod := range podList.Items { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{ + Namespace: pod.Namespace, + Name: pod.Name, + }, + }) + } + return reqs + } +} + +// proxyClassesWithServiceMonitor returns an event handler that, given that the event is for the Prometheus +// ServiceMonitor CRD, returns all ProxyClasses that define that a ServiceMonitor should be created. 
+func proxyClassesWithServiceMonitor(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + crd, ok := o.(*apiextensionsv1.CustomResourceDefinition) + if !ok { + logger.Debugf("[unexpected] ServiceMonitor CRD handler received an object that is not a CustomResourceDefinition") + return nil + } + if crd.Name != serviceMonitorCRD { + logger.Debugf("[unexpected] ServiceMonitor CRD handler received an unexpected CRD %q", crd.Name) + return nil + } + pcl := &tsapi.ProxyClassList{} + if err := cl.List(ctx, pcl); err != nil { + logger.Debugf("[unexpected] error listing ProxyClasses: %v", err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, pc := range pcl.Items { + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && pc.Spec.Metrics.ServiceMonitor.Enable { + reqs = append(reqs, reconcile.Request{ + NamespacedName: types.NamespacedName{Namespace: pc.Namespace, Name: pc.Name}, + }) + } + } + return reqs + } +} + +// crdTransformer gets called before a CRD is stored to c/r cache, it removes the CRD spec to reduce memory consumption. +func crdTransformer(log *zap.SugaredLogger) toolscache.TransformFunc { + return func(o any) (any, error) { + crd, ok := o.(*apiextensionsv1.CustomResourceDefinition) + if !ok { + log.Infof("[unexpected] CRD transformer called for a non-CRD type") + return crd, nil + } + crd.Spec = apiextensionsv1.CustomResourceDefinitionSpec{} + return crd, nil + } +} + +// indexEgressServices adds a local index to cached Tailscale egress Services meant to be exposed on a ProxyGroup. The // index is used a list filter. 
func indexEgressServices(o client.Object) []string { if !isEgressSvcForProxyGroup(o) { @@ -1018,3 +1420,63 @@ func indexEgressServices(o client.Object) []string { } return []string{o.GetAnnotations()[AnnotationProxyGroup]} } + +// indexPGIngresses is used to select ProxyGroup-backed Services which are +// locally indexed in the cache for efficient listing without requiring labels. +func indexPGIngresses(o client.Object) []string { + if !hasProxyGroupAnnotation(o) { + return nil + } + return []string{o.GetAnnotations()[AnnotationProxyGroup]} +} + +// serviceHandlerForIngressPG returns a handler for Service events that ensures that if the Service +// associated with an event is a backend Service for a tailscale Ingress with ProxyGroup annotation, +// the associated Ingress gets reconciled. +func serviceHandlerForIngressPG(cl client.Client, logger *zap.SugaredLogger) handler.MapFunc { + return func(ctx context.Context, o client.Object) []reconcile.Request { + ingList := networkingv1.IngressList{} + if err := cl.List(ctx, &ingList, client.InNamespace(o.GetNamespace())); err != nil { + logger.Debugf("error listing Ingresses: %v", err) + return nil + } + reqs := make([]reconcile.Request, 0) + for _, ing := range ingList.Items { + if ing.Spec.IngressClassName == nil || *ing.Spec.IngressClassName != tailscaleIngressClassName { + continue + } + if !hasProxyGroupAnnotation(&ing) { + continue + } + if ing.Spec.DefaultBackend != nil && ing.Spec.DefaultBackend.Service != nil && ing.Spec.DefaultBackend.Service.Name == o.GetName() { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) + } + for _, rule := range ing.Spec.Rules { + if rule.HTTP == nil { + continue + } + for _, path := range rule.HTTP.Paths { + if path.Backend.Service != nil && path.Backend.Service.Name == o.GetName() { + reqs = append(reqs, reconcile.Request{NamespacedName: client.ObjectKeyFromObject(&ing)}) + } + } + } + } + return reqs + } +} + +func 
hasProxyGroupAnnotation(obj client.Object) bool { + return obj.GetAnnotations()[AnnotationProxyGroup] != "" +} + +func id(ctx context.Context, lc *local.Client) (string, error) { + st, err := lc.StatusWithoutPeers(ctx) + if err != nil { + return "", fmt.Errorf("error getting tailscale status: %w", err) + } + if st.Self == nil { + return "", fmt.Errorf("unexpected: device's status does not contain self status") + } + return string(st.Self.ID), nil +} diff --git a/cmd/k8s-operator/operator_test.go b/cmd/k8s-operator/operator_test.go index 21e1d4313749e..85cff14e483dd 100644 --- a/cmd/k8s-operator/operator_test.go +++ b/cmd/k8s-operator/operator_test.go @@ -16,6 +16,7 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" networkingv1 "k8s.io/api/networking/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" @@ -105,7 +106,7 @@ func TestLoadBalancerClass(t *testing.T) { }}, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Delete the misconfiguration so the proxy starts getting created on the // next reconcile. 
@@ -127,9 +128,9 @@ func TestLoadBalancerClass(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) want.Annotations = nil want.ObjectMeta.Finalizers = []string{"tailscale.com/finalizer"} @@ -142,7 +143,7 @@ func TestLoadBalancerClass(t *testing.T) { Message: "no Tailscale hostname known yet, waiting for proxy pod to finish auth", }}, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Normally the Tailscale proxy pod would come up here and write its info // into the secret. Simulate that, then verify reconcile again and verify @@ -168,7 +169,7 @@ func TestLoadBalancerClass(t *testing.T) { }, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, which should make the // operator clean up. 
@@ -205,7 +206,7 @@ func TestLoadBalancerClass(t *testing.T) { Type: corev1.ServiceTypeClusterIP, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestTailnetTargetFQDNAnnotation(t *testing.T) { @@ -265,9 +266,9 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { app: kubetypes.AppEgressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -287,10 +288,10 @@ func TestTailnetTargetFQDNAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, want) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) // Change the tailscale-target-fqdn annotation which should update the // StatefulSet @@ -377,9 +378,9 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { app: kubetypes.AppEgressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -399,10 
+400,10 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, want) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) // Change the tailscale-target-ip annotation which should update the // StatefulSet @@ -432,6 +433,148 @@ func TestTailnetTargetIPAnnotation(t *testing.T) { expectMissing[corev1.Secret](t, fc, "operator-ns", fullName) } +func TestTailnetTargetIPAnnotation_IPCouldNotBeParsed(t *testing.T) { + fc := fake.NewFakeClient() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + recorder: record.NewFakeRecorder(100), + } + tailnetTargetIP := "invalid-ip" + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + }) + + expectReconciled(t, sr, "default", "test") + + t0 := conditionTime(clock) + + want := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + 
AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + Status: corev1.ServiceStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyReady), + Status: metav1.ConditionFalse, + LastTransitionTime: t0, + Reason: reasonProxyInvalid, + Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "invalid-ip" could not be parsed as a valid IP Address, error: ParseAddr("invalid-ip"): unable to parse IP`, + }}, + }, + } + + expectEqual(t, fc, want) +} + +func TestTailnetTargetIPAnnotation_InvalidIP(t *testing.T) { + fc := fake.NewFakeClient() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + recorder: record.NewFakeRecorder(100), + } + tailnetTargetIP := "999.999.999.999" + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + }) + + expectReconciled(t, sr, "default", "test") + + t0 := conditionTime(clock) + + want := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + AnnotationTailnetTargetIP: tailnetTargetIP, + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: 
corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + Status: corev1.ServiceStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyReady), + Status: metav1.ConditionFalse, + LastTransitionTime: t0, + Reason: reasonProxyInvalid, + Message: `unable to provision proxy resources: invalid Service: invalid value of annotation tailscale.com/tailnet-ip: "999.999.999.999" could not be parsed as a valid IP Address, error: ParseAddr("999.999.999.999"): IPv4 field has value >255`, + }}, + }, + } + + expectEqual(t, fc, want) +} + func TestAnnotations(t *testing.T) { fc := fake.NewFakeClient() ft := &fakeTSClient{} @@ -486,9 +629,9 @@ func TestAnnotations(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -507,7 +650,7 @@ func TestAnnotations(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, which should make the // operator clean up. 
@@ -535,7 +678,7 @@ func TestAnnotations(t *testing.T) { Type: corev1.ServiceTypeClusterIP, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestAnnotationIntoLB(t *testing.T) { @@ -592,9 +735,9 @@ func TestAnnotationIntoLB(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) // Normally the Tailscale proxy pod would come up here and write its info // into the secret. Simulate that, since it would have normally happened at @@ -626,7 +769,7 @@ func TestAnnotationIntoLB(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Remove Tailscale's annotation, and at the same time convert the service // into a tailscale LoadBalancer. @@ -637,8 +780,8 @@ func TestAnnotationIntoLB(t *testing.T) { }) expectReconciled(t, sr, "default", "test") // None of the proxy machinery should have changed... - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) // ... but the service should have a LoadBalancer status. 
want = &corev1.Service{ @@ -667,7 +810,7 @@ func TestAnnotationIntoLB(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestLBIntoAnnotation(t *testing.T) { @@ -722,9 +865,9 @@ func TestLBIntoAnnotation(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) // Normally the Tailscale proxy pod would come up here and write its info // into the secret. Simulate that, then verify reconcile again and verify @@ -764,7 +907,7 @@ func TestLBIntoAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, but also add the // tailscale annotation. 
@@ -783,8 +926,8 @@ func TestLBIntoAnnotation(t *testing.T) { }) expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) want = &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ @@ -804,7 +947,7 @@ func TestLBIntoAnnotation(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestCustomHostname(t *testing.T) { @@ -862,9 +1005,9 @@ func TestCustomHostname(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, o), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, o)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) want := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "test", @@ -884,7 +1027,7 @@ func TestCustomHostname(t *testing.T) { Conditions: proxyCreatedCondition(clock), }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) // Turn the service back into a ClusterIP service, which should make the // operator clean up. 
@@ -915,7 +1058,7 @@ func TestCustomHostname(t *testing.T) { Type: corev1.ServiceTypeClusterIP, }, } - expectEqual(t, fc, want, nil) + expectEqual(t, fc, want) } func TestCustomPriorityClassName(t *testing.T) { @@ -975,7 +1118,7 @@ func TestCustomPriorityClassName(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) } func TestProxyClassForService(t *testing.T) { @@ -987,7 +1130,7 @@ func TestProxyClassForService(t *testing.T) { AcceptRoutes: true, }, StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"bar.io/foo": "some-val"}, Pod: &tsapi.Pod{Annotations: map[string]string{"foo.io/bar": "some-val"}}}}, } @@ -1043,9 +1186,9 @@ func TestProxyClassForService(t *testing.T) { clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 2. The Service gets updated with tailscale.com/proxy-class label // pointing at the 'custom-metadata' ProxyClass. The ProxyClass is not @@ -1054,8 +1197,8 @@ func TestProxyClassForService(t *testing.T) { mak.Set(&svc.Labels, LabelProxyClass, "custom-metadata") }) expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) + expectEqual(t, fc, expectedSecret(t, fc, opts)) // 3. 
ProxyClass is set to Ready, the Service gets reconciled by the // services-reconciler and the customization from the ProxyClass is @@ -1070,7 +1213,7 @@ func TestProxyClassForService(t *testing.T) { }) opts.proxyClass = pc.Name expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) expectEqual(t, fc, expectedSecret(t, fc, opts), removeAuthKeyIfExistsModifier(t)) // 4. tailscale.com/proxy-class label is removed from the Service, the @@ -1081,7 +1224,7 @@ func TestProxyClassForService(t *testing.T) { }) opts.proxyClass = "" expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) } func TestDefaultLoadBalancer(t *testing.T) { @@ -1127,7 +1270,7 @@ func TestDefaultLoadBalancer(t *testing.T) { fullName, shortName := findGenName(t, fc, "default", "test", "svc") - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) o := configOpts{ stsName: shortName, secretName: fullName, @@ -1137,8 +1280,7 @@ func TestDefaultLoadBalancer(t *testing.T) { clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) - + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) } func TestProxyFirewallMode(t *testing.T) { @@ -1194,77 +1336,14 @@ func TestProxyFirewallMode(t *testing.T) { clusterTargetIP: "10.20.30.40", app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, o), removeHashAnnotation, removeResourceReqs) } -func TestTailscaledConfigfileHash(t *testing.T) { - fc := fake.NewFakeClient() - ft := &fakeTSClient{} - zl, err := 
zap.NewDevelopment() - if err != nil { - t.Fatal(err) - } - clock := tstest.NewClock(tstest.ClockOpts{}) - sr := &ServiceReconciler{ - Client: fc, - ssr: &tailscaleSTSReconciler{ - Client: fc, - tsClient: ft, - defaultTags: []string{"tag:k8s"}, - operatorNamespace: "operator-ns", - proxyImage: "tailscale/tailscale", - }, - logger: zl.Sugar(), - clock: clock, - isDefaultLoadBalancer: true, - } - - // Create a service that we should manage, and check that the initial round - // of objects looks right. - mustCreate(t, fc, &corev1.Service{ - ObjectMeta: metav1.ObjectMeta{ - Name: "test", - Namespace: "default", - // The apiserver is supposed to set the UID, but the fake client - // doesn't. So, set it explicitly because other code later depends - // on it being set. - UID: types.UID("1234-UID"), - }, - Spec: corev1.ServiceSpec{ - ClusterIP: "10.20.30.40", - Type: corev1.ServiceTypeLoadBalancer, - }, - }) - - expectReconciled(t, sr, "default", "test") - - fullName, shortName := findGenName(t, fc, "default", "test", "svc") - o := configOpts{ - stsName: shortName, - secretName: fullName, - namespace: "default", - parentType: "svc", - hostname: "default-test", - clusterTargetIP: "10.20.30.40", - confFileHash: "e09bededa0379920141cbd0b0dbdf9b8b66545877f9e8397423f5ce3e1ba439e", - app: kubetypes.AppIngressProxy, - } - expectEqual(t, fc, expectedSTS(t, fc, o), nil) - - // 2. Hostname gets changed, configfile is updated and a new hash value - // is produced. 
- mustUpdate(t, fc, "default", "test", func(svc *corev1.Service) { - mak.Set(&svc.Annotations, AnnotationHostname, "another-test") - }) - o.hostname = "another-test" - o.confFileHash = "5d754cf55463135ee34aa9821f2fd8483b53eb0570c3740c84a086304f427684" - expectReconciled(t, sr, "default", "test") - expectEqual(t, fc, expectedSTS(t, fc, o), nil) -} func Test_isMagicDNSName(t *testing.T) { tests := []struct { - in string - want bool + in string + want bool + validationOpts *validationOpts }{ { in: "foo.tail4567.ts.net", @@ -1278,10 +1357,33 @@ func Test_isMagicDNSName(t *testing.T) { in: "foo.tail4567", want: false, }, + { + in: "foo.ts.net", + want: false, + }, + { + in: "foo.tail4567.example.com.", + want: true, + validationOpts: &validationOpts{baseDomain: "example.com"}, + }, + { + in: "foo.example.com.", + want: false, + validationOpts: &validationOpts{baseDomain: "example.com"}, + }, + { + in: "foo.example.com.", + want: true, + validationOpts: &validationOpts{baseDomain: "example.com", relaxedDomainValidation: true}, + }, } for _, tt := range tests { t.Run(tt.in, func(t *testing.T) { - if got := isMagicDNSName(tt.in); got != tt.want { + opts := validationOpts{} + if tt.validationOpts != nil { + opts = *tt.validationOpts + } + if got := isMagicDNSName(tt.in, opts); got != tt.want { t.Errorf("isMagicDNSName(%q) = %v, want %v", tt.in, got, tt.want) } }) @@ -1309,10 +1411,10 @@ func Test_serviceHandlerForIngress(t *testing.T) { Name: "headless-1", Namespace: "tailscale", Labels: map[string]string{ - LabelManaged: "true", - LabelParentName: "ing-1", - LabelParentNamespace: "ns-1", - LabelParentType: "ingress", + kubetypes.LabelManaged: "true", + LabelParentName: "ing-1", + LabelParentNamespace: "ns-1", + LabelParentType: "ingress", }, }, } @@ -1537,9 +1639,9 @@ func Test_authKeyRemoval(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - 
expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 2. Apply update to the Secret that imitates the proxy setting device_id. s := expectedSecret(t, fc, opts) @@ -1551,7 +1653,7 @@ func Test_authKeyRemoval(t *testing.T) { expectReconciled(t, sr, "default", "test") opts.shouldRemoveAuthKey = true opts.secretExtraData = map[string][]byte{"device_id": []byte("dkkdi4CNTRL")} - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) + expectEqual(t, fc, expectedSecret(t, fc, opts)) } func Test_externalNameService(t *testing.T) { @@ -1611,9 +1713,9 @@ func Test_externalNameService(t *testing.T) { app: kubetypes.AppIngressProxy, } - expectEqual(t, fc, expectedSecret(t, fc, opts), nil) - expectEqual(t, fc, expectedHeadlessService(shortName, "svc"), nil) - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSecret(t, fc, opts)) + expectEqual(t, fc, expectedHeadlessService(shortName, "svc")) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) // 2. Change the ExternalName and verify that changes get propagated. 
mustUpdate(t, sr, "default", "test", func(s *corev1.Service) { @@ -1621,7 +1723,155 @@ func Test_externalNameService(t *testing.T) { }) expectReconciled(t, sr, "default", "test") opts.clusterTargetDNS = "bar.com" - expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation) + expectEqual(t, fc, expectedSTS(t, fc, opts), removeHashAnnotation, removeResourceReqs) +} + +func Test_metricsResourceCreation(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{Name: "metrics", Generation: 1}, + Spec: tsapi.ProxyClassSpec{}, + Status: tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Status: metav1.ConditionTrue, + Type: string(tsapi.ProxyClassReady), + ObservedGeneration: 1, + }}}, + } + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + UID: types.UID("1234-UID"), + Labels: map[string]string{LabelProxyClass: "metrics"}, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + }, + } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pc, svc). + WithStatusSubresource(pc). 
+ Build() + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + operatorNamespace: "operator-ns", + }, + logger: zl.Sugar(), + clock: clock, + } + expectReconciled(t, sr, "default", "test") + fullName, shortName := findGenName(t, fc, "default", "test", "svc") + opts := configOpts{ + stsName: shortName, + secretName: fullName, + namespace: "default", + parentType: "svc", + tailscaleNamespace: "operator-ns", + hostname: "default-test", + namespaced: true, + proxyType: proxyTypeIngressService, + app: kubetypes.AppIngressProxy, + resourceVersion: "1", + } + + // 1. Enable metrics- expect metrics Service to be created + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec = tsapi.ProxyClassSpec{Metrics: &tsapi.Metrics{Enable: true}} + }) + expectReconciled(t, sr, "default", "test") + opts.enableMetrics = true + expectEqual(t, fc, expectedMetricsService(opts)) + + // 2. Enable ServiceMonitor - should not error when there is no ServiceMonitor CRD in cluster + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true} + }) + expectReconciled(t, sr, "default", "test") + + // 3. Create ServiceMonitor CRD and reconcile- ServiceMonitor should get created + mustCreate(t, fc, crd) + expectReconciled(t, sr, "default", "test") + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 4. 
A change to ServiceMonitor config gets reflected in the ServiceMonitor resource + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics.ServiceMonitor.Labels = tsapi.Labels{"foo": "bar"} + }) + expectReconciled(t, sr, "default", "test") + opts.serviceMonitorLabels = tsapi.Labels{"foo": "bar"} + opts.resourceVersion = "2" + expectEqual(t, fc, expectedMetricsService(opts)) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) + + // 5. Disable metrics- expect metrics Service to be deleted + mustUpdate(t, fc, "", "metrics", func(pc *tsapi.ProxyClass) { + pc.Spec.Metrics = nil + }) + expectReconciled(t, sr, "default", "test") + expectMissing[corev1.Service](t, fc, "operator-ns", metricsResourceName(opts.stsName)) + // ServiceMonitor gets garbage collected when Service gets deleted (it has OwnerReference of the Service + // object). We cannot test this using the fake client. +} + +func TestIgnorePGService(t *testing.T) { + // NOTE: creating proxygroup stuff just to be sure that it's all ignored + _, _, fc, _ := setupServiceTest(t) + + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + clock := tstest.NewClock(tstest.ClockOpts{}) + sr := &ServiceReconciler{ + Client: fc, + ssr: &tailscaleSTSReconciler{ + Client: fc, + tsClient: ft, + defaultTags: []string{"tag:k8s"}, + operatorNamespace: "operator-ns", + proxyImage: "tailscale/tailscale", + }, + logger: zl.Sugar(), + clock: clock, + } + + // Create a service that we should manage, and check that the initial round + // of objects looks right. + mustCreate(t, fc, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. 
+ UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/proxygroup": "test-pg", + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeClusterIP, + }, + }) + + expectReconciled(t, sr, "default", "test") + + findNoGenName(t, fc, "default", "test", "svc") } func toFQDN(t *testing.T, s string) dnsname.FQDN { diff --git a/cmd/k8s-operator/proxyclass.go b/cmd/k8s-operator/proxyclass.go index 882a9030fa75d..5ec9897d0a8b7 100644 --- a/cmd/k8s-operator/proxyclass.go +++ b/cmd/k8s-operator/proxyclass.go @@ -15,6 +15,7 @@ import ( dockerref "github.com/distribution/reference" "go.uber.org/zap" corev1 "k8s.io/api/core/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" apivalidation "k8s.io/apimachinery/pkg/api/validation" @@ -95,14 +96,14 @@ func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req reconcile.Re pcr.mu.Unlock() oldPCStatus := pc.Status.DeepCopy() - if errs := pcr.validate(pc); errs != nil { + if errs := pcr.validate(ctx, pc); errs != nil { msg := fmt.Sprintf(messageProxyClassInvalid, errs.ToAggregate().Error()) pcr.recorder.Event(pc, corev1.EventTypeWarning, reasonProxyClassInvalid, msg) tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, pc.Generation, pcr.clock, logger) } else { tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, pc.Generation, pcr.clock, logger) } - if !apiequality.Semantic.DeepEqual(oldPCStatus, pc.Status) { + if !apiequality.Semantic.DeepEqual(oldPCStatus, &pc.Status) { if err := pcr.Client.Status().Update(ctx, pc); err != nil { logger.Errorf("error updating ProxyClass status: %v", err) return reconcile.Result{}, err @@ -111,10 +112,10 @@ func (pcr *ProxyClassReconciler) Reconcile(ctx context.Context, req 
reconcile.Re return reconcile.Result{}, nil } -func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations field.ErrorList) { +func (pcr *ProxyClassReconciler) validate(ctx context.Context, pc *tsapi.ProxyClass) (violations field.ErrorList) { if sts := pc.Spec.StatefulSet; sts != nil { if len(sts.Labels) > 0 { - if errs := metavalidation.ValidateLabels(sts.Labels, field.NewPath(".spec.statefulSet.labels")); errs != nil { + if errs := metavalidation.ValidateLabels(sts.Labels.Parse(), field.NewPath(".spec.statefulSet.labels")); errs != nil { violations = append(violations, errs...) } } @@ -125,7 +126,7 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel } if pod := sts.Pod; pod != nil { if len(pod.Labels) > 0 { - if errs := metavalidation.ValidateLabels(pod.Labels, field.NewPath(".spec.statefulSet.pod.labels")); errs != nil { + if errs := metavalidation.ValidateLabels(pod.Labels.Parse(), field.NewPath(".spec.statefulSet.pod.labels")); errs != nil { violations = append(violations, errs...) 
} } @@ -160,9 +161,28 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel violations = append(violations, field.TypeInvalid(field.NewPath("spec", "statefulSet", "pod", "tailscaleInitContainer", "image"), tc.Image, err.Error())) } } + + if tc.Debug != nil { + violations = append(violations, field.TypeInvalid(field.NewPath("spec", "statefulSet", "pod", "tailscaleInitContainer", "debug"), tc.Debug, "debug settings cannot be configured on the init container")) + } } } } + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && pc.Spec.Metrics.ServiceMonitor.Enable { + found, err := hasServiceMonitorCRD(ctx, pcr.Client) + if err != nil { + pcr.logger.Infof("[unexpected]: error retrieving %q CRD: %v", serviceMonitorCRD, err) + // best effort validation - don't error out here + } else if !found { + msg := fmt.Sprintf("ProxyClass defines that a ServiceMonitor custom resource should be created, but %q CRD was not found", serviceMonitorCRD) + violations = append(violations, field.TypeInvalid(field.NewPath("spec", "metrics", "serviceMonitor"), "enable", msg)) + } + } + if pc.Spec.Metrics != nil && pc.Spec.Metrics.ServiceMonitor != nil && len(pc.Spec.Metrics.ServiceMonitor.Labels) > 0 { + if errs := metavalidation.ValidateLabels(pc.Spec.Metrics.ServiceMonitor.Labels.Parse(), field.NewPath(".spec.metrics.serviceMonitor.labels")); errs != nil { + violations = append(violations, errs...) + } + } // We do not validate embedded fields (security context, resource // requirements etc) as we inherit upstream validation for those fields. 
// Invalid values would get rejected by upstream validations at apply @@ -170,6 +190,16 @@ func (pcr *ProxyClassReconciler) validate(pc *tsapi.ProxyClass) (violations fiel return violations } +func hasServiceMonitorCRD(ctx context.Context, cl client.Client) (bool, error) { + sm := &apiextensionsv1.CustomResourceDefinition{} + if err := cl.Get(ctx, types.NamespacedName{Name: serviceMonitorCRD}, sm); apierrors.IsNotFound(err) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +} + // maybeCleanup removes tailscale.com finalizer and ensures that the ProxyClass // is no longer counted towards k8s_proxyclass_resources. func (pcr *ProxyClassReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, pc *tsapi.ProxyClass) error { diff --git a/cmd/k8s-operator/proxyclass_test.go b/cmd/k8s-operator/proxyclass_test.go index eb68811fc6b94..48290eea782b5 100644 --- a/cmd/k8s-operator/proxyclass_test.go +++ b/cmd/k8s-operator/proxyclass_test.go @@ -8,10 +8,12 @@ package main import ( + "context" "testing" "time" "go.uber.org/zap" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" @@ -34,10 +36,10 @@ func TestProxyClass(t *testing.T) { }, Spec: tsapi.ProxyClassSpec{ StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar", "xyz1234": "abc567"}, + Labels: tsapi.Labels{"foo": "bar", "xyz1234": "abc567"}, Annotations: map[string]string{"foo.io/bar": "{'key': 'val1232'}"}, Pod: &tsapi.Pod{ - Labels: map[string]string{"foo": "bar", "xyz1234": "abc567"}, + Labels: tsapi.Labels{"foo": "bar", "xyz1234": "abc567"}, Annotations: map[string]string{"foo.io/bar": "{'key': 'val1232'}"}, TailscaleContainer: &tsapi.Container{ Env: []tsapi.Env{{Name: "FOO", Value: "BAR"}}, @@ -76,7 +78,7 @@ func TestProxyClass(t *testing.T) { LastTransitionTime: metav1.Time{Time: 
cl.Now().Truncate(time.Second)}, }) - expectEqual(t, fc, pc, nil) + expectEqual(t, fc, pc) // 2. A ProxyClass resource with invalid labels gets its status updated to Invalid with an error message. pc.Spec.StatefulSet.Labels["foo"] = "?!someVal" @@ -86,7 +88,7 @@ func TestProxyClass(t *testing.T) { expectReconciled(t, pcr, "", "test") msg := `ProxyClass is not valid: .spec.statefulSet.labels: Invalid value: "?!someVal": a valid label must be an empty string or consist of alphanumeric characters, '-', '_' or '.', and must start and end with an alphanumeric character (e.g. 'MyValue', or 'my_value', or '12345', regex used for validation is '(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?')` tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, pc, nil) + expectEqual(t, fc, pc) expectedEvent := "Warning ProxyClassInvalid ProxyClass is not valid: .spec.statefulSet.labels: Invalid value: \"?!someVal\": a valid label must be an empty string or consist of alphanumeric characters, '-', '_' or '.', and must start and end with an alphanumeric character (e.g. 
'MyValue', or 'my_value', or '12345', regex used for validation is '(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?')" expectEvents(t, fr, []string{expectedEvent}) @@ -100,7 +102,7 @@ func TestProxyClass(t *testing.T) { expectReconciled(t, pcr, "", "test") msg = `ProxyClass is not valid: spec.statefulSet.pod.tailscaleContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, pc, nil) + expectEqual(t, fc, pc) expectedEvent = `Warning ProxyClassInvalid ProxyClass is not valid: spec.statefulSet.pod.tailscaleContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` expectEvents(t, fr, []string{expectedEvent}) @@ -119,7 +121,7 @@ func TestProxyClass(t *testing.T) { expectReconciled(t, pcr, "", "test") msg = `ProxyClass is not valid: spec.statefulSet.pod.tailscaleInitContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, pc, nil) + expectEqual(t, fc, pc) expectedEvent = `Warning ProxyClassInvalid ProxyClass is not valid: spec.statefulSet.pod.tailscaleInitContainer.image: Invalid value: "FOO bar": invalid reference format: repository name (library/FOO bar) must be lowercase` expectEvents(t, fr, []string{expectedEvent}) @@ -134,4 +136,95 @@ func TestProxyClass(t *testing.T) { "Warning CustomTSEnvVar ProxyClass overrides the default value for EXPERIMENTAL_ALLOW_PROXYING_CLUSTER_TRAFFIC_VIA_INGRESS env var for tailscale container. 
Running with custom values for Tailscale env vars is not recommended and might break in the future."} expectReconciled(t, pcr, "", "test") expectEvents(t, fr, expectedEvents) + + // 6. A ProxyClass with ServiceMonitor enabled and in a cluster that has not ServiceMonitor CRD is invalid + pc.Spec.Metrics = &tsapi.Metrics{Enable: true, ServiceMonitor: &tsapi.ServiceMonitor{Enable: true}} + mustUpdate(t, fc, "", "test", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec = pc.Spec + }) + expectReconciled(t, pcr, "", "test") + msg = `ProxyClass is not valid: spec.metrics.serviceMonitor: Invalid value: "enable": ProxyClass defines that a ServiceMonitor custom resource should be created, but "servicemonitors.monitoring.coreos.com" CRD was not found` + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) + expectedEvent = "Warning ProxyClassInvalid " + msg + expectEvents(t, fr, []string{expectedEvent}) + + // 7. A ProxyClass with ServiceMonitor enabled and in a cluster that does have the ServiceMonitor CRD is valid + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + mustCreate(t, fc, crd) + expectReconciled(t, pcr, "", "test") + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) + + // 7. A ProxyClass with invalid ServiceMonitor labels gets its status updated to Invalid with an error message. 
+ pc.Spec.Metrics.ServiceMonitor.Labels = tsapi.Labels{"foo": "bar!"} + mustUpdate(t, fc, "", "test", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics.ServiceMonitor.Labels = pc.Spec.Metrics.ServiceMonitor.Labels + }) + expectReconciled(t, pcr, "", "test") + msg = `ProxyClass is not valid: .spec.metrics.serviceMonitor.labels: Invalid value: "bar!": a valid label must be an empty string or consist of alphanumeric characters, '-', '_' or '.', and must start and end with an alphanumeric character (e.g. 'MyValue', or 'my_value', or '12345', regex used for validation is '(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?')` + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionFalse, reasonProxyClassInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) + + // 8. A ProxyClass with valid ServiceMonitor labels gets its status updated to Valid. + pc.Spec.Metrics.ServiceMonitor.Labels = tsapi.Labels{"foo": "bar", "xyz1234": "abc567", "empty": "", "onechar": "a"} + mustUpdate(t, fc, "", "test", func(proxyClass *tsapi.ProxyClass) { + proxyClass.Spec.Metrics.ServiceMonitor.Labels = pc.Spec.Metrics.ServiceMonitor.Labels + }) + expectReconciled(t, pcr, "", "test") + tsoperator.SetProxyClassCondition(pc, tsapi.ProxyClassReady, metav1.ConditionTrue, reasonProxyClassValid, reasonProxyClassValid, 0, cl, zl.Sugar()) + expectEqual(t, fc, pc) +} + +func TestValidateProxyClass(t *testing.T) { + for name, tc := range map[string]struct { + pc *tsapi.ProxyClass + valid bool + }{ + "empty": { + valid: true, + pc: &tsapi.ProxyClass{}, + }, + "debug_enabled_for_main_container": { + valid: true, + pc: &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Debug: &tsapi.Debug{ + Enable: true, + }, + }, + }, + }, + }, + }, + }, + "debug_enabled_for_init_container": { + valid: false, + pc: &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + StatefulSet: &tsapi.StatefulSet{ + 
Pod: &tsapi.Pod{ + TailscaleInitContainer: &tsapi.Container{ + Debug: &tsapi.Debug{ + Enable: true, + }, + }, + }, + }, + }, + }, + }, + } { + t.Run(name, func(t *testing.T) { + pcr := &ProxyClassReconciler{} + err := pcr.validate(context.Background(), tc.pc) + valid := err == nil + if valid != tc.valid { + t.Errorf("expected valid=%v, got valid=%v, err=%v", tc.valid, valid, err) + } + }) + } } diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index 1f9983aa98962..f263829d73963 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -12,6 +12,7 @@ import ( "fmt" "net/http" "slices" + "strings" "sync" "github.com/pkg/errors" @@ -31,6 +32,7 @@ import ( "tailscale.com/ipn" tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/egressservices" "tailscale.com/kube/kubetypes" "tailscale.com/tailcfg" "tailscale.com/tstime" @@ -45,9 +47,15 @@ const ( reasonProxyGroupReady = "ProxyGroupReady" reasonProxyGroupCreating = "ProxyGroupCreating" reasonProxyGroupInvalid = "ProxyGroupInvalid" + + // Copied from k8s.io/apiserver/pkg/registry/generic/registry/store.go@cccad306d649184bf2a0e319ba830c53f65c445c + optimisticLockErrorMsg = "the object has been modified; please apply your changes to the latest version and try again" ) -var gaugeProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupCount) +var ( + gaugeEgressProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupEgressCount) + gaugeIngressProxyGroupResources = clientmetric.NewGauge(kubetypes.MetricProxyGroupIngressCount) +) // ProxyGroupReconciler ensures cluster resources for a ProxyGroup definition. 
type ProxyGroupReconciler struct { @@ -64,8 +72,9 @@ type ProxyGroupReconciler struct { tsFirewallMode string defaultProxyClass string - mu sync.Mutex // protects following - proxyGroups set.Slice[types.UID] // for proxygroups gauge + mu sync.Mutex // protects following + egressProxyGroups set.Slice[types.UID] // for egress proxygroups gauge + ingressProxyGroups set.Slice[types.UID] // for ingress proxygroups gauge } func (r *ProxyGroupReconciler) logger(name string) *zap.SugaredLogger { @@ -110,7 +119,7 @@ func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Requ oldPGStatus := pg.Status.DeepCopy() setStatusReady := func(pg *tsapi.ProxyGroup, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, status, reason, message, pg.Generation, r.clock, logger) - if !apiequality.Semantic.DeepEqual(oldPGStatus, pg.Status) { + if !apiequality.Semantic.DeepEqual(oldPGStatus, &pg.Status) { // An error encountered here should get returned by the Reconcile function. 
if updateErr := r.Client.Status().Update(ctx, pg); updateErr != nil { err = errors.Wrap(err, updateErr.Error()) @@ -158,6 +167,7 @@ func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Requ r.recorder.Eventf(pg, corev1.EventTypeWarning, reasonProxyGroupCreationFailed, err.Error()) return setStatusReady(pg, metav1.ConditionFalse, reasonProxyGroupCreationFailed, err.Error()) } + validateProxyClassForPG(logger, pg, proxyClass) if !tsoperator.ProxyClassIsReady(proxyClass) { message := fmt.Sprintf("the ProxyGroup's ProxyClass %s is not yet in a ready state, waiting...", proxyClassName) logger.Info(message) @@ -166,9 +176,17 @@ func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Requ } if err = r.maybeProvision(ctx, pg, proxyClass); err != nil { - err = fmt.Errorf("error provisioning ProxyGroup resources: %w", err) - r.recorder.Eventf(pg, corev1.EventTypeWarning, reasonProxyGroupCreationFailed, err.Error()) - return setStatusReady(pg, metav1.ConditionFalse, reasonProxyGroupCreationFailed, err.Error()) + reason := reasonProxyGroupCreationFailed + msg := fmt.Sprintf("error provisioning ProxyGroup resources: %s", err) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonProxyGroupCreating + msg = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(msg) + } else { + r.recorder.Eventf(pg, corev1.EventTypeWarning, reason, msg) + } + return setStatusReady(pg, metav1.ConditionFalse, reason, msg) } desiredReplicas := int(pgReplicas(pg)) @@ -188,11 +206,35 @@ func (r *ProxyGroupReconciler) Reconcile(ctx context.Context, req reconcile.Requ return setStatusReady(pg, metav1.ConditionTrue, reasonProxyGroupReady, reasonProxyGroupReady) } +// validateProxyClassForPG applies custom validation logic for ProxyClass applied to ProxyGroup. 
+func validateProxyClassForPG(logger *zap.SugaredLogger, pg *tsapi.ProxyGroup, pc *tsapi.ProxyClass) { + if pg.Spec.Type == tsapi.ProxyGroupTypeIngress { + return + } + // Our custom logic for ensuring minimum downtime ProxyGroup update rollouts relies on the local health check + // being accessible on the replica Pod IP:9002. This address can also be modified by users, via + // TS_LOCAL_ADDR_PORT env var. + // + // Currently TS_LOCAL_ADDR_PORT controls Pod's health check and metrics address. _Probably_ there is no need for + // users to set this to a custom value. Users who want to consume metrics should integrate with the metrics + // Service and/or ServiceMonitor, rather than Pods directly. The health check is likely not useful to integrate + // directly with for operator proxies (and we should aim for unified lifecycle logic in the operator, users + // shouldn't need to set their own). + // + // TODO(irbekrm): maybe disallow configuring this env var in future (in Tailscale 1.84 or later).
+ if hasLocalAddrPortSet(pc) { + msg := fmt.Sprintf("ProxyClass %s applied to an egress ProxyGroup has TS_LOCAL_ADDR_PORT env var set to a custom value."+ + "This will disable the ProxyGroup graceful failover mechanism, so you might experience downtime when ProxyGroup pods are restarted."+ + "In future we will remove the ability to set custom TS_LOCAL_ADDR_PORT for egress ProxyGroups."+ + "Please raise an issue if you expect that this will cause issues for your workflow.", pc.Name) + logger.Warn(msg) + } +} + func (r *ProxyGroupReconciler) maybeProvision(ctx context.Context, pg *tsapi.ProxyGroup, proxyClass *tsapi.ProxyClass) error { logger := r.logger(pg.Name) r.mu.Lock() - r.proxyGroups.Add(pg.UID) - gaugeProxyGroupResources.Set(int64(r.proxyGroups.Len())) + r.ensureAddedToGaugeForProxyGroup(pg) r.mu.Unlock() cfgHash, err := r.ensureConfigSecretsCreated(ctx, pg, proxyClass) @@ -238,27 +280,76 @@ func (r *ProxyGroupReconciler) maybeProvision(ctx context.Context, pg *tsapi.Pro return fmt.Errorf("error provisioning RoleBinding: %w", err) } if pg.Spec.Type == tsapi.ProxyGroupTypeEgress { - cm := pgEgressCM(pg, r.tsNamespace) + cm, hp := pgEgressCM(pg, r.tsNamespace) if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, cm, func(existing *corev1.ConfigMap) { existing.ObjectMeta.Labels = cm.ObjectMeta.Labels existing.ObjectMeta.OwnerReferences = cm.ObjectMeta.OwnerReferences + mak.Set(&existing.BinaryData, egressservices.KeyHEPPings, hp) }); err != nil { - return fmt.Errorf("error provisioning ConfigMap: %w", err) + return fmt.Errorf("error provisioning egress ConfigMap %q: %w", cm.Name, err) } } - ss, err := pgStatefulSet(pg, r.tsNamespace, r.proxyImage, r.tsFirewallMode, cfgHash) + if pg.Spec.Type == tsapi.ProxyGroupTypeIngress { + cm := pgIngressCM(pg, r.tsNamespace) + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, cm, func(existing *corev1.ConfigMap) { + existing.ObjectMeta.Labels = cm.ObjectMeta.Labels + existing.ObjectMeta.OwnerReferences = 
cm.ObjectMeta.OwnerReferences + }); err != nil { + return fmt.Errorf("error provisioning ingress ConfigMap %q: %w", cm.Name, err) + } + } + ss, err := pgStatefulSet(pg, r.tsNamespace, r.proxyImage, r.tsFirewallMode, proxyClass) if err != nil { return fmt.Errorf("error generating StatefulSet spec: %w", err) } - ss = applyProxyClassToStatefulSet(proxyClass, ss, nil, logger) - if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, ss, func(s *appsv1.StatefulSet) { + cfg := &tailscaleSTSConfig{ + proxyType: string(pg.Spec.Type), + } + ss = applyProxyClassToStatefulSet(proxyClass, ss, cfg, logger) + capver, err := r.capVerForPG(ctx, pg, logger) + if err != nil { + return fmt.Errorf("error getting device info: %w", err) + } + + updateSS := func(s *appsv1.StatefulSet) { + + // This is a temporary workaround to ensure that egress ProxyGroup proxies with capver older than 110 + // are restarted when tailscaled configfile contents have changed. + // This workaround ensures that: + // 1. The hash mechanism is used to trigger pod restarts for proxies below capver 110. + // 2. Proxies above capver 110 are not unnecessarily restarted when the configfile contents change. + // 3. If the hash has already been set, but the capver is above 110, the old hash is preserved to avoid + // unnecessary pod restarts that could result in an update loop where capver cannot be determined for a + // restarting Pod and the hash is re-added again. + // Note that this workaround is only applied to egress ProxyGroups, because ingress ProxyGroup was added after capver 110. + // Note also that the hash annotation is only set on updates, not creation, because if the StatefulSet is + // being created, there is no need for a restart. + // TODO(irbekrm): remove this in 1.84.
+ hash := cfgHash + if capver >= 110 { + hash = s.Spec.Template.GetAnnotations()[podAnnotationLastSetConfigFileHash] + } + s.Spec = ss.Spec + if hash != "" && pg.Spec.Type == tsapi.ProxyGroupTypeEgress { + mak.Set(&s.Spec.Template.Annotations, podAnnotationLastSetConfigFileHash, hash) + } + s.ObjectMeta.Labels = ss.ObjectMeta.Labels s.ObjectMeta.Annotations = ss.ObjectMeta.Annotations s.ObjectMeta.OwnerReferences = ss.ObjectMeta.OwnerReferences - s.Spec = ss.Spec - }); err != nil { + } + if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, ss, updateSS); err != nil { return fmt.Errorf("error provisioning StatefulSet: %w", err) } + mo := &metricsOpts{ + tsNamespace: r.tsNamespace, + proxyStsName: pg.Name, + proxyLabels: pgLabels(pg.Name, nil), + proxyType: "proxygroup", + } + if err := reconcileMetricsResources(ctx, logger, mo, proxyClass, r.Client); err != nil { + return fmt.Errorf("error reconciling metrics resources: %w", err) + } if err := r.cleanupDanglingResources(ctx, pg); err != nil { return fmt.Errorf("error cleaning up dangling resources: %w", err) @@ -327,10 +418,17 @@ func (r *ProxyGroupReconciler) maybeCleanup(ctx context.Context, pg *tsapi.Proxy } } + mo := &metricsOpts{ + proxyLabels: pgLabels(pg.Name, nil), + tsNamespace: r.tsNamespace, + proxyType: "proxygroup"} + if err := maybeCleanupMetricsResources(ctx, mo, r.Client); err != nil { + return false, fmt.Errorf("error cleaning up metrics resources: %w", err) + } + logger.Infof("cleaned up ProxyGroup resources") r.mu.Lock() - r.proxyGroups.Remove(pg.UID) - gaugeProxyGroupResources.Set(int64(r.proxyGroups.Len())) + r.ensureRemovedFromGaugeForProxyGroup(pg) r.mu.Unlock() return true, nil } @@ -353,11 +451,11 @@ func (r *ProxyGroupReconciler) deleteTailnetDevice(ctx context.Context, id tailc func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, pg *tsapi.ProxyGroup, proxyClass *tsapi.ProxyClass) (hash string, err error) { logger := r.logger(pg.Name) - var allConfigs 
[]tailscaledConfigs + var configSHA256Sum string for i := range pgReplicas(pg) { cfgSecret := &corev1.Secret{ ObjectMeta: metav1.ObjectMeta{ - Name: fmt.Sprintf("%s-%d-config", pg.Name, i), + Name: pgConfigSecretName(pg.Name, i), Namespace: r.tsNamespace, Labels: pgSecretLabels(pg.Name, "config"), OwnerReferences: pgOwnerReference(pg), @@ -366,7 +464,7 @@ func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, p var existingCfgSecret *corev1.Secret // unmodified copy of secret if err := r.Get(ctx, client.ObjectKeyFromObject(cfgSecret), cfgSecret); err == nil { - logger.Debugf("secret %s/%s already exists", cfgSecret.GetNamespace(), cfgSecret.GetName()) + logger.Debugf("Secret %s/%s already exists", cfgSecret.GetNamespace(), cfgSecret.GetName()) existingCfgSecret = cfgSecret.DeepCopy() } else if !apierrors.IsNotFound(err) { return "", err @@ -374,7 +472,7 @@ func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, p var authKey string if existingCfgSecret == nil { - logger.Debugf("creating authkey for new ProxyGroup proxy") + logger.Debugf("Creating authkey for new ProxyGroup proxy") tags := pg.Spec.Tags.Stringify() if len(tags) == 0 { tags = r.defaultTags @@ -389,39 +487,83 @@ func (r *ProxyGroupReconciler) ensureConfigSecretsCreated(ctx context.Context, p if err != nil { return "", fmt.Errorf("error creating tailscaled config: %w", err) } - allConfigs = append(allConfigs, configs) for cap, cfg := range configs { cfgJSON, err := json.Marshal(cfg) if err != nil { return "", fmt.Errorf("error marshalling tailscaled config: %w", err) } - mak.Set(&cfgSecret.StringData, tsoperator.TailscaledConfigFileName(cap), string(cfgJSON)) + mak.Set(&cfgSecret.Data, tsoperator.TailscaledConfigFileName(cap), cfgJSON) + } + + // The config sha256 sum is a value for a hash annotation used to trigger + // pod restarts when tailscaled config changes. 
Any config changes apply + // to all replicas, so it is sufficient to only hash the config for the + // first replica. + // + // In future, we're aiming to eliminate restarts altogether and have + // pods dynamically reload their config when it changes. + if i == 0 { + sum := sha256.New() + for _, cfg := range configs { + // Zero out the auth key so it doesn't affect the sha256 hash when we + // remove it from the config after the pods have all authed. Otherwise + // all the pods will need to restart immediately after authing. + cfg.AuthKey = nil + b, err := json.Marshal(cfg) + if err != nil { + return "", err + } + if _, err := sum.Write(b); err != nil { + return "", err + } + } + + configSHA256Sum = fmt.Sprintf("%x", sum.Sum(nil)) } if existingCfgSecret != nil { - logger.Debugf("patching the existing ProxyGroup config Secret %s", cfgSecret.Name) - if err := r.Patch(ctx, cfgSecret, client.MergeFrom(existingCfgSecret)); err != nil { - return "", err + if !apiequality.Semantic.DeepEqual(existingCfgSecret, cfgSecret) { + logger.Debugf("Updating the existing ProxyGroup config Secret %s", cfgSecret.Name) + if err := r.Update(ctx, cfgSecret); err != nil { + return "", err + } } } else { - logger.Debugf("creating a new config Secret %s for the ProxyGroup", cfgSecret.Name) + logger.Debugf("Creating a new config Secret %s for the ProxyGroup", cfgSecret.Name) if err := r.Create(ctx, cfgSecret); err != nil { return "", err } } } - sum := sha256.New() - b, err := json.Marshal(allConfigs) - if err != nil { - return "", err - } - if _, err := sum.Write(b); err != nil { - return "", err - } + return configSHA256Sum, nil +} + +// ensureAddedToGaugeForProxyGroup ensures the gauge metric for the ProxyGroup resource is updated when the ProxyGroup +// is created. r.mu must be held. 
+func (r *ProxyGroupReconciler) ensureAddedToGaugeForProxyGroup(pg *tsapi.ProxyGroup) { + switch pg.Spec.Type { + case tsapi.ProxyGroupTypeEgress: + r.egressProxyGroups.Add(pg.UID) + case tsapi.ProxyGroupTypeIngress: + r.ingressProxyGroups.Add(pg.UID) + } + gaugeEgressProxyGroupResources.Set(int64(r.egressProxyGroups.Len())) + gaugeIngressProxyGroupResources.Set(int64(r.ingressProxyGroups.Len())) +} - return fmt.Sprintf("%x", sum.Sum(nil)), nil +// ensureRemovedFromGaugeForProxyGroup ensures the gauge metric for the ProxyGroup resource type is updated when the +// ProxyGroup is deleted. r.mu must be held. +func (r *ProxyGroupReconciler) ensureRemovedFromGaugeForProxyGroup(pg *tsapi.ProxyGroup) { + switch pg.Spec.Type { + case tsapi.ProxyGroupTypeEgress: + r.egressProxyGroups.Remove(pg.UID) + case tsapi.ProxyGroupTypeIngress: + r.ingressProxyGroups.Remove(pg.UID) + } + gaugeEgressProxyGroupResources.Set(int64(r.egressProxyGroups.Len())) + gaugeIngressProxyGroupResources.Set(int64(r.ingressProxyGroups.Len())) } func pgTailscaledConfig(pg *tsapi.ProxyGroup, class *tsapi.ProxyClass, idx int32, authKey string, oldSecret *corev1.Secret) (tailscaledConfigs, error) { @@ -434,7 +576,7 @@ func pgTailscaledConfig(pg *tsapi.ProxyGroup, class *tsapi.ProxyClass, idx int32 } if pg.Spec.HostnamePrefix != "" { - conf.Hostname = ptr.To(fmt.Sprintf("%s%d", pg.Spec.HostnamePrefix, idx)) + conf.Hostname = ptr.To(fmt.Sprintf("%s-%d", pg.Spec.HostnamePrefix, idx)) } if shouldAcceptRoutes(class) { @@ -459,10 +601,35 @@ func pgTailscaledConfig(pg *tsapi.ProxyGroup, class *tsapi.ProxyClass, idx int32 conf.AuthKey = key } capVerConfigs := make(map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) + + // AdvertiseServices config is set by ingress-pg-reconciler, so make sure we + // don't overwrite it here. 
+ if err := copyAdvertiseServicesConfig(conf, oldSecret, 106); err != nil { + return nil, err + } capVerConfigs[106] = *conf return capVerConfigs, nil } +func copyAdvertiseServicesConfig(conf *ipn.ConfigVAlpha, oldSecret *corev1.Secret, capVer tailcfg.CapabilityVersion) error { + if oldSecret == nil { + return nil + } + + oldConfB := oldSecret.Data[tsoperator.TailscaledConfigFileName(capVer)] + if len(oldConfB) == 0 { + return nil + } + + var oldConf ipn.ConfigVAlpha + if err := json.Unmarshal(oldConfB, &oldConf); err != nil { + return fmt.Errorf("error unmarshalling existing config: %w", err) + } + conf.AdvertiseServices = oldConf.AdvertiseServices + + return nil +} + func (r *ProxyGroupReconciler) validate(_ *tsapi.ProxyGroup) error { return nil } @@ -483,7 +650,7 @@ func (r *ProxyGroupReconciler) getNodeMetadata(ctx context.Context, pg *tsapi.Pr return nil, fmt.Errorf("unexpected secret %s was labelled as owned by the ProxyGroup %s: %w", secret.Name, pg.Name, err) } - id, dnsName, ok, err := getNodeMetadata(ctx, &secret) + prefs, ok, err := getDevicePrefs(&secret) if err != nil { return nil, err } @@ -491,12 +658,19 @@ func (r *ProxyGroupReconciler) getNodeMetadata(ctx context.Context, pg *tsapi.Pr continue } - metadata = append(metadata, nodeMetadata{ + nm := nodeMetadata{ ordinal: ordinal, stateSecret: &secret, - tsID: id, - dnsName: dnsName, - }) + tsID: prefs.Config.NodeID, + dnsName: prefs.Config.UserProfile.LoginName, + } + pod := &corev1.Pod{} + if err := r.Get(ctx, client.ObjectKey{Namespace: r.tsNamespace, Name: secret.Name}, pod); err != nil && !apierrors.IsNotFound(err) { + return nil, err + } else if err == nil { + nm.podUID = string(pod.UID) + } + metadata = append(metadata, nm) } return metadata, nil @@ -528,6 +702,29 @@ func (r *ProxyGroupReconciler) getDeviceInfo(ctx context.Context, pg *tsapi.Prox type nodeMetadata struct { ordinal int stateSecret *corev1.Secret - tsID tailcfg.StableNodeID - dnsName string + // podUID is the UID of the current 
Pod or empty if the Pod does not exist. + podUID string + tsID tailcfg.StableNodeID + dnsName string +} + +// capVerForPG returns best effort capability version for the given ProxyGroup. It attempts to find it by looking at the +// Secret + Pod for the replica with ordinal 0. Returns -1 if it is not possible to determine the capability version +// (i.e there is no Pod yet). +func (r *ProxyGroupReconciler) capVerForPG(ctx context.Context, pg *tsapi.ProxyGroup, logger *zap.SugaredLogger) (tailcfg.CapabilityVersion, error) { + metas, err := r.getNodeMetadata(ctx, pg) + if err != nil { + return -1, fmt.Errorf("error getting node metadata: %w", err) + } + if len(metas) == 0 { + return -1, nil + } + dev, err := deviceInfo(metas[0].stateSecret, metas[0].podUID, logger) + if err != nil { + return -1, fmt.Errorf("error getting device info: %w", err) + } + if dev == nil { + return -1, nil + } + return dev.capver, nil } diff --git a/cmd/k8s-operator/proxygroup_specs.go b/cmd/k8s-operator/proxygroup_specs.go index 9aa7ac3b008a3..1d12c39e0241e 100644 --- a/cmd/k8s-operator/proxygroup_specs.go +++ b/cmd/k8s-operator/proxygroup_specs.go @@ -7,20 +7,28 @@ package main import ( "fmt" + "slices" + "strconv" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/yaml" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/egressservices" + "tailscale.com/kube/ingressservices" + "tailscale.com/kube/kubetypes" "tailscale.com/types/ptr" ) +// deletionGracePeriodSeconds is set to 6 minutes to ensure that the pre-stop hook of these proxies have enough chance to terminate gracefully. +const deletionGracePeriodSeconds int64 = 360 + // Returns the base StatefulSet definition for a ProxyGroup. A ProxyClass may be // applied over the top after. 
-func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHash string) (*appsv1.StatefulSet, error) { +func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode string, proxyClass *tsapi.ProxyClass) (*appsv1.StatefulSet, error) { ss := new(appsv1.StatefulSet) if err := yaml.Unmarshal(proxyYaml, &ss); err != nil { return nil, fmt.Errorf("failed to unmarshal proxy spec: %w", err) @@ -52,12 +60,13 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa Namespace: namespace, Labels: pgLabels(pg.Name, nil), DeletionGracePeriodSeconds: ptr.To[int64](10), - Annotations: map[string]string{ - podAnnotationLastSetConfigFileHash: cfgHash, - }, } tmpl.Spec.ServiceAccountName = pg.Name tmpl.Spec.InitContainers[0].Image = image + proxyConfigVolName := pgEgressCMName(pg.Name) + if pg.Spec.Type == tsapi.ProxyGroupTypeIngress { + proxyConfigVolName = pgIngressCMName(pg.Name) + } tmpl.Spec.Volumes = func() []corev1.Volume { var volumes []corev1.Volume for i := range pgReplicas(pg) { @@ -65,24 +74,22 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa Name: fmt.Sprintf("tailscaledconfig-%d", i), VolumeSource: corev1.VolumeSource{ Secret: &corev1.SecretVolumeSource{ - SecretName: fmt.Sprintf("%s-%d-config", pg.Name, i), + SecretName: pgConfigSecretName(pg.Name, i), }, }, }) } - if pg.Spec.Type == tsapi.ProxyGroupTypeEgress { - volumes = append(volumes, corev1.Volume{ - Name: pgEgressCMName(pg.Name), - VolumeSource: corev1.VolumeSource{ - ConfigMap: &corev1.ConfigMapVolumeSource{ - LocalObjectReference: corev1.LocalObjectReference{ - Name: pgEgressCMName(pg.Name), - }, + volumes = append(volumes, corev1.Volume{ + Name: proxyConfigVolName, + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: proxyConfigVolName, }, }, - }) - } + }, + }) return volumes }() @@ -92,6 +99,10 @@ func pgStatefulSet(pg 
*tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa c.Image = image c.VolumeMounts = func() []corev1.VolumeMount { var mounts []corev1.VolumeMount + + // TODO(tomhjp): Read config directly from the secret instead. The + // mounts change on scaling up/down which causes unnecessary restarts + // for pods that haven't meaningfully changed. for i := range pgReplicas(pg) { mounts = append(mounts, corev1.VolumeMount{ Name: fmt.Sprintf("tailscaledconfig-%d", i), @@ -100,13 +111,11 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa }) } - if pg.Spec.Type == tsapi.ProxyGroupTypeEgress { - mounts = append(mounts, corev1.VolumeMount{ - Name: pgEgressCMName(pg.Name), - MountPath: "/etc/proxies", - ReadOnly: true, - }) - } + mounts = append(mounts, corev1.VolumeMount{ + Name: proxyConfigVolName, + MountPath: "/etc/proxies", + ReadOnly: true, + }) return mounts }() @@ -121,15 +130,6 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa }, }, }, - { - Name: "POD_NAME", - ValueFrom: &corev1.EnvVarSource{ - FieldRef: &corev1.ObjectFieldSelector{ - // Secret is named after the pod. - FieldPath: "metadata.name", - }, - }, - }, { Name: "TS_KUBE_SECRET", Value: "$(POD_NAME)", @@ -142,10 +142,6 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig/$(POD_NAME)", }, - { - Name: "TS_USERSPACE", - Value: "false", - }, } if tsFirewallMode != "" { @@ -156,15 +152,79 @@ func pgStatefulSet(pg *tsapi.ProxyGroup, namespace, image, tsFirewallMode, cfgHa } if pg.Spec.Type == tsapi.ProxyGroupTypeEgress { + envs = append(envs, + // TODO(irbekrm): in 1.80 we deprecated TS_EGRESS_SERVICES_CONFIG_PATH in favour of + // TS_EGRESS_PROXIES_CONFIG_PATH. Remove it in 1.84. 
+ corev1.EnvVar{ + Name: "TS_EGRESS_SERVICES_CONFIG_PATH", + Value: fmt.Sprintf("/etc/proxies/%s", egressservices.KeyEgressServices), + }, + corev1.EnvVar{ + Name: "TS_EGRESS_PROXIES_CONFIG_PATH", + Value: "/etc/proxies", + }, + corev1.EnvVar{ + Name: "TS_INTERNAL_APP", + Value: kubetypes.AppProxyGroupEgress, + }, + corev1.EnvVar{ + Name: "TS_ENABLE_HEALTH_CHECK", + Value: "true", + }) + } else { // ingress envs = append(envs, corev1.EnvVar{ - Name: "TS_EGRESS_SERVICES_CONFIG_PATH", - Value: fmt.Sprintf("/etc/proxies/%s", egressservices.KeyEgressServices), - }) + Name: "TS_INTERNAL_APP", + Value: kubetypes.AppProxyGroupIngress, + }, + corev1.EnvVar{ + Name: "TS_INGRESS_PROXIES_CONFIG_PATH", + Value: fmt.Sprintf("/etc/proxies/%s", ingressservices.IngressConfigKey), + }, + corev1.EnvVar{ + Name: "TS_SERVE_CONFIG", + Value: fmt.Sprintf("/etc/proxies/%s", serveConfigKey), + }, + corev1.EnvVar{ + // Run proxies in cert share mode to + // ensure that only one TLS cert is + // issued for an HA Ingress. + Name: "TS_EXPERIMENTAL_CERT_SHARE", + Value: "true", + }, + ) } - - return envs + return append(c.Env, envs...) }() + // The pre-stop hook is used to ensure that a replica does not get terminated while cluster traffic for egress + // services is still being routed to it. + // + // This mechanism currently (2025-01-26) relies on the local health check being accessible on the Pod's + // IP, so it is not supported for ProxyGroups where users have configured TS_LOCAL_ADDR_PORT to a custom + // value. + // + // NB: For _Ingress_ ProxyGroups, we run shutdown logic within containerboot + // in reaction to a SIGTERM signal instead of using a pre-stop hook. This is + // because Ingress pods need to unadvertise services, and it's preferable to + // avoid triggering those side-effects from a GET request that would be + // accessible to the whole cluster network (in the absence of NetworkPolicy + // rules). + // + // TODO(tomhjp): add a readiness probe or gate to Ingress Pods.
There is a + // small window where the Pod is marked ready but routing can still fail. + if pg.Spec.Type == tsapi.ProxyGroupTypeEgress && !hasLocalAddrPortSet(proxyClass) { + c.Lifecycle = &corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: kubetypes.EgessServicesPreshutdownEP, + Port: intstr.FromInt(defaultLocalAddrPort), + }, + }, + } + // Set the deletion grace period to 6 minutes to ensure that the pre-stop hook has enough time to terminate + // gracefully. + ss.Spec.Template.DeletionGracePeriodSeconds = ptr.To(deletionGracePeriodSeconds) + } return ss, nil } @@ -188,6 +248,13 @@ func pgRole(pg *tsapi.ProxyGroup, namespace string) *rbacv1.Role { OwnerReferences: pgOwnerReference(pg), }, Rules: []rbacv1.PolicyRule{ + { + APIGroups: []string{""}, + Resources: []string{"secrets"}, + Verbs: []string{ + "list", + }, + }, { APIGroups: []string{""}, Resources: []string{"secrets"}, @@ -199,13 +266,22 @@ func pgRole(pg *tsapi.ProxyGroup, namespace string) *rbacv1.Role { ResourceNames: func() (secrets []string) { for i := range pgReplicas(pg) { secrets = append(secrets, - fmt.Sprintf("%s-%d-config", pg.Name, i), // Config with auth key. - fmt.Sprintf("%s-%d", pg.Name, i), // State. + pgConfigSecretName(pg.Name, i), // Config with auth key. + fmt.Sprintf("%s-%d", pg.Name, i), // State. 
) } return secrets }(), }, + { + APIGroups: []string{""}, + Resources: []string{"events"}, + Verbs: []string{ + "create", + "patch", + "get", + }, + }, }, } } @@ -247,7 +323,9 @@ func pgStateSecrets(pg *tsapi.ProxyGroup, namespace string) (secrets []*corev1.S return secrets } -func pgEgressCM(pg *tsapi.ProxyGroup, namespace string) *corev1.ConfigMap { +func pgEgressCM(pg *tsapi.ProxyGroup, namespace string) (*corev1.ConfigMap, []byte) { + hp := hepPings(pg) + hpBs := []byte(strconv.Itoa(hp)) return &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ Name: pgEgressCMName(pg.Name), @@ -255,12 +333,24 @@ func pgEgressCM(pg *tsapi.ProxyGroup, namespace string) *corev1.ConfigMap { Labels: pgLabels(pg.Name, nil), OwnerReferences: pgOwnerReference(pg), }, + BinaryData: map[string][]byte{egressservices.KeyHEPPings: hpBs}, + }, hpBs +} + +func pgIngressCM(pg *tsapi.ProxyGroup, namespace string) *corev1.ConfigMap { + return &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgIngressCMName(pg.Name), + Namespace: namespace, + Labels: pgLabels(pg.Name, nil), + OwnerReferences: pgOwnerReference(pg), + }, } } -func pgSecretLabels(pgName, typ string) map[string]string { +func pgSecretLabels(pgName, secretType string) map[string]string { return pgLabels(pgName, map[string]string{ - labelSecretType: typ, // "config" or "state". + kubetypes.LabelSecretType: secretType, // "config" or "state". 
}) } @@ -270,7 +360,7 @@ func pgLabels(pgName string, customLabels map[string]string) map[string]string { l[k] = v } - l[LabelManaged] = "true" + l[kubetypes.LabelManaged] = "true" l[LabelParentType] = "proxygroup" l[LabelParentName] = pgName @@ -289,6 +379,30 @@ func pgReplicas(pg *tsapi.ProxyGroup) int32 { return 2 } +func pgConfigSecretName(pgName string, i int32) string { + return fmt.Sprintf("%s-%d-config", pgName, i) +} + func pgEgressCMName(pg string) string { return fmt.Sprintf("%s-egress-config", pg) } + +// hasLocalAddrPortSet returns true if the proxyclass has the TS_LOCAL_ADDR_PORT env var set. For egress ProxyGroups, +// currently (2025-01-26) this means that the ProxyGroup does not support graceful failover. +func hasLocalAddrPortSet(proxyClass *tsapi.ProxyClass) bool { + if proxyClass == nil || proxyClass.Spec.StatefulSet == nil || proxyClass.Spec.StatefulSet.Pod == nil || proxyClass.Spec.StatefulSet.Pod.TailscaleContainer == nil { + return false + } + return slices.ContainsFunc(proxyClass.Spec.StatefulSet.Pod.TailscaleContainer.Env, func(env tsapi.Env) bool { + return env.Name == envVarTSLocalAddrPort + }) +} + +// hepPings returns the number of times a health check endpoint exposed by a Service fronting ProxyGroup replicas should +// be pinged to ensure that all currently configured backend replicas are hit. +func hepPings(pg *tsapi.ProxyGroup) int { + rc := pgReplicas(pg) + // Assuming a Service implemented using round robin load balancing, number-of-replica-times should be enough, but in + // practice, we cannot assume that the requests will be load balanced perfectly. 
+ return int(rc) * 3 +} diff --git a/cmd/k8s-operator/proxygroup_test.go b/cmd/k8s-operator/proxygroup_test.go index 445db7537ddb6..159329eda2335 100644 --- a/cmd/k8s-operator/proxygroup_test.go +++ b/cmd/k8s-operator/proxygroup_test.go @@ -17,15 +17,20 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" "tailscale.com/client/tailscale" + "tailscale.com/ipn" tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" "tailscale.com/tstest" "tailscale.com/types/ptr" + "tailscale.com/util/mak" ) const testProxyImage = "tailscale/tailscale:test" @@ -35,6 +40,8 @@ var defaultProxyClassAnnotations = map[string]string{ } func TestProxyGroup(t *testing.T) { + const initialCfgHash = "6632726be70cf224049580deb4d317bba065915b5fd415461d60ed621c91b196" + pc := &tsapi.ProxyClass{ ObjectMeta: metav1.ObjectMeta{ Name: "default-pc", @@ -50,6 +57,9 @@ func TestProxyGroup(t *testing.T) { Name: "test", Finalizers: []string{"tailscale.com/finalizer"}, }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeEgress, + }, } fc := fake.NewClientBuilder(). 
@@ -74,12 +84,21 @@ func TestProxyGroup(t *testing.T) { l: zl.Sugar(), clock: cl, } + crd := &apiextensionsv1.CustomResourceDefinition{ObjectMeta: metav1.ObjectMeta{Name: serviceMonitorCRD}} + opts := configOpts{ + proxyType: "proxygroup", + stsName: pg.Name, + parentType: "proxygroup", + tailscaleNamespace: "tailscale", + resourceVersion: "1", + } t.Run("proxyclass_not_ready", func(t *testing.T) { expectReconciled(t, reconciler, "", pg.Name) tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "the ProxyGroup's ProxyClass default-pc is not yet in a ready state, waiting...", 0, cl, zl.Sugar()) - expectEqual(t, fc, pg, nil) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, false, "", pc) }) t.Run("observe_ProxyGroupCreating_status_reason", func(t *testing.T) { @@ -99,11 +118,12 @@ func TestProxyGroup(t *testing.T) { expectReconciled(t, reconciler, "", pg.Name) tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "0/2 ProxyGroup pods running", 0, cl, zl.Sugar()) - expectEqual(t, fc, pg, nil) - if expected := 1; reconciler.proxyGroups.Len() != expected { - t.Fatalf("expected %d recorders, got %d", expected, reconciler.proxyGroups.Len()) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, "", pc) + if expected := 1; reconciler.egressProxyGroups.Len() != expected { + t.Fatalf("expected %d egress ProxyGroups, got %d", expected, reconciler.egressProxyGroups.Len()) } - expectProxyGroupResources(t, fc, pg, true) + expectProxyGroupResources(t, fc, pg, true, "", pc) keyReq := tailscale.KeyCapabilities{ Devices: tailscale.KeyDeviceCapabilities{ Create: tailscale.KeyDeviceCreateCapabilities{ @@ -134,8 +154,8 @@ func TestProxyGroup(t *testing.T) { }, } tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionTrue, reasonProxyGroupReady, reasonProxyGroupReady, 0, cl, zl.Sugar()) - expectEqual(t, fc, pg, nil) - 
expectProxyGroupResources(t, fc, pg, true) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash, pc) }) t.Run("scale_up_to_3", func(t *testing.T) { @@ -145,7 +165,8 @@ func TestProxyGroup(t *testing.T) { }) expectReconciled(t, reconciler, "", pg.Name) tsoperator.SetProxyGroupCondition(pg, tsapi.ProxyGroupReady, metav1.ConditionFalse, reasonProxyGroupCreating, "2/3 ProxyGroup pods running", 0, cl, zl.Sugar()) - expectEqual(t, fc, pg, nil) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash, pc) addNodeIDToStateSecrets(t, fc, pg) expectReconciled(t, reconciler, "", pg.Name) @@ -154,8 +175,8 @@ func TestProxyGroup(t *testing.T) { Hostname: "hostname-nodeid-2", TailnetIPs: []string{"1.2.3.4", "::1"}, }) - expectEqual(t, fc, pg, nil) - expectProxyGroupResources(t, fc, pg, true) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash, pc) }) t.Run("scale_down_to_1", func(t *testing.T) { @@ -163,11 +184,47 @@ func TestProxyGroup(t *testing.T) { mustUpdate(t, fc, "", pg.Name, func(p *tsapi.ProxyGroup) { p.Spec = pg.Spec }) + expectReconciled(t, reconciler, "", pg.Name) + pg.Status.Devices = pg.Status.Devices[:1] // truncate to only the first device. 
- expectEqual(t, fc, pg, nil) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, initialCfgHash, pc) + }) + + t.Run("trigger_config_change_and_observe_new_config_hash", func(t *testing.T) { + pc.Spec.TailscaleConfig = &tsapi.TailscaleConfig{ + AcceptRoutes: true, + } + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec = pc.Spec + }) + + expectReconciled(t, reconciler, "", pg.Name) - expectProxyGroupResources(t, fc, pg, true) + expectEqual(t, fc, pg) + expectProxyGroupResources(t, fc, pg, true, "518a86e9fae64f270f8e0ec2a2ea6ca06c10f725035d3d6caca132cd61e42a74", pc) + }) + + t.Run("enable_metrics", func(t *testing.T) { + pc.Spec.Metrics = &tsapi.Metrics{Enable: true} + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec = pc.Spec + }) + expectReconciled(t, reconciler, "", pg.Name) + expectEqual(t, fc, expectedMetricsService(opts)) + }) + t.Run("enable_service_monitor_no_crd", func(t *testing.T) { + pc.Spec.Metrics.ServiceMonitor = &tsapi.ServiceMonitor{Enable: true} + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec.Metrics = pc.Spec.Metrics + }) + expectReconciled(t, reconciler, "", pg.Name) + }) + t.Run("create_crd_expect_service_monitor", func(t *testing.T) { + mustCreate(t, fc, crd) + expectReconciled(t, reconciler, "", pg.Name) + expectEqualUnstructured(t, fc, expectedServiceMonitor(t, opts)) }) t.Run("delete_and_cleanup", func(t *testing.T) { @@ -177,39 +234,397 @@ func TestProxyGroup(t *testing.T) { expectReconciled(t, reconciler, "", pg.Name) - expectMissing[tsapi.Recorder](t, fc, "", pg.Name) - if expected := 0; reconciler.proxyGroups.Len() != expected { - t.Fatalf("expected %d ProxyGroups, got %d", expected, reconciler.proxyGroups.Len()) + expectMissing[tsapi.ProxyGroup](t, fc, "", pg.Name) + if expected := 0; reconciler.egressProxyGroups.Len() != expected { + t.Fatalf("expected %d ProxyGroups, got %d", expected, reconciler.egressProxyGroups.Len()) } // 2 nodes should get deleted as 
part of the scale down, and then finally // the first node gets deleted with the ProxyGroup cleanup. if diff := cmp.Diff(tsClient.deleted, []string{"nodeid-1", "nodeid-2", "nodeid-0"}); diff != "" { t.Fatalf("unexpected deleted devices (-got +want):\n%s", diff) } + expectMissing[corev1.Service](t, reconciler, "tailscale", metricsResourceName(pg.Name)) // The fake client does not clean up objects whose owner has been // deleted, so we can't test for the owned resources getting deleted. }) } -func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.ProxyGroup, shouldExist bool) { +func TestProxyGroupTypes(t *testing.T) { + pc := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{}, + } + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pc). + WithStatusSubresource(pc). + Build() + mustUpdateStatus(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Status.Conditions = []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }} + }) + + zl, _ := zap.NewDevelopment() + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + proxyImage: testProxyImage, + Client: fc, + l: zl.Sugar(), + tsClient: &fakeTSClient{}, + clock: tstest.NewClock(tstest.ClockOpts{}), + } + + t.Run("egress_type", func(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-egress", + UID: "test-egress-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeEgress, + Replicas: ptr.To[int32](0), + }, + } + mustCreate(t, fc, pg) + + expectReconciled(t, reconciler, "", pg.Name) + verifyProxyGroupCounts(t, reconciler, 0, 1) + + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + verifyEnvVar(t, sts, "TS_INTERNAL_APP", 
kubetypes.AppProxyGroupEgress) + verifyEnvVar(t, sts, "TS_EGRESS_PROXIES_CONFIG_PATH", "/etc/proxies") + verifyEnvVar(t, sts, "TS_ENABLE_HEALTH_CHECK", "true") + + // Verify that egress configuration has been set up. + cm := &corev1.ConfigMap{} + cmName := fmt.Sprintf("%s-egress-config", pg.Name) + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: tsNamespace, Name: cmName}, cm); err != nil { + t.Fatalf("failed to get ConfigMap: %v", err) + } + + expectedVolumes := []corev1.Volume{ + { + Name: cmName, + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: cmName, + }, + }, + }, + }, + } + + expectedVolumeMounts := []corev1.VolumeMount{ + { + Name: cmName, + MountPath: "/etc/proxies", + ReadOnly: true, + }, + } + + if diff := cmp.Diff(expectedVolumes, sts.Spec.Template.Spec.Volumes); diff != "" { + t.Errorf("unexpected volumes (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(expectedVolumeMounts, sts.Spec.Template.Spec.Containers[0].VolumeMounts); diff != "" { + t.Errorf("unexpected volume mounts (-want +got):\n%s", diff) + } + + expectedLifecycle := corev1.Lifecycle{ + PreStop: &corev1.LifecycleHandler{ + HTTPGet: &corev1.HTTPGetAction{ + Path: kubetypes.EgessServicesPreshutdownEP, + Port: intstr.FromInt(defaultLocalAddrPort), + }, + }, + } + if diff := cmp.Diff(expectedLifecycle, *sts.Spec.Template.Spec.Containers[0].Lifecycle); diff != "" { + t.Errorf("unexpected lifecycle (-want +got):\n%s", diff) + } + if *sts.Spec.Template.DeletionGracePeriodSeconds != deletionGracePeriodSeconds { + t.Errorf("unexpected deletion grace period seconds %d, want %d", *sts.Spec.Template.DeletionGracePeriodSeconds, deletionGracePeriodSeconds) + } + }) + t.Run("egress_type_no_lifecycle_hook_when_local_addr_port_set", func(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-egress-no-lifecycle", + UID: "test-egress-no-lifecycle-uid", + }, + 
Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeEgress, + Replicas: ptr.To[int32](0), + ProxyClass: "test", + }, + } + mustCreate(t, fc, pg) + mustUpdate(t, fc, "", pc.Name, func(p *tsapi.ProxyClass) { + p.Spec.StatefulSet = &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Env: []tsapi.Env{{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "127.0.0.1:8080", + }}, + }, + }, + } + }) + expectReconciled(t, reconciler, "", pg.Name) + + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + if sts.Spec.Template.Spec.Containers[0].Lifecycle != nil { + t.Error("lifecycle hook was set when TS_LOCAL_ADDR_PORT was configured via ProxyClass") + } + }) + + t.Run("ingress_type", func(t *testing.T) { + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-ingress", + UID: "test-ingress-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + Replicas: ptr.To[int32](0), + }, + } + if err := fc.Create(context.Background(), pg); err != nil { + t.Fatal(err) + } + + expectReconciled(t, reconciler, "", pg.Name) + verifyProxyGroupCounts(t, reconciler, 1, 2) + + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + verifyEnvVar(t, sts, "TS_INTERNAL_APP", kubetypes.AppProxyGroupIngress) + verifyEnvVar(t, sts, "TS_SERVE_CONFIG", "/etc/proxies/serve-config.json") + verifyEnvVar(t, sts, "TS_EXPERIMENTAL_CERT_SHARE", "true") + + // Verify ConfigMap volume mount + cmName := fmt.Sprintf("%s-ingress-config", pg.Name) + expectedVolume := corev1.Volume{ + Name: cmName, + VolumeSource: corev1.VolumeSource{ + ConfigMap: &corev1.ConfigMapVolumeSource{ + LocalObjectReference: corev1.LocalObjectReference{ + Name: cmName, + }, + }, + }, + } + 
+ expectedVolumeMount := corev1.VolumeMount{ + Name: cmName, + MountPath: "/etc/proxies", + ReadOnly: true, + } + + if diff := cmp.Diff([]corev1.Volume{expectedVolume}, sts.Spec.Template.Spec.Volumes); diff != "" { + t.Errorf("unexpected volumes (-want +got):\n%s", diff) + } + + if diff := cmp.Diff([]corev1.VolumeMount{expectedVolumeMount}, sts.Spec.Template.Spec.Containers[0].VolumeMounts); diff != "" { + t.Errorf("unexpected volume mounts (-want +got):\n%s", diff) + } + }) +} + +func TestIngressAdvertiseServicesConfigPreserved(t *testing.T) { + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + Build() + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + proxyImage: testProxyImage, + Client: fc, + l: zap.Must(zap.NewDevelopment()).Sugar(), + tsClient: &fakeTSClient{}, + clock: tstest.NewClock(tstest.ClockOpts{}), + } + + existingServices := []string{"svc1", "svc2"} + existingConfigBytes, err := json.Marshal(ipn.ConfigVAlpha{ + AdvertiseServices: existingServices, + Version: "should-get-overwritten", + }) + if err != nil { + t.Fatal(err) + } + + const pgName = "test-ingress" + mustCreate(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pgName, 0), + Namespace: tsNamespace, + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(106): existingConfigBytes, + }, + }) + + mustCreate(t, fc, &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgName, + UID: "test-ingress-uid", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + Replicas: ptr.To[int32](1), + }, + }) + expectReconciled(t, reconciler, "", pgName) + + expectedConfigBytes, err := json.Marshal(ipn.ConfigVAlpha{ + // Preserved. 
+ AdvertiseServices: existingServices, + + // Everything else got updated in the reconcile: + Version: "alpha0", + AcceptDNS: "false", + AcceptRoutes: "false", + Locked: "false", + Hostname: ptr.To(fmt.Sprintf("%s-%d", pgName, 0)), + }) + if err != nil { + t.Fatal(err) + } + expectEqual(t, fc, &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName(pgName, 0), + Namespace: tsNamespace, + ResourceVersion: "2", + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(106): expectedConfigBytes, + }, + }) +} + +func proxyClassesForLEStagingTest() (*tsapi.ProxyClass, *tsapi.ProxyClass, *tsapi.ProxyClass) { + pcLEStaging := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "le-staging", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{ + UseLetsEncryptStagingEnvironment: true, + }, + } + + pcLEStagingFalse := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "le-staging-false", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{ + UseLetsEncryptStagingEnvironment: false, + }, + } + + pcOther := &tsapi.ProxyClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: "other", + Generation: 1, + }, + Spec: tsapi.ProxyClassSpec{}, + } + + return pcLEStaging, pcLEStagingFalse, pcOther +} + +func setProxyClassReady(t *testing.T, fc client.Client, cl *tstest.Clock, name string) *tsapi.ProxyClass { + t.Helper() + pc := &tsapi.ProxyClass{} + if err := fc.Get(context.Background(), client.ObjectKey{Name: name}, pc); err != nil { + t.Fatal(err) + } + pc.Status = tsapi.ProxyClassStatus{ + Conditions: []metav1.Condition{{ + Type: string(tsapi.ProxyClassReady), + Status: metav1.ConditionTrue, + Reason: reasonProxyClassValid, + Message: reasonProxyClassValid, + LastTransitionTime: metav1.Time{Time: cl.Now().Truncate(time.Second)}, + ObservedGeneration: pc.Generation, + }}, + } + if err := fc.Status().Update(context.Background(), pc); err != nil { + t.Fatal(err) + } + return pc +} + +func verifyProxyGroupCounts(t *testing.T, r 
*ProxyGroupReconciler, wantIngress, wantEgress int) { + t.Helper() + if r.ingressProxyGroups.Len() != wantIngress { + t.Errorf("expected %d ingress proxy groups, got %d", wantIngress, r.ingressProxyGroups.Len()) + } + if r.egressProxyGroups.Len() != wantEgress { + t.Errorf("expected %d egress proxy groups, got %d", wantEgress, r.egressProxyGroups.Len()) + } +} + +func verifyEnvVar(t *testing.T, sts *appsv1.StatefulSet, name, expectedValue string) { + t.Helper() + for _, env := range sts.Spec.Template.Spec.Containers[0].Env { + if env.Name == name { + if env.Value != expectedValue { + t.Errorf("expected %s=%s, got %s", name, expectedValue, env.Value) + } + return + } + } + t.Errorf("%s environment variable not found", name) +} + +func verifyEnvVarNotPresent(t *testing.T, sts *appsv1.StatefulSet, name string) { + t.Helper() + for _, env := range sts.Spec.Template.Spec.Containers[0].Env { + if env.Name == name { + t.Errorf("environment variable %s should not be present", name) + return + } + } +} + +func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.ProxyGroup, shouldExist bool, cfgHash string, proxyClass *tsapi.ProxyClass) { t.Helper() role := pgRole(pg, tsNamespace) roleBinding := pgRoleBinding(pg, tsNamespace) serviceAccount := pgServiceAccount(pg, tsNamespace) - statefulSet, err := pgStatefulSet(pg, tsNamespace, testProxyImage, "auto", "") + statefulSet, err := pgStatefulSet(pg, tsNamespace, testProxyImage, "auto", proxyClass) if err != nil { t.Fatal(err) } statefulSet.Annotations = defaultProxyClassAnnotations + if cfgHash != "" { + mak.Set(&statefulSet.Spec.Template.Annotations, podAnnotationLastSetConfigFileHash, cfgHash) + } if shouldExist { - expectEqual(t, fc, role, nil) - expectEqual(t, fc, roleBinding, nil) - expectEqual(t, fc, serviceAccount, nil) - expectEqual(t, fc, statefulSet, func(ss *appsv1.StatefulSet) { - ss.Spec.Template.Annotations[podAnnotationLastSetConfigFileHash] = "" - }) + expectEqual(t, fc, role) + expectEqual(t, 
fc, roleBinding) + expectEqual(t, fc, serviceAccount) + expectEqual(t, fc, statefulSet, removeResourceReqs) } else { expectMissing[rbacv1.Role](t, fc, role.Namespace, role.Name) expectMissing[rbacv1.RoleBinding](t, fc, roleBinding.Namespace, roleBinding.Name) @@ -218,11 +633,13 @@ func expectProxyGroupResources(t *testing.T, fc client.WithWatch, pg *tsapi.Prox } var expectedSecrets []string - for i := range pgReplicas(pg) { - expectedSecrets = append(expectedSecrets, - fmt.Sprintf("%s-%d", pg.Name, i), - fmt.Sprintf("%s-%d-config", pg.Name, i), - ) + if shouldExist { + for i := range pgReplicas(pg) { + expectedSecrets = append(expectedSecrets, + fmt.Sprintf("%s-%d", pg.Name, i), + pgConfigSecretName(pg.Name, i), + ) + } } expectSecrets(t, fc, expectedSecrets) } @@ -265,3 +682,146 @@ func addNodeIDToStateSecrets(t *testing.T, fc client.WithWatch, pg *tsapi.ProxyG }) } } + +func TestProxyGroupLetsEncryptStaging(t *testing.T) { + cl := tstest.NewClock(tstest.ClockOpts{}) + zl := zap.Must(zap.NewDevelopment()) + + // Set up test cases- most are shared with non-HA Ingress. + type proxyGroupLETestCase struct { + leStagingTestCase + pgType tsapi.ProxyGroupType + } + pcLEStaging, pcLEStagingFalse, pcOther := proxyClassesForLEStagingTest() + sharedTestCases := testCasesForLEStagingTests(pcLEStaging, pcLEStagingFalse, pcOther) + var tests []proxyGroupLETestCase + for _, tt := range sharedTestCases { + tests = append(tests, proxyGroupLETestCase{ + leStagingTestCase: tt, + pgType: tsapi.ProxyGroupTypeIngress, + }) + } + tests = append(tests, proxyGroupLETestCase{ + leStagingTestCase: leStagingTestCase{ + name: "egress_pg_with_staging_proxyclass", + proxyClassPerResource: "le-staging", + useLEStagingEndpoint: false, + }, + pgType: tsapi.ProxyGroupTypeEgress, + }) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + builder := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme) + + // Pre-populate the fake client with ProxyClasses. 
+ builder = builder.WithObjects(pcLEStaging, pcLEStagingFalse, pcOther). + WithStatusSubresource(pcLEStaging, pcLEStagingFalse, pcOther) + + fc := builder.Build() + + // If the test case needs a ProxyClass to exist, ensure it is set to Ready. + if tt.proxyClassPerResource != "" || tt.defaultProxyClass != "" { + name := tt.proxyClassPerResource + if name == "" { + name = tt.defaultProxyClass + } + setProxyClassReady(t, fc, cl, name) + } + + // Create ProxyGroup + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tt.pgType, + Replicas: ptr.To[int32](1), + ProxyClass: tt.proxyClassPerResource, + }, + } + mustCreate(t, fc, pg) + + reconciler := &ProxyGroupReconciler{ + tsNamespace: tsNamespace, + proxyImage: testProxyImage, + defaultTags: []string{"tag:test"}, + defaultProxyClass: tt.defaultProxyClass, + Client: fc, + tsClient: &fakeTSClient{}, + l: zl.Sugar(), + clock: cl, + } + + expectReconciled(t, reconciler, "", pg.Name) + + // Verify that the StatefulSet created for ProxyGroup has + // the expected setting for the staging endpoint. + sts := &appsv1.StatefulSet{} + if err := fc.Get(context.Background(), client.ObjectKey{Namespace: tsNamespace, Name: pg.Name}, sts); err != nil { + t.Fatalf("failed to get StatefulSet: %v", err) + } + + if tt.useLEStagingEndpoint { + verifyEnvVar(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL", letsEncryptStagingEndpoint) + } else { + verifyEnvVarNotPresent(t, sts, "TS_DEBUG_ACME_DIRECTORY_URL") + } + }) + } +} + +type leStagingTestCase struct { + name string + // ProxyClass set on ProxyGroup or Ingress resource. + proxyClassPerResource string + // Default ProxyClass. + defaultProxyClass string + useLEStagingEndpoint bool +} + +// Shared test cases for LE staging endpoint configuration for ProxyGroup and +// non-HA Ingress.
+func testCasesForLEStagingTests(pcLEStaging, pcLEStagingFalse, pcOther *tsapi.ProxyClass) []leStagingTestCase { + return []leStagingTestCase{ + { + name: "with_staging_proxyclass", + proxyClassPerResource: "le-staging", + useLEStagingEndpoint: true, + }, + { + name: "with_staging_proxyclass_false", + proxyClassPerResource: "le-staging-false", + useLEStagingEndpoint: false, + }, + { + name: "with_other_proxyclass", + proxyClassPerResource: "other", + useLEStagingEndpoint: false, + }, + { + name: "no_proxyclass", + proxyClassPerResource: "", + useLEStagingEndpoint: false, + }, + { + name: "with_default_staging_proxyclass", + proxyClassPerResource: "", + defaultProxyClass: "le-staging", + useLEStagingEndpoint: true, + }, + { + name: "with_default_other_proxyclass", + proxyClassPerResource: "", + defaultProxyClass: "other", + useLEStagingEndpoint: false, + }, + { + name: "with_default_staging_proxyclass_false", + proxyClassPerResource: "", + defaultProxyClass: "le-staging-false", + useLEStagingEndpoint: false, + }, + } +} diff --git a/cmd/k8s-operator/sts.go b/cmd/k8s-operator/sts.go index 6378a82636939..b8dc79dacf8c4 100644 --- a/cmd/k8s-operator/sts.go +++ b/cmd/k8s-operator/sts.go @@ -15,6 +15,7 @@ import ( "net/http" "os" "slices" + "strconv" "strings" "go.uber.org/zap" @@ -43,11 +44,9 @@ const ( // Labels that the operator sets on StatefulSets and Pods. If you add a // new label here, do also add it to tailscaleManagedLabels var to // ensure that it does not get overwritten by ProxyClass configuration. - LabelManaged = "tailscale.com/managed" LabelParentType = "tailscale.com/parent-resource-type" LabelParentName = "tailscale.com/parent-resource" LabelParentNamespace = "tailscale.com/parent-resource-ns" - labelSecretType = "tailscale.com/secret-type" // "config" or "state". 
// LabelProxyClass can be set by users on tailscale Ingresses and Services that define cluster ingress or // cluster egress, to specify that configuration in this ProxyClass should be applied to resources created for @@ -94,11 +93,22 @@ const ( podAnnotationLastSetTailnetTargetFQDN = "tailscale.com/operator-last-set-ts-tailnet-target-fqdn" // podAnnotationLastSetConfigFileHash is sha256 hash of the current tailscaled configuration contents. podAnnotationLastSetConfigFileHash = "tailscale.com/operator-last-set-config-file-hash" + + proxyTypeEgress = "egress_service" + proxyTypeIngressService = "ingress_service" + proxyTypeIngressResource = "ingress_resource" + proxyTypeConnector = "connector" + proxyTypeProxyGroup = "proxygroup" + + envVarTSLocalAddrPort = "TS_LOCAL_ADDR_PORT" + defaultLocalAddrPort = 9002 // metrics and health check port + + letsEncryptStagingEndpoint = "https://acme-staging-v02.api.letsencrypt.org/directory" ) var ( // tailscaleManagedLabels are label keys that tailscale operator sets on StatefulSets and Pods. - tailscaleManagedLabels = []string{LabelManaged, LabelParentType, LabelParentName, LabelParentNamespace, "app"} + tailscaleManagedLabels = []string{kubetypes.LabelManaged, LabelParentType, LabelParentName, LabelParentNamespace, "app"} // tailscaleManagedAnnotations are annotation keys that tailscale operator sets on StatefulSets and Pods. tailscaleManagedAnnotations = []string{podAnnotationLastSetClusterIP, podAnnotationLastSetTailnetTargetIP, podAnnotationLastSetTailnetTargetFQDN, podAnnotationLastSetConfigFileHash} ) @@ -122,6 +132,10 @@ type tailscaleSTSConfig struct { Hostname string Tags []string // if empty, use defaultTags + ControlURL string + + proxyType string + // Connector specifies a configuration of a Connector instance if that's // what this StatefulSet should be created for. 
Connector *connector @@ -132,10 +146,13 @@ type tailscaleSTSConfig struct { } type connector struct { - // routes is a list of subnet routes that this Connector should expose. + // routes is a list of routes that this Connector should advertise either as a subnet router or as an app + // connector. routes string // isExitNode defines whether this Connector should act as an exit node. isExitNode bool + // isAppConnector defines whether this Connector should act as an app connector. + isAppConnector bool } type tsnetServer interface { CertDomains() []string @@ -150,6 +167,7 @@ type tailscaleSTSReconciler struct { proxyImage string proxyPriorityClassName string tsFirewallMode string + controlUrl string } func (sts tailscaleSTSReconciler) validate() error { @@ -160,8 +178,8 @@ func (sts tailscaleSTSReconciler) validate() error { } // IsHTTPSEnabledOnTailnet reports whether HTTPS is enabled on the tailnet. -func (a *tailscaleSTSReconciler) IsHTTPSEnabledOnTailnet() bool { - return len(a.tsnetServer.CertDomains()) > 0 +func IsHTTPSEnabledOnTailnet(tsnetServer tsnetServer) bool { + return len(tsnetServer.CertDomains()) > 0 } // Provision ensures that the StatefulSet for the given service is running and @@ -186,22 +204,38 @@ func (a *tailscaleSTSReconciler) Provision(ctx context.Context, logger *zap.Suga } sts.ProxyClass = proxyClass - secretName, tsConfigHash, configs, err := a.createOrGetSecret(ctx, logger, sts, hsvc) + if a.controlUrl != "" { + sts.ControlURL = a.controlUrl + } + + secretName, tsConfigHash, _, err := a.createOrGetSecret(ctx, logger, sts, hsvc) + if a.controlUrl != "" { + sts.ControlURL = a.controlUrl + } + if err != nil { return nil, fmt.Errorf("failed to create or get API key secret: %w", err) } - _, err = a.reconcileSTS(ctx, logger, sts, hsvc, secretName, tsConfigHash, configs) + _, err = a.reconcileSTS(ctx, logger, sts, hsvc, secretName, tsConfigHash) if err != nil { return nil, fmt.Errorf("failed to reconcile statefulset: %w", err) } - + mo := 
&metricsOpts{ + proxyStsName: hsvc.Name, + tsNamespace: hsvc.Namespace, + proxyLabels: hsvc.Labels, + proxyType: sts.proxyType, + } + if err = reconcileMetricsResources(ctx, logger, mo, sts.ProxyClass, a.Client); err != nil { + return nil, fmt.Errorf("failed to ensure metrics resources: %w", err) + } return hsvc, nil } // Cleanup removes all resources associated that were created by Provision with // the given labels. It returns true when all resources have been removed, // otherwise it returns false and the caller should retry later. -func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.SugaredLogger, labels map[string]string) (done bool, _ error) { +func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.SugaredLogger, labels map[string]string, typ string) (done bool, _ error) { // Need to delete the StatefulSet first, and delete it with foreground // cascading deletion. That way, the pod that's writing to the Secret will // stop running before we start looking at the Secret's contents, and @@ -227,21 +261,21 @@ func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.Sugare return false, nil } - id, _, _, err := a.DeviceInfo(ctx, labels) + dev, err := a.DeviceInfo(ctx, labels, logger) if err != nil { return false, fmt.Errorf("getting device info: %w", err) } - if id != "" { - logger.Debugf("deleting device %s from control", string(id)) - if err := a.tsClient.DeleteDevice(ctx, string(id)); err != nil { + if dev != nil && dev.id != "" { + logger.Debugf("deleting device %s from control", string(dev.id)) + if err := a.tsClient.DeleteDevice(ctx, string(dev.id)); err != nil { errResp := &tailscale.ErrResponse{} if ok := errors.As(err, errResp); ok && errResp.Status == http.StatusNotFound { - logger.Debugf("device %s not found, likely because it has already been deleted from control", string(id)) + logger.Debugf("device %s not found, likely because it has already been deleted from control", string(dev.id)) } else 
{ return false, fmt.Errorf("deleting device: %w", err) } } else { - logger.Debugf("device %s deleted from control", string(id)) + logger.Debugf("device %s deleted from control", string(dev.id)) } } @@ -254,6 +288,14 @@ func (a *tailscaleSTSReconciler) Cleanup(ctx context.Context, logger *zap.Sugare return false, err } } + mo := &metricsOpts{ + proxyLabels: labels, + tsNamespace: a.operatorNamespace, + proxyType: typ, + } + if err := maybeCleanupMetricsResources(ctx, mo, a.Client); err != nil { + return false, fmt.Errorf("error cleaning up metrics resources: %w", err) + } return true, nil } @@ -409,44 +451,75 @@ func sanitizeConfigBytes(c ipn.ConfigVAlpha) string { return string(sanitizedBytes) } -// DeviceInfo returns the device ID, hostname and IPs for the Tailscale device -// that acts as an operator proxy. It retrieves info from a Kubernetes Secret -// labeled with the provided labels. -// Either of device ID, hostname and IPs can be empty string if not found in the Secret. -func (a *tailscaleSTSReconciler) DeviceInfo(ctx context.Context, childLabels map[string]string) (id tailcfg.StableNodeID, hostname string, ips []string, err error) { +// DeviceInfo returns the device ID, hostname, IPs and capver for the Tailscale device that acts as an operator proxy. +// It retrieves info from a Kubernetes Secret labeled with the provided labels. Capver is cross-validated against the +// Pod to ensure that it is the currently running Pod that set the capver. If the Pod or the Secret does not exist, the +// returned capver is -1. Either of device ID, hostname and IPs can be empty string if not found in the Secret. 
+func (a *tailscaleSTSReconciler) DeviceInfo(ctx context.Context, childLabels map[string]string, logger *zap.SugaredLogger) (dev *device, err error) { sec, err := getSingleObject[corev1.Secret](ctx, a.Client, a.operatorNamespace, childLabels) if err != nil { - return "", "", nil, err + return dev, err } if sec == nil { - return "", "", nil, nil + return dev, nil } + podUID := "" + pod := new(corev1.Pod) + if err := a.Get(ctx, types.NamespacedName{Namespace: sec.Namespace, Name: sec.Name}, pod); err != nil && !apierrors.IsNotFound(err) { + return dev, err + } else if err == nil { + podUID = string(pod.ObjectMeta.UID) + } + return deviceInfo(sec, podUID, logger) +} - return deviceInfo(sec) +// device contains tailscale state of a proxy device as gathered from its tailscale state Secret. +type device struct { + id tailcfg.StableNodeID // device's stable ID + hostname string // MagicDNS name of the device + ips []string // Tailscale IPs of the device + // ingressDNSName is the L7 Ingress DNS name. In practice this will be the same value as hostname, but only set + // when the device has been configured to serve traffic on it via 'tailscale serve'. + ingressDNSName string + capver tailcfg.CapabilityVersion } -func deviceInfo(sec *corev1.Secret) (id tailcfg.StableNodeID, hostname string, ips []string, err error) { - id = tailcfg.StableNodeID(sec.Data["device_id"]) +func deviceInfo(sec *corev1.Secret, podUID string, log *zap.SugaredLogger) (dev *device, err error) { + id := tailcfg.StableNodeID(sec.Data[kubetypes.KeyDeviceID]) if id == "" { - return "", "", nil, nil + return dev, nil } + dev = &device{id: id} // Kubernetes chokes on well-formed FQDNs with the trailing dot, so we have // to remove it. 
- hostname = strings.TrimSuffix(string(sec.Data["device_fqdn"]), ".") - if hostname == "" { + dev.hostname = strings.TrimSuffix(string(sec.Data[kubetypes.KeyDeviceFQDN]), ".") + if dev.hostname == "" { // Device ID gets stored and retrieved in a different flow than // FQDN and IPs. A device that acts as Kubernetes operator - // proxy, but whose route setup has failed might have an device + // proxy, but whose route setup has failed might have a device // ID, but no FQDN/IPs. If so, return the ID, to allow the // operator to clean up such devices. - return id, "", nil, nil + return dev, nil + } + dev.ingressDNSName = dev.hostname + pcv := proxyCapVer(sec, podUID, log) + dev.capver = pcv + // TODO(irbekrm): we fall back to using the hostname field to determine Ingress's hostname to ensure backwards + // compatibility. In 1.82 we can remove this fallback mechanism. + if pcv >= 109 { + dev.ingressDNSName = strings.TrimSuffix(string(sec.Data[kubetypes.KeyHTTPSEndpoint]), ".") + if strings.EqualFold(dev.ingressDNSName, kubetypes.ValueNoHTTPS) { + dev.ingressDNSName = "" + } } - if rawDeviceIPs, ok := sec.Data["device_ips"]; ok { + if rawDeviceIPs, ok := sec.Data[kubetypes.KeyDeviceIPs]; ok { + ips := make([]string, 0) if err := json.Unmarshal(rawDeviceIPs, &ips); err != nil { - return "", "", nil, err + return nil, err } + dev.ips = ips } - return id, hostname, ips, nil + return dev, nil } func newAuthKey(ctx context.Context, tsClient tsClient, tags []string) (string, error) { @@ -473,7 +546,7 @@ var proxyYaml []byte //go:embed deploy/manifests/userspace-proxy.yaml var userspaceProxyYaml []byte -func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts *tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecret, tsConfigHash string, configs map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) (*appsv1.StatefulSet, error) { +func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.SugaredLogger, sts 
*tailscaleSTSConfig, headlessSvc *corev1.Service, proxySecret, tsConfigHash string) (*appsv1.StatefulSet, error) { ss := new(appsv1.StatefulSet) if sts.ServeConfig != nil && sts.ForwardClusterTrafficViaL7IngressProxy != true { // If forwarding cluster traffic via is required we need non-userspace + NET_ADMIN + forwarding if err := yaml.Unmarshal(userspaceProxyYaml, &ss); err != nil { @@ -518,11 +591,6 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S Name: "TS_KUBE_SECRET", Value: proxySecret, }, - corev1.EnvVar{ - // Old tailscaled config key is still used for backwards compatibility. - Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", - Value: "/etc/tsconfig/tailscaled", - }, corev1.EnvVar{ // New style is in the form of cap-.hujson. Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", @@ -535,8 +603,6 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S Value: "true", }) } - // Configure containeboot to run tailscaled with a configfile read from the state Secret. - mak.Set(&ss.Spec.Template.Annotations, podAnnotationLastSetConfigFileHash, tsConfigHash) configVolume := corev1.Volume{ Name: "tailscaledconfig", @@ -606,6 +672,12 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S }, }) } + + dev, err := a.DeviceInfo(ctx, sts.ChildResourceLabels, logger) + if err != nil { + return nil, fmt.Errorf("failed to get device info: %w", err) + } + app, err := appInfoForProxy(sts) if err != nil { // No need to error out if now or in future we end up in a @@ -624,7 +696,25 @@ func (a *tailscaleSTSReconciler) reconcileSTS(ctx context.Context, logger *zap.S ss = applyProxyClassToStatefulSet(sts.ProxyClass, ss, sts, logger) } updateSS := func(s *appsv1.StatefulSet) { + // This is a temporary workaround to ensure that proxies with capver older than 110 + // are restarted when tailscaled configfile contents have changed. + // This workaround ensures that: + // 1. 
The hash mechanism is used to trigger pod restarts for proxies below capver 110. + // 2. Proxies above capver 110 are not unnecessarily restarted when the configfile contents change. + // 3. If the hash has already been set, but the capver is above 110, the old hash is preserved to avoid + // unnecessary pod restarts that could result in an update loop where capver cannot be determined for a + // restarting Pod and the hash is re-added again. + // Note that the hash annotation is only set on updates, not creation, because if the StatefulSet is + // being created, there is no need for a restart. + // TODO(irbekrm): remove this in 1.84. + hash := tsConfigHash + if dev == nil || dev.capver >= 110 { + hash = s.Spec.Template.GetAnnotations()[podAnnotationLastSetConfigFileHash] + } s.Spec = ss.Spec + if hash != "" { + mak.Set(&s.Spec.Template.Annotations, podAnnotationLastSetConfigFileHash, hash) + } s.ObjectMeta.Labels = ss.Labels s.ObjectMeta.Annotations = ss.Annotations } @@ -668,24 +758,53 @@ func mergeStatefulSetLabelsOrAnnots(current, custom map[string]string, managed [ return custom } +func debugSetting(pc *tsapi.ProxyClass) bool { + if pc == nil || + pc.Spec.StatefulSet == nil || + pc.Spec.StatefulSet.Pod == nil || + pc.Spec.StatefulSet.Pod.TailscaleContainer == nil || + pc.Spec.StatefulSet.Pod.TailscaleContainer.Debug == nil { + // This default will change to false in 1.82.0. 
+ return pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable + } + + return pc.Spec.StatefulSet.Pod.TailscaleContainer.Debug.Enable +} + func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, stsCfg *tailscaleSTSConfig, logger *zap.SugaredLogger) *appsv1.StatefulSet { if pc == nil || ss == nil { return ss } - if stsCfg != nil && pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable { - if stsCfg.TailnetTargetFQDN == "" && stsCfg.TailnetTargetIP == "" && !stsCfg.ForwardClusterTrafficViaL7IngressProxy { - enableMetrics(ss, pc) - } else if stsCfg.ForwardClusterTrafficViaL7IngressProxy { + + metricsEnabled := pc.Spec.Metrics != nil && pc.Spec.Metrics.Enable + debugEnabled := debugSetting(pc) + if metricsEnabled || debugEnabled { + isEgress := stsCfg != nil && (stsCfg.TailnetTargetFQDN != "" || stsCfg.TailnetTargetIP != "") + isForwardingL7Ingress := stsCfg != nil && stsCfg.ForwardClusterTrafficViaL7IngressProxy + if isEgress { // TODO (irbekrm): fix this // For Ingress proxies that have been configured with // tailscale.com/experimental-forward-cluster-traffic-via-ingress // annotation, all cluster traffic is forwarded to the // Ingress backend(s). - logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for Ingress proxies that accept cluster traffic.") - } else { + logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for egress proxies.") + } else if isForwardingL7Ingress { // TODO (irbekrm): fix this // For egress proxies, currently all cluster traffic is forwarded to the tailnet target. 
logger.Info("ProxyClass specifies that metrics should be enabled, but this is currently not supported for Ingress proxies that accept cluster traffic.") + } else { + enableEndpoints(ss, metricsEnabled, debugEnabled) + } + } + if pc.Spec.UseLetsEncryptStagingEnvironment && (stsCfg.proxyType == proxyTypeIngressResource || stsCfg.proxyType == string(tsapi.ProxyGroupTypeIngress)) { + for i, c := range ss.Spec.Template.Spec.Containers { + if c.Name == "tailscale" { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, corev1.EnvVar{ + Name: "TS_DEBUG_ACME_DIRECTORY_URL", + Value: letsEncryptStagingEndpoint, + }) + break + } } } @@ -694,7 +813,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, } // Update StatefulSet metadata. - if wantsSSLabels := pc.Spec.StatefulSet.Labels; len(wantsSSLabels) > 0 { + if wantsSSLabels := pc.Spec.StatefulSet.Labels.Parse(); len(wantsSSLabels) > 0 { ss.ObjectMeta.Labels = mergeStatefulSetLabelsOrAnnots(ss.ObjectMeta.Labels, wantsSSLabels, tailscaleManagedLabels) } if wantsSSAnnots := pc.Spec.StatefulSet.Annotations; len(wantsSSAnnots) > 0 { @@ -706,7 +825,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, return ss } wantsPod := pc.Spec.StatefulSet.Pod - if wantsPodLabels := wantsPod.Labels; len(wantsPodLabels) > 0 { + if wantsPodLabels := wantsPod.Labels.Parse(); len(wantsPodLabels) > 0 { ss.Spec.Template.ObjectMeta.Labels = mergeStatefulSetLabelsOrAnnots(ss.Spec.Template.ObjectMeta.Labels, wantsPodLabels, tailscaleManagedLabels) } if wantsPodAnnots := wantsPod.Annotations; len(wantsPodAnnots) > 0 { @@ -718,6 +837,7 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, ss.Spec.Template.Spec.NodeSelector = wantsPod.NodeSelector ss.Spec.Template.Spec.Affinity = wantsPod.Affinity ss.Spec.Template.Spec.Tolerations = wantsPod.Tolerations + ss.Spec.Template.Spec.TopologySpreadConstraints = 
wantsPod.TopologySpreadConstraints // Update containers. updateContainer := func(overlay *tsapi.Container, base corev1.Container) corev1.Container { @@ -762,16 +882,58 @@ func applyProxyClassToStatefulSet(pc *tsapi.ProxyClass, ss *appsv1.StatefulSet, return ss } -func enableMetrics(ss *appsv1.StatefulSet, pc *tsapi.ProxyClass) { +func enableEndpoints(ss *appsv1.StatefulSet, metrics, debug bool) { for i, c := range ss.Spec.Template.Spec.Containers { if c.Name == "tailscale" { - // Serve metrics on on :9001/debug/metrics. If - // we didn't specify Pod IP here, the proxy would, in - // some cases, also listen to its Tailscale IP- we don't - // want folks to start relying on this side-effect as a - // feature. - ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(POD_IP):9001"}) - ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, corev1.ContainerPort{Name: "metrics", Protocol: "TCP", HostPort: 9001, ContainerPort: 9001}) + if debug { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, + // Serve tailscaled's debug metrics on + // :9001/debug/metrics. If we didn't specify Pod IP + // here, the proxy would, in some cases, also listen to its + // Tailscale IP- we don't want folks to start relying on this + // side-effect as a feature. + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001", + }, + // TODO(tomhjp): Can remove this env var once 1.76.x is no + // longer supported. 
+ corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + ) + + ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, + corev1.ContainerPort{ + Name: "debug", + Protocol: "TCP", + ContainerPort: 9001, + }, + ) + } + + if metrics { + ss.Spec.Template.Spec.Containers[i].Env = append(ss.Spec.Template.Spec.Containers[i].Env, + // Serve client metrics on :9002/metrics. + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + ss.Spec.Template.Spec.Containers[i].Ports = append(ss.Spec.Template.Spec.Containers[i].Ports, + corev1.ContainerPort{ + Name: "metrics", + Protocol: "TCP", + ContainerPort: 9002, + }, + ) + } + break } } @@ -785,15 +947,9 @@ func readAuthKey(secret *corev1.Secret, key string) (*string, error) { return origConf.AuthKey, nil } -// tailscaledConfig takes a proxy config, a newly generated auth key if -// generated and a Secret with the previous proxy state and auth key and -// returns tailscaled configuration and a hash of that configuration. -// -// As of 2024-05-09 it also returns legacy tailscaled config without the -// later added NoStatefulFilter field to support proxies older than cap95. -// TODO (irbekrm): remove the legacy config once we no longer need to support -// versions older than cap94, -// https://tailscale.com/kb/1236/kubernetes-operator#operator-and-proxies +// tailscaledConfig takes a proxy config, a newly generated auth key if generated and a Secret with the previous proxy +// state and auth key and returns tailscaled config files for currently supported proxy versions and a hash of that +// configuration. 
func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *corev1.Secret) (tailscaledConfigs, error) { conf := &ipn.ConfigVAlpha{ Version: "alpha0", @@ -801,21 +957,19 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co AcceptRoutes: "false", // AcceptRoutes defaults to true Locked: "false", Hostname: &stsC.Hostname, - NoStatefulFiltering: "false", + NoStatefulFiltering: "true", // Explicitly enforce default value, see #14216 + AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, } - // For egress proxies only, we need to ensure that stateful filtering is - // not in place so that traffic from cluster can be forwarded via - // Tailscale IPs. - if stsC.TailnetTargetFQDN != "" || stsC.TailnetTargetIP != "" { - conf.NoStatefulFiltering = "true" - } if stsC.Connector != nil { routes, err := netutil.CalcAdvertiseRoutes(stsC.Connector.routes, stsC.Connector.isExitNode) if err != nil { return nil, fmt.Errorf("error calculating routes: %w", err) } conf.AdvertiseRoutes = routes + if stsC.Connector.isAppConnector { + conf.AppConnector.Advertise = true + } } if shouldAcceptRoutes(stsC.ProxyClass) { conf.AcceptRoutes = "true" @@ -830,11 +984,17 @@ func tailscaledConfig(stsC *tailscaleSTSConfig, newAuthkey string, oldSecret *co } conf.AuthKey = key } + + if stsC.ControlURL != "" { + conf.ServerURL = &stsC.ControlURL + } + capVerConfigs := make(map[tailcfg.CapabilityVersion]ipn.ConfigVAlpha) + capVerConfigs[107] = *conf + + // AppConnector config option is only understood by clients of capver 107 and newer. + conf.AppConnector = nil capVerConfigs[95] = *conf - // legacy config should not contain NoStatefulFiltering field. 
- conf.NoStatefulFiltering.Clear() - capVerConfigs[94] = *conf return capVerConfigs, nil } @@ -907,13 +1067,13 @@ func tailscaledConfigHash(c tailscaledConfigs) (string, error) { return fmt.Sprintf("%x", h.Sum(nil)), nil } -// createOrUpdate adds obj to the k8s cluster, unless the object already exists, -// in which case update is called to make changes to it. If update is nil, the -// existing object is returned unmodified. +// createOrMaybeUpdate adds obj to the k8s cluster, unless the object already exists, +// in which case update is called to make changes to it. If update is nil or returns +// an error, the object is returned unmodified. // // obj is looked up by its Name and Namespace if Name is set, otherwise it's // looked up by labels. -func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, ns string, obj O, update func(O)) (O, error) { +func createOrMaybeUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, ns string, obj O, update func(O) error) (O, error) { var ( existing O err error @@ -928,7 +1088,9 @@ func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, } if err == nil && existing != nil { if update != nil { - update(existing) + if err := update(existing); err != nil { + return nil, err + } if err := c.Update(ctx, existing); err != nil { return nil, err } @@ -944,6 +1106,21 @@ func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, return obj, nil } +// createOrUpdate adds obj to the k8s cluster, unless the object already exists, +// in which case update is called to make changes to it. If update is nil, the +// existing object is returned unmodified. +// +// obj is looked up by its Name and Namespace if Name is set, otherwise it's +// looked up by labels. 
+func createOrUpdate[T any, O ptrObject[T]](ctx context.Context, c client.Client, ns string, obj O, update func(O)) (O, error) { + return createOrMaybeUpdate(ctx, c, ns, obj, func(o O) error { + if update != nil { + update(o) + } + return nil + }) +} + // getSingleObject searches for k8s objects of type T // (e.g. corev1.Service) with the given labels, and returns // it. Returns nil if no objects match the labels, and an error if @@ -1007,3 +1184,24 @@ func nameForService(svc *corev1.Service) string { func isValidFirewallMode(m string) bool { return m == "auto" || m == "nftables" || m == "iptables" } + +// proxyCapVer accepts a proxy state Secret and UID of the current proxy Pod returns the capability version of the +// tailscale running in that Pod. This is best effort - if the capability version can not (currently) be determined, it +// returns -1. +func proxyCapVer(sec *corev1.Secret, podUID string, log *zap.SugaredLogger) tailcfg.CapabilityVersion { + if sec == nil || podUID == "" { + return tailcfg.CapabilityVersion(-1) + } + if len(sec.Data[kubetypes.KeyCapVer]) == 0 || len(sec.Data[kubetypes.KeyPodUID]) == 0 { + return tailcfg.CapabilityVersion(-1) + } + capVer, err := strconv.Atoi(string(sec.Data[kubetypes.KeyCapVer])) + if err != nil { + log.Infof("[unexpected]: unexpected capability version in proxy's state Secret, expected an integer, got %q", string(sec.Data[kubetypes.KeyCapVer])) + return tailcfg.CapabilityVersion(-1) + } + if !strings.EqualFold(podUID, string(sec.Data[kubetypes.KeyPodUID])) { + return tailcfg.CapabilityVersion(-1) + } + return tailcfg.CapabilityVersion(capVer) +} diff --git a/cmd/k8s-operator/sts_test.go b/cmd/k8s-operator/sts_test.go index b2b2c8b93a2d7..35c512c8cd05b 100644 --- a/cmd/k8s-operator/sts_test.go +++ b/cmd/k8s-operator/sts_test.go @@ -18,8 +18,10 @@ import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/yaml" 
tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" "tailscale.com/types/ptr" ) @@ -60,10 +62,10 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { proxyClassAllOpts := &tsapi.ProxyClass{ Spec: tsapi.ProxyClassSpec{ StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"foo.io/bar": "foo"}, Pod: &tsapi.Pod{ - Labels: map[string]string{"bar": "foo"}, + Labels: tsapi.Labels{"bar": "foo"}, Annotations: map[string]string{"bar.io/foo": "foo"}, SecurityContext: &corev1.PodSecurityContext{ RunAsUser: ptr.To(int64(0)), @@ -73,6 +75,16 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { NodeSelector: map[string]string{"beta.kubernetes.io/os": "linux"}, Affinity: &corev1.Affinity{NodeAffinity: &corev1.NodeAffinity{RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{}}}, Tolerations: []corev1.Toleration{{Key: "", Operator: "Exists"}}, + TopologySpreadConstraints: []corev1.TopologySpreadConstraint{ + { + WhenUnsatisfiable: "DoNotSchedule", + TopologyKey: "kubernetes.io/hostname", + MaxSkew: 3, + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{"foo": "bar"}, + }, + }, + }, TailscaleContainer: &tsapi.Container{ SecurityContext: &corev1.SecurityContext{ Privileged: ptr.To(true), @@ -105,21 +117,36 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { proxyClassJustLabels := &tsapi.ProxyClass{ Spec: tsapi.ProxyClassSpec{ StatefulSet: &tsapi.StatefulSet{ - Labels: map[string]string{"foo": "bar"}, + Labels: tsapi.Labels{"foo": "bar"}, Annotations: map[string]string{"foo.io/bar": "foo"}, Pod: &tsapi.Pod{ - Labels: map[string]string{"bar": "foo"}, + Labels: tsapi.Labels{"bar": "foo"}, Annotations: map[string]string{"bar.io/foo": "foo"}, }, }, }, } - proxyClassMetrics := &tsapi.ProxyClass{ - Spec: tsapi.ProxyClassSpec{ - Metrics: &tsapi.Metrics{Enable: true}, - }, - } + proxyClassWithMetricsDebug := 
func(metrics bool, debug *bool) *tsapi.ProxyClass { + return &tsapi.ProxyClass{ + Spec: tsapi.ProxyClassSpec{ + Metrics: &tsapi.Metrics{Enable: metrics}, + StatefulSet: func() *tsapi.StatefulSet { + if debug == nil { + return nil + } + + return &tsapi.StatefulSet{ + Pod: &tsapi.Pod{ + TailscaleContainer: &tsapi.Container{ + Debug: &tsapi.Debug{Enable: *debug}, + }, + }, + } + }(), + }, + } + } var userspaceProxySS, nonUserspaceProxySS appsv1.StatefulSet if err := yaml.Unmarshal(userspaceProxyYaml, &userspaceProxySS); err != nil { t.Fatalf("unmarshaling userspace proxy template: %v", err) @@ -130,8 +157,8 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { // Set a couple additional fields so we can test that we don't // mistakenly override those. labels := map[string]string{ - LabelManaged: "true", - LabelParentName: "foo", + kubetypes.LabelManaged: "true", + LabelParentName: "foo", } annots := map[string]string{ podAnnotationLastSetClusterIP: "1.2.3.4", @@ -149,9 +176,9 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { // 1. Test that a ProxyClass with all fields set gets correctly applied // to a Statefulset built from non-userspace proxy template. 
wantSS := nonUserspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassAllOpts.Spec.StatefulSet.Pod.Annotations wantSS.Spec.Template.Spec.SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.SecurityContext wantSS.Spec.Template.Spec.ImagePullSecrets = proxyClassAllOpts.Spec.StatefulSet.Pod.ImagePullSecrets @@ -159,6 +186,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations + wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext wantSS.Spec.Template.Spec.InitContainers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleInitContainer.SecurityContext wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources @@ -172,28 +200,28 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { gotSS := applyProxyClassToStatefulSet(proxyClassAllOpts, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, 
wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with all fields set to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with all fields set to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) } // 2. Test that a ProxyClass with custom labels and annotations for // StatefulSet and Pod set gets correctly applied to a Statefulset built // from non-userspace proxy template. wantSS = nonUserspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassJustLabels.Spec.StatefulSet.Pod.Annotations gotSS = applyProxyClassToStatefulSet(proxyClassJustLabels, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for non-userspace proxy (-got +want):\n%s", diff) } // 3. Test that a ProxyClass with all fields set gets correctly applied // to a Statefulset built from a userspace proxy template. 
wantSS = userspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassAllOpts.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassAllOpts.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassAllOpts.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassAllOpts.Spec.StatefulSet.Pod.Annotations wantSS.Spec.Template.Spec.SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.SecurityContext wantSS.Spec.Template.Spec.ImagePullSecrets = proxyClassAllOpts.Spec.StatefulSet.Pod.ImagePullSecrets @@ -201,6 +229,7 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.NodeSelector = proxyClassAllOpts.Spec.StatefulSet.Pod.NodeSelector wantSS.Spec.Template.Spec.Affinity = proxyClassAllOpts.Spec.StatefulSet.Pod.Affinity wantSS.Spec.Template.Spec.Tolerations = proxyClassAllOpts.Spec.StatefulSet.Pod.Tolerations + wantSS.Spec.Template.Spec.TopologySpreadConstraints = proxyClassAllOpts.Spec.StatefulSet.Pod.TopologySpreadConstraints wantSS.Spec.Template.Spec.Containers[0].SecurityContext = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.SecurityContext wantSS.Spec.Template.Spec.Containers[0].Resources = proxyClassAllOpts.Spec.StatefulSet.Pod.TailscaleContainer.Resources wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, []corev1.EnvVar{{Name: "foo", Value: "bar"}, {Name: "TS_USERSPACE", Value: "true"}, {Name: "bar"}}...) 
@@ -208,36 +237,61 @@ func Test_applyProxyClassToStatefulSet(t *testing.T) { wantSS.Spec.Template.Spec.Containers[0].Image = "ghcr.io/my-repo/tailscale:v0.01testsomething" gotSS = applyProxyClassToStatefulSet(proxyClassAllOpts, userspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with all options to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with all options to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) } // 4. Test that a ProxyClass with custom labels and annotations gets correctly applied // to a Statefulset built from a userspace proxy template. wantSS = userspaceProxySS.DeepCopy() - wantSS.ObjectMeta.Labels = mergeMapKeys(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels) - wantSS.ObjectMeta.Annotations = mergeMapKeys(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) - wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels + updateMap(wantSS.ObjectMeta.Labels, proxyClassJustLabels.Spec.StatefulSet.Labels.Parse()) + updateMap(wantSS.ObjectMeta.Annotations, proxyClassJustLabels.Spec.StatefulSet.Annotations) + wantSS.Spec.Template.Labels = proxyClassJustLabels.Spec.StatefulSet.Pod.Labels.Parse() wantSS.Spec.Template.Annotations = proxyClassJustLabels.Spec.StatefulSet.Pod.Annotations gotSS = applyProxyClassToStatefulSet(proxyClassJustLabels, userspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with custom labels and annotations to a StatefulSet for a userspace proxy (-got +want):\n%s", diff) } - // 5. 
Test that a ProxyClass with metrics enabled gets correctly applied to a StatefulSet. + // 5. Metrics enabled defaults to enabling both metrics and debug. wantSS = nonUserspaceProxySS.DeepCopy() - wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(POD_IP):9001"}) - wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9001, HostPort: 9001}} - gotSS = applyProxyClassToStatefulSet(proxyClassMetrics, nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_DEBUG_ADDR_PORT", Value: "$(POD_IP):9001"}, + corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, + corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, + corev1.EnvVar{Name: "TS_ENABLE_METRICS", Value: "true"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{ + {Name: "debug", Protocol: "TCP", ContainerPort: 9001}, + {Name: "metrics", Protocol: "TCP", ContainerPort: 9002}, + } + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, nil), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) if diff := cmp.Diff(gotSS, wantSS); diff != "" { - t.Fatalf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } -} -func mergeMapKeys(a, b map[string]string) map[string]string { - for key, val := range b { - a[key] = val + // 6. Enable _just_ metrics by explicitly disabling debug. 
+ wantSS = nonUserspaceProxySS.DeepCopy() + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_LOCAL_ADDR_PORT", Value: "$(POD_IP):9002"}, + corev1.EnvVar{Name: "TS_ENABLE_METRICS", Value: "true"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "metrics", Protocol: "TCP", ContainerPort: 9002}} + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(true, ptr.To(false)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + if diff := cmp.Diff(gotSS, wantSS); diff != "" { + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) + } + + // 7. Enable _just_ debug without metrics. + wantSS = nonUserspaceProxySS.DeepCopy() + wantSS.Spec.Template.Spec.Containers[0].Env = append(wantSS.Spec.Template.Spec.Containers[0].Env, + corev1.EnvVar{Name: "TS_DEBUG_ADDR_PORT", Value: "$(POD_IP):9001"}, + corev1.EnvVar{Name: "TS_TAILSCALED_EXTRA_ARGS", Value: "--debug=$(TS_DEBUG_ADDR_PORT)"}, + ) + wantSS.Spec.Template.Spec.Containers[0].Ports = []corev1.ContainerPort{{Name: "debug", Protocol: "TCP", ContainerPort: 9001}} + gotSS = applyProxyClassToStatefulSet(proxyClassWithMetricsDebug(false, ptr.To(true)), nonUserspaceProxySS.DeepCopy(), new(tailscaleSTSConfig), zl.Sugar()) + if diff := cmp.Diff(gotSS, wantSS); diff != "" { + t.Errorf("Unexpected result applying ProxyClass with metrics enabled to a StatefulSet (-got +want):\n%s", diff) } - return a } func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { @@ -250,28 +304,28 @@ func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { }{ { name: "no custom labels specified and none present in current labels, return current labels", - current: map[string]string{LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, - want: map[string]string{LabelManaged: "true", LabelParentName: "foo", 
LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { name: "no custom labels specified, but some present in current labels, return tailscale managed labels only from the current labels", - current: map[string]string{"foo": "bar", "something.io/foo": "bar", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, - want: map[string]string{LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{"foo": "bar", "something.io/foo": "bar", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { name: "custom labels specified, current labels only contain tailscale managed labels, return a union of both", - current: map[string]string{LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, - want: map[string]string{"foo": "bar", "something.io/foo": "bar", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{"foo": "bar", "something.io/foo": "bar", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { name: "custom 
labels specified, current labels contain tailscale managed labels and custom labels, some of which re not present in the new custom labels, return a union of managed labels and the desired custom labels", - current: map[string]string{"foo": "bar", "bar": "baz", "app": "1234", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + current: map[string]string{"foo": "bar", "bar": "baz", "app": "1234", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, custom: map[string]string{"foo": "bar", "something.io/foo": "bar"}, - want: map[string]string{"foo": "bar", "something.io/foo": "bar", "app": "1234", LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, + want: map[string]string{"foo": "bar", "something.io/foo": "bar", "app": "1234", kubetypes.LabelManaged: "true", LabelParentName: "foo", LabelParentType: "svc", LabelParentNamespace: "foo"}, managed: tailscaleManagedLabels, }, { @@ -331,3 +385,10 @@ func Test_mergeStatefulSetLabelsOrAnnots(t *testing.T) { }) } } + +// updateMap updates map a with the values from map b. 
+func updateMap(a, b map[string]string) { + for key, val := range b { + a[key] = val + } +} diff --git a/cmd/k8s-operator/svc-for-pg.go b/cmd/k8s-operator/svc-for-pg.go new file mode 100644 index 0000000000000..a7255b73e8f2c --- /dev/null +++ b/cmd/k8s-operator/svc-for-pg.go @@ -0,0 +1,861 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/netip" + "reflect" + "slices" + "strings" + "sync" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + apiequality "k8s.io/apimachinery/pkg/api/equality" + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/internal/client/tailscale" + "tailscale.com/ipn" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/ingressservices" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/util/clientmetric" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + finalizerName = "tailscale.com/service-pg-finalizer" + + reasonIngressSvcInvalid = "IngressSvcInvalid" + reasonIngressSvcValid = "IngressSvcValid" + reasonIngressSvcConfigured = "IngressSvcConfigured" + reasonIngressSvcNoBackendsConfigured = "IngressSvcNoBackendsConfigured" + reasonIngressSvcCreationFailed = "IngressSvcCreationFailed" +) + +var gaugePGServiceResources = clientmetric.NewGauge(kubetypes.MetricServicePGResourceCount) + +// HAServiceReconciler is a controller that reconciles Tailscale Kubernetes +// Services that should be exposed on an ingress ProxyGroup (in HA mode). 
+type HAServiceReconciler struct { + client.Client + isDefaultLoadBalancer bool + recorder record.EventRecorder + logger *zap.SugaredLogger + tsClient tsClient + tsnetServer tsnetServer + tsNamespace string + lc localClient + defaultTags []string + operatorID string // stableID of the operator's Tailscale device + + validationOpts validationOpts + + clock tstime.Clock + + mu sync.Mutex // protects following + // managedServices is a set of all Service resources that we're currently + // managing. This is only used for metrics. + managedServices set.Slice[types.UID] +} + +// Reconcile reconciles Services that should be exposed over Tailscale in HA +// mode (on a ProxyGroup). It looks at all Services with +// tailscale.com/proxy-group annotation. For each such Service, it ensures that +// a Tailscale Service named after the hostname of the Service exists and is up to +// date. +// HA Services support multi-cluster Service setup. +// Each Tailscale Service contains a list of owner references that uniquely identify +// the operator. When a Service that acts as a +// backend is being deleted, the corresponding Tailscale Service is only deleted if the +// only owner reference that it contains is for this operator. If other owner +// references are found, then the cleanup operation only removes this operator's owner +// reference. +func (r *HAServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (res reconcile.Result, err error) { + logger := r.logger.With("Service", req.NamespacedName) + logger.Debugf("starting reconcile") + defer logger.Debugf("reconcile finished") + + svc := new(corev1.Service) + err = r.Get(ctx, req.NamespacedName, svc) + if apierrors.IsNotFound(err) { + // Request object not found, could have been deleted after reconcile request.
+ logger.Debugf("Service not found, assuming it was deleted") + return res, nil + } else if err != nil { + return res, fmt.Errorf("failed to get Service: %w", err) + } + + hostname := nameForService(svc) + logger = logger.With("hostname", hostname) + + if !svc.DeletionTimestamp.IsZero() || !r.isTailscaleService(svc) { + logger.Debugf("Service is being deleted or is (no longer) referring to Tailscale ingress/egress, ensuring any created resources are cleaned up") + _, err = r.maybeCleanup(ctx, hostname, svc, logger) + return res, err + } + + // needsRequeue is set to true if the underlying Tailscale Service has changed as a result of this reconcile. If that + // is the case, we reconcile the Ingress one more time to ensure that concurrent updates to the Tailscale Service in a + // multi-cluster Ingress setup have not resulted in another actor overwriting our Tailscale Service update. + needsRequeue := false + needsRequeue, err = r.maybeProvision(ctx, hostname, svc, logger) + if err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + if needsRequeue { + res = reconcile.Result{RequeueAfter: requeueInterval()} + } + + return reconcile.Result{}, nil +} + +// maybeProvision ensures that a Tailscale Service for this Ingress exists and is up to date and that the serve config for the +// corresponding ProxyGroup contains the Ingress backend's definition. +// If a Tailscale Service does not exist, it will be created. +// If a Tailscale Service exists, but only with owner references from other operator instances, an owner reference for this +// operator instance is added. +// If a Tailscale Service exists, but does not have an owner reference from any operator, we error +// out assuming that this is an owner reference created by an unknown actor. +// Returns true if the operation resulted in a Tailscale Service update. 
+func (r *HAServiceReconciler) maybeProvision(ctx context.Context, hostname string, svc *corev1.Service, logger *zap.SugaredLogger) (svcsChanged bool, err error) { + oldSvcStatus := svc.Status.DeepCopy() + defer func() { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { + // An error encountered here should get returned by the Reconcile function. + err = errors.Join(err, r.Client.Status().Update(ctx, svc)) + } + }() + + pgName := svc.Annotations[AnnotationProxyGroup] + if pgName == "" { + logger.Infof("[unexpected] no ProxyGroup annotation, skipping Tailscale Service provisioning") + return false, nil + } + + logger = logger.With("ProxyGroup", pgName) + + pg := &tsapi.ProxyGroup{} + if err := r.Get(ctx, client.ObjectKey{Name: pgName}, pg); err != nil { + if apierrors.IsNotFound(err) { + msg := fmt.Sprintf("ProxyGroup %q does not exist", pgName) + logger.Warnf(msg) + r.recorder.Event(svc, corev1.EventTypeWarning, "ProxyGroupNotFound", msg) + return false, nil + } + return false, fmt.Errorf("getting ProxyGroup %q: %w", pgName, err) + } + if !tsoperator.ProxyGroupIsReady(pg) { + logger.Infof("ProxyGroup is not (yet) ready") + return false, nil + } + + // Validate Service configuration + if violations := validateService(svc, r.validationOpts); len(violations) > 0 { + msg := fmt.Sprintf("unable to provision proxy resources: invalid Service: %s", strings.Join(violations, ", ")) + r.recorder.Event(svc, corev1.EventTypeWarning, "INVALIDSERVICE", msg) + r.logger.Error(msg) + tsoperator.SetServiceCondition(svc, tsapi.IngressSvcValid, metav1.ConditionFalse, reasonIngressSvcInvalid, msg, r.clock, logger) + return false, nil + } + + if !slices.Contains(svc.Finalizers, finalizerName) { + // This log line is printed exactly once during initial provisioning, + // because once the finalizer is in place this block gets skipped. So, + // this is a nice place to tell the operator that the high level, + // multi-reconcile operation is underway. 
+ logger.Infof("exposing Service over tailscale") + svc.Finalizers = append(svc.Finalizers, finalizerName) + if err := r.Update(ctx, svc); err != nil { + return false, fmt.Errorf("failed to add finalizer: %w", err) + } + r.mu.Lock() + r.managedServices.Add(svc.UID) + gaugePGServiceResources.Set(int64(r.managedServices.Len())) + r.mu.Unlock() + } + + // 1. Ensure that if Service's hostname/name has changed, any Tailscale Service + // resources corresponding to the old hostname are cleaned up. + // In practice, this function will ensure that any Tailscale Services that are + // associated with the provided ProxyGroup and no longer owned by a + // Service are cleaned up. This is fine- it is not expensive and ensures + // that in edge cases (a single update changed both hostname and removed + // ProxyGroup annotation) the Tailscale Service is more likely to be + // (eventually) removed. + svcsChanged, err = r.maybeCleanupProxyGroup(ctx, pgName, logger) + if err != nil { + return false, fmt.Errorf("failed to cleanup Tailscale Service resources for ProxyGroup: %w", err) + } + + // 2. Ensure that there isn't a Tailscale Service with the same hostname + // already created and not owned by this Service. + serviceName := tailcfg.ServiceName("svc:" + hostname) + existingTSSvc, err := r.tsClient.GetVIPService(ctx, serviceName) + if isErrorFeatureFlagNotEnabled(err) { + logger.Warn(msgFeatureFlagNotEnabled) + r.recorder.Event(svc, corev1.EventTypeWarning, warningTailscaleServiceFeatureFlagNotEnabled, msgFeatureFlagNotEnabled) + return false, nil + } + if err != nil && !isErrorTailscaleServiceNotFound(err) { + return false, fmt.Errorf("error getting Tailscale Service %q: %w", hostname, err) + } + + // 3. Generate the Tailscale Service owner annotation for new or existing Tailscale Service. + // This checks and ensures that Tailscale Service's owner references are updated + // for this Service and errors if that is not possible (i.e. 
because it + // appears that the Tailscale Service has been created by a non-operator actor). + updatedAnnotations, err := r.ownerAnnotations(existingTSSvc) + if err != nil { + instr := fmt.Sprintf("To proceed, you can either manually delete the existing Tailscale Service or choose a different hostname with the '%s' annotaion", AnnotationHostname) + msg := fmt.Sprintf("error ensuring ownership of Tailscale Service %s: %v. %s", hostname, err, instr) + logger.Warn(msg) + r.recorder.Event(svc, corev1.EventTypeWarning, "InvalidTailscaleService", msg) + tsoperator.SetServiceCondition(svc, tsapi.IngressSvcValid, metav1.ConditionFalse, reasonIngressSvcInvalid, msg, r.clock, logger) + return false, nil + } + + tags := r.defaultTags + if tstr, ok := svc.Annotations[AnnotationTags]; ok && tstr != "" { + tags = strings.Split(tstr, ",") + } + + tsSvc := &tailscale.VIPService{ + Name: serviceName, + Tags: tags, + Ports: []string{"do-not-validate"}, // we don't want to validate ports + Comment: managedTSServiceComment, + Annotations: updatedAnnotations, + } + if existingTSSvc != nil { + tsSvc.Addrs = existingTSSvc.Addrs + } + + // TODO(irbekrm): right now if two Service resources attempt to apply different Tailscale Service configs (different + // tags) we can end up reconciling those in a loop. We should detect when a Service + // with the same generation number has been reconciled ~more than N times and stop attempting to apply updates. 
+ if existingTSSvc == nil || + !reflect.DeepEqual(tsSvc.Tags, existingTSSvc.Tags) || + !ownersAreSetAndEqual(tsSvc, existingTSSvc) { + logger.Infof("Ensuring Tailscale Service exists and is up to date") + if err := r.tsClient.CreateOrUpdateVIPService(ctx, tsSvc); err != nil { + return false, fmt.Errorf("error creating Tailscale Service: %w", err) + } + existingTSSvc = tsSvc + } + + cm, cfgs, err := ingressSvcsConfigs(ctx, r.Client, pgName, r.tsNamespace) + if err != nil { + return false, fmt.Errorf("error retrieving ingress services configuration: %w", err) + } + if cm == nil { + logger.Info("ConfigMap not yet created, waiting..") + return false, nil + } + + if existingTSSvc.Addrs == nil { + existingTSSvc, err = r.tsClient.GetVIPService(ctx, tsSvc.Name) + if err != nil { + return false, fmt.Errorf("error getting Tailscale Service: %w", err) + } + if existingTSSvc.Addrs == nil { + // TODO(irbekrm): this should be a retry + return false, fmt.Errorf("unexpected: Tailscale Service addresses not populated") + } + } + + var tsSvcIPv4 netip.Addr + var tsSvcIPv6 netip.Addr + for _, tsip := range existingTSSvc.Addrs { + ip, err := netip.ParseAddr(tsip) + if err != nil { + return false, fmt.Errorf("error parsing Tailscale Service address: %w", err) + } + + if ip.Is4() { + tsSvcIPv4 = ip + } else if ip.Is6() { + tsSvcIPv6 = ip + } + } + + cfg := ingressservices.Config{} + for _, cip := range svc.Spec.ClusterIPs { + ip, err := netip.ParseAddr(cip) + if err != nil { + return false, fmt.Errorf("error parsing Kubernetes Service address: %w", err) + } + + if ip.Is4() { + cfg.IPv4Mapping = &ingressservices.Mapping{ + ClusterIP: ip, + TailscaleServiceIP: tsSvcIPv4, + } + } else if ip.Is6() { + cfg.IPv6Mapping = &ingressservices.Mapping{ + ClusterIP: ip, + TailscaleServiceIP: tsSvcIPv6, + } + } + } + + existingCfg := cfgs[serviceName.String()] + if !reflect.DeepEqual(existingCfg, cfg) { + mak.Set(&cfgs, serviceName.String(), cfg) + cfgBytes, err := json.Marshal(cfgs) + if err != nil 
{ + return false, fmt.Errorf("error marshaling ingress config: %w", err) + } + mak.Set(&cm.BinaryData, ingressservices.IngressConfigKey, cfgBytes) + if err := r.Update(ctx, cm); err != nil { + return false, fmt.Errorf("error updating ingress config: %w", err) + } + } + + logger.Infof("updating AdvertiseServices config") + // 4. Update tailscaled's AdvertiseServices config, which should add the Tailscale Service + // IPs to the ProxyGroup Pods' AllowedIPs in the next netmap update if approved. + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, svc, pg.Name, serviceName, &cfg, true, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config: %w", err) + } + + count, err := r.numberPodsAdvertising(ctx, pgName, serviceName) + if err != nil { + return false, fmt.Errorf("failed to get number of advertised Pods: %w", err) + } + + // TODO(irbekrm): here and when creating the Tailscale Service, verify if the + // error is not terminal (and therefore should not be reconciled). For + // example, if the hostname is already a hostname of a Tailscale node, + // the GET here will fail. 
+ // If there are no Pods advertising the Tailscale Service (yet), we want to set 'svc.Status.LoadBalancer.Ingress' to nil" + var lbs []corev1.LoadBalancerIngress + conditionStatus := metav1.ConditionFalse + conditionType := tsapi.IngressSvcConfigured + conditionReason := reasonIngressSvcNoBackendsConfigured + conditionMessage := fmt.Sprintf("%d/%d proxy backends ready and advertising", count, pgReplicas(pg)) + if count != 0 { + dnsName, err := r.dnsNameForService(ctx, serviceName) + if err != nil { + return false, fmt.Errorf("error getting DNS name for Service: %w", err) + } + + lbs = []corev1.LoadBalancerIngress{ + { + Hostname: dnsName, + IP: tsSvcIPv4.String(), + }, + } + + conditionStatus = metav1.ConditionTrue + conditionReason = reasonIngressSvcConfigured + } + + tsoperator.SetServiceCondition(svc, conditionType, conditionStatus, conditionReason, conditionMessage, r.clock, logger) + svc.Status.LoadBalancer.Ingress = lbs + + return svcsChanged, nil +} + +// maybeCleanup ensures that any resources, such as a Tailscale Service created for this Service, are cleaned up when the +// Service is being deleted or is unexposed. The cleanup is safe for a multi-cluster setup- the Tailscale Service is only +// deleted if it does not contain any other owner references. If it does the cleanup only removes the owner reference +// corresponding to this Service. +func (r *HAServiceReconciler) maybeCleanup(ctx context.Context, hostname string, svc *corev1.Service, logger *zap.SugaredLogger) (svcChanged bool, err error) { + logger.Debugf("Ensuring any resources for Service are cleaned up") + ix := slices.Index(svc.Finalizers, finalizerName) + if ix < 0 { + logger.Debugf("no finalizer, nothing to do") + return false, nil + } + logger.Infof("Ensuring that Tailscale Service %q configuration is cleaned up", hostname) + + defer func() { + if err != nil { + return + } + err = r.deleteFinalizer(ctx, svc, logger) + }() + + serviceName := tailcfg.ServiceName("svc:" + hostname) + // 1. 
Clean up the Tailscale Service. + svcChanged, err = r.cleanupTailscaleService(ctx, serviceName, logger) + if err != nil { + return false, fmt.Errorf("error deleting Tailscale Service: %w", err) + } + + // 2. Unadvertise the Tailscale Service. + pgName := svc.Annotations[AnnotationProxyGroup] + if err = r.maybeUpdateAdvertiseServicesConfig(ctx, svc, pgName, serviceName, nil, false, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config services: %w", err) + } + + // TODO: maybe wait for the service to be unadvertised, only then remove the backend routing + + // 3. Clean up ingress config (routing rules). + cm, cfgs, err := ingressSvcsConfigs(ctx, r.Client, pgName, r.tsNamespace) + if err != nil { + return false, fmt.Errorf("error retrieving ingress services configuration: %w", err) + } + if cm == nil || cfgs == nil { + return true, nil + } + logger.Infof("Removing Tailscale Service %q from ingress config for ProxyGroup %q", hostname, pgName) + delete(cfgs, serviceName.String()) + cfgBytes, err := json.Marshal(cfgs) + if err != nil { + return false, fmt.Errorf("error marshaling ingress config: %w", err) + } + mak.Set(&cm.BinaryData, ingressservices.IngressConfigKey, cfgBytes) + return true, r.Update(ctx, cm) +} + +// Tailscale Services that are associated with the provided ProxyGroup and no longer managed by this operator's instance are deleted, if not owned by other operator instances, else the owner reference is cleaned up. +// Returns true if the operation resulted in existing Tailscale Service updates (owner reference removal).
+func (r *HAServiceReconciler) maybeCleanupProxyGroup(ctx context.Context, proxyGroupName string, logger *zap.SugaredLogger) (svcsChanged bool, err error) { + cm, config, err := ingressSvcsConfigs(ctx, r.Client, proxyGroupName, r.tsNamespace) + if err != nil { + return false, fmt.Errorf("failed to get ingress service config: %s", err) + } + + svcList := &corev1.ServiceList{} + if err := r.Client.List(ctx, svcList, client.MatchingFields{indexIngressProxyGroup: proxyGroupName}); err != nil { + return false, fmt.Errorf("failed to find Services for ProxyGroup %q: %w", proxyGroupName, err) + } + + ingressConfigChanged := false + for tsSvcName, cfg := range config { + found := false + for _, svc := range svcList.Items { + if strings.EqualFold(fmt.Sprintf("svc:%s", nameForService(&svc)), tsSvcName) { + found = true + break + } + } + if !found { + logger.Infof("Tailscale Service %q is not owned by any Service, cleaning up", tsSvcName) + + // Make sure the Tailscale Service is not advertised in tailscaled or serve config. 
+ if err = r.maybeUpdateAdvertiseServicesConfig(ctx, nil, proxyGroupName, tailcfg.ServiceName(tsSvcName), &cfg, false, logger); err != nil { + return false, fmt.Errorf("failed to update tailscaled config services: %w", err) + } + + svcsChanged, err = r.cleanupTailscaleService(ctx, tailcfg.ServiceName(tsSvcName), logger) + if err != nil { + return false, fmt.Errorf("deleting Tailscale Service %q: %w", tsSvcName, err) + } + + _, ok := config[tsSvcName] + if ok { + logger.Infof("Removing Tailscale Service %q from serve config", tsSvcName) + delete(config, tsSvcName) + ingressConfigChanged = true + } + } + } + + if ingressConfigChanged { + configBytes, err := json.Marshal(config) + if err != nil { + return false, fmt.Errorf("marshaling serve config: %w", err) + } + mak.Set(&cm.BinaryData, ingressservices.IngressConfigKey, configBytes) + if err := r.Update(ctx, cm); err != nil { + return false, fmt.Errorf("updating serve config: %w", err) + } + } + + return svcsChanged, nil +} + +func (r *HAServiceReconciler) deleteFinalizer(ctx context.Context, svc *corev1.Service, logger *zap.SugaredLogger) error { + svc.Finalizers = slices.DeleteFunc(svc.Finalizers, func(f string) bool { + return f == finalizerName + }) + logger.Debugf("ensure %q finalizer is removed", finalizerName) + + if err := r.Update(ctx, svc); err != nil { + return fmt.Errorf("failed to remove finalizer %q: %w", finalizerName, err) + } + r.mu.Lock() + defer r.mu.Unlock() + r.managedServices.Remove(svc.UID) + gaugePGServiceResources.Set(int64(r.managedServices.Len())) + return nil +} + +func (r *HAServiceReconciler) isTailscaleService(svc *corev1.Service) bool { + proxyGroup := svc.Annotations[AnnotationProxyGroup] + return r.shouldExpose(svc) && proxyGroup != "" +} + +func (r *HAServiceReconciler) shouldExpose(svc *corev1.Service) bool { + return r.shouldExposeClusterIP(svc) +} + +func (r *HAServiceReconciler) shouldExposeClusterIP(svc *corev1.Service) bool { + if svc.Spec.ClusterIP == "" || svc.Spec.ClusterIP 
== "None" { + return false + } + return isTailscaleLoadBalancerService(svc, r.isDefaultLoadBalancer) || hasExposeAnnotation(svc) +} + +// tailnetCertDomain returns the base domain (TCD) of the current tailnet. +func (r *HAServiceReconciler) tailnetCertDomain(ctx context.Context) (string, error) { + st, err := r.lc.StatusWithoutPeers(ctx) + if err != nil { + return "", fmt.Errorf("error getting tailscale status: %w", err) + } + return st.CurrentTailnet.MagicDNSSuffix, nil +} + +// cleanupTailscaleService deletes any Tailscale Service by the provided name if it is not owned by operator instances other than this one. +// If a Tailscale Service is found, but contains other owner references, only removes this operator's owner reference. +// If a Tailscale Service by the given name is not found or does not contain this operator's owner reference, do nothing. +// It returns true if an existing Tailscale Service was updated to remove owner reference, as well as any error that occurred. +func (r *HAServiceReconciler) cleanupTailscaleService(ctx context.Context, name tailcfg.ServiceName, logger *zap.SugaredLogger) (updated bool, err error) { + svc, err := r.tsClient.GetVIPService(ctx, name) + if isErrorFeatureFlagNotEnabled(err) { + msg := fmt.Sprintf("Unable to proceed with cleanup: %s.", msgFeatureFlagNotEnabled) + logger.Warn(msg) + return false, nil + } + if err != nil { + errResp := &tailscale.ErrResponse{} + ok := errors.As(err, errResp) + if ok && errResp.Status == http.StatusNotFound { + return false, nil + } + if !ok { + return false, fmt.Errorf("unexpected error getting Tailscale Service %q: %w", name.String(), err) + } + + return false, fmt.Errorf("error getting Tailscale Service: %w", err) + } + if svc == nil { + return false, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return false, fmt.Errorf("error parsing Tailscale Service owner annotation: %w", err) + } + if o == nil || len(o.OwnerRefs) == 0 { + return false, nil + } + // Comparing with 
the operatorID only means that we will not be able to + // clean up Tailscale Services in cases where the operator was deleted from the + // cluster before deleting the Ingress. Perhaps the comparison could be + // 'if or.OperatorID == r.operatorID || or.ingressUID == r.ingressUID'. + ix := slices.IndexFunc(o.OwnerRefs, func(or OwnerRef) bool { + return or.OperatorID == r.operatorID + }) + if ix == -1 { + return false, nil + } + if len(o.OwnerRefs) == 1 { + logger.Infof("Deleting Tailscale Service %q", name) + return false, r.tsClient.DeleteVIPService(ctx, name) + } + o.OwnerRefs = slices.Delete(o.OwnerRefs, ix, ix+1) + logger.Infof("Updating Tailscale Service %q", name) + json, err := json.Marshal(o) + if err != nil { + return false, fmt.Errorf("error marshalling updated Tailscale Service owner reference: %w", err) + } + svc.Annotations[ownerAnnotation] = string(json) + return true, r.tsClient.CreateOrUpdateVIPService(ctx, svc) +} + +func (a *HAServiceReconciler) backendRoutesSetup(ctx context.Context, serviceName, replicaName, pgName string, wantsCfg *ingressservices.Config, logger *zap.SugaredLogger) (bool, error) { + logger.Debugf("checking backend routes for service '%s'", serviceName) + pod := &corev1.Pod{} + err := a.Get(ctx, client.ObjectKey{Namespace: a.tsNamespace, Name: replicaName}, pod) + if apierrors.IsNotFound(err) { + logger.Debugf("Pod %q not found", replicaName) + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to get Pod: %w", err) + } + secret := &corev1.Secret{} + err = a.Get(ctx, client.ObjectKey{Namespace: a.tsNamespace, Name: replicaName}, secret) + if apierrors.IsNotFound(err) { + logger.Debugf("Secret %q not found", replicaName) + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to get Secret: %w", err) + } + if len(secret.Data) == 0 || secret.Data[ingressservices.IngressConfigKey] == nil { + return false, nil + } + gotCfgB := secret.Data[ingressservices.IngressConfigKey] + var gotCfgs 
ingressservices.Status + if err := json.Unmarshal(gotCfgB, &gotCfgs); err != nil { + return false, fmt.Errorf("error unmarshalling ingress config: %w", err) + } + statusUpToDate, err := isCurrentStatus(gotCfgs, pod, logger) + if err != nil { + return false, fmt.Errorf("error checking ingress config status: %w", err) + } + if !statusUpToDate || !reflect.DeepEqual(gotCfgs.Configs.GetConfig(serviceName), wantsCfg) { + logger.Debugf("Pod %q is not ready to advertise Tailscale Service", pod.Name) + return false, nil + } + return true, nil +} + +func isCurrentStatus(gotCfgs ingressservices.Status, pod *corev1.Pod, logger *zap.SugaredLogger) (bool, error) { + ips := pod.Status.PodIPs + if len(ips) == 0 { + logger.Debugf("Pod %q does not yet have IPs, unable to determine if status is up to date", pod.Name) + return false, nil + } + + if len(ips) > 2 { + return false, fmt.Errorf("pod 'status.PodIPs' can contain at most 2 IPs, got %d (%v)", len(ips), ips) + } + var podIPv4, podIPv6 string + for _, ip := range ips { + parsed, err := netip.ParseAddr(ip.IP) + if err != nil { + return false, fmt.Errorf("error parsing IP address %s: %w", ip.IP, err) + } + if parsed.Is4() { + podIPv4 = parsed.String() + continue + } + podIPv6 = parsed.String() + } + if podIPv4 != gotCfgs.PodIPv4 || podIPv6 != gotCfgs.PodIPv6 { + return false, nil + } + return true, nil +} + +func (a *HAServiceReconciler) maybeUpdateAdvertiseServicesConfig(ctx context.Context, svc *corev1.Service, pgName string, serviceName tailcfg.ServiceName, cfg *ingressservices.Config, shouldBeAdvertised bool, logger *zap.SugaredLogger) (err error) { + logger.Debugf("checking advertisement for service '%s'", serviceName) + // Get all config Secrets for this ProxyGroup. 
+ // Get all Pods + secrets := &corev1.SecretList{} + if err := a.List(ctx, secrets, client.InNamespace(a.tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, "config"))); err != nil { + return fmt.Errorf("failed to list config Secrets: %w", err) + } + + if svc != nil && shouldBeAdvertised { + shouldBeAdvertised, err = a.checkEndpointsReady(ctx, svc, logger) + if err != nil { + return fmt.Errorf("failed to check readiness of Service '%s' endpoints: %w", svc.Name, err) + } + } + + for _, secret := range secrets.Items { + var updated bool + for fileName, confB := range secret.Data { + var conf ipn.ConfigVAlpha + if err := json.Unmarshal(confB, &conf); err != nil { + return fmt.Errorf("error unmarshalling ProxyGroup config: %w", err) + } + + idx := slices.Index(conf.AdvertiseServices, serviceName.String()) + isAdvertised := idx >= 0 + switch { + case !isAdvertised && !shouldBeAdvertised: + logger.Debugf("service %q shouldn't be advertised", serviceName) + continue + case isAdvertised && shouldBeAdvertised: + logger.Debugf("service %q is already advertised", serviceName) + continue + case isAdvertised && !shouldBeAdvertised: + logger.Debugf("deleting advertisement for service %q", serviceName) + conf.AdvertiseServices = slices.Delete(conf.AdvertiseServices, idx, idx+1) + case shouldBeAdvertised: + replicaName, ok := strings.CutSuffix(secret.Name, "-config") + if !ok { + logger.Infof("[unexpected] unable to determine replica name from config Secret name %q, unable to determine if backend routing has been configured", secret.Name) + return nil + } + ready, err := a.backendRoutesSetup(ctx, serviceName.String(), replicaName, pgName, cfg, logger) + if err != nil { + return fmt.Errorf("error checking backend routes: %w", err) + } + if !ready { + logger.Debugf("service %q is not ready to be advertised", serviceName) + continue + } + + conf.AdvertiseServices = append(conf.AdvertiseServices, serviceName.String()) + } + confB, err := json.Marshal(conf) + if err != nil { + 
return fmt.Errorf("error marshalling ProxyGroup config: %w", err) + } + mak.Set(&secret.Data, fileName, confB) + updated = true + } + if updated { + if err := a.Update(ctx, &secret); err != nil { + return fmt.Errorf("error updating ProxyGroup config Secret: %w", err) + } + } + } + return nil +} + +func (a *HAServiceReconciler) numberPodsAdvertising(ctx context.Context, pgName string, serviceName tailcfg.ServiceName) (int, error) { + // Get all state Secrets for this ProxyGroup. + secrets := &corev1.SecretList{} + if err := a.List(ctx, secrets, client.InNamespace(a.tsNamespace), client.MatchingLabels(pgSecretLabels(pgName, "state"))); err != nil { + return 0, fmt.Errorf("failed to list ProxyGroup %q state Secrets: %w", pgName, err) + } + + var count int + for _, secret := range secrets.Items { + prefs, ok, err := getDevicePrefs(&secret) + if err != nil { + return 0, fmt.Errorf("error getting node metadata: %w", err) + } + if !ok { + continue + } + if slices.Contains(prefs.AdvertiseServices, serviceName.String()) { + count++ + } + } + + return count, nil +} + +// ownerAnnotations returns the updated annotations required to ensure this +// instance of the operator is included as an owner. If the Tailscale Service is not +// nil, but does not contain an owner we return an error as this likely means +// that the Tailscale Service was created by something other than a Tailscale +// Kubernetes operator. 
+func (r *HAServiceReconciler) ownerAnnotations(svc *tailscale.VIPService) (map[string]string, error) { + ref := OwnerRef{ + OperatorID: r.operatorID, + } + if svc == nil { + c := ownerAnnotationValue{OwnerRefs: []OwnerRef{ref}} + json, err := json.Marshal(c) + if err != nil { + return nil, fmt.Errorf("[unexpected] unable to marshal Tailscale Service owner annotation contents: %w, please report this", err) + } + return map[string]string{ + ownerAnnotation: string(json), + }, nil + } + o, err := parseOwnerAnnotation(svc) + if err != nil { + return nil, err + } + if o == nil || len(o.OwnerRefs) == 0 { + return nil, fmt.Errorf("Tailscale Service %s exists, but does not contain owner annotation with owner references; not proceeding as this is likely a resource created by something other than the Tailscale Kubernetes operator", svc.Name) + } + if slices.Contains(o.OwnerRefs, ref) { // up to date + return svc.Annotations, nil + } + o.OwnerRefs = append(o.OwnerRefs, ref) + json, err := json.Marshal(o) + if err != nil { + return nil, fmt.Errorf("error marshalling updated owner references: %w", err) + } + + newAnnots := make(map[string]string, len(svc.Annotations)+1) + for k, v := range svc.Annotations { + newAnnots[k] = v + } + newAnnots[ownerAnnotation] = string(json) + return newAnnots, nil +} + +// dnsNameForService returns the DNS name for the given Tailscale Service name. +func (r *HAServiceReconciler) dnsNameForService(ctx context.Context, svc tailcfg.ServiceName) (string, error) { + s := svc.WithoutPrefix() + tcd, err := r.tailnetCertDomain(ctx) + if err != nil { + return "", fmt.Errorf("error determining DNS name base: %w", err) + } + return s + "." + tcd, nil +} + +// ingressSvcsConfig returns a ConfigMap that contains ingress services configuration for the provided ProxyGroup as well +// as unmarshalled configuration from the ConfigMap. 
+func ingressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, tsNamespace string) (cm *corev1.ConfigMap, cfgs ingressservices.Configs, err error) { + name := pgIngressCMName(proxyGroupName) + cm = &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: tsNamespace, + }, + } + err = cl.Get(ctx, client.ObjectKeyFromObject(cm), cm) + if apierrors.IsNotFound(err) { // ProxyGroup resources have not been created (yet) + return nil, nil, nil + } + if err != nil { + return nil, nil, fmt.Errorf("error retrieving ingress services ConfigMap %s: %v", name, err) + } + cfgs = ingressservices.Configs{} + if len(cm.BinaryData[ingressservices.IngressConfigKey]) != 0 { + if err := json.Unmarshal(cm.BinaryData[ingressservices.IngressConfigKey], &cfgs); err != nil { + return nil, nil, fmt.Errorf("error unmarshaling ingress services config %v: %w", cm.BinaryData[ingressservices.IngressConfigKey], err) + } + } + return cm, cfgs, nil +} + +func (r *HAServiceReconciler) getEndpointSlicesForService(ctx context.Context, svc *corev1.Service, logger *zap.SugaredLogger) ([]discoveryv1.EndpointSlice, error) { + logger.Debugf("looking for endpoint slices for svc with name '%s' in namespace '%s' matching label '%s=%s'", svc.Name, svc.Namespace, discoveryv1.LabelServiceName, svc.Name) + // https://kubernetes.io/docs/concepts/services-networking/endpoint-slices/#ownership + labels := map[string]string{discoveryv1.LabelServiceName: svc.Name} + eps := new(discoveryv1.EndpointSliceList) + if err := r.List(ctx, eps, client.InNamespace(svc.Namespace), client.MatchingLabels(labels)); err != nil { + return nil, fmt.Errorf("error listing EndpointSlices: %w", err) + } + + if len(eps.Items) == 0 { + logger.Debugf("Service '%s' EndpointSlice does not yet exist. 
We will reconcile again once it's created", svc.Name) + return nil, nil + } + + return eps.Items, nil +} + +func (r *HAServiceReconciler) checkEndpointsReady(ctx context.Context, svc *corev1.Service, logger *zap.SugaredLogger) (bool, error) { + epss, err := r.getEndpointSlicesForService(ctx, svc, logger) + if err != nil { + return false, fmt.Errorf("failed to list EndpointSlices for Service %q: %w", svc.Name, err) + } + for _, eps := range epss { + for _, ep := range eps.Endpoints { + if *ep.Conditions.Ready { + return true, nil + } + } + } + + logger.Debugf("could not find any ready Endpoints in EndpointSlice") + return false, nil +} diff --git a/cmd/k8s-operator/svc-for-pg_test.go b/cmd/k8s-operator/svc-for-pg_test.go new file mode 100644 index 0000000000000..4bb633cb8ff66 --- /dev/null +++ b/cmd/k8s-operator/svc-for-pg_test.go @@ -0,0 +1,367 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "encoding/json" + "fmt" + "math/rand/v2" + "net/netip" + "testing" + + "go.uber.org/zap" + corev1 "k8s.io/api/core/v1" + discoveryv1 "k8s.io/api/discovery/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" + "k8s.io/client-go/tools/record" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/fake" + "tailscale.com/ipn/ipnstate" + tsoperator "tailscale.com/k8s-operator" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/ingressservices" + "tailscale.com/tstest" + "tailscale.com/types/ptr" + "tailscale.com/util/mak" + + "tailscale.com/tailcfg" +) + +func TestServicePGReconciler(t *testing.T) { + svcPGR, stateSecret, fc, ft := setupServiceTest(t) + svcs := []*corev1.Service{} + config := []string{} + for i := range 4 { + svc, _ := setupTestService(t, fmt.Sprintf("test-svc-%d", i), "", fmt.Sprintf("1.2.3.%d", i), fc, stateSecret) + svcs = append(svcs, svc) + + // Verify initial 
reconciliation + expectReconciled(t, svcPGR, "default", svc.Name) + + config = append(config, fmt.Sprintf("svc:default-%s", svc.Name)) + verifyTailscaleService(t, ft, fmt.Sprintf("svc:default-%s", svc.Name), []string{"do-not-validate"}) + verifyTailscaledConfig(t, fc, config) + } + + for i, svc := range svcs { + if err := fc.Delete(context.Background(), svc); err != nil { + t.Fatalf("deleting Service: %v", err) + } + + expectReconciled(t, svcPGR, "default", svc.Name) + + // Verify the ConfigMap was cleaned up + cm := &corev1.ConfigMap{} + if err := fc.Get(context.Background(), types.NamespacedName{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, cm); err != nil { + t.Fatalf("getting ConfigMap: %v", err) + } + + cfgs := ingressservices.Configs{} + if err := json.Unmarshal(cm.BinaryData[ingressservices.IngressConfigKey], &cfgs); err != nil { + t.Fatalf("unmarshaling serve config: %v", err) + } + + if len(cfgs) > len(svcs)-(i+1) { + t.Error("serve config not cleaned up") + } + + config = removeEl(config, fmt.Sprintf("svc:default-%s", svc.Name)) + verifyTailscaledConfig(t, fc, config) + } +} + +func TestServicePGReconciler_UpdateHostname(t *testing.T) { + svcPGR, stateSecret, fc, ft := setupServiceTest(t) + + cip := "4.1.6.7" + svc, _ := setupTestService(t, "test-service", "", cip, fc, stateSecret) + + expectReconciled(t, svcPGR, "default", svc.Name) + + verifyTailscaleService(t, ft, fmt.Sprintf("svc:default-%s", svc.Name), []string{"do-not-validate"}) + verifyTailscaledConfig(t, fc, []string{fmt.Sprintf("svc:default-%s", svc.Name)}) + + hostname := "foobarbaz" + mustUpdate(t, fc, svc.Namespace, svc.Name, func(s *corev1.Service) { + mak.Set(&s.Annotations, AnnotationHostname, hostname) + }) + + // NOTE: we need to update the ingress config Secret because there is no containerboot in the fake proxy Pod + updateIngressConfigSecret(t, fc, stateSecret, hostname, cip) + expectReconciled(t, svcPGR, "default", svc.Name) + + verifyTailscaleService(t, ft, 
fmt.Sprintf("svc:%s", hostname), []string{"do-not-validate"}) + verifyTailscaledConfig(t, fc, []string{fmt.Sprintf("svc:%s", hostname)}) + + _, err := ft.GetVIPService(context.Background(), tailcfg.ServiceName(fmt.Sprintf("svc:default-%s", svc.Name))) + if err == nil { + t.Fatalf("svc:default-%s not cleaned up", svc.Name) + } + if !isErrorTailscaleServiceNotFound(err) { + t.Fatalf("unexpected error: %v", err) + } +} + +func setupServiceTest(t *testing.T) (*HAServiceReconciler, *corev1.Secret, client.Client, *fakeTSClient) { + // Pre-create the ProxyGroup + pg := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg", + Generation: 1, + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + + // Pre-create the ConfigMap for the ProxyGroup + pgConfigMap := &corev1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-ingress-config", + Namespace: "operator-ns", + }, + BinaryData: map[string][]byte{ + "serve-config.json": []byte(`{"Services":{}}`), + }, + } + + // Pre-create a config Secret for the ProxyGroup + pgCfgSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: pgConfigSecretName("test-pg", 0), + Namespace: "operator-ns", + Labels: pgSecretLabels("test-pg", "config"), + }, + Data: map[string][]byte{ + tsoperator.TailscaledConfigFileName(106): []byte(`{"Version":""}`), + }, + } + + pgStateSecret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-0", + Namespace: "operator-ns", + }, + Data: map[string][]byte{}, + } + + pgPod := &corev1.Pod{ + TypeMeta: metav1.TypeMeta{Kind: "Pod", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pg-0", + Namespace: "operator-ns", + }, + Status: corev1.PodStatus{ + PodIPs: []corev1.PodIP{ + { + IP: "4.3.2.1", + }, + }, + }, + } + + fc := fake.NewClientBuilder(). + WithScheme(tsapi.GlobalScheme). + WithObjects(pg, pgCfgSecret, pgConfigMap, pgPod, pgStateSecret). + WithStatusSubresource(pg). 
+ WithIndex(new(corev1.Service), indexIngressProxyGroup, indexPGIngresses). + Build() + + // Set ProxyGroup status to ready + pg.Status.Conditions = []metav1.Condition{ + { + Type: string(tsapi.ProxyGroupReady), + Status: metav1.ConditionTrue, + ObservedGeneration: 1, + }, + } + if err := fc.Status().Update(context.Background(), pg); err != nil { + t.Fatal(err) + } + fakeTsnetServer := &fakeTSNetServer{certDomains: []string{"foo.com"}} + + ft := &fakeTSClient{} + zl, err := zap.NewDevelopment() + if err != nil { + t.Fatal(err) + } + + lc := &fakeLocalClient{ + status: &ipnstate.Status{ + CurrentTailnet: &ipnstate.TailnetStatus{ + MagicDNSSuffix: "ts.net", + }, + }, + } + + cl := tstest.NewClock(tstest.ClockOpts{}) + svcPGR := &HAServiceReconciler{ + Client: fc, + tsClient: ft, + clock: cl, + defaultTags: []string{"tag:k8s"}, + tsNamespace: "operator-ns", + tsnetServer: fakeTsnetServer, + logger: zl.Sugar(), + recorder: record.NewFakeRecorder(10), + lc: lc, + } + + return svcPGR, pgStateSecret, fc, ft +} + +func TestServicePGReconciler_MultiCluster(t *testing.T) { + var ft *fakeTSClient + var lc localClient + for i := 0; i <= 10; i++ { + pgr, stateSecret, fc, fti := setupServiceTest(t) + if i == 0 { + ft = fti + lc = pgr.lc + } else { + pgr.tsClient = ft + pgr.lc = lc + } + + svc, _ := setupTestService(t, "test-multi-cluster", "", "4.3.2.1", fc, stateSecret) + expectReconciled(t, pgr, "default", svc.Name) + + tsSvcs, err := ft.ListVIPServices(context.Background()) + if err != nil { + t.Fatalf("getting Tailscale Service: %v", err) + } + + if len(tsSvcs) != 1 { + t.Fatalf("unexpected number of Tailscale Services (%d)", len(tsSvcs)) + } + + for name := range tsSvcs { + t.Logf("found Tailscale Service with name %q", name.String()) + } + } +} + +func TestIgnoreRegularService(t *testing.T) { + pgr, _, fc, ft := setupServiceTest(t) + + svc := &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test", + Namespace: "default", + // The apiserver is supposed to set the 
UID, but the fake client + // doesn't. So, set it explicitly because other code later depends + // on it being set. + UID: types.UID("1234-UID"), + Annotations: map[string]string{ + "tailscale.com/expose": "true", + }, + }, + Spec: corev1.ServiceSpec{ + ClusterIP: "10.20.30.40", + Type: corev1.ServiceTypeClusterIP, + }, + } + + mustCreate(t, fc, svc) + expectReconciled(t, pgr, "default", "test") + + verifyTailscaledConfig(t, fc, nil) + + tsSvcs, err := ft.ListVIPServices(context.Background()) + if err == nil { + if len(tsSvcs) > 0 { + t.Fatal("unexpected Tailscale Services found") + } + } +} + +func removeEl(s []string, value string) []string { + result := s[:0] + for _, v := range s { + if v != value { + result = append(result, v) + } + } + return result +} + +func updateIngressConfigSecret(t *testing.T, fc client.Client, stateSecret *corev1.Secret, serviceName string, clusterIP string) { + ingressConfig := ingressservices.Configs{ + fmt.Sprintf("svc:%s", serviceName): ingressservices.Config{ + IPv4Mapping: &ingressservices.Mapping{ + TailscaleServiceIP: netip.MustParseAddr(vipTestIP), + ClusterIP: netip.MustParseAddr(clusterIP), + }, + }, + } + + ingressStatus := ingressservices.Status{ + Configs: ingressConfig, + PodIPv4: "4.3.2.1", + } + + icJson, err := json.Marshal(ingressStatus) + if err != nil { + t.Fatalf("failed to json marshal ingress config: %s", err.Error()) + } + + mustUpdate(t, fc, stateSecret.Namespace, stateSecret.Name, func(sec *corev1.Secret) { + mak.Set(&sec.Data, ingressservices.IngressConfigKey, icJson) + }) +} + +func setupTestService(t *testing.T, svcName string, hostname string, clusterIP string, fc client.Client, stateSecret *corev1.Secret) (svc *corev1.Service, eps *discoveryv1.EndpointSlice) { + uid := rand.IntN(100) + svc = &corev1.Service{ + TypeMeta: metav1.TypeMeta{Kind: "Service", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: svcName, + Namespace: "default", + UID: types.UID(fmt.Sprintf("%d-UID", uid)), + Annotations: 
map[string]string{ + "tailscale.com/proxy-group": "test-pg", + }, + }, + Spec: corev1.ServiceSpec{ + Type: corev1.ServiceTypeLoadBalancer, + LoadBalancerClass: ptr.To("tailscale"), + ClusterIP: clusterIP, + ClusterIPs: []string{clusterIP}, + }, + } + + eps = &discoveryv1.EndpointSlice{ + TypeMeta: metav1.TypeMeta{Kind: "EndpointSlice", APIVersion: "v1"}, + ObjectMeta: metav1.ObjectMeta{ + Name: svcName, + Namespace: "default", + Labels: map[string]string{ + discoveryv1.LabelServiceName: svcName, + }, + }, + AddressType: discoveryv1.AddressTypeIPv4, + Endpoints: []discoveryv1.Endpoint{ + { + Addresses: []string{"4.3.2.1"}, + Conditions: discoveryv1.EndpointConditions{ + Ready: ptr.To(true), + }, + }, + }, + } + + updateIngressConfigSecret(t, fc, stateSecret, fmt.Sprintf("default-%s", svcName), clusterIP) + + mustCreate(t, fc, svc) + mustCreate(t, fc, eps) + + return svc, eps +} diff --git a/cmd/k8s-operator/svc.go b/cmd/k8s-operator/svc.go index f45f922463113..7cd19bc591c59 100644 --- a/cmd/k8s-operator/svc.go +++ b/cmd/k8s-operator/svc.go @@ -65,6 +65,10 @@ type ServiceReconciler struct { clock tstime.Clock defaultProxyClass string + + validationOpts validationOpts + + noFqdnDotAppend bool } var ( @@ -84,10 +88,10 @@ func childResourceLabels(name, ns, typ string) map[string]string { // proxying. Instead, we have to do our own filtering and tracking with // labels. 
return map[string]string{ - LabelManaged: "true", - LabelParentName: name, - LabelParentNamespace: ns, - LabelParentType: typ, + kubetypes.LabelManaged: "true", + LabelParentName: name, + LabelParentNamespace: ns, + LabelParentType: typ, } } @@ -121,7 +125,15 @@ func (a *ServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request return reconcile.Result{}, a.maybeCleanup(ctx, logger, svc) } - return reconcile.Result{}, a.maybeProvision(ctx, logger, svc) + if err := a.maybeProvision(ctx, logger, svc); err != nil { + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + logger.Infof("optimistic lock error, retrying: %s", err) + } else { + return reconcile.Result{}, err + } + } + + return reconcile.Result{}, nil } // maybeCleanup removes any existing resources related to serving svc over tailscale. @@ -131,7 +143,7 @@ func (a *ServiceReconciler) Reconcile(ctx context.Context, req reconcile.Request func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.SugaredLogger, svc *corev1.Service) (err error) { oldSvcStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldSvcStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { // An error encountered here should get returned by the Reconcile function. 
err = errors.Join(err, a.Client.Status().Update(ctx, svc)) } @@ -152,7 +164,12 @@ func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare return nil } - if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(svc.Name, svc.Namespace, "svc")); err != nil { + proxyTyp := proxyTypeEgress + if a.shouldExpose(svc) { + proxyTyp = proxyTypeIngressService + } + + if done, err := a.ssr.Cleanup(ctx, logger, childResourceLabels(svc.Name, svc.Namespace, "svc"), proxyTyp); err != nil { return fmt.Errorf("failed to cleanup: %w", err) } else if !done { logger.Debugf("cleanup not done yet, waiting for next reconcile") @@ -191,7 +208,7 @@ func (a *ServiceReconciler) maybeCleanup(ctx context.Context, logger *zap.Sugare func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.SugaredLogger, svc *corev1.Service) (err error) { oldSvcStatus := svc.Status.DeepCopy() defer func() { - if !apiequality.Semantic.DeepEqual(oldSvcStatus, svc.Status) { + if !apiequality.Semantic.DeepEqual(oldSvcStatus, &svc.Status) { // An error encountered here should get returned by the Reconcile function. 
err = errors.Join(err, a.Client.Status().Update(ctx, svc)) } @@ -207,7 +224,7 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga tsoperator.SetServiceCondition(svc, tsapi.ProxyReady, metav1.ConditionFalse, reasonProxyInvalid, msg, a.clock, logger) return nil } - if violations := validateService(svc); len(violations) > 0 { + if violations := validateService(svc, a.validationOpts); len(violations) > 0 { msg := fmt.Sprintf("unable to provision proxy resources: invalid Service: %s", strings.Join(violations, ", ")) a.recorder.Event(svc, corev1.EventTypeWarning, "INVALIDSERVICE", msg) a.logger.Error(msg) @@ -256,6 +273,10 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga ChildResourceLabels: crl, ProxyClassName: proxyClass, } + sts.proxyType = proxyTypeEgress + if a.shouldExpose(svc) { + sts.proxyType = proxyTypeIngressService + } a.mu.Lock() if a.shouldExposeClusterIP(svc) { @@ -272,7 +293,7 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga gaugeEgressProxies.Set(int64(a.managedEgressProxies.Len())) } else if fqdn := svc.Annotations[AnnotationTailnetTargetFQDN]; fqdn != "" { fqdn := svc.Annotations[AnnotationTailnetTargetFQDN] - if !strings.HasSuffix(fqdn, ".") { + if !strings.HasSuffix(fqdn, ".") && !a.noFqdnDotAppend { fqdn = fqdn + "." } sts.TailnetTargetFQDN = fqdn @@ -311,11 +332,11 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - _, tsHost, tsIPs, err := a.ssr.DeviceInfo(ctx, crl) + dev, err := a.ssr.DeviceInfo(ctx, crl, logger) if err != nil { return fmt.Errorf("failed to get device ID: %w", err) } - if tsHost == "" { + if dev == nil || dev.hostname == "" { msg := "no Tailscale hostname known yet, waiting for proxy pod to finish auth" logger.Debug(msg) // No hostname yet. Wait for the proxy pod to auth. 
@@ -324,9 +345,9 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } - logger.Debugf("setting Service LoadBalancer status to %q, %s", tsHost, strings.Join(tsIPs, ", ")) + logger.Debugf("setting Service LoadBalancer status to %q, %s", dev.hostname, strings.Join(dev.ips, ", ")) ingress := []corev1.LoadBalancerIngress{ - {Hostname: tsHost}, + {Hostname: dev.hostname}, } clusterIPAddr, err := netip.ParseAddr(svc.Spec.ClusterIP) if err != nil { @@ -334,7 +355,7 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga tsoperator.SetServiceCondition(svc, tsapi.ProxyReady, metav1.ConditionFalse, reasonProxyFailed, msg, a.clock, logger) return errors.New(msg) } - for _, ip := range tsIPs { + for _, ip := range dev.ips { addr, err := netip.ParseAddr(ip) if err != nil { continue @@ -348,19 +369,24 @@ func (a *ServiceReconciler) maybeProvision(ctx context.Context, logger *zap.Suga return nil } -func validateService(svc *corev1.Service) []string { +func validateService(svc *corev1.Service, opts validationOpts) []string { violations := make([]string, 0) if svc.Annotations[AnnotationTailnetTargetFQDN] != "" && svc.Annotations[AnnotationTailnetTargetIP] != "" { violations = append(violations, fmt.Sprintf("only one of annotations %s and %s can be set", AnnotationTailnetTargetIP, AnnotationTailnetTargetFQDN)) } if fqdn := svc.Annotations[AnnotationTailnetTargetFQDN]; fqdn != "" { - if !isMagicDNSName(fqdn) { + if !isMagicDNSName(fqdn, opts) { violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q does not appear to be a valid MagicDNS name", AnnotationTailnetTargetFQDN, fqdn)) } } - - // TODO(irbekrm): validate that tailscale.com/tailnet-ip annotation is a - // valid IP address (tailscale/tailscale#13671). 
+ if ipStr := svc.Annotations[AnnotationTailnetTargetIP]; ipStr != "" { + ip, err := netip.ParseAddr(ipStr) + if err != nil { + violations = append(violations, fmt.Sprintf("invalid value of annotation %s: %q could not be parsed as a valid IP Address, error: %s", AnnotationTailnetTargetIP, ipStr, err)) + } else if !ip.IsValid() { + violations = append(violations, fmt.Sprintf("parsed IP address in annotation %s: %q is not valid", AnnotationTailnetTargetIP, ipStr)) + } + } svcName := nameForService(svc) if err := dnsname.ValidLabel(svcName); err != nil { diff --git a/cmd/k8s-operator/testutils_test.go b/cmd/k8s-operator/testutils_test.go index 6b6297cbdd4fe..619aecc56816e 100644 --- a/cmd/k8s-operator/testutils_test.go +++ b/cmd/k8s-operator/testutils_test.go @@ -8,6 +8,8 @@ package main import ( "context" "encoding/json" + "fmt" + "net/http" "net/netip" "reflect" "strings" @@ -21,17 +23,25 @@ import ( corev1 "k8s.io/api/core/v1" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/types" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" - "tailscale.com/client/tailscale" + "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" "tailscale.com/types/ptr" "tailscale.com/util/mak" ) +const ( + vipTestIP = "5.6.7.8" +) + // confgOpts contains configuration options for creating cluster resources for // Tailscale proxies. 
type configOpts struct { @@ -39,7 +49,10 @@ type configOpts struct { secretName string hostname string namespace string + tailscaleNamespace string + namespaced bool parentType string + proxyType string priorityClassName string firewallMode string tailnetTargetIP string @@ -48,6 +61,7 @@ type configOpts struct { clusterTargetDNS string subnetRoutes string isExitNode bool + isAppConnector bool confFileHash string serveConfig *ipn.ServeConfig shouldEnableForwardingClusterTrafficViaIngress bool @@ -55,6 +69,10 @@ type configOpts struct { app string shouldRemoveAuthKey bool secretExtraData map[string][]byte + resourceVersion string + + enableMetrics bool + serviceMonitorLabels tsapi.Labels } func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.StatefulSet { @@ -69,14 +87,13 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Env: []corev1.EnvVar{ {Name: "TS_USERSPACE", Value: "false"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.name"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: opts.secretName}, - {Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", Value: "/etc/tsconfig/tailscaled"}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, }, SecurityContext: &corev1.SecurityContext{ - Capabilities: &corev1.Capabilities{ - Add: []corev1.Capability{"NET_ADMIN"}, - }, + Privileged: ptr.To(true), }, ImagePullPolicy: "Always", } @@ -86,7 +103,7 @@ func expectedSTS(t *testing.T, cl 
client.Client, opts configOpts) *appsv1.Statef Value: "true", }) } - annots := make(map[string]string) + var annots map[string]string var volumes []corev1.Volume volumes = []corev1.Volume{ { @@ -104,7 +121,7 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef MountPath: "/etc/tsconfig", }} if opts.confFileHash != "" { - annots["tailscale.com/operator-last-set-config-file-hash"] = opts.confFileHash + mak.Set(&annots, "tailscale.com/operator-last-set-config-file-hash", opts.confFileHash) } if opts.firewallMode != "" { tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ @@ -113,13 +130,13 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef }) } if opts.tailnetTargetIP != "" { - annots["tailscale.com/operator-last-set-ts-tailnet-target-ip"] = opts.tailnetTargetIP + mak.Set(&annots, "tailscale.com/operator-last-set-ts-tailnet-target-ip", opts.tailnetTargetIP) tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_TAILNET_TARGET_IP", Value: opts.tailnetTargetIP, }) } else if opts.tailnetTargetFQDN != "" { - annots["tailscale.com/operator-last-set-ts-tailnet-target-fqdn"] = opts.tailnetTargetFQDN + mak.Set(&annots, "tailscale.com/operator-last-set-ts-tailnet-target-fqdn", opts.tailnetTargetFQDN) tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_TAILNET_TARGET_FQDN", Value: opts.tailnetTargetFQDN, @@ -130,13 +147,13 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Name: "TS_DEST_IP", Value: opts.clusterTargetIP, }) - annots["tailscale.com/operator-last-set-cluster-ip"] = opts.clusterTargetIP + mak.Set(&annots, "tailscale.com/operator-last-set-cluster-ip", opts.clusterTargetIP) } else if opts.clusterTargetDNS != "" { tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ Name: "TS_EXPERIMENTAL_DEST_DNS_NAME", Value: opts.clusterTargetDNS, }) - annots["tailscale.com/operator-last-set-cluster-dns-name"] = opts.clusterTargetDNS + 
mak.Set(&annots, "tailscale.com/operator-last-set-cluster-dns-name", opts.clusterTargetDNS) } if opts.serveConfig != nil { tsContainer.Env = append(tsContainer.Env, corev1.EnvVar{ @@ -150,6 +167,29 @@ func expectedSTS(t *testing.T, cl client.Client, opts configOpts) *appsv1.Statef Name: "TS_INTERNAL_APP", Value: opts.app, }) + if opts.enableMetrics { + tsContainer.Env = append(tsContainer.Env, + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001"}, + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + tsContainer.Ports = append(tsContainer.Ports, + corev1.ContainerPort{Name: "debug", ContainerPort: 9001, Protocol: "TCP"}, + corev1.ContainerPort{Name: "metrics", ContainerPort: 9002, Protocol: "TCP"}, + ) + } ss := &appsv1.StatefulSet{ TypeMeta: metav1.TypeMeta{ Kind: "StatefulSet", @@ -228,8 +268,9 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps Env: []corev1.EnvVar{ {Name: "TS_USERSPACE", Value: "true"}, {Name: "POD_IP", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "status.podIP"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.name"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, + {Name: "POD_UID", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{APIVersion: "", FieldPath: "metadata.uid"}, ResourceFieldRef: nil, ConfigMapKeyRef: nil, SecretKeyRef: nil}}, {Name: "TS_KUBE_SECRET", Value: opts.secretName}, - {Name: "EXPERIMENTAL_TS_CONFIGFILE_PATH", Value: "/etc/tsconfig/tailscaled"}, {Name: "TS_EXPERIMENTAL_VERSIONED_CONFIG_DIR", Value: "/etc/tsconfig"}, {Name: "TS_SERVE_CONFIG", 
Value: "/etc/tailscaled/serve-config"}, {Name: "TS_INTERNAL_APP", Value: opts.app}, @@ -240,6 +281,29 @@ func expectedSTSUserspace(t *testing.T, cl client.Client, opts configOpts) *apps {Name: "serve-config", ReadOnly: true, MountPath: "/etc/tailscaled"}, }, } + if opts.enableMetrics { + tsContainer.Env = append(tsContainer.Env, + corev1.EnvVar{ + Name: "TS_DEBUG_ADDR_PORT", + Value: "$(POD_IP):9001"}, + corev1.EnvVar{ + Name: "TS_TAILSCALED_EXTRA_ARGS", + Value: "--debug=$(TS_DEBUG_ADDR_PORT)", + }, + corev1.EnvVar{ + Name: "TS_LOCAL_ADDR_PORT", + Value: "$(POD_IP):9002", + }, + corev1.EnvVar{ + Name: "TS_ENABLE_METRICS", + Value: "true", + }, + ) + tsContainer.Ports = append(tsContainer.Ports, corev1.ContainerPort{ + Name: "debug", ContainerPort: 9001, Protocol: "TCP"}, + corev1.ContainerPort{Name: "metrics", ContainerPort: 9002, Protocol: "TCP"}, + ) + } volumes := []corev1.Volume{ { Name: "tailscaledconfig", @@ -334,6 +398,90 @@ func expectedHeadlessService(name string, parentType string) *corev1.Service { } } +func expectedMetricsService(opts configOpts) *corev1.Service { + labels := metricsLabels(opts) + selector := map[string]string{ + "tailscale.com/managed": "true", + "tailscale.com/parent-resource": "test", + "tailscale.com/parent-resource-type": opts.parentType, + } + if opts.namespaced { + selector["tailscale.com/parent-resource-ns"] = opts.namespace + } + return &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: metricsResourceName(opts.stsName), + Namespace: opts.tailscaleNamespace, + Labels: labels, + }, + Spec: corev1.ServiceSpec{ + Selector: selector, + Type: corev1.ServiceTypeClusterIP, + Ports: []corev1.ServicePort{{Protocol: "TCP", Port: 9002, Name: "metrics"}}, + }, + } +} + +func metricsLabels(opts configOpts) map[string]string { + promJob := fmt.Sprintf("ts_%s_default_test", opts.proxyType) + if !opts.namespaced { + promJob = fmt.Sprintf("ts_%s_test", opts.proxyType) + } + labels := map[string]string{ + "tailscale.com/managed": "true", 
+ "tailscale.com/metrics-target": opts.stsName, + "ts_prom_job": promJob, + "ts_proxy_type": opts.proxyType, + "ts_proxy_parent_name": "test", + } + if opts.namespaced { + labels["ts_proxy_parent_namespace"] = "default" + } + return labels +} + +func expectedServiceMonitor(t *testing.T, opts configOpts) *unstructured.Unstructured { + t.Helper() + smLabels := metricsLabels(opts) + if len(opts.serviceMonitorLabels) != 0 { + smLabels = mergeMapKeys(smLabels, opts.serviceMonitorLabels.Parse()) + } + name := metricsResourceName(opts.stsName) + sm := &ServiceMonitor{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: opts.tailscaleNamespace, + Labels: smLabels, + ResourceVersion: opts.resourceVersion, + OwnerReferences: []metav1.OwnerReference{{APIVersion: "v1", Kind: "Service", Name: name, BlockOwnerDeletion: ptr.To(true), Controller: ptr.To(true)}}, + }, + TypeMeta: metav1.TypeMeta{ + Kind: "ServiceMonitor", + APIVersion: "monitoring.coreos.com/v1", + }, + Spec: ServiceMonitorSpec{ + Selector: metav1.LabelSelector{MatchLabels: metricsLabels(opts)}, + Endpoints: []ServiceMonitorEndpoint{{ + Port: "metrics", + }}, + NamespaceSelector: ServiceMonitorNamespaceSelector{ + MatchNames: []string{opts.tailscaleNamespace}, + }, + JobLabel: "ts_prom_job", + TargetLabels: []string{ + "ts_proxy_parent_name", + "ts_proxy_parent_namespace", + "ts_proxy_type", + }, + }, + } + u, err := serviceMonitorToUnstructured(sm) + if err != nil { + t.Fatalf("error converting ServiceMonitor to unstructured: %v", err) + } + return u +} + func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Secret { t.Helper() s := &corev1.Secret{ @@ -350,12 +498,14 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec mak.Set(&s.StringData, "serve-config", string(serveConfigBs)) } conf := &ipn.ConfigVAlpha{ - Version: "alpha0", - AcceptDNS: "false", - Hostname: &opts.hostname, - Locked: "false", - AuthKey: ptr.To("secret-authkey"), - AcceptRoutes: 
"false", + Version: "alpha0", + AcceptDNS: "false", + Hostname: &opts.hostname, + Locked: "false", + AuthKey: ptr.To("secret-authkey"), + AcceptRoutes: "false", + AppConnector: &ipn.AppConnectorPrefs{Advertise: false}, + NoStatefulFiltering: "true", } if opts.proxyClass != "" { t.Logf("applying configuration from ProxyClass %s", opts.proxyClass) @@ -370,6 +520,9 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec if opts.shouldRemoveAuthKey { conf.AuthKey = nil } + if opts.isAppConnector { + conf.AppConnector = &ipn.AppConnectorPrefs{Advertise: true} + } var routes []netip.Prefix if opts.subnetRoutes != "" || opts.isExitNode { r := opts.subnetRoutes @@ -385,21 +538,17 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec } } conf.AdvertiseRoutes = routes - b, err := json.Marshal(conf) + bnn, err := json.Marshal(conf) if err != nil { t.Fatalf("error marshalling tailscaled config") } - if opts.tailnetTargetFQDN != "" || opts.tailnetTargetIP != "" { - conf.NoStatefulFiltering = "true" - } else { - conf.NoStatefulFiltering = "false" - } + conf.AppConnector = nil bn, err := json.Marshal(conf) if err != nil { t.Fatalf("error marshalling tailscaled config") } - mak.Set(&s.StringData, "tailscaled", string(b)) mak.Set(&s.StringData, "cap-95.hujson", string(bn)) + mak.Set(&s.StringData, "cap-107.hujson", string(bnn)) labels := map[string]string{ "tailscale.com/managed": "true", "tailscale.com/parent-resource": "test", @@ -416,13 +565,30 @@ func expectedSecret(t *testing.T, cl client.Client, opts configOpts) *corev1.Sec return s } +func findNoGenName(t *testing.T, client client.Client, ns, name, typ string) { + t.Helper() + labels := map[string]string{ + kubetypes.LabelManaged: "true", + LabelParentName: name, + LabelParentNamespace: ns, + LabelParentType: typ, + } + s, err := getSingleObject[corev1.Secret](context.Background(), client, "operator-ns", labels) + if err != nil { + t.Fatalf("finding secrets for %q: 
%v", name, err) + } + if s != nil { + t.Fatalf("found unexpected secret with name %q", s.GetName()) + } +} + func findGenName(t *testing.T, client client.Client, ns, name, typ string) (full, noSuffix string) { t.Helper() labels := map[string]string{ - LabelManaged: "true", - LabelParentName: name, - LabelParentNamespace: ns, - LabelParentType: typ, + kubetypes.LabelManaged: "true", + LabelParentName: name, + LabelParentNamespace: ns, + LabelParentType: typ, } s, err := getSingleObject[corev1.Secret](context.Background(), client, "operator-ns", labels) if err != nil { @@ -440,6 +606,21 @@ func mustCreate(t *testing.T, client client.Client, obj client.Object) { t.Fatalf("creating %q: %v", obj.GetName(), err) } } +func mustCreateAll(t *testing.T, client client.Client, objs ...client.Object) { + t.Helper() + for _, obj := range objs { + mustCreate(t, client, obj) + } +} + +func mustDeleteAll(t *testing.T, client client.Client, objs ...client.Object) { + t.Helper() + for _, obj := range objs { + if err := client.Delete(context.Background(), obj); err != nil { + t.Fatalf("deleting %q: %v", obj.GetName(), err) + } + } +} func mustUpdate[T any, O ptrObject[T]](t *testing.T, client client.Client, ns, name string, update func(O)) { t.Helper() @@ -477,7 +658,7 @@ func mustUpdateStatus[T any, O ptrObject[T]](t *testing.T, client client.Client, // modify func to ensure that they are removed from the cluster object and the // object passed as 'want'. If no such modifications are needed, you can pass // nil in place of the modify function. 
-func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want O, modifier func(O)) { +func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want O, modifiers ...func(O)) { t.Helper() got := O(new(T)) if err := client.Get(context.Background(), types.NamespacedName{ @@ -491,7 +672,7 @@ func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want // so just remove it from both got and want. got.SetResourceVersion("") want.SetResourceVersion("") - if modifier != nil { + for _, modifier := range modifiers { modifier(want) modifier(got) } @@ -500,13 +681,29 @@ func expectEqual[T any, O ptrObject[T]](t *testing.T, client client.Client, want } } +func expectEqualUnstructured(t *testing.T, client client.Client, want *unstructured.Unstructured) { + t.Helper() + got := &unstructured.Unstructured{} + got.SetGroupVersionKind(want.GroupVersionKind()) + if err := client.Get(context.Background(), types.NamespacedName{ + Name: want.GetName(), + Namespace: want.GetNamespace(), + }, got); err != nil { + t.Fatalf("getting %q: %v", want.GetName(), err) + } + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("unexpected contents of Unstructured (-got +want):\n%s", diff) + } +} + func expectMissing[T any, O ptrObject[T]](t *testing.T, client client.Client, ns, name string) { t.Helper() obj := O(new(T)) - if err := client.Get(context.Background(), types.NamespacedName{ + err := client.Get(context.Background(), types.NamespacedName{ Name: name, Namespace: ns, - }, obj); !apierrors.IsNotFound(err) { + }, obj) + if !apierrors.IsNotFound(err) { t.Fatalf("%s %s/%s unexpectedly present, wanted missing", reflect.TypeOf(obj).Elem().Name(), ns, name) } } @@ -547,6 +744,19 @@ func expectRequeue(t *testing.T, sr reconcile.Reconciler, ns, name string) { t.Fatalf("expected timed requeue, got success") } } +func expectError(t *testing.T, sr reconcile.Reconciler, ns, name string) { + t.Helper() + req := reconcile.Request{ + 
NamespacedName: types.NamespacedName{ + Name: name, + Namespace: ns, + }, + } + _, err := sr.Reconcile(context.Background(), req) + if err == nil { + t.Error("Reconcile: expected error but did not get one") + } +} // expectEvents accepts a test recorder and a list of events, tests that expected // events are sent down the recorder's channel. Waits for 5s for each event. @@ -580,6 +790,7 @@ type fakeTSClient struct { sync.Mutex keyRequests []tailscale.KeyCapabilities deleted []string + vipServices map[tailcfg.ServiceName]*tailscale.VIPService } type fakeTSNetServer struct { certDomains []string @@ -637,12 +848,21 @@ func (c *fakeTSClient) Deleted() []string { // change to the configfile contents). func removeHashAnnotation(sts *appsv1.StatefulSet) { delete(sts.Spec.Template.Annotations, podAnnotationLastSetConfigFileHash) + if len(sts.Spec.Template.Annotations) == 0 { + sts.Spec.Template.Annotations = nil + } +} + +func removeResourceReqs(sts *appsv1.StatefulSet) { + if sts != nil { + sts.Spec.Template.Spec.Resources = nil + } } func removeTargetPortsFromSvc(svc *corev1.Service) { newPorts := make([]corev1.ServicePort, 0) for _, p := range svc.Spec.Ports { - newPorts = append(newPorts, corev1.ServicePort{Protocol: p.Protocol, Port: p.Port}) + newPorts = append(newPorts, corev1.ServicePort{Protocol: p.Protocol, Port: p.Port, Name: p.Name}) } svc.Spec.Ports = newPorts } @@ -650,29 +870,90 @@ func removeTargetPortsFromSvc(svc *corev1.Service) { func removeAuthKeyIfExistsModifier(t *testing.T) func(s *corev1.Secret) { return func(secret *corev1.Secret) { t.Helper() - if len(secret.StringData["tailscaled"]) != 0 { + if len(secret.StringData["cap-95.hujson"]) != 0 { conf := &ipn.ConfigVAlpha{} - if err := json.Unmarshal([]byte(secret.StringData["tailscaled"]), conf); err != nil { - t.Fatalf("error unmarshalling 'tailscaled' contents: %v", err) + if err := json.Unmarshal([]byte(secret.StringData["cap-95.hujson"]), conf); err != nil { + t.Fatalf("error umarshalling 
'cap-95.hujson' contents: %v", err) } conf.AuthKey = nil b, err := json.Marshal(conf) if err != nil { - t.Fatalf("error marshalling updated 'tailscaled' config: %v", err) + t.Fatalf("error marshalling 'cap-95.huson' contents: %v", err) } - mak.Set(&secret.StringData, "tailscaled", string(b)) + mak.Set(&secret.StringData, "cap-95.hujson", string(b)) } - if len(secret.StringData["cap-95.hujson"]) != 0 { + if len(secret.StringData["cap-107.hujson"]) != 0 { conf := &ipn.ConfigVAlpha{} - if err := json.Unmarshal([]byte(secret.StringData["cap-95.hujson"]), conf); err != nil { - t.Fatalf("error umarshalling 'cap-95.hujson' contents: %v", err) + if err := json.Unmarshal([]byte(secret.StringData["cap-107.hujson"]), conf); err != nil { + t.Fatalf("error umarshalling 'cap-107.hujson' contents: %v", err) } conf.AuthKey = nil b, err := json.Marshal(conf) if err != nil { - t.Fatalf("error marshalling 'cap-95.huson' contents: %v", err) + t.Fatalf("error marshalling 'cap-107.huson' contents: %v", err) } - mak.Set(&secret.StringData, "cap-95.hujson", string(b)) + mak.Set(&secret.StringData, "cap-107.hujson", string(b)) } } } + +func (c *fakeTSClient) GetVIPService(ctx context.Context, name tailcfg.ServiceName) (*tailscale.VIPService, error) { + c.Lock() + defer c.Unlock() + if c.vipServices == nil { + return nil, tailscale.ErrResponse{Status: http.StatusNotFound} + } + svc, ok := c.vipServices[name] + if !ok { + return nil, tailscale.ErrResponse{Status: http.StatusNotFound} + } + return svc, nil +} + +func (c *fakeTSClient) ListVIPServices(ctx context.Context) (map[tailcfg.ServiceName]*tailscale.VIPService, error) { + c.Lock() + defer c.Unlock() + if c.vipServices == nil { + return nil, &tailscale.ErrResponse{Status: http.StatusNotFound} + } + return c.vipServices, nil +} + +func (c *fakeTSClient) CreateOrUpdateVIPService(ctx context.Context, svc *tailscale.VIPService) error { + c.Lock() + defer c.Unlock() + if c.vipServices == nil { + c.vipServices = 
make(map[tailcfg.ServiceName]*tailscale.VIPService) + } + + if svc.Addrs == nil { + svc.Addrs = []string{vipTestIP} + } + + c.vipServices[svc.Name] = svc + return nil +} + +func (c *fakeTSClient) DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error { + c.Lock() + defer c.Unlock() + if c.vipServices != nil { + delete(c.vipServices, name) + } + return nil +} + +type fakeLocalClient struct { + status *ipnstate.Status +} + +func (f *fakeLocalClient) StatusWithoutPeers(ctx context.Context) (*ipnstate.Status, error) { + if f.status == nil { + return &ipnstate.Status{ + Self: &ipnstate.PeerStatus{ + DNSName: "test-node.test.ts.net.", + }, + }, nil + } + return f.status, nil +} diff --git a/cmd/k8s-operator/tsclient.go b/cmd/k8s-operator/tsclient.go new file mode 100644 index 0000000000000..f606394376589 --- /dev/null +++ b/cmd/k8s-operator/tsclient.go @@ -0,0 +1,58 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "os" + + "golang.org/x/oauth2/clientcredentials" + "tailscale.com/internal/client/tailscale" + "tailscale.com/tailcfg" +) + +// defaultTailnet is a value that can be used in Tailscale API calls instead of tailnet name to indicate that the API +// call should be performed on the default tailnet for the provided credentials. 
+const ( + defaultTailnet = "-" + defaultBaseURL = "https://api.tailscale.com" +) + +func newTSClient(ctx context.Context, clientIDPath, clientSecretPath string) (tsClient, error) { + tokenURL := defaultEnv("TOKEN_URL", "https://login.tailscale.com/api/v2/oauth/token") + controlURL := defaultEnv("CONTROL_URL", "") + clientID, err := os.ReadFile(clientIDPath) + if err != nil { + return nil, fmt.Errorf("error reading client ID %q: %w", clientIDPath, err) + } + clientSecret, err := os.ReadFile(clientSecretPath) + if err != nil { + return nil, fmt.Errorf("reading client secret %q: %w", clientSecretPath, err) + } + credentials := clientcredentials.Config{ + ClientID: string(clientID), + ClientSecret: string(clientSecret), + TokenURL: tokenURL, + } + c := tailscale.NewClient(defaultTailnet, nil) + c.UserAgent = "tailscale-k8s-operator" + c.HTTPClient = credentials.Client(ctx) + c.BaseURL = controlURL + return c, nil +} + +type tsClient interface { + CreateKey(ctx context.Context, caps tailscale.KeyCapabilities) (string, *tailscale.Key, error) + Device(ctx context.Context, deviceID string, fields *tailscale.DeviceFieldsOpts) (*tailscale.Device, error) + DeleteDevice(ctx context.Context, nodeStableID string) error + // GetVIPService is a method for getting a Tailscale Service. VIPService is the original name for Tailscale Service. + GetVIPService(ctx context.Context, name tailcfg.ServiceName) (*tailscale.VIPService, error) + // CreateOrUpdateVIPService is a method for creating or updating a Tailscale Service. + CreateOrUpdateVIPService(ctx context.Context, svc *tailscale.VIPService) error + // DeleteVIPService is a method for deleting a Tailscale Service. 
+ DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error +} diff --git a/cmd/k8s-operator/tsrecorder.go b/cmd/k8s-operator/tsrecorder.go index cfe38c50af311..081543cd384db 100644 --- a/cmd/k8s-operator/tsrecorder.go +++ b/cmd/k8s-operator/tsrecorder.go @@ -8,12 +8,13 @@ package main import ( "context" "encoding/json" + "errors" "fmt" "net/http" "slices" + "strings" "sync" - "github.com/pkg/errors" "go.uber.org/zap" xslices "golang.org/x/exp/slices" appsv1 "k8s.io/api/apps/v1" @@ -21,8 +22,10 @@ import ( rbacv1 "k8s.io/api/rbac/v1" apiequality "k8s.io/apimachinery/pkg/api/equality" apierrors "k8s.io/apimachinery/pkg/api/errors" + apivalidation "k8s.io/apimachinery/pkg/api/validation" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" + "k8s.io/apimachinery/pkg/util/validation/field" "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" @@ -38,6 +41,7 @@ import ( const ( reasonRecorderCreationFailed = "RecorderCreationFailed" + reasonRecorderCreating = "RecorderCreating" reasonRecorderCreated = "RecorderCreated" reasonRecorderInvalid = "RecorderInvalid" @@ -102,10 +106,10 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques oldTSRStatus := tsr.Status.DeepCopy() setStatusReady := func(tsr *tsapi.Recorder, status metav1.ConditionStatus, reason, message string) (reconcile.Result, error) { tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, status, reason, message, tsr.Generation, r.clock, logger) - if !apiequality.Semantic.DeepEqual(oldTSRStatus, tsr.Status) { + if !apiequality.Semantic.DeepEqual(oldTSRStatus, &tsr.Status) { // An error encountered here should get returned by the Reconcile function. 
if updateErr := r.Client.Status().Update(ctx, tsr); updateErr != nil { - err = errors.Wrap(err, updateErr.Error()) + err = errors.Join(err, updateErr) } } return reconcile.Result{}, err @@ -119,23 +123,28 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques logger.Infof("ensuring Recorder is set up") tsr.Finalizers = append(tsr.Finalizers, FinalizerName) if err := r.Update(ctx, tsr); err != nil { - logger.Errorf("error adding finalizer: %w", err) return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderCreationFailed, reasonRecorderCreationFailed) } } - if err := r.validate(tsr); err != nil { - logger.Errorf("error validating Recorder spec: %w", err) + if err := r.validate(ctx, tsr); err != nil { message := fmt.Sprintf("Recorder is invalid: %s", err) r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderInvalid, message) return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderInvalid, message) } if err = r.maybeProvision(ctx, tsr); err != nil { - logger.Errorf("error creating Recorder resources: %w", err) + reason := reasonRecorderCreationFailed message := fmt.Sprintf("failed creating Recorder: %s", err) - r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderCreationFailed, message) - return setStatusReady(tsr, metav1.ConditionFalse, reasonRecorderCreationFailed, message) + if strings.Contains(err.Error(), optimisticLockErrorMsg) { + reason = reasonRecorderCreating + message = fmt.Sprintf("optimistic lock error, retrying: %s", err) + err = nil + logger.Info(message) + } else { + r.recorder.Eventf(tsr, corev1.EventTypeWarning, reasonRecorderCreationFailed, message) + } + return setStatusReady(tsr, metav1.ConditionFalse, reason, message) } logger.Info("Recorder resources synced") @@ -153,20 +162,26 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco if err := r.ensureAuthSecretCreated(ctx, tsr); err != nil { return fmt.Errorf("error creating secrets: %w", err) } - // State 
secret is precreated so we can use the Recorder CR as its owner ref. + // State Secret is precreated so we can use the Recorder CR as its owner ref. sec := tsrStateSecret(tsr, r.tsNamespace) if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sec, func(s *corev1.Secret) { s.ObjectMeta.Labels = sec.ObjectMeta.Labels s.ObjectMeta.Annotations = sec.ObjectMeta.Annotations - s.ObjectMeta.OwnerReferences = sec.ObjectMeta.OwnerReferences }); err != nil { return fmt.Errorf("error creating state Secret: %w", err) } sa := tsrServiceAccount(tsr, r.tsNamespace) - if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sa, func(s *corev1.ServiceAccount) { + if _, err := createOrMaybeUpdate(ctx, r.Client, r.tsNamespace, sa, func(s *corev1.ServiceAccount) error { + // Perform this check within the update function to make sure we don't + // have a race condition between the previous check and the update. + if err := saOwnedByRecorder(s, tsr); err != nil { + return err + } + s.ObjectMeta.Labels = sa.ObjectMeta.Labels s.ObjectMeta.Annotations = sa.ObjectMeta.Annotations - s.ObjectMeta.OwnerReferences = sa.ObjectMeta.OwnerReferences + + return nil }); err != nil { return fmt.Errorf("error creating ServiceAccount: %w", err) } @@ -174,7 +189,6 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, role, func(r *rbacv1.Role) { r.ObjectMeta.Labels = role.ObjectMeta.Labels r.ObjectMeta.Annotations = role.ObjectMeta.Annotations - r.ObjectMeta.OwnerReferences = role.ObjectMeta.OwnerReferences r.Rules = role.Rules }); err != nil { return fmt.Errorf("error creating Role: %w", err) @@ -183,7 +197,6 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, roleBinding, func(r *rbacv1.RoleBinding) { r.ObjectMeta.Labels = roleBinding.ObjectMeta.Labels r.ObjectMeta.Annotations = roleBinding.ObjectMeta.Annotations - 
r.ObjectMeta.OwnerReferences = roleBinding.ObjectMeta.OwnerReferences r.RoleRef = roleBinding.RoleRef r.Subjects = roleBinding.Subjects }); err != nil { @@ -193,12 +206,18 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco if _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, ss, func(s *appsv1.StatefulSet) { s.ObjectMeta.Labels = ss.ObjectMeta.Labels s.ObjectMeta.Annotations = ss.ObjectMeta.Annotations - s.ObjectMeta.OwnerReferences = ss.ObjectMeta.OwnerReferences s.Spec = ss.Spec }); err != nil { return fmt.Errorf("error creating StatefulSet: %w", err) } + // ServiceAccount name may have changed, in which case we need to clean up + // the previous ServiceAccount. RoleBinding will already be updated to point + // to the new ServiceAccount. + if err := r.maybeCleanupServiceAccounts(ctx, tsr, sa.Name); err != nil { + return fmt.Errorf("error cleaning up ServiceAccounts: %w", err) + } + var devices []tsapi.RecorderTailnetDevice device, ok, err := r.getDeviceInfo(ctx, tsr.Name) @@ -217,13 +236,54 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsr *tsapi.Reco return nil } +func saOwnedByRecorder(sa *corev1.ServiceAccount, tsr *tsapi.Recorder) error { + // If ServiceAccount name has been configured, check that we don't clobber + // a pre-existing SA not owned by this Recorder. + if sa.Name != tsr.Name && !apiequality.Semantic.DeepEqual(sa.OwnerReferences, tsrOwnerReference(tsr)) { + return fmt.Errorf("custom ServiceAccount name %q specified but conflicts with a pre-existing ServiceAccount in the %s namespace", sa.Name, sa.Namespace) + } + + return nil +} + +// maybeCleanupServiceAccounts deletes any dangling ServiceAccounts +// owned by the Recorder if the ServiceAccount name has been changed. +// They would eventually be cleaned up by owner reference deletion, but +// this avoids a long-lived Recorder with many ServiceAccount name changes +// accumulating a large amount of garbage. 
+// +// This is a no-op if the ServiceAccount name has not changed. +func (r *RecorderReconciler) maybeCleanupServiceAccounts(ctx context.Context, tsr *tsapi.Recorder, currentName string) error { + logger := r.logger(tsr.Name) + + // List all ServiceAccounts owned by this Recorder. + sas := &corev1.ServiceAccountList{} + if err := r.List(ctx, sas, client.InNamespace(r.tsNamespace), client.MatchingLabels(labels("recorder", tsr.Name, nil))); err != nil { + return fmt.Errorf("error listing ServiceAccounts for cleanup: %w", err) + } + for _, sa := range sas.Items { + if sa.Name == currentName { + continue + } + if err := r.Delete(ctx, &sa); err != nil { + if apierrors.IsNotFound(err) { + logger.Debugf("ServiceAccount %s not found, likely already deleted", sa.Name) + } else { + return fmt.Errorf("error deleting ServiceAccount %s: %w", sa.Name, err) + } + } + } + + return nil +} + // maybeCleanup just deletes the device from the tailnet. All the kubernetes // resources linked to a Recorder will get cleaned up via owner references // (which we can use because they are all in the same namespace). 
func (r *RecorderReconciler) maybeCleanup(ctx context.Context, tsr *tsapi.Recorder) (bool, error) { logger := r.logger(tsr.Name) - id, _, ok, err := r.getNodeMetadata(ctx, tsr.Name) + prefs, ok, err := r.getDevicePrefs(ctx, tsr.Name) if err != nil { return false, err } @@ -236,6 +296,7 @@ func (r *RecorderReconciler) maybeCleanup(ctx context.Context, tsr *tsapi.Record return true, nil } + id := string(prefs.Config.NodeID) logger.Debugf("deleting device %s from control", string(id)) if err := r.tsClient.DeleteDevice(ctx, string(id)); err != nil { errResp := &tailscale.ErrResponse{} @@ -294,11 +355,41 @@ func (r *RecorderReconciler) ensureAuthSecretCreated(ctx context.Context, tsr *t return nil } -func (r *RecorderReconciler) validate(tsr *tsapi.Recorder) error { +func (r *RecorderReconciler) validate(ctx context.Context, tsr *tsapi.Recorder) error { if !tsr.Spec.EnableUI && tsr.Spec.Storage.S3 == nil { return errors.New("must either enable UI or use S3 storage to ensure recordings are accessible") } + // Check any custom ServiceAccount config doesn't conflict with pre-existing + // ServiceAccounts. This check is performed once during validation to ensure + // errors are raised early, but also again during any Updates to prevent a race. + specSA := tsr.Spec.StatefulSet.Pod.ServiceAccount + if specSA.Name != "" && specSA.Name != tsr.Name { + sa := &corev1.ServiceAccount{} + key := client.ObjectKey{ + Name: specSA.Name, + Namespace: r.tsNamespace, + } + + err := r.Get(ctx, key, sa) + switch { + case apierrors.IsNotFound(err): + // ServiceAccount doesn't exist, so no conflict. + case err != nil: + return fmt.Errorf("error getting ServiceAccount %q for validation: %w", tsr.Spec.StatefulSet.Pod.ServiceAccount.Name, err) + default: + // ServiceAccount exists, check if it's owned by the Recorder. 
+ if err := saOwnedByRecorder(sa, tsr); err != nil { + return err + } + } + } + if len(specSA.Annotations) > 0 { + if violations := apivalidation.ValidateAnnotations(specSA.Annotations, field.NewPath(".spec.statefulSet.pod.serviceAccount.annotations")); len(violations) > 0 { + return violations.ToAggregate() + } + } + return nil } @@ -320,34 +411,33 @@ func (r *RecorderReconciler) getStateSecret(ctx context.Context, tsrName string) return secret, nil } -func (r *RecorderReconciler) getNodeMetadata(ctx context.Context, tsrName string) (id tailcfg.StableNodeID, dnsName string, ok bool, err error) { +func (r *RecorderReconciler) getDevicePrefs(ctx context.Context, tsrName string) (prefs prefs, ok bool, err error) { secret, err := r.getStateSecret(ctx, tsrName) if err != nil || secret == nil { - return "", "", false, err + return prefs, false, err } - return getNodeMetadata(ctx, secret) + return getDevicePrefs(secret) } -// getNodeMetadata returns 'ok == true' iff the node ID is found. The dnsName +// getDevicePrefs returns 'ok == true' iff the node ID is found. The dnsName // is expected to always be non-empty if the node ID is, but not required. -func getNodeMetadata(ctx context.Context, secret *corev1.Secret) (id tailcfg.StableNodeID, dnsName string, ok bool, err error) { +func getDevicePrefs(secret *corev1.Secret) (prefs prefs, ok bool, err error) { // TODO(tomhjp): Should maybe use ipn to parse the following info instead. 
currentProfile, ok := secret.Data[currentProfileKey] if !ok { - return "", "", false, nil + return prefs, false, nil } profileBytes, ok := secret.Data[string(currentProfile)] if !ok { - return "", "", false, nil + return prefs, false, nil } - var profile profile - if err := json.Unmarshal(profileBytes, &profile); err != nil { - return "", "", false, fmt.Errorf("failed to extract node profile info from state Secret %s: %w", secret.Name, err) + if err := json.Unmarshal(profileBytes, &prefs); err != nil { + return prefs, false, fmt.Errorf("failed to extract node profile info from state Secret %s: %w", secret.Name, err) } - ok = profile.Config.NodeID != "" - return tailcfg.StableNodeID(profile.Config.NodeID), profile.Config.UserProfile.LoginName, ok, nil + ok = prefs.Config.NodeID != "" + return prefs, ok, nil } func (r *RecorderReconciler) getDeviceInfo(ctx context.Context, tsrName string) (d tsapi.RecorderTailnetDevice, ok bool, err error) { @@ -360,14 +450,14 @@ func (r *RecorderReconciler) getDeviceInfo(ctx context.Context, tsrName string) } func getDeviceInfo(ctx context.Context, tsClient tsClient, secret *corev1.Secret) (d tsapi.RecorderTailnetDevice, ok bool, err error) { - nodeID, dnsName, ok, err := getNodeMetadata(ctx, secret) + prefs, ok, err := getDevicePrefs(secret) if !ok || err != nil { return tsapi.RecorderTailnetDevice{}, false, err } // TODO(tomhjp): The profile info doesn't include addresses, which is why we // need the API. Should we instead update the profile to include addresses? 
- device, err := tsClient.Device(ctx, string(nodeID), nil) + device, err := tsClient.Device(ctx, string(prefs.Config.NodeID), nil) if err != nil { return tsapi.RecorderTailnetDevice{}, false, fmt.Errorf("failed to get device info from API: %w", err) } @@ -376,20 +466,25 @@ func getDeviceInfo(ctx context.Context, tsClient tsClient, secret *corev1.Secret Hostname: device.Hostname, TailnetIPs: device.Addresses, } - if dnsName != "" { + if dnsName := prefs.Config.UserProfile.LoginName; dnsName != "" { d.URL = fmt.Sprintf("https://%s", dnsName) } return d, true, nil } -type profile struct { +// [prefs] is a subset of the ipn.Prefs struct used for extracting information +// from the state Secret of Tailscale devices. +type prefs struct { Config struct { - NodeID string `json:"NodeID"` + NodeID tailcfg.StableNodeID `json:"NodeID"` UserProfile struct { + // LoginName is the MagicDNS name of the device, e.g. foo.tail-scale.ts.net. LoginName string `json:"LoginName"` } `json:"UserProfile"` } `json:"Config"` + + AdvertiseServices []string `json:"AdvertiseServices"` } func markedForDeletion(obj metav1.Object) bool { diff --git a/cmd/k8s-operator/tsrecorder_specs.go b/cmd/k8s-operator/tsrecorder_specs.go index 4a74fb7e03442..7c6e80aed56fd 100644 --- a/cmd/k8s-operator/tsrecorder_specs.go +++ b/cmd/k8s-operator/tsrecorder_specs.go @@ -39,7 +39,7 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string) *appsv1.StatefulSet { Annotations: tsr.Spec.StatefulSet.Pod.Annotations, }, Spec: corev1.PodSpec{ - ServiceAccountName: tsr.Name, + ServiceAccountName: tsrServiceAccountName(tsr), Affinity: tsr.Spec.StatefulSet.Pod.Affinity, SecurityContext: tsr.Spec.StatefulSet.Pod.SecurityContext, ImagePullSecrets: tsr.Spec.StatefulSet.Pod.ImagePullSecrets, @@ -100,14 +100,25 @@ func tsrStatefulSet(tsr *tsapi.Recorder, namespace string) *appsv1.StatefulSet { func tsrServiceAccount(tsr *tsapi.Recorder, namespace string) *corev1.ServiceAccount { return &corev1.ServiceAccount{ ObjectMeta: 
metav1.ObjectMeta{ - Name: tsr.Name, + Name: tsrServiceAccountName(tsr), Namespace: namespace, Labels: labels("recorder", tsr.Name, nil), OwnerReferences: tsrOwnerReference(tsr), + Annotations: tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations, }, } } +func tsrServiceAccountName(tsr *tsapi.Recorder) string { + sa := tsr.Spec.StatefulSet.Pod.ServiceAccount + name := tsr.Name + if sa.Name != "" { + name = sa.Name + } + + return name +} + func tsrRole(tsr *tsapi.Recorder, namespace string) *rbacv1.Role { return &rbacv1.Role{ ObjectMeta: metav1.ObjectMeta{ @@ -130,6 +141,15 @@ func tsrRole(tsr *tsapi.Recorder, namespace string) *rbacv1.Role { fmt.Sprintf("%s-0", tsr.Name), // Contains the node state. }, }, + { + APIGroups: []string{""}, + Resources: []string{"events"}, + Verbs: []string{ + "get", + "create", + "patch", + }, + }, }, } } @@ -145,7 +165,7 @@ func tsrRoleBinding(tsr *tsapi.Recorder, namespace string) *rbacv1.RoleBinding { Subjects: []rbacv1.Subject{ { Kind: "ServiceAccount", - Name: tsr.Name, + Name: tsrServiceAccountName(tsr), Namespace: namespace, }, }, @@ -203,6 +223,14 @@ func env(tsr *tsapi.Recorder) []corev1.EnvVar { }, }, }, + { + Name: "POD_UID", + ValueFrom: &corev1.EnvVarSource{ + FieldRef: &corev1.ObjectFieldSelector{ + FieldPath: "metadata.uid", + }, + }, + }, { Name: "TS_STATE", Value: "kube:$(POD_NAME)", diff --git a/cmd/k8s-operator/tsrecorder_test.go b/cmd/k8s-operator/tsrecorder_test.go index bd73e8fb9ec26..e6d56ef2f04c6 100644 --- a/cmd/k8s-operator/tsrecorder_test.go +++ b/cmd/k8s-operator/tsrecorder_test.go @@ -8,6 +8,7 @@ package main import ( "context" "encoding/json" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -41,7 +42,7 @@ func TestRecorder(t *testing.T) { Build() tsClient := &fakeTSClient{} zl, _ := zap.NewDevelopment() - fr := record.NewFakeRecorder(1) + fr := record.NewFakeRecorder(2) cl := tstest.NewClock(tstest.ClockOpts{}) reconciler := &RecorderReconciler{ tsNamespace: tsNamespace, @@ -52,12 +53,12 @@ func 
TestRecorder(t *testing.T) { clock: cl, } - t.Run("invalid spec gives an error condition", func(t *testing.T) { + t.Run("invalid_spec_gives_an_error_condition", func(t *testing.T) { expectReconciled(t, reconciler, "", tsr.Name) msg := "Recorder is invalid: must either enable UI or use S3 storage to ensure recordings are accessible" tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, metav1.ConditionFalse, reasonRecorderInvalid, msg, 0, cl, zl.Sugar()) - expectEqual(t, fc, tsr, nil) + expectEqual(t, fc, tsr) if expected := 0; reconciler.recorders.Len() != expected { t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) } @@ -65,10 +66,66 @@ func TestRecorder(t *testing.T) { expectedEvent := "Warning RecorderInvalid Recorder is invalid: must either enable UI or use S3 storage to ensure recordings are accessible" expectEvents(t, fr, []string{expectedEvent}) - }) - t.Run("observe Ready=true status condition for a valid spec", func(t *testing.T) { tsr.Spec.EnableUI = true + tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations = map[string]string{ + "invalid space characters": "test", + } + mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { + t.Spec = tsr.Spec + }) + expectReconciled(t, reconciler, "", tsr.Name) + + // Only check part of this error message, because it's defined in an + // external package and may change. 
+ if err := fc.Get(context.Background(), client.ObjectKey{ + Name: tsr.Name, + }, tsr); err != nil { + t.Fatal(err) + } + if len(tsr.Status.Conditions) != 1 { + t.Fatalf("expected 1 condition, got %d", len(tsr.Status.Conditions)) + } + cond := tsr.Status.Conditions[0] + if cond.Type != string(tsapi.RecorderReady) || cond.Status != metav1.ConditionFalse || cond.Reason != reasonRecorderInvalid { + t.Fatalf("expected condition RecorderReady false due to RecorderInvalid, got %v", cond) + } + for _, msg := range []string{cond.Message, <-fr.Events} { + if !strings.Contains(msg, `"invalid space characters"`) { + t.Fatalf("expected invalid annotation key in error message, got %q", cond.Message) + } + } + }) + + t.Run("conflicting_service_account_config_marked_as_invalid", func(t *testing.T) { + mustCreate(t, fc, &corev1.ServiceAccount{ + ObjectMeta: metav1.ObjectMeta{ + Name: "pre-existing-sa", + Namespace: tsNamespace, + }, + }) + + tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations = nil + tsr.Spec.StatefulSet.Pod.ServiceAccount.Name = "pre-existing-sa" + mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { + t.Spec = tsr.Spec + }) + + expectReconciled(t, reconciler, "", tsr.Name) + + msg := `Recorder is invalid: custom ServiceAccount name "pre-existing-sa" specified but conflicts with a pre-existing ServiceAccount in the tailscale namespace` + tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, metav1.ConditionFalse, reasonRecorderInvalid, msg, 0, cl, zl.Sugar()) + expectEqual(t, fc, tsr) + if expected := 0; reconciler.recorders.Len() != expected { + t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) + } + + expectedEvent := "Warning RecorderInvalid " + msg + expectEvents(t, fr, []string{expectedEvent}) + }) + + t.Run("observe_Ready_true_status_condition_for_a_valid_spec", func(t *testing.T) { + tsr.Spec.StatefulSet.Pod.ServiceAccount.Name = "" mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { t.Spec = tsr.Spec }) @@ -76,14 
+133,49 @@ func TestRecorder(t *testing.T) { expectReconciled(t, reconciler, "", tsr.Name) tsoperator.SetRecorderCondition(tsr, tsapi.RecorderReady, metav1.ConditionTrue, reasonRecorderCreated, reasonRecorderCreated, 0, cl, zl.Sugar()) - expectEqual(t, fc, tsr, nil) + expectEqual(t, fc, tsr) if expected := 1; reconciler.recorders.Len() != expected { t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) } expectRecorderResources(t, fc, tsr, true) }) - t.Run("populate node info in state secret, and see it appear in status", func(t *testing.T) { + t.Run("valid_service_account_config", func(t *testing.T) { + tsr.Spec.StatefulSet.Pod.ServiceAccount.Name = "test-sa" + tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations = map[string]string{ + "test": "test", + } + mustUpdate(t, fc, "", "test", func(t *tsapi.Recorder) { + t.Spec = tsr.Spec + }) + + expectReconciled(t, reconciler, "", tsr.Name) + + expectEqual(t, fc, tsr) + if expected := 1; reconciler.recorders.Len() != expected { + t.Fatalf("expected %d recorders, got %d", expected, reconciler.recorders.Len()) + } + expectRecorderResources(t, fc, tsr, true) + + // Get the service account and check the annotations. 
+ sa := &corev1.ServiceAccount{} + if err := fc.Get(context.Background(), client.ObjectKey{ + Name: tsr.Spec.StatefulSet.Pod.ServiceAccount.Name, + Namespace: tsNamespace, + }, sa); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(sa.Annotations, tsr.Spec.StatefulSet.Pod.ServiceAccount.Annotations); diff != "" { + t.Fatalf("unexpected service account annotations (-got +want):\n%s", diff) + } + if sa.Name != tsr.Spec.StatefulSet.Pod.ServiceAccount.Name { + t.Fatalf("unexpected service account name: got %q, want %q", sa.Name, tsr.Spec.StatefulSet.Pod.ServiceAccount.Name) + } + + expectMissing[corev1.ServiceAccount](t, fc, tsNamespace, tsr.Name) + }) + + t.Run("populate_node_info_in_state_secret_and_see_it_appear_in_status", func(t *testing.T) { bytes, err := json.Marshal(map[string]any{ "Config": map[string]any{ "NodeID": "nodeid-123", @@ -112,10 +204,10 @@ func TestRecorder(t *testing.T) { URL: "https://test-0.example.ts.net", }, } - expectEqual(t, fc, tsr, nil) + expectEqual(t, fc, tsr) }) - t.Run("delete the Recorder and observe cleanup", func(t *testing.T) { + t.Run("delete_the_Recorder_and_observe_cleanup", func(t *testing.T) { if err := fc.Delete(context.Background(), tsr); err != nil { t.Fatal(err) } @@ -145,12 +237,12 @@ func expectRecorderResources(t *testing.T, fc client.WithWatch, tsr *tsapi.Recor statefulSet := tsrStatefulSet(tsr, tsNamespace) if shouldExist { - expectEqual(t, fc, auth, nil) - expectEqual(t, fc, state, nil) - expectEqual(t, fc, role, nil) - expectEqual(t, fc, roleBinding, nil) - expectEqual(t, fc, serviceAccount, nil) - expectEqual(t, fc, statefulSet, nil) + expectEqual(t, fc, auth) + expectEqual(t, fc, state) + expectEqual(t, fc, role) + expectEqual(t, fc, roleBinding) + expectEqual(t, fc, serviceAccount) + expectEqual(t, fc, statefulSet, removeResourceReqs) } else { expectMissing[corev1.Secret](t, fc, auth.Namespace, auth.Name) expectMissing[corev1.Secret](t, fc, state.Namespace, state.Name) diff --git a/cmd/nardump/nardump.go 
b/cmd/nardump/nardump.go index 05be7b65a7e37..f8947b02b852c 100644 --- a/cmd/nardump/nardump.go +++ b/cmd/nardump/nardump.go @@ -100,14 +100,13 @@ func (nw *narWriter) writeDir(dirPath string) error { sub := path.Join(dirPath, ent.Name()) var err error switch { - case mode.IsRegular(): - err = nw.writeRegular(sub) case mode.IsDir(): err = nw.writeDir(sub) + case mode.IsRegular(): + err = nw.writeRegular(sub) + case mode&os.ModeSymlink != 0: + err = nw.writeSymlink(sub) default: - // TODO(bradfitz): symlink, but requires fighting io/fs a bit - // to get at Readlink or the osFS via fs. But for now - // we don't need symlinks because they're not in Go's archive. return fmt.Errorf("unsupported file type %v at %q", sub, mode) } if err != nil { @@ -143,6 +142,23 @@ func (nw *narWriter) writeRegular(path string) error { return nil } +func (nw *narWriter) writeSymlink(path string) error { + nw.str("(") + nw.str("type") + nw.str("symlink") + nw.str("target") + // broken symlinks are valid in a nar + // given we do os.chdir(dir) and os.dirfs(".") above + // readlink now resolves relative links even if they are broken + link, err := os.Readlink(path) + if err != nil { + return err + } + nw.str(link) + nw.str(")") + return nil +} + func (nw *narWriter) str(s string) { if err := writeString(nw.w, s); err != nil { panic(writeNARError{err}) diff --git a/cmd/nardump/nardump_test.go b/cmd/nardump/nardump_test.go new file mode 100644 index 0000000000000..3b87e7962d638 --- /dev/null +++ b/cmd/nardump/nardump_test.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "crypto/sha256" + "fmt" + "os" + "runtime" + "testing" +) + +// setupTmpdir sets up a known golden layout, covering all allowed file/folder types in a nar +func setupTmpdir(t *testing.T) string { + tmpdir := t.TempDir() + pwd, _ := os.Getwd() + os.Chdir(tmpdir) + defer os.Chdir(pwd) + os.MkdirAll("sub/dir", 0755) + os.Symlink("brokenfile", 
"brokenlink") + os.Symlink("sub/dir", "dirl") + os.Symlink("/abs/nonexistentdir", "dirb") + os.Create("sub/dir/file1") + f, _ := os.Create("file2m") + _ = f.Truncate(2 * 1024 * 1024) + f.Close() + os.Symlink("../file2m", "sub/goodlink") + return tmpdir +} + +func TestWriteNar(t *testing.T) { + if runtime.GOOS == "windows" { + // Skip test on Windows as the Nix package manager is not supported on this platform + t.Skip("nix package manager is not available on Windows") + } + dir := setupTmpdir(t) + t.Run("nar", func(t *testing.T) { + // obtained via `nix-store --dump /tmp/... | sha256sum` of the above test dir + expected := "727613a36f41030e93a4abf2649c3ec64a2757ccff364e3f6f7d544eb976e442" + h := sha256.New() + os.Chdir(dir) + err := writeNAR(h, os.DirFS(".")) + if err != nil { + t.Fatal(err) + } + hash := fmt.Sprintf("%x", h.Sum(nil)) + if expected != hash { + t.Fatal("sha256sum of nar not matched", hash, expected) + } + }) +} diff --git a/cmd/natc/ippool/ippool.go b/cmd/natc/ippool/ippool.go new file mode 100644 index 0000000000000..3a46a6e7ad186 --- /dev/null +++ b/cmd/natc/ippool/ippool.go @@ -0,0 +1,118 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// ippool implements IP address storage, creation, and retrieval for cmd/natc +package ippool + +import ( + "errors" + "log" + "math/big" + "net/netip" + "sync" + + "github.com/gaissmai/bart" + "go4.org/netipx" + "tailscale.com/syncs" + "tailscale.com/tailcfg" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" +) + +var ErrNoIPsAvailable = errors.New("no IPs available") + +type IPPool struct { + perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] + IPSet *netipx.IPSet +} + +func (ipp *IPPool) DomainForIP(from tailcfg.NodeID, addr netip.Addr) (string, bool) { + ps, ok := ipp.perPeerMap.Load(from) + if !ok { + log.Printf("handleTCPFlow: no perPeerState for %v", from) + return "", false + } + domain, ok := ps.domainForIP(addr) + if !ok { + log.Printf("handleTCPFlow: no 
domain for IP %v\n", addr) + return "", false + } + return domain, ok +} + +func (ipp *IPPool) IPForDomain(from tailcfg.NodeID, domain string) (netip.Addr, error) { + npps := &perPeerState{ + ipset: ipp.IPSet, + } + ps, _ := ipp.perPeerMap.LoadOrStore(from, npps) + return ps.ipForDomain(domain) +} + +// perPeerState holds the state for a single peer. +type perPeerState struct { + ipset *netipx.IPSet + + mu sync.Mutex + addrInUse *big.Int + domainToAddr map[string]netip.Addr + addrToDomain *bart.Table[string] +} + +// domainForIP returns the domain name assigned to the given IP address and +// whether it was found. +func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) { + ps.mu.Lock() + defer ps.mu.Unlock() + if ps.addrToDomain == nil { + return "", false + } + return ps.addrToDomain.Lookup(ip) +} + +// ipForDomain assigns a pair of unique IP addresses for the given domain and +// returns them. The first address is an IPv4 address and the second is an IPv6 +// address. If the domain already has assigned addresses, it returns them. +func (ps *perPeerState) ipForDomain(domain string) (netip.Addr, error) { + fqdn, err := dnsname.ToFQDN(domain) + if err != nil { + return netip.Addr{}, err + } + domain = fqdn.WithoutTrailingDot() + + ps.mu.Lock() + defer ps.mu.Unlock() + if addr, ok := ps.domainToAddr[domain]; ok { + return addr, nil + } + addr := ps.assignAddrsLocked(domain) + if !addr.IsValid() { + return netip.Addr{}, ErrNoIPsAvailable + } + return addr, nil +} + +// unusedIPv4Locked returns an unused IPv4 address from the available ranges. +func (ps *perPeerState) unusedIPv4Locked() netip.Addr { + if ps.addrInUse == nil { + ps.addrInUse = big.NewInt(0) + } + return allocAddr(ps.ipset, ps.addrInUse) +} + +// assignAddrsLocked assigns a pair of unique IP addresses for the given domain +// and returns them. The first address is an IPv4 address and the second is an +// IPv6 address. It does not check if the domain already has assigned addresses. 
+// ps.mu must be held. +func (ps *perPeerState) assignAddrsLocked(domain string) netip.Addr { + if ps.addrToDomain == nil { + ps.addrToDomain = &bart.Table[string]{} + } + v4 := ps.unusedIPv4Locked() + if !v4.IsValid() { + return netip.Addr{} + } + addr := v4 + mak.Set(&ps.domainToAddr, domain, addr) + ps.addrToDomain.Insert(netip.PrefixFrom(addr, addr.BitLen()), domain) + return addr +} diff --git a/cmd/natc/ippool/ippool_test.go b/cmd/natc/ippool/ippool_test.go new file mode 100644 index 0000000000000..2919d7757af8c --- /dev/null +++ b/cmd/natc/ippool/ippool_test.go @@ -0,0 +1,107 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "errors" + "fmt" + "net/netip" + "testing" + + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +func TestIPPoolExhaustion(t *testing.T) { + smallPrefix := netip.MustParsePrefix("100.64.1.0/30") // Only 4 IPs: .0, .1, .2, .3 + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(smallPrefix) + addrPool := must.Get(ipsb.IPSet()) + pool := IPPool{IPSet: addrPool} + + assignedIPs := make(map[netip.Addr]string) + + domains := []string{"a.example.com", "b.example.com", "c.example.com", "d.example.com", "e.example.com"} + + var errs []error + + from := tailcfg.NodeID(12345) + + for i := 0; i < 5; i++ { + for _, domain := range domains { + addr, err := pool.IPForDomain(from, domain) + if err != nil { + errs = append(errs, fmt.Errorf("failed to get IP for domain %q: %w", domain, err)) + continue + } + + if d, ok := assignedIPs[addr]; ok { + if d != domain { + t.Errorf("IP %s reused for domain %q, previously assigned to %q", addr, domain, d) + } + } else { + assignedIPs[addr] = domain + } + } + } + + for addr, domain := range assignedIPs { + if addr.Is4() && !smallPrefix.Contains(addr) { + t.Errorf("IP %s for domain %q not in expected range %s", addr, domain, smallPrefix) + } + } + + // expect one error for each iteration with the 5th domain + if len(errs) 
!= 5 { + t.Errorf("Expected 5 errors, got %d: %v", len(errs), errs) + } + for _, err := range errs { + if !errors.Is(err, ErrNoIPsAvailable) { + t.Errorf("generateDNSResponse() error = %v, want ErrNoIPsAvailable", err) + } + } +} + +func TestIPPool(t *testing.T) { + var ipsb netipx.IPSetBuilder + ipsb.AddPrefix(netip.MustParsePrefix("100.64.1.0/24")) + addrPool := must.Get(ipsb.IPSet()) + pool := IPPool{ + IPSet: addrPool, + } + from := tailcfg.NodeID(12345) + addr, err := pool.IPForDomain(from, "example.com") + if err != nil { + t.Fatalf("ipForDomain() error = %v", err) + } + + if !addr.IsValid() { + t.Fatal("ipForDomain() returned an invalid address") + } + + if !addr.Is4() { + t.Errorf("Address is not IPv4: %s", addr) + } + + if !addrPool.Contains(addr) { + t.Errorf("IPv4 address %s not in range %s", addr, addrPool) + } + + domain, ok := pool.DomainForIP(from, addr) + if !ok { + t.Errorf("domainForIP(%s) not found", addr) + } else if domain != "example.com" { + t.Errorf("domainForIP(%s) = %s, want %s", addr, domain, "example.com") + } + + addr2, err := pool.IPForDomain(from, "example.com") + if err != nil { + t.Fatalf("ipForDomain() second call error = %v", err) + } + + if addr.Compare(addr2) != 0 { + t.Errorf("ipForDomain() second call = %v, want %v", addr2, addr) + } +} diff --git a/cmd/natc/ippool/ipx.go b/cmd/natc/ippool/ipx.go new file mode 100644 index 0000000000000..8259a56dbf30e --- /dev/null +++ b/cmd/natc/ippool/ipx.go @@ -0,0 +1,130 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "math/big" + "math/bits" + "math/rand/v2" + "net/netip" + + "go4.org/netipx" +) + +func addrLessOrEqual(a, b netip.Addr) bool { + if a.Less(b) { + return true + } + if a == b { + return true + } + return false +} + +// indexOfAddr returns the index of addr in ipset, or -1 if not found. 
+func indexOfAddr(addr netip.Addr, ipset *netipx.IPSet) int { + var base int // offset of the current range + for _, r := range ipset.Ranges() { + if addr.Less(r.From()) { + return -1 + } + numFrom := v4ToNum(r.From()) + if addrLessOrEqual(addr, r.To()) { + numInRange := int(v4ToNum(addr) - numFrom) + return base + numInRange + } + numTo := v4ToNum(r.To()) + base += int(numTo-numFrom) + 1 + } + return -1 +} + +// addrAtIndex returns the address at the given index in ipset, or an empty +// address if index is out of range. +func addrAtIndex(index int, ipset *netipx.IPSet) netip.Addr { + if index < 0 { + return netip.Addr{} + } + var base int // offset of the current range + for _, r := range ipset.Ranges() { + numFrom := v4ToNum(r.From()) + numTo := v4ToNum(r.To()) + if index <= base+int(numTo-numFrom) { + return numToV4(uint32(int(numFrom) + index - base)) + } + base += int(numTo-numFrom) + 1 + } + return netip.Addr{} +} + +// TODO(golang/go#9455): once we have uint128 we can easily implement for all addrs. + +// v4ToNum returns a uint32 representation of the IPv4 address. If addr is not +// an IPv4 address, this function will panic. +func v4ToNum(addr netip.Addr) uint32 { + addr = addr.Unmap() + if !addr.Is4() { + panic("only IPv4 addresses are supported by v4ToNum") + } + b := addr.As4() + var o uint32 + o = o<<8 | uint32(b[0]) + o = o<<8 | uint32(b[1]) + o = o<<8 | uint32(b[2]) + o = o<<8 | uint32(b[3]) + return o +} + +func numToV4(i uint32) netip.Addr { + var addr [4]byte + addr[0] = byte((i >> 24) & 0xff) + addr[1] = byte((i >> 16) & 0xff) + addr[2] = byte((i >> 8) & 0xff) + addr[3] = byte(i & 0xff) + return netip.AddrFrom4(addr) +} + +// allocAddr returns an address in ipset that is not already marked allocated in allocated. +func allocAddr(ipset *netipx.IPSet, allocated *big.Int) netip.Addr { + // first try to allocate a random IP from each range, if we land on one. 
+ var base uint32 // index offset of the current range + for _, r := range ipset.Ranges() { + numFrom := v4ToNum(r.From()) + numTo := v4ToNum(r.To()) + randInRange := rand.N(numTo - numFrom) + randIndex := base + randInRange + if allocated.Bit(int(randIndex)) == 0 { + allocated.SetBit(allocated, int(randIndex), 1) + return numToV4(numFrom + randInRange) + } + base += numTo - numFrom + 1 + } + + // fall back to seeking a free bit in the allocated set + index := -1 + for i, word := range allocated.Bits() { + zbi := leastZeroBit(uint(word)) + if zbi == -1 { + continue + } + index = i*bits.UintSize + zbi + allocated.SetBit(allocated, index, 1) + break + } + if index == -1 { + return netip.Addr{} + } + return addrAtIndex(index, ipset) +} + +// leastZeroBit returns the index of the least significant zero bit in the given uint, or -1 +// if all bits are set. +func leastZeroBit(n uint) int { + notN := ^n + rightmostBit := notN & -notN + if rightmostBit == 0 { + return -1 + } + return bits.TrailingZeros(rightmostBit) +} diff --git a/cmd/natc/ippool/ipx_test.go b/cmd/natc/ippool/ipx_test.go new file mode 100644 index 0000000000000..2e2b9d3d45baf --- /dev/null +++ b/cmd/natc/ippool/ipx_test.go @@ -0,0 +1,150 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ippool + +import ( + "math" + "math/big" + "net/netip" + "testing" + + "go4.org/netipx" + "tailscale.com/util/must" +) + +func TestV4ToNum(t *testing.T) { + cases := []struct { + addr netip.Addr + num uint32 + }{ + {netip.MustParseAddr("0.0.0.0"), 0}, + {netip.MustParseAddr("255.255.255.255"), 0xffffffff}, + {netip.MustParseAddr("8.8.8.8"), 0x08080808}, + {netip.MustParseAddr("192.168.0.1"), 0xc0a80001}, + {netip.MustParseAddr("10.0.0.1"), 0x0a000001}, + {netip.MustParseAddr("172.16.0.1"), 0xac100001}, + {netip.MustParseAddr("100.64.0.1"), 0x64400001}, + } + + for _, tc := range cases { + num := v4ToNum(tc.addr) + if num != tc.num { + t.Errorf("addrNum(%v) = %d, want %d", 
tc.addr, num, tc.num) + } + if numToV4(num) != tc.addr { + t.Errorf("numToV4(%d) = %v, want %v", num, numToV4(num), tc.addr) + } + } + + func() { + defer func() { + if r := recover(); r == nil { + t.Fatal("expected panic") + } + }() + + v4ToNum(netip.MustParseAddr("::1")) + }() +} + +func TestAddrIndex(t *testing.T) { + builder := netipx.IPSetBuilder{} + builder.AddRange(netipx.MustParseIPRange("10.0.0.1-10.0.0.5")) + builder.AddRange(netipx.MustParseIPRange("192.168.0.1-192.168.0.10")) + ipset := must.Get(builder.IPSet()) + + indexCases := []struct { + addr netip.Addr + index int + }{ + {netip.MustParseAddr("10.0.0.1"), 0}, + {netip.MustParseAddr("10.0.0.2"), 1}, + {netip.MustParseAddr("10.0.0.3"), 2}, + {netip.MustParseAddr("10.0.0.4"), 3}, + {netip.MustParseAddr("10.0.0.5"), 4}, + {netip.MustParseAddr("192.168.0.1"), 5}, + {netip.MustParseAddr("192.168.0.5"), 9}, + {netip.MustParseAddr("192.168.0.10"), 14}, + {netip.MustParseAddr("172.16.0.1"), -1}, // Not in set + } + + for _, tc := range indexCases { + index := indexOfAddr(tc.addr, ipset) + if index != tc.index { + t.Errorf("indexOfAddr(%v) = %d, want %d", tc.addr, index, tc.index) + } + if tc.index == -1 { + continue + } + addr := addrAtIndex(tc.index, ipset) + if addr != tc.addr { + t.Errorf("addrAtIndex(%d) = %v, want %v", tc.index, addr, tc.addr) + } + } +} + +func TestAllocAddr(t *testing.T) { + builder := netipx.IPSetBuilder{} + builder.AddRange(netipx.MustParseIPRange("10.0.0.1-10.0.0.5")) + builder.AddRange(netipx.MustParseIPRange("192.168.0.1-192.168.0.10")) + ipset := must.Get(builder.IPSet()) + + allocated := new(big.Int) + for range 15 { + addr := allocAddr(ipset, allocated) + if !addr.IsValid() { + t.Errorf("allocAddr() = invalid, want valid") + } + if !ipset.Contains(addr) { + t.Errorf("allocAddr() = %v, not in set", addr) + } + } + addr := allocAddr(ipset, allocated) + if addr.IsValid() { + t.Errorf("allocAddr() = %v, want invalid", addr) + } + wantAddr := netip.MustParseAddr("10.0.0.2") + 
allocated.SetBit(allocated, indexOfAddr(wantAddr, ipset), 0) + addr = allocAddr(ipset, allocated) + if addr != wantAddr { + t.Errorf("allocAddr() = %v, want %v", addr, wantAddr) + } +} + +func TestLeastZeroBit(t *testing.T) { + cases := []struct { + num uint + want int + }{ + {math.MaxUint, -1}, + {0, 0}, + {0b01, 1}, + {0b11, 2}, + {0b111, 3}, + {math.MaxUint, -1}, + {math.MaxUint - 1, 0}, + } + if math.MaxUint == math.MaxUint64 { + cases = append(cases, []struct { + num uint + want int + }{ + {math.MaxUint >> 1, 63}, + }...) + } else { + cases = append(cases, []struct { + num uint + want int + }{ + {math.MaxUint >> 1, 31}, + }...) + } + + for _, tc := range cases { + got := leastZeroBit(tc.num) + if got != tc.want { + t.Errorf("leastZeroBit(%b) = %d, want %d", tc.num, got, tc.want) + } + } +} diff --git a/cmd/natc/natc.go b/cmd/natc/natc.go index d94523c6e4161..b327f55bdc3ea 100644 --- a/cmd/natc/natc.go +++ b/cmd/natc/natc.go @@ -8,8 +8,8 @@ package main import ( "context" - "encoding/binary" "errors" + "expvar" "flag" "fmt" "log" @@ -19,24 +19,25 @@ import ( "net/netip" "os" "strings" - "sync" "time" "github.com/gaissmai/bart" "github.com/inetaf/tcpproxy" "github.com/peterbourgon/ff/v3" + "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/cmd/natc/ippool" "tailscale.com/envknob" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/net/netutil" - "tailscale.com/syncs" - "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tsweb" - "tailscale.com/util/dnsname" "tailscale.com/util/mak" + "tailscale.com/util/must" + "tailscale.com/wgengine/netstack" ) func main() { @@ -90,18 +91,6 @@ func main() { } ignoreDstTable.Insert(pfx, true) } - var v4Prefixes []netip.Prefix - for _, s := range strings.Split(*v4PfxStr, ",") { - p := netip.MustParsePrefix(strings.TrimSpace(s)) - if p.Masked() != p { - log.Fatalf("v4 prefix %v is not a 
masked prefix", p) - } - v4Prefixes = append(v4Prefixes, p) - } - if len(v4Prefixes) == 0 { - log.Fatalf("no v4 prefixes specified") - } - dnsAddr := v4Prefixes[0].Addr() ts := &tsnet.Server{ Hostname: *hostname, } @@ -112,6 +101,7 @@ func main() { ts.Port = uint16(*wgPort) } defer ts.Close() + if *verboseTSNet { ts.Logf = log.Printf } @@ -129,6 +119,16 @@ func main() { log.Fatalf("debug serve: %v", http.Serve(dln, mux)) }() } + + if err := ts.Start(); err != nil { + log.Fatalf("ts.Start: %v", err) + } + // TODO(raggi): this is not a public interface or guarantee. + ns := ts.Sys().Netstack.Get().(*netstack.Impl) + if *debugPort != 0 { + expvar.Publish("netstack", ns.ExpVar()) + } + lc, err := ts.LocalClient() if err != nil { log.Fatalf("LocalClient() failed: %v", err) @@ -137,36 +137,68 @@ func main() { log.Fatalf("ts.Up: %v", err) } + var prefixes []netip.Prefix + for _, s := range strings.Split(*v4PfxStr, ",") { + p := netip.MustParsePrefix(strings.TrimSpace(s)) + if p.Masked() != p { + log.Fatalf("v4 prefix %v is not a masked prefix", p) + } + prefixes = append(prefixes, p) + } + routes, dnsAddr, addrPool := calculateAddresses(prefixes) + + v6ULA := ula(uint16(*siteID)) c := &connector{ ts: ts, - lc: lc, - dnsAddr: dnsAddr, - v4Ranges: v4Prefixes, - v6ULA: ula(uint16(*siteID)), + whois: lc, + v6ULA: v6ULA, ignoreDsts: ignoreDstTable, + ipPool: &ippool.IPPool{IPSet: addrPool}, + routes: routes, + dnsAddr: dnsAddr, + resolver: net.DefaultResolver, } - c.run(ctx) + c.run(ctx, lc) +} + +func calculateAddresses(prefixes []netip.Prefix) (*netipx.IPSet, netip.Addr, *netipx.IPSet) { + var ipsb netipx.IPSetBuilder + for _, p := range prefixes { + ipsb.AddPrefix(p) + } + routesToAdvertise := must.Get(ipsb.IPSet()) + dnsAddr := routesToAdvertise.Ranges()[0].From() + ipsb.Remove(dnsAddr) + addrPool := must.Get(ipsb.IPSet()) + return routesToAdvertise, dnsAddr, addrPool +} + +type lookupNetIPer interface { + LookupNetIP(ctx context.Context, net, host string) ([]netip.Addr, 
error) +} + +type whoiser interface { + WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) } type connector struct { // ts is the tsnet.Server used to host the connector. ts *tsnet.Server - // lc is the LocalClient used to interact with the tsnet.Server hosting this + // whois is the local.Client used to interact with the tsnet.Server hosting this // connector. - lc *tailscale.LocalClient + whois whoiser // dnsAddr is the IPv4 address to listen on for DNS requests. It is used to // prevent the app connector from assigning it to a domain. dnsAddr netip.Addr - // v4Ranges is the list of IPv4 ranges to advertise and assign addresses from. - // These are masked prefixes. - v4Ranges []netip.Prefix + // routes is the set of IPv4 ranges advertised to the tailnet, or ipset with + // the dnsAddr removed. + routes *netipx.IPSet + // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. v6ULA netip.Prefix - perPeerMap syncs.Map[tailcfg.NodeID, *perPeerState] - // ignoreDsts is initialized at start up with the contents of --ignore-destinations (if none it is nil) // It is never mutated, only used for lookups. // Users who want to natc a DNS wildcard but not every address record in that domain can supply the @@ -175,6 +207,12 @@ type connector struct { // return a dns response that contains the ip addresses we discovered with the lookup (ie not the // natc behavior, which would return a dummy ip address pointing at natc). ignoreDsts *bart.Table[bool] + + // ipPool contains the per-peer IPv4 address assignments. + ipPool *ippool.IPPool + + // resolver is used to lookup IP addresses for DNS queries. + resolver lookupNetIPer } // v6ULA is the ULA prefix used by the app connector to assign IPv6 addresses. @@ -194,11 +232,11 @@ func ula(siteID uint16) netip.Prefix { // // The passed in context is only used for the initial setup. The connector runs // forever. 
-func (c *connector) run(ctx context.Context) { - if _, err := c.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ +func (c *connector) run(ctx context.Context, lc *local.Client) { + if _, err := lc.EditPrefs(ctx, &ipn.MaskedPrefs{ AdvertiseRoutesSet: true, Prefs: ipn.Prefs{ - AdvertiseRoutes: append(c.v4Ranges, c.v6ULA), + AdvertiseRoutes: append(c.routes.Prefixes(), c.v6ULA), }, }); err != nil { log.Fatalf("failed to advertise routes: %v", err) @@ -228,26 +266,6 @@ func (c *connector) serveDNS() { } } -func lookupDestinationIP(domain string) ([]netip.Addr, error) { - netIPs, err := net.LookupIP(domain) - if err != nil { - var dnsError *net.DNSError - if errors.As(err, &dnsError) && dnsError.IsNotFound { - return nil, nil - } else { - return nil, err - } - } - var addrs []netip.Addr - for _, ip := range netIPs { - a, ok := netip.AddrFromSlice(ip) - if ok { - addrs = append(addrs, a) - } - } - return addrs, nil -} - // handleDNS handles a DNS request to the app connector. // It generates a response based on the request and the node that sent it. // @@ -262,157 +280,161 @@ func lookupDestinationIP(domain string) ([]netip.Addr, error) { func (c *connector) handleDNS(pc net.PacketConn, buf []byte, remoteAddr *net.UDPAddr) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - who, err := c.lc.WhoIs(ctx, remoteAddr.String()) + who, err := c.whois.WhoIs(ctx, remoteAddr.String()) if err != nil { - log.Printf("HandleDNS: WhoIs failed: %v\n", err) + log.Printf("HandleDNS(remote=%s): WhoIs failed: %v\n", remoteAddr.String(), err) return } var msg dnsmessage.Message err = msg.Unpack(buf) if err != nil { - log.Printf("HandleDNS: dnsmessage unpack failed: %v\n ", err) + log.Printf("HandleDNS(remote=%s): dnsmessage unpack failed: %v\n", remoteAddr.String(), err) return } - // If there are destination ips that we don't want to route, we - // have to do a dns lookup here to find the destination ip. 
- if c.ignoreDsts != nil { - if len(msg.Questions) > 0 { - q := msg.Questions[0] - switch q.Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - dstAddrs, err := lookupDestinationIP(q.Name.String()) + var resolves map[string][]netip.Addr + var addrQCount int + for _, q := range msg.Questions { + if q.Type != dnsmessage.TypeA && q.Type != dnsmessage.TypeAAAA { + continue + } + addrQCount++ + if _, ok := resolves[q.Name.String()]; !ok { + addrs, err := c.resolver.LookupNetIP(ctx, "ip", q.Name.String()) + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) && dnsErr.IsNotFound { + continue + } + if err != nil { + log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err) + return + } + // Note: If _any_ destination is ignored, pass through all of the resolved + // addresses as-is. + // + // This could result in some odd split-routing if there was a mix of + // ignored and non-ignored addresses, but it's currently the user + // preferred behavior. + if !c.ignoreDestination(addrs) { + addr, err := c.ipPool.IPForDomain(who.Node.ID, q.Name.String()) if err != nil { - log.Printf("HandleDNS: lookup destination failed: %v\n ", err) - return - } - if c.ignoreDestination(dstAddrs) { - bs, err := dnsResponse(&msg, dstAddrs) - // TODO (fran): treat as SERVFAIL - if err != nil { - log.Printf("HandleDNS: generate ignore response failed: %v\n", err) - return - } - _, err = pc.WriteTo(bs, remoteAddr) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - } + log.Printf("HandleDNS(remote=%s): lookup destination failed: %v\n", remoteAddr.String(), err) return } + addrs = []netip.Addr{addr, v6ForV4(c.v6ULA.Addr(), addr)} } + mak.Set(&resolves, q.Name.String(), addrs) } } - // None of the destination IP addresses match an ignore destination prefix, do - // the natc thing. 
- - resp, err := c.generateDNSResponse(&msg, who.Node.ID) - // TODO (fran): treat as SERVFAIL - if err != nil { - log.Printf("HandleDNS: connector handling failed: %v\n", err) - return - } - // TODO (fran): treat as NXDOMAIN - if len(resp) == 0 { - return - } - // This connector handled the DNS request - _, err = pc.WriteTo(resp, remoteAddr) - if err != nil { - log.Printf("HandleDNS: write failed: %v\n", err) - } -} - -// tsMBox is the mailbox used in SOA records. -// The convention is to replace the @ symbol with a dot. -// So in this case, the mailbox is support.tailscale.com. with the trailing dot -// to indicate that it is a fully qualified domain name. -var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") -// generateDNSResponse generates a DNS response for the given request. The from -// argument is the NodeID of the node that sent the request. -func (c *connector) generateDNSResponse(req *dnsmessage.Message, from tailcfg.NodeID) ([]byte, error) { - pm, _ := c.perPeerMap.LoadOrStore(from, &perPeerState{c: c}) - var addrs []netip.Addr - if len(req.Questions) > 0 { - switch req.Questions[0].Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - var err error - addrs, err = pm.ipForDomain(req.Questions[0].Name.String()) - if err != nil { - return nil, err - } - } + rcode := dnsmessage.RCodeSuccess + if addrQCount > 0 && len(resolves) == 0 { + rcode = dnsmessage.RCodeNameError } - return dnsResponse(req, addrs) -} -// dnsResponse makes a DNS response for the natc. If the dnsmessage is requesting TypeAAAA -// or TypeA the provided addrs of the requested type will be used. 
-func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) { b := dnsmessage.NewBuilder(nil, dnsmessage.Header{ - ID: req.Header.ID, + ID: msg.Header.ID, Response: true, Authoritative: true, + RCode: rcode, }) b.EnableCompression() - if len(req.Questions) == 0 { - return b.Finish() - } - q := req.Questions[0] if err := b.StartQuestions(); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage start questions failed: %v\n", remoteAddr.String(), err) + return } - if err := b.Question(q); err != nil { - return nil, err + + for _, q := range msg.Questions { + b.Question(q) } + if err := b.StartAnswers(); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage start answers failed: %v\n", remoteAddr.String(), err) + return } - switch q.Type { - case dnsmessage.TypeAAAA, dnsmessage.TypeA: - want6 := q.Type == dnsmessage.TypeAAAA - for _, ip := range addrs { - if want6 != ip.Is6() { - continue + + for _, q := range msg.Questions { + switch q.Type { + case dnsmessage.TypeSOA: + if err := b.SOAResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, + Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage SOA resource failed: %v\n", remoteAddr.String(), err) + return + } + case dnsmessage.TypeNS: + if err := b.NSResource( + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.NSResource{NS: tsMBox}, + ); err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage NS resource failed: %v\n", remoteAddr.String(), err) + return } - if want6 { + case dnsmessage.TypeAAAA: + for _, addr := range resolves[q.Name.String()] { + if !addr.Is6() { + continue + } if err := b.AAAAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5}, - dnsmessage.AAAAResource{AAAA: ip.As16()}, + dnsmessage.ResourceHeader{Name: q.Name, Class: 
q.Class, TTL: 120}, + dnsmessage.AAAAResource{AAAA: addr.As16()}, ); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage AAAA resource failed: %v\n", remoteAddr.String(), err) + return + } + } + case dnsmessage.TypeA: + for _, addr := range resolves[q.Name.String()] { + if !addr.Is4() { + continue } - } else { if err := b.AResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 5}, - dnsmessage.AResource{A: ip.As4()}, + dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, + dnsmessage.AResource{A: addr.As4()}, ); err != nil { - return nil, err + log.Printf("HandleDNS(remote=%s): dnsmessage A resource failed: %v\n", remoteAddr.String(), err) + return } } } - case dnsmessage.TypeSOA: - if err := b.SOAResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.SOAResource{NS: q.Name, MBox: tsMBox, Serial: 2023030600, - Refresh: 120, Retry: 120, Expire: 120, MinTTL: 60}, - ); err != nil { - return nil, err - } - case dnsmessage.TypeNS: - if err := b.NSResource( - dnsmessage.ResourceHeader{Name: q.Name, Class: q.Class, TTL: 120}, - dnsmessage.NSResource{NS: tsMBox}, - ); err != nil { - return nil, err - } } - return b.Finish() + + out, err := b.Finish() + if err != nil { + log.Printf("HandleDNS(remote=%s): dnsmessage finish failed: %v\n", remoteAddr.String(), err) + return + } + _, err = pc.WriteTo(out, remoteAddr) + if err != nil { + log.Printf("HandleDNS(remote=%s): write failed: %v\n", remoteAddr.String(), err) + } } +func v6ForV4(ula netip.Addr, v4 netip.Addr) netip.Addr { + as16 := ula.As16() + as4 := v4.As4() + copy(as16[12:], as4[:]) + return netip.AddrFrom16(as16) +} + +func v4ForV6(v6 netip.Addr) netip.Addr { + as16 := v6.As16() + var as4 [4]byte + copy(as4[:], as16[12:]) + return netip.AddrFrom4(as4) +} + +// tsMBox is the mailbox used in SOA records. +// The convention is to replace the @ symbol with a dot. +// So in this case, the mailbox is support.tailscale.com. 
with the trailing dot +// to indicate that it is a fully qualified domain name. +var tsMBox = dnsmessage.MustNewName("support.tailscale.com.") + // handleTCPFlow handles a TCP flow from the given source to the given // destination. It uses the source address to determine the node that sent the // request and the destination address to determine the domain that the request @@ -421,32 +443,31 @@ func dnsResponse(req *dnsmessage.Message, addrs []netip.Addr) ([]byte, error) { func (c *connector) handleTCPFlow(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - who, err := c.lc.WhoIs(ctx, src.Addr().String()) + who, err := c.whois.WhoIs(ctx, src.Addr().String()) cancel() if err != nil { log.Printf("HandleTCPFlow: WhoIs failed: %v\n", err) return nil, false } - - from := who.Node.ID - ps, ok := c.perPeerMap.Load(from) - if !ok { - log.Printf("handleTCPFlow: no perPeerState for %v", from) - return nil, false + dstAddr := dst.Addr() + if dstAddr.Is6() { + dstAddr = v4ForV6(dstAddr) } - domain, ok := ps.domainForIP(dst.Addr()) + domain, ok := c.ipPool.DomainForIP(who.Node.ID, dstAddr) if !ok { - log.Printf("handleTCPFlow: no domain for IP %v\n", dst.Addr()) return nil, false } return func(conn net.Conn) { - proxyTCPConn(conn, domain) + proxyTCPConn(conn, domain, c) }, true } // ignoreDestination reports whether any of the provided dstAddrs match the prefixes configured // in --ignore-destinations func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { + if c.ignoreDsts == nil { + return false + } for _, a := range dstAddrs { if _, ok := c.ignoreDsts.Lookup(a); ok { return true @@ -455,121 +476,74 @@ func (c *connector) ignoreDestination(dstAddrs []netip.Addr) bool { return false } -func proxyTCPConn(c net.Conn, dest string) { +func proxyTCPConn(c net.Conn, dest string, ctor *connector) { if c.RemoteAddr() == nil { log.Printf("proxyTCPConn: nil RemoteAddr") 
c.Close() return } - addrPortStr := c.LocalAddr().String() - _, port, err := net.SplitHostPort(addrPortStr) + laddr, err := netip.ParseAddrPort(c.LocalAddr().String()) if err != nil { - log.Printf("tcpRoundRobinHandler.Handle: bogus addrPort %q", addrPortStr) + log.Printf("proxyTCPConn: ParseAddrPort failed: %v", err) c.Close() return } - p := &tcpproxy.Proxy{ - ListenFunc: func(net, laddr string) (net.Listener, error) { - return netutil.NewOneConnListener(c, nil), nil - }, + daddrs, err := ctor.resolver.LookupNetIP(context.TODO(), "ip", dest) + if err != nil { + log.Printf("proxyTCPConn: LookupNetIP failed: %v", err) + c.Close() + return } - p.AddRoute(addrPortStr, &tcpproxy.DialProxy{ - Addr: fmt.Sprintf("%s:%s", dest, port), - }) - p.Start() -} - -// perPeerState holds the state for a single peer. -type perPeerState struct { - c *connector - mu sync.Mutex - domainToAddr map[string][]netip.Addr - addrToDomain *bart.Table[string] -} - -// domainForIP returns the domain name assigned to the given IP address and -// whether it was found. -func (ps *perPeerState) domainForIP(ip netip.Addr) (_ string, ok bool) { - ps.mu.Lock() - defer ps.mu.Unlock() - if ps.addrToDomain == nil { - return "", false + if len(daddrs) == 0 { + log.Printf("proxyTCPConn: no IP addresses found for %s", dest) + c.Close() + return } - return ps.addrToDomain.Lookup(ip) -} -// ipForDomain assigns a pair of unique IP addresses for the given domain and -// returns them. The first address is an IPv4 address and the second is an IPv6 -// address. If the domain already has assigned addresses, it returns them. 
-func (ps *perPeerState) ipForDomain(domain string) ([]netip.Addr, error) { - fqdn, err := dnsname.ToFQDN(domain) - if err != nil { - return nil, err + if ctor.ignoreDestination(daddrs) { + log.Printf("proxyTCPConn: closing connection to ignored destination %s (%v)", dest, daddrs) + c.Close() + return } - domain = fqdn.WithoutTrailingDot() - ps.mu.Lock() - defer ps.mu.Unlock() - if addrs, ok := ps.domainToAddr[domain]; ok { - return addrs, nil + p := &tcpproxy.Proxy{ + ListenFunc: func(net, laddr string) (net.Listener, error) { + return netutil.NewOneConnListener(c, nil), nil + }, } - addrs := ps.assignAddrsLocked(domain) - return addrs, nil -} -// isIPUsedLocked reports whether the given IP address is already assigned to a -// domain. -// ps.mu must be held. -func (ps *perPeerState) isIPUsedLocked(ip netip.Addr) bool { - _, ok := ps.addrToDomain.Lookup(ip) - return ok -} - -// unusedIPv4Locked returns an unused IPv4 address from the available ranges. -func (ps *perPeerState) unusedIPv4Locked() netip.Addr { - // TODO: skip ranges that have been exhausted - for _, r := range ps.c.v4Ranges { - ip := randV4(r) - for r.Contains(ip) { - if !ps.isIPUsedLocked(ip) && ip != ps.c.dnsAddr { - return ip + // TODO(raggi): more code could avoid this shuffle, but avoiding allocations + // for now most of the time daddrs will be short. + rand.Shuffle(len(daddrs), func(i, j int) { + daddrs[i], daddrs[j] = daddrs[j], daddrs[i] + }) + daddr := daddrs[0] + + // Try to match the upstream and downstream protocols (v4/v6) + if laddr.Addr().Is6() { + for _, addr := range daddrs { + if addr.Is6() { + daddr = addr + break + } + } + } else { + for _, addr := range daddrs { + if addr.Is4() { + daddr = addr + break } - ip = ip.Next() } } - return netip.Addr{} -} - -// randV4 returns a random IPv4 address within the given prefix. 
-func randV4(maskedPfx netip.Prefix) netip.Addr { - bits := 32 - maskedPfx.Bits() - randBits := rand.Uint32N(1 << uint(bits)) - ip4 := maskedPfx.Addr().As4() - pn := binary.BigEndian.Uint32(ip4[:]) - binary.BigEndian.PutUint32(ip4[:], randBits|pn) - return netip.AddrFrom4(ip4) -} + // TODO(raggi): drop this library, it ends up being allocation and + // indirection heavy and really doesn't help us here. + dsockaddrs := netip.AddrPortFrom(daddr, laddr.Port()).String() + p.AddRoute(dsockaddrs, &tcpproxy.DialProxy{ + Addr: dsockaddrs, + }) -// assignAddrsLocked assigns a pair of unique IP addresses for the given domain -// and returns them. The first address is an IPv4 address and the second is an -// IPv6 address. It does not check if the domain already has assigned addresses. -// ps.mu must be held. -func (ps *perPeerState) assignAddrsLocked(domain string) []netip.Addr { - if ps.addrToDomain == nil { - ps.addrToDomain = &bart.Table[string]{} - } - v4 := ps.unusedIPv4Locked() - as16 := ps.c.v6ULA.Addr().As16() - as4 := v4.As4() - copy(as16[12:], as4[:]) - v6 := netip.AddrFrom16(as16) - addrs := []netip.Addr{v4, v6} - mak.Set(&ps.domainToAddr, domain, addrs) - for _, a := range addrs { - ps.addrToDomain.Insert(netip.PrefixFrom(a, a.BitLen()), domain) - } - return addrs + p.Start() } diff --git a/cmd/natc/natc_test.go b/cmd/natc/natc_test.go new file mode 100644 index 0000000000000..0320db8a4ea59 --- /dev/null +++ b/cmd/natc/natc_test.go @@ -0,0 +1,482 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "context" + "fmt" + "io" + "net" + "net/netip" + "testing" + "time" + + "github.com/gaissmai/bart" + "golang.org/x/net/dns/dnsmessage" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/cmd/natc/ippool" + "tailscale.com/tailcfg" + "tailscale.com/util/must" +) + +func prefixEqual(a, b netip.Prefix) bool { + return a.Bits() == b.Bits() && a.Addr() == b.Addr() +} + +func TestULA(t *testing.T) { + 
tests := []struct { + name string + siteID uint16 + expected string + }{ + {"zero", 0, "fd7a:115c:a1e0:a99c:0000::/80"}, + {"one", 1, "fd7a:115c:a1e0:a99c:0001::/80"}, + {"max", 65535, "fd7a:115c:a1e0:a99c:ffff::/80"}, + {"random", 12345, "fd7a:115c:a1e0:a99c:3039::/80"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := ula(tc.siteID) + expected := netip.MustParsePrefix(tc.expected) + if !prefixEqual(got, expected) { + t.Errorf("ula(%d) = %s; want %s", tc.siteID, got, expected) + } + }) + } +} + +type recordingPacketConn struct { + writes [][]byte +} + +func (w *recordingPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + w.writes = append(w.writes, b) + return len(b), nil +} + +func (w *recordingPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + return 0, nil, io.EOF +} + +func (w *recordingPacketConn) Close() error { + return nil +} + +func (w *recordingPacketConn) LocalAddr() net.Addr { + return nil +} + +func (w *recordingPacketConn) RemoteAddr() net.Addr { + return nil +} + +func (w *recordingPacketConn) SetDeadline(t time.Time) error { + return nil +} + +func (w *recordingPacketConn) SetReadDeadline(t time.Time) error { + return nil +} + +func (w *recordingPacketConn) SetWriteDeadline(t time.Time) error { + return nil +} + +type resolver struct { + resolves map[string][]netip.Addr + fails map[string]bool +} + +func (r *resolver) LookupNetIP(ctx context.Context, _net, host string) ([]netip.Addr, error) { + if addrs, ok := r.resolves[host]; ok { + return addrs, nil + } + if _, ok := r.fails[host]; ok { + return nil, &net.DNSError{IsTimeout: false, IsNotFound: false, Name: host, IsTemporary: true} + } + return nil, &net.DNSError{IsNotFound: true, Name: host} +} + +type whois struct { + peers map[string]*apitype.WhoIsResponse +} + +func (w *whois) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { + addr := netip.MustParseAddrPort(remoteAddr).Addr().String() + if peer, ok := 
w.peers[addr]; ok { + return peer, nil + } + return nil, fmt.Errorf("peer not found") +} + +func TestDNSResponse(t *testing.T) { + tests := []struct { + name string + questions []dnsmessage.Question + wantEmpty bool + wantAnswers []struct { + name string + qType dnsmessage.Type + addr netip.Addr + } + wantNXDOMAIN bool + wantIgnored bool + }{ + { + name: "empty_request", + questions: []dnsmessage.Question{}, + wantEmpty: false, + wantAnswers: nil, + }, + { + name: "a_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "example.com.", + qType: dnsmessage.TypeA, + addr: netip.MustParseAddr("100.64.0.0"), + }, + }, + }, + { + name: "aaaa_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "example.com.", + qType: dnsmessage.TypeAAAA, + addr: netip.MustParseAddr("fd7a:115c:a1e0::"), + }, + }, + }, + { + name: "soa_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeSOA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: nil, + }, + { + name: "ns_record", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("example.com."), + Type: dnsmessage.TypeNS, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: nil, + }, + { + name: "nxdomain", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("noexist.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantNXDOMAIN: true, + }, + { + name: "servfail", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("fail.example.com."), + Type: 
dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantEmpty: true, // TODO: pass through instead? + }, + { + name: "ignored", + questions: []dnsmessage.Question{ + { + Name: dnsmessage.MustNewName("ignore.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + }, + wantAnswers: []struct { + name string + qType dnsmessage.Type + addr netip.Addr + }{ + { + name: "ignore.example.com.", + qType: dnsmessage.TypeA, + addr: netip.MustParseAddr("8.8.4.4"), + }, + }, + wantIgnored: true, + }, + } + + var rpc recordingPacketConn + remoteAddr := must.Get(net.ResolveUDPAddr("udp", "100.64.254.1:12345")) + + routes, dnsAddr, addrPool := calculateAddresses([]netip.Prefix{netip.MustParsePrefix("10.64.0.0/24")}) + v6ULA := ula(1) + c := connector{ + resolver: &resolver{ + resolves: map[string][]netip.Addr{ + "example.com.": { + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("2001:4860:4860::8888"), + }, + "ignore.example.com.": { + netip.MustParseAddr("8.8.4.4"), + }, + }, + fails: map[string]bool{ + "fail.example.com.": true, + }, + }, + whois: &whois{ + peers: map[string]*apitype.WhoIsResponse{ + "100.64.254.1": { + Node: &tailcfg.Node{ID: 123}, + }, + }, + }, + ignoreDsts: &bart.Table[bool]{}, + routes: routes, + v6ULA: v6ULA, + ipPool: &ippool.IPPool{IPSet: addrPool}, + dnsAddr: dnsAddr, + } + c.ignoreDsts.Insert(netip.MustParsePrefix("8.8.4.4/32"), true) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + rb := dnsmessage.NewBuilder(nil, + dnsmessage.Header{ + ID: 1234, + }, + ) + must.Do(rb.StartQuestions()) + for _, q := range tc.questions { + rb.Question(q) + } + + c.handleDNS(&rpc, must.Get(rb.Finish()), remoteAddr) + + writes := rpc.writes + rpc.writes = rpc.writes[:0] + + if tc.wantEmpty { + if len(writes) != 0 { + t.Errorf("handleDNS() returned non-empty response when expected empty") + } + return + } + + if !tc.wantEmpty && len(writes) != 1 { + t.Fatalf("handleDNS() returned an unexpected number of 
responses: %d, want 1", len(writes)) + } + + resp := writes[0] + var msg dnsmessage.Message + err := msg.Unpack(resp) + if err != nil { + t.Fatalf("Failed to unpack response: %v", err) + } + + if !msg.Header.Response { + t.Errorf("Response header is not set") + } + + if msg.Header.ID != 1234 { + t.Errorf("Response ID = %d, want %d", msg.Header.ID, 1234) + } + + if len(tc.wantAnswers) > 0 { + if len(msg.Answers) != len(tc.wantAnswers) { + t.Errorf("got %d answers, want %d:\n%s", len(msg.Answers), len(tc.wantAnswers), msg.GoString()) + } else { + for i, want := range tc.wantAnswers { + ans := msg.Answers[i] + + gotName := ans.Header.Name.String() + if gotName != want.name { + t.Errorf("answer[%d] name = %s, want %s", i, gotName, want.name) + } + + if ans.Header.Type != want.qType { + t.Errorf("answer[%d] type = %v, want %v", i, ans.Header.Type, want.qType) + } + + switch want.qType { + case dnsmessage.TypeA: + if ans.Body.(*dnsmessage.AResource) == nil { + t.Errorf("answer[%d] not an A record", i) + continue + } + case dnsmessage.TypeAAAA: + if ans.Body.(*dnsmessage.AAAAResource) == nil { + t.Errorf("answer[%d] not an AAAA record", i) + continue + } + } + + var gotIP netip.Addr + switch want.qType { + case dnsmessage.TypeA: + resource := ans.Body.(*dnsmessage.AResource) + gotIP = netip.AddrFrom4([4]byte(resource.A)) + case dnsmessage.TypeAAAA: + resource := ans.Body.(*dnsmessage.AAAAResource) + gotIP = netip.AddrFrom16([16]byte(resource.AAAA)) + } + + var wantIP netip.Addr + if tc.wantIgnored { + var net string + var fxSelectIP func(netip.Addr) bool + switch want.qType { + case dnsmessage.TypeA: + net = "ip4" + fxSelectIP = func(a netip.Addr) bool { + return a.Is4() + } + case dnsmessage.TypeAAAA: + //TODO(fran) is this branch exercised? 
+ net = "ip6" + fxSelectIP = func(a netip.Addr) bool { + return a.Is6() + } + } + ips := must.Get(c.resolver.LookupNetIP(t.Context(), net, want.name)) + for _, ip := range ips { + if fxSelectIP(ip) { + wantIP = ip + break + } + } + } else { + addr := must.Get(c.ipPool.IPForDomain(tailcfg.NodeID(123), want.name)) + switch want.qType { + case dnsmessage.TypeA: + wantIP = addr + case dnsmessage.TypeAAAA: + wantIP = v6ForV4(v6ULA.Addr(), addr) + } + } + if gotIP != wantIP { + t.Errorf("answer[%d] IP = %s, want %s", i, gotIP, wantIP) + } + } + } + } + + if tc.wantNXDOMAIN { + if msg.RCode != dnsmessage.RCodeNameError { + t.Errorf("expected NXDOMAIN, got %v", msg.RCode) + } + if len(msg.Answers) != 0 { + t.Errorf("expected no answers, got %d", len(msg.Answers)) + } + } + }) + } +} + +func TestIgnoreDestination(t *testing.T) { + ignoreDstTable := &bart.Table[bool]{} + ignoreDstTable.Insert(netip.MustParsePrefix("192.168.1.0/24"), true) + ignoreDstTable.Insert(netip.MustParsePrefix("10.0.0.0/8"), true) + + c := &connector{ + ignoreDsts: ignoreDstTable, + } + + tests := []struct { + name string + addrs []netip.Addr + expected bool + }{ + { + name: "no_match", + addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1")}, + expected: false, + }, + { + name: "one_match", + addrs: []netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("192.168.1.5")}, + expected: true, + }, + { + name: "all_match", + addrs: []netip.Addr{netip.MustParseAddr("10.0.0.1"), netip.MustParseAddr("192.168.1.5")}, + expected: true, + }, + { + name: "empty_addrs", + addrs: []netip.Addr{}, + expected: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := c.ignoreDestination(tc.addrs) + if got != tc.expected { + t.Errorf("ignoreDestination(%v) = %v, want %v", tc.addrs, got, tc.expected) + } + }) + } +} + +func TestV6V4(t *testing.T) { + v6ULA := ula(1) + + tests := [][]string{ + {"100.64.0.0", "fd7a:115c:a1e0:a99c:1:0:6440:0"}, + 
{"0.0.0.0", "fd7a:115c:a1e0:a99c:1::"}, + {"255.255.255.255", "fd7a:115c:a1e0:a99c:1:0:ffff:ffff"}, + } + + for i, test := range tests { + // to v6 + v6 := v6ForV4(v6ULA.Addr(), netip.MustParseAddr(test[0])) + want := netip.MustParseAddr(test[1]) + if v6 != want { + t.Fatalf("test %d: want: %v, got: %v", i, want, v6) + } + + // to v4 + v4 := v4ForV6(netip.MustParseAddr(test[1])) + want = netip.MustParseAddr(test[0]) + if v4 != want { + t.Fatalf("test %d: want: %v, got: %v", i, want, v4) + } + } +} diff --git a/cmd/pgproxy/pgproxy.go b/cmd/pgproxy/pgproxy.go index 468649ee2bc5f..e102c8ae47411 100644 --- a/cmd/pgproxy/pgproxy.go +++ b/cmd/pgproxy/pgproxy.go @@ -24,7 +24,7 @@ import ( "strings" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/metrics" "tailscale.com/tsnet" "tailscale.com/tsweb" @@ -105,7 +105,7 @@ type proxy struct { upstreamHost string // "my.database.com" upstreamCertPool *x509.CertPool downstreamCert []tls.Certificate - client *tailscale.LocalClient + client *local.Client activeSessions expvar.Int startedSessions expvar.Int @@ -115,7 +115,7 @@ type proxy struct { // newProxy returns a proxy that forwards connections to // upstreamAddr. The upstream's TLS session is verified using the CA // cert(s) in upstreamCAPath. 
-func newProxy(upstreamAddr, upstreamCAPath string, client *tailscale.LocalClient) (*proxy, error) { +func newProxy(upstreamAddr, upstreamCAPath string, client *local.Client) (*proxy, error) { bs, err := os.ReadFile(upstreamCAPath) if err != nil { return nil, err diff --git a/cmd/proxy-to-grafana/proxy-to-grafana.go b/cmd/proxy-to-grafana/proxy-to-grafana.go index f1c67bad5e28e..27f5e338c8d65 100644 --- a/cmd/proxy-to-grafana/proxy-to-grafana.go +++ b/cmd/proxy-to-grafana/proxy-to-grafana.go @@ -19,8 +19,25 @@ // header_property = username // auto_sign_up = true // whitelist = 127.0.0.1 -// headers = Name:X-WEBAUTH-NAME +// headers = Email:X-Webauth-User, Name:X-Webauth-Name, Role:X-Webauth-Role // enable_login_token = true +// +// You can use grants in Tailscale ACL to give users different roles in Grafana. +// For example, to give group:eng the Editor role, add the following to your ACLs: +// +// "grants": [ +// { +// "src": ["group:eng"], +// "dst": ["tag:grafana"], +// "app": { +// "tailscale.com/cap/proxy-to-grafana": [{ +// "role": "editor", +// }], +// }, +// }, +// ], +// +// If multiple roles are specified, the most permissive role is used. package main import ( @@ -36,7 +53,7 @@ import ( "strings" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/tailscale/apitype" "tailscale.com/tailcfg" "tailscale.com/tsnet" ) @@ -49,6 +66,57 @@ var ( loginServer = flag.String("login-server", "", "URL to alternative control server. If empty, the default Tailscale control is used.") ) +// aclCap is the Tailscale ACL capability used to configure proxy-to-grafana. +const aclCap tailcfg.PeerCapability = "tailscale.com/cap/proxy-to-grafana" + +// aclGrant is an access control rule that assigns Grafana permissions +// while provisioning a user. +type aclGrant struct { + // Role is one of: "viewer", "editor", "admin". + Role string `json:"role"` +} + +// grafanaRole defines possible Grafana roles. 
+type grafanaRole int + +const ( + // Roles are ordered by their permissions, with the least permissive role first. + // If a user has multiple roles, the most permissive role is used. + ViewerRole grafanaRole = iota + EditorRole + AdminRole +) + +// String returns the string representation of a grafanaRole. +// It is used as a header value in the HTTP request to Grafana. +func (r grafanaRole) String() string { + switch r { + case ViewerRole: + return "Viewer" + case EditorRole: + return "Editor" + case AdminRole: + return "Admin" + default: + // A safe default. + return "Viewer" + } +} + +// roleFromString converts a string to a grafanaRole. +// It is used to parse the role from the ACL grant. +func roleFromString(s string) (grafanaRole, error) { + switch strings.ToLower(s) { + case "viewer": + return ViewerRole, nil + case "editor": + return EditorRole, nil + case "admin": + return AdminRole, nil + } + return ViewerRole, fmt.Errorf("unknown role: %q", s) +} + func main() { flag.Parse() if *hostname == "" || strings.Contains(*hostname, ".") { @@ -127,14 +195,23 @@ func main() { log.Fatal(http.Serve(ln, proxy)) } -func modifyRequest(req *http.Request, localClient *tailscale.LocalClient) { - // with enable_login_token set to true, we get a cookie that handles +func modifyRequest(req *http.Request, localClient whoisIdentitySource) { + // Delete any existing X-Webauth-* headers to prevent possible spoofing + // if getting Tailnet identity fails. 
+ for h := range req.Header { + if strings.HasPrefix(h, "X-Webauth-") { + req.Header.Del(h) + } + } + + // Set the X-Webauth-* headers only for the /login path + // With enable_login_token set to true, we get a cookie that handles // auth for paths that are not /login if req.URL.Path != "/login" { return } - user, err := getTailscaleUser(req.Context(), localClient, req.RemoteAddr) + user, role, err := getTailscaleIdentity(req.Context(), localClient, req.RemoteAddr) if err != nil { log.Printf("error getting Tailscale user: %v", err) return @@ -142,19 +219,37 @@ func modifyRequest(req *http.Request, localClient *tailscale.LocalClient) { req.Header.Set("X-Webauth-User", user.LoginName) req.Header.Set("X-Webauth-Name", user.DisplayName) + req.Header.Set("X-Webauth-Role", role.String()) } -func getTailscaleUser(ctx context.Context, localClient *tailscale.LocalClient, ipPort string) (*tailcfg.UserProfile, error) { +func getTailscaleIdentity(ctx context.Context, localClient whoisIdentitySource, ipPort string) (*tailcfg.UserProfile, grafanaRole, error) { whois, err := localClient.WhoIs(ctx, ipPort) if err != nil { - return nil, fmt.Errorf("failed to identify remote host: %w", err) + return nil, ViewerRole, fmt.Errorf("failed to identify remote host: %w", err) } if whois.Node.IsTagged() { - return nil, fmt.Errorf("tagged nodes are not users") + return nil, ViewerRole, fmt.Errorf("tagged nodes are not users") } if whois.UserProfile == nil || whois.UserProfile.LoginName == "" { - return nil, fmt.Errorf("failed to identify remote user") + return nil, ViewerRole, fmt.Errorf("failed to identify remote user") } - return whois.UserProfile, nil + role := ViewerRole + grants, err := tailcfg.UnmarshalCapJSON[aclGrant](whois.CapMap, aclCap) + if err != nil { + return nil, ViewerRole, fmt.Errorf("failed to unmarshal ACL grants: %w", err) + } + for _, g := range grants { + r, err := roleFromString(g.Role) + if err != nil { + return nil, ViewerRole, fmt.Errorf("failed to parse role: %w", 
err) + } + role = max(role, r) + } + + return whois.UserProfile, role, nil +} + +type whoisIdentitySource interface { + WhoIs(ctx context.Context, ipPort string) (*apitype.WhoIsResponse, error) } diff --git a/cmd/proxy-to-grafana/proxy-to-grafana_test.go b/cmd/proxy-to-grafana/proxy-to-grafana_test.go new file mode 100644 index 0000000000000..083c4bc494ad6 --- /dev/null +++ b/cmd/proxy-to-grafana/proxy-to-grafana_test.go @@ -0,0 +1,77 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" +) + +type mockWhoisSource struct { + id *apitype.WhoIsResponse +} + +func (m *mockWhoisSource) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsResponse, error) { + if m.id == nil { + return nil, fmt.Errorf("missing mock identity") + } + return m.id, nil +} + +var whois = &apitype.WhoIsResponse{ + UserProfile: &tailcfg.UserProfile{ + LoginName: "foobar@example.com", + DisplayName: "Foobar", + }, + Node: &tailcfg.Node{ + ID: 1, + }, +} + +func TestModifyRequest_Login(t *testing.T) { + req := httptest.NewRequest("GET", "/login", nil) + modifyRequest(req, &mockWhoisSource{id: whois}) + + if got := req.Header.Get("X-Webauth-User"); got != "foobar@example.com" { + t.Errorf("X-Webauth-User = %q; want %q", got, "foobar@example.com") + } + + if got := req.Header.Get("X-Webauth-Role"); got != "Viewer" { + t.Errorf("X-Webauth-Role = %q; want %q", got, "Viewer") + } +} + +func TestModifyRequest_RemoveHeaders_Login(t *testing.T) { + req := httptest.NewRequest("GET", "/login", nil) + req.Header.Set("X-Webauth-User", "malicious@example.com") + req.Header.Set("X-Webauth-Role", "Admin") + + modifyRequest(req, &mockWhoisSource{id: whois}) + + if got := req.Header.Get("X-Webauth-User"); got != "foobar@example.com" { + t.Errorf("X-Webauth-User = %q; want %q", got, "foobar@example.com") + } + if got := 
req.Header.Get("X-Webauth-Role"); got != "Viewer" { + t.Errorf("X-Webauth-Role = %q; want %q", got, "Viewer") + } +} + +func TestModifyRequest_RemoveHeaders_API(t *testing.T) { + req := httptest.NewRequest("DELETE", "/api/org/users/1", nil) + req.Header.Set("X-Webauth-User", "malicious@example.com") + req.Header.Set("X-Webauth-Role", "Admin") + + modifyRequest(req, &mockWhoisSource{id: whois}) + + if got := req.Header.Get("X-Webauth-User"); got != "" { + t.Errorf("X-Webauth-User = %q; want %q", got, "") + } + if got := req.Header.Get("X-Webauth-Role"); got != "" { + t.Errorf("X-Webauth-Role = %q; want %q", got, "") + } +} diff --git a/cmd/sniproxy/handlers.go b/cmd/sniproxy/handlers.go index 102110fe36dc7..1973eecc017a3 100644 --- a/cmd/sniproxy/handlers.go +++ b/cmd/sniproxy/handlers.go @@ -14,6 +14,7 @@ import ( "github.com/inetaf/tcpproxy" "tailscale.com/net/netutil" + "tailscale.com/net/netx" ) type tcpRoundRobinHandler struct { @@ -22,7 +23,7 @@ type tcpRoundRobinHandler struct { To []string // DialContext is used to make the outgoing TCP connection. - DialContext func(ctx context.Context, network, address string) (net.Conn, error) + DialContext netx.DialFunc // ReachableIPs enumerates the IP addresses this handler is reachable on. 
ReachableIPs []netip.Addr diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index fa83aaf4ab44e..c020b4a1f1605 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -22,7 +22,7 @@ import ( "strings" "github.com/peterbourgon/ff/v3" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/tailcfg" @@ -157,10 +157,8 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro // NetMap contains app-connector configuration if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - sn := nm.SelfNode.AsStruct() - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorConfig](sn.CapMap, configCapKey) + nmConf, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorConfig](nm.SelfNode.CapMap(), configCapKey) if err != nil { log.Printf("failed to read app connector configuration from coordination server: %v", err) } else if len(nmConf) > 0 { @@ -185,7 +183,7 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro type sniproxy struct { srv Server ts *tsnet.Server - lc *tailscale.LocalClient + lc *local.Client } func (s *sniproxy) advertiseRoutesFromConfig(ctx context.Context, c *appctype.AppConnectorConfig) error { diff --git a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go index ee929299a4273..39af584ecd481 100644 --- a/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go +++ b/cmd/ssh-auth-none-demo/ssh-auth-none-demo.go @@ -6,6 +6,9 @@ // highlight the unique parts of the Tailscale SSH server so SSH // client authors can hit it easily and fix their SSH clients without // needing to set up Tailscale and Tailscale SSH. +// +// Connections are allowed using any username except for "denyme". Connecting as +// "denyme" will result in an authentication failure with error message. 
package main import ( @@ -16,6 +19,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" "flag" "fmt" "io" @@ -24,7 +28,7 @@ import ( "path/filepath" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" "tailscale.com/tempfork/gliderlabs/ssh" ) @@ -62,13 +66,21 @@ func main() { Handler: handleSessionPostSSHAuth, ServerConfigCallback: func(ctx ssh.Context) *gossh.ServerConfig { start := time.Now() + var spac gossh.ServerPreAuthConn return &gossh.ServerConfig{ - NextAuthMethodCallback: func(conn gossh.ConnMetadata, prevErrors []error) []string { - return []string{"tailscale"} + PreAuthConnCallback: func(conn gossh.ServerPreAuthConn) { + spac = conn }, NoClientAuth: true, // required for the NoClientAuthCallback to run NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { - cm.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) + spac.SendAuthBanner(fmt.Sprintf("# Banner: doing none auth at %v\r\n", time.Since(start))) + + if cm.User() == "denyme" { + return nil, &gossh.BannerError{ + Err: errors.New("denying access"), + Message: "denyme is not allowed to access this machine\n", + } + } totalBanners := 2 if cm.User() == "banners" { @@ -77,9 +89,9 @@ func main() { for banner := 2; banner <= totalBanners; banner++ { time.Sleep(time.Second) if banner == totalBanners { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) + spac.SendAuthBanner(fmt.Sprintf("# Banner%d: access granted at %v\r\n", banner, time.Since(start))) } else { - cm.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) + spac.SendAuthBanner(fmt.Sprintf("# Banner%d at %v\r\n", banner, time.Since(start))) } } return nil, nil diff --git a/cmd/stunc/stunc.go b/cmd/stunc/stunc.go index 9743a33007265..c4b2eedd39f90 100644 --- a/cmd/stunc/stunc.go +++ b/cmd/stunc/stunc.go @@ -5,24 +5,40 @@ package main import ( + "flag" "log" 
"net" "os" "strconv" + "time" "tailscale.com/net/stun" ) func main() { log.SetFlags(0) - - if len(os.Args) < 2 || len(os.Args) > 3 { - log.Fatalf("usage: %s [port]", os.Args[0]) - } - host := os.Args[1] + var host string port := "3478" - if len(os.Args) == 3 { - port = os.Args[2] + + var readTimeout time.Duration + flag.DurationVar(&readTimeout, "timeout", 3*time.Second, "response wait timeout") + + flag.Parse() + + values := flag.Args() + if len(values) < 1 || len(values) > 2 { + log.Printf("usage: %s [port]", os.Args[0]) + flag.PrintDefaults() + os.Exit(1) + } else { + for i, value := range values { + switch i { + case 0: + host = value + case 1: + port = value + } + } } _, err := strconv.ParseUint(port, 10, 16) if err != nil { @@ -46,6 +62,10 @@ func main() { log.Fatal(err) } + err = c.SetReadDeadline(time.Now().Add(readTimeout)) + if err != nil { + log.Fatal(err) + } var buf [1024]byte n, raddr, err := c.ReadFromUDPAddrPort(buf[:]) if err != nil { diff --git a/cmd/stund/depaware.txt b/cmd/stund/depaware.txt index a35f59516ee32..da768039431fe 100644 --- a/cmd/stund/depaware.txt +++ b/cmd/stund/depaware.txt @@ -8,12 +8,11 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ - github.com/google/uuid from tailscale.com/util/fastuuid + github.com/munnerz/goautoneg from github.com/prometheus/common/expfmt đŸ’Ŗ github.com/prometheus/client_golang/prometheus from tailscale.com/tsweb/promvarz github.com/prometheus/client_golang/prometheus/internal from github.com/prometheus/client_golang/prometheus github.com/prometheus/client_model/go from github.com/prometheus/client_golang/prometheus+ github.com/prometheus/common/expfmt from 
github.com/prometheus/client_golang/prometheus+ - github.com/prometheus/common/internal/bitbucket.org/ww/goautoneg from github.com/prometheus/common/expfmt github.com/prometheus/common/model from github.com/prometheus/client_golang/prometheus+ LD github.com/prometheus/procfs from github.com/prometheus/client_golang/prometheus LD github.com/prometheus/procfs/internal/fs from github.com/prometheus/procfs @@ -50,33 +49,38 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar google.golang.org/protobuf/types/known/timestamppb from github.com/prometheus/client_golang/prometheus+ tailscale.com from tailscale.com/version tailscale.com/envknob from tailscale.com/tsweb+ + tailscale.com/feature from tailscale.com/tsweb tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/metrics from tailscale.com/net/stunserver+ tailscale.com/net/netaddr from tailscale.com/net/tsaddr tailscale.com/net/stun from tailscale.com/net/stunserver tailscale.com/net/stunserver from tailscale.com/cmd/stund tailscale.com/net/tsaddr from tailscale.com/tsweb + tailscale.com/syncs from tailscale.com/metrics tailscale.com/tailcfg from tailscale.com/version - tailscale.com/tsweb from tailscale.com/cmd/stund - tailscale.com/tsweb/promvarz from tailscale.com/tsweb + tailscale.com/tsweb from tailscale.com/cmd/stund+ + tailscale.com/tsweb/promvarz from tailscale.com/cmd/stund tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/dnstype from tailscale.com/tailcfg tailscale.com/types/ipproto from tailscale.com/tailcfg tailscale.com/types/key from tailscale.com/tailcfg tailscale.com/types/lazy from tailscale.com/version+ - tailscale.com/types/logger from tailscale.com/tsweb + tailscale.com/types/logger from tailscale.com/tsweb+ tailscale.com/types/opt from tailscale.com/envknob+ tailscale.com/types/ptr from tailscale.com/tailcfg+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/tailcfg+ 
tailscale.com/types/tkatype from tailscale.com/tailcfg+ tailscale.com/types/views from tailscale.com/net/tsaddr+ tailscale.com/util/ctxkey from tailscale.com/tsweb+ L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/tailcfg - tailscale.com/util/fastuuid from tailscale.com/tsweb - tailscale.com/util/lineread from tailscale.com/version/distro + tailscale.com/util/lineiter from tailscale.com/version/distro + tailscale.com/util/mak from tailscale.com/syncs tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + tailscale.com/util/rands from tailscale.com/tsweb tailscale.com/util/slicesx from tailscale.com/tailcfg + tailscale.com/util/testenv from tailscale.com/types/logger tailscale.com/util/vizerror from tailscale.com/tailcfg+ tailscale.com/version from tailscale.com/envknob+ tailscale.com/version/distro from tailscale.com/envknob @@ -86,11 +90,12 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar golang.org/x/crypto/cryptobyte from crypto/ecdsa+ golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ golang.org/x/crypto/curve25519 from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ + golang.org/x/exp/constraints from tailscale.com/tsweb/varz golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http golang.org/x/net/http/httpproxy from net/http @@ -112,7 +117,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar container/list from crypto/tls+ context from crypto/tls+ crypto from 
crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509 @@ -120,20 +125,59 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar crypto/ecdsa from crypto/tls+ crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ - crypto/hmac from crypto/tls+ + crypto/hmac from crypto/tls + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from 
crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from net/http+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509 - database/sql/driver from github.com/google/uuid - embed from crypto/internal/nistec+ + embed from google.golang.org/protobuf/internal/editiondefaults+ encoding from encoding/json+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/go-json-experiment/json @@ -144,7 +188,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar encoding/pem from crypto/tls+ errors from bufio+ expvar from github.com/prometheus/client_golang/prometheus+ - flag from tailscale.com/cmd/stund + flag from tailscale.com/cmd/stund+ fmt from compress/flate+ go/token from google.golang.org/protobuf/internal/strs hash from crypto+ @@ -152,9 +196,48 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar hash/fnv from google.golang.org/protobuf/internal/detrand hash/maphash from go4.org/mem html from net/http/pprof+ + internal/abi from crypto/x509/internal/macos+ + 
internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from crypto/internal/fips140deps/godebug+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/exithook from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/syscall/execenv from os + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/unsafeheader from internal/reflectlite+ io from bufio+ io/fs from crypto/x509+ - io/ioutil from google.golang.org/protobuf/internal/impl iter from maps+ log from expvar+ log/internal from log @@ -163,7 +246,7 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar math/big from crypto/dsa+ math/bits from compress/flate+ math/rand from math/big+ - math/rand/v2 from 
tailscale.com/util/fastuuid+ + math/rand/v2 from crypto/ecdsa+ mime from github.com/prometheus/common/expfmt+ mime/multipart from net/http mime/quotedprintable from mime/multipart @@ -171,17 +254,19 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar net/http from expvar+ net/http/httptrace from net/http net/http/internal from net/http + net/http/internal/ascii from net/http net/http/pprof from tailscale.com/tsweb net/netip from go4.org/netipx+ net/textproto from golang.org/x/net/http/httpguts+ net/url from crypto/x509+ - os from crypto/rand+ + os from crypto/internal/sysrand+ os/signal from tailscale.com/cmd/stund path from github.com/prometheus/client_golang/prometheus/internal+ path/filepath from crypto/x509+ reflect from crypto/x509+ regexp from github.com/prometheus/client_golang/prometheus/internal+ regexp/syntax from regexp + runtime from crypto/internal/fips140+ runtime/debug from github.com/prometheus/client_golang/prometheus+ runtime/metrics from github.com/prometheus/client_golang/prometheus+ runtime/pprof from net/http/pprof @@ -192,10 +277,12 @@ tailscale.com/cmd/stund dependencies: (generated by github.com/tailscale/depawar strings from bufio+ sync from compress/flate+ sync/atomic from context+ - syscall from crypto/rand+ + syscall from crypto/internal/sysrand+ text/tabwriter from runtime/pprof time from compress/gzip+ unicode from bytes+ unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique diff --git a/cmd/stund/stund.go b/cmd/stund/stund.go index c38429169b066..1055d966f42c5 100644 --- a/cmd/stund/stund.go +++ b/cmd/stund/stund.go @@ -15,6 +15,9 @@ import ( "tailscale.com/net/stunserver" "tailscale.com/tsweb" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) var ( diff --git a/cmd/systray/logo.go b/cmd/systray/logo.go deleted file mode 100644 index cd79c94a02ea4..0000000000000 --- a/cmd/systray/logo.go +++ /dev/null @@ 
-1,220 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo || !darwin - -package main - -import ( - "bytes" - "context" - "image/color" - "image/png" - "sync" - "time" - - "fyne.io/systray" - "github.com/fogleman/gg" -) - -// tsLogo represents the state of the 3x3 dot grid in the Tailscale logo. -// A 0 represents a gray dot, any other value is a white dot. -type tsLogo [9]byte - -var ( - // disconnected is all gray dots - disconnected = tsLogo{ - 0, 0, 0, - 0, 0, 0, - 0, 0, 0, - } - - // connected is the normal Tailscale logo - connected = tsLogo{ - 0, 0, 0, - 1, 1, 1, - 0, 1, 0, - } - - // loading is a special tsLogo value that is not meant to be rendered directly, - // but indicates that the loading animation should be shown. - loading = tsLogo{'l', 'o', 'a', 'd', 'i', 'n', 'g'} - - // loadingIcons are shown in sequence as an animated loading icon. - loadingLogos = []tsLogo{ - { - 0, 1, 1, - 1, 0, 1, - 0, 0, 1, - }, - { - 0, 1, 1, - 0, 0, 1, - 0, 1, 0, - }, - { - 0, 1, 1, - 0, 0, 0, - 0, 0, 1, - }, - { - 0, 0, 1, - 0, 1, 0, - 0, 0, 0, - }, - { - 0, 1, 0, - 0, 0, 0, - 0, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 1, - 0, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 0, - 0, 0, 0, - }, - { - 0, 0, 1, - 0, 0, 0, - 0, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 0, - 1, 0, 0, - }, - { - 0, 0, 0, - 0, 0, 0, - 1, 1, 0, - }, - { - 0, 0, 0, - 1, 0, 0, - 1, 1, 0, - }, - { - 0, 0, 0, - 1, 1, 0, - 0, 1, 0, - }, - { - 0, 0, 0, - 1, 1, 0, - 0, 1, 1, - }, - { - 0, 0, 0, - 1, 1, 1, - 0, 0, 1, - }, - { - 0, 1, 0, - 0, 1, 1, - 1, 0, 1, - }, - } -) - -var ( - black = color.NRGBA{0, 0, 0, 255} - white = color.NRGBA{255, 255, 255, 255} - gray = color.NRGBA{255, 255, 255, 102} -) - -// render returns a PNG image of the logo. 
-func (logo tsLogo) render() *bytes.Buffer { - const radius = 25 - const borderUnits = 1 - dim := radius * (8 + borderUnits*2) - - dc := gg.NewContext(dim, dim) - dc.DrawRectangle(0, 0, float64(dim), float64(dim)) - dc.SetColor(black) - dc.Fill() - - for y := 0; y < 3; y++ { - for x := 0; x < 3; x++ { - px := (borderUnits + 1 + 3*x) * radius - py := (borderUnits + 1 + 3*y) * radius - col := white - if logo[y*3+x] == 0 { - col = gray - } - dc.DrawCircle(float64(px), float64(py), radius) - dc.SetColor(col) - dc.Fill() - } - } - - b := bytes.NewBuffer(nil) - png.Encode(b, dc.Image()) - return b -} - -// setAppIcon renders logo and sets it as the systray icon. -func setAppIcon(icon tsLogo) { - if icon == loading { - startLoadingAnimation() - } else { - stopLoadingAnimation() - systray.SetIcon(icon.render().Bytes()) - } -} - -var ( - loadingMu sync.Mutex // protects loadingCancel - - // loadingCancel stops the loading animation in the systray icon. - // This is nil if the animation is not currently active. - loadingCancel func() -) - -// startLoadingAnimation starts the animated loading icon in the system tray. -// The animation continues until [stopLoadingAnimation] is called. -// If the loading animation is already active, this func does nothing. -func startLoadingAnimation() { - loadingMu.Lock() - defer loadingMu.Unlock() - - if loadingCancel != nil { - // loading icon already displayed - return - } - - ctx := context.Background() - ctx, loadingCancel = context.WithCancel(ctx) - - go func() { - t := time.NewTicker(500 * time.Millisecond) - var i int - for { - select { - case <-ctx.Done(): - return - case <-t.C: - systray.SetIcon(loadingLogos[i].render().Bytes()) - i++ - if i >= len(loadingLogos) { - i = 0 - } - } - } - }() -} - -// stopLoadingAnimation stops the animated loading icon in the system tray. -// If the loading animation is not currently active, this func does nothing. 
-func stopLoadingAnimation() { - loadingMu.Lock() - defer loadingMu.Unlock() - - if loadingCancel != nil { - loadingCancel() - loadingCancel = nil - } -} diff --git a/cmd/systray/systray.go b/cmd/systray/systray.go index aca38f627c65a..0185a1bc2dc5e 100644 --- a/cmd/systray/systray.go +++ b/cmd/systray/systray.go @@ -3,256 +3,13 @@ //go:build cgo || !darwin -// The systray command is a minimal Tailscale systray application for Linux. +// systray is a minimal Tailscale systray application. package main import ( - "context" - "errors" - "fmt" - "io" - "log" - "os" - "strings" - "sync" - "time" - - "fyne.io/systray" - "github.com/atotto/clipboard" - dbus "github.com/godbus/dbus/v5" - "github.com/toqueteos/webbrowser" - "tailscale.com/client/tailscale" - "tailscale.com/ipn" - "tailscale.com/ipn/ipnstate" -) - -var ( - localClient tailscale.LocalClient - chState chan ipn.State // tailscale state changes - - appIcon *os.File + "tailscale.com/client/systray" ) func main() { - systray.Run(onReady, onExit) -} - -// Menu represents the systray menu, its items, and the current Tailscale state. -type Menu struct { - mu sync.Mutex // protects the entire Menu - status *ipnstate.Status - - connect *systray.MenuItem - disconnect *systray.MenuItem - - self *systray.MenuItem - more *systray.MenuItem - quit *systray.MenuItem - - eventCancel func() // cancel eventLoop -} - -func onReady() { - log.Printf("starting") - ctx := context.Background() - - setAppIcon(disconnected) - - // dbus wants a file path for notification icons, so copy to a temp file. - appIcon, _ = os.CreateTemp("", "tailscale-systray.png") - io.Copy(appIcon, connected.render()) - - chState = make(chan ipn.State, 1) - - status, err := localClient.Status(ctx) - if err != nil { - log.Print(err) - } - - menu := new(Menu) - menu.rebuild(status) - - go watchIPNBus(ctx) -} - -// rebuild the systray menu based on the current Tailscale state. 
-// -// We currently rebuild the entire menu because it is not easy to update the existing menu. -// You cannot iterate over the items in a menu, nor can you remove some items like separators. -// So for now we rebuild the whole thing, and can optimize this later if needed. -func (menu *Menu) rebuild(status *ipnstate.Status) { - menu.mu.Lock() - defer menu.mu.Unlock() - - if menu.eventCancel != nil { - menu.eventCancel() - } - menu.status = status - systray.ResetMenu() - - menu.connect = systray.AddMenuItem("Connect", "") - menu.disconnect = systray.AddMenuItem("Disconnect", "") - menu.disconnect.Hide() - systray.AddSeparator() - - if status != nil && status.Self != nil { - title := fmt.Sprintf("This Device: %s (%s)", status.Self.HostName, status.Self.TailscaleIPs[0]) - menu.self = systray.AddMenuItem(title, "") - } - systray.AddSeparator() - - menu.more = systray.AddMenuItem("More settings", "") - menu.more.Enable() - - menu.quit = systray.AddMenuItem("Quit", "Quit the app") - menu.quit.Enable() - - ctx := context.Background() - ctx, menu.eventCancel = context.WithCancel(ctx) - go menu.eventLoop(ctx) -} - -// eventLoop is the main event loop for handling click events on menu items -// and responding to Tailscale state changes. -// This method does not return until ctx.Done is closed. 
-func (menu *Menu) eventLoop(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - case state := <-chState: - switch state { - case ipn.Running: - setAppIcon(loading) - status, err := localClient.Status(ctx) - if err != nil { - log.Printf("error getting tailscale status: %v", err) - } - menu.rebuild(status) - setAppIcon(connected) - menu.connect.SetTitle("Connected") - menu.connect.Disable() - menu.disconnect.Show() - menu.disconnect.Enable() - case ipn.NoState, ipn.Stopped: - menu.connect.SetTitle("Connect") - menu.connect.Enable() - menu.disconnect.Hide() - setAppIcon(disconnected) - case ipn.Starting: - setAppIcon(loading) - } - case <-menu.connect.ClickedCh: - _, err := localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - WantRunning: true, - }, - WantRunningSet: true, - }) - if err != nil { - log.Print(err) - continue - } - - case <-menu.disconnect.ClickedCh: - _, err := localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ - Prefs: ipn.Prefs{ - WantRunning: false, - }, - WantRunningSet: true, - }) - if err != nil { - log.Printf("disconnecting: %v", err) - continue - } - - case <-menu.self.ClickedCh: - copyTailscaleIP(menu.status.Self) - - case <-menu.more.ClickedCh: - webbrowser.Open("http://100.100.100.100/") - - case <-menu.quit.ClickedCh: - systray.Quit() - } - } -} - -// watchIPNBus subscribes to the tailscale event bus and sends state updates to chState. -// This method does not return. -func watchIPNBus(ctx context.Context) { - for { - if err := watchIPNBusInner(ctx); err != nil { - log.Println(err) - if errors.Is(err, context.Canceled) { - // If the context got canceled, we will never be able to - // reconnect to IPN bus, so exit the process. - log.Fatalf("watchIPNBus: %v", err) - } - } - // If our watch connection breaks, wait a bit before reconnecting. No - // reason to spam the logs if e.g. tailscaled is restarting or goes - // down. 
- time.Sleep(3 * time.Second) - } -} - -func watchIPNBusInner(ctx context.Context) error { - watcher, err := localClient.WatchIPNBus(ctx, ipn.NotifyInitialState|ipn.NotifyNoPrivateKeys) - if err != nil { - return fmt.Errorf("watching ipn bus: %w", err) - } - defer watcher.Close() - for { - select { - case <-ctx.Done(): - return nil - default: - n, err := watcher.Next() - if err != nil { - return fmt.Errorf("ipnbus error: %w", err) - } - if n.State != nil { - chState <- *n.State - log.Printf("new state: %v", n.State) - } - } - } -} - -// copyTailscaleIP copies the first Tailscale IP of the given device to the clipboard -// and sends a notification with the copied value. -func copyTailscaleIP(device *ipnstate.PeerStatus) { - if device == nil || len(device.TailscaleIPs) == 0 { - return - } - name := strings.Split(device.DNSName, ".")[0] - ip := device.TailscaleIPs[0].String() - err := clipboard.WriteAll(ip) - if err != nil { - log.Printf("clipboard error: %v", err) - } - - sendNotification(fmt.Sprintf("Copied Address for %v", name), ip) -} - -// sendNotification sends a desktop notification with the given title and content. 
-func sendNotification(title, content string) { - conn, err := dbus.SessionBus() - if err != nil { - log.Printf("dbus: %v", err) - return - } - timeout := 3 * time.Second - obj := conn.Object("org.freedesktop.Notifications", "/org/freedesktop/Notifications") - call := obj.Call("org.freedesktop.Notifications.Notify", 0, "Tailscale", uint32(0), - appIcon.Name(), title, content, []string{}, map[string]dbus.Variant{}, int32(timeout.Milliseconds())) - if call.Err != nil { - log.Printf("dbus: %v", call.Err) - } -} - -func onExit() { - log.Printf("exiting") - os.Remove(appIcon.Name()) + new(systray.Menu).Run() } diff --git a/cmd/tailscale/cli/advertise.go b/cmd/tailscale/cli/advertise.go new file mode 100644 index 0000000000000..83d1a35aa8a14 --- /dev/null +++ b/cmd/tailscale/cli/advertise.go @@ -0,0 +1,76 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "flag" + "fmt" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/tailcfg" +) + +var advertiseArgs struct { + services string // comma-separated list of services to advertise +} + +// TODO(naman): This flag may move to set.go or serve_v2.go after the WIPCode +// envknob is not needed. +func advertiseCmd() *ffcli.Command { + if !envknob.UseWIPCode() { + return nil + } + return &ffcli.Command{ + Name: "advertise", + ShortUsage: "tailscale advertise --services=", + ShortHelp: "Advertise this node as a destination for a service", + Exec: runAdvertise, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("advertise") + fs.StringVar(&advertiseArgs.services, "services", "", "comma-separated services to advertise; each must start with \"svc:\" (e.g. 
\"svc:idp,svc:nas,svc:database\")") + return fs + })(), + } +} + +func runAdvertise(ctx context.Context, args []string) error { + if len(args) > 0 { + return flag.ErrHelp + } + + services, err := parseServiceNames(advertiseArgs.services) + if err != nil { + return err + } + + _, err = localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ + AdvertiseServicesSet: true, + Prefs: ipn.Prefs{ + AdvertiseServices: services, + }, + }) + return err +} + +// parseServiceNames takes a comma-separated list of service names +// (eg. "svc:hello,svc:webserver,svc:catphotos"), splits them into +// a list and validates each service name. If valid, it returns +// the service names in a slice of strings. +func parseServiceNames(servicesArg string) ([]string, error) { + var services []string + if servicesArg != "" { + services = strings.Split(servicesArg, ",") + for _, svc := range services { + err := tailcfg.ServiceName(svc).Validate() + if err != nil { + return nil, fmt.Errorf("service %q: %s", svc, err) + } + } + } + return services, nil +} diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index 864cf6903a6d0..d7e8e5ca22dce 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -21,10 +21,12 @@ import ( "github.com/mattn/go-colorable" "github.com/mattn/go-isatty" "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/local" "tailscale.com/client/tailscale" "tailscale.com/cmd/tailscale/cli/ffcomplete" "tailscale.com/envknob" "tailscale.com/paths" + "tailscale.com/util/slicesx" "tailscale.com/version/distro" ) @@ -63,38 +65,50 @@ func newFlagSet(name string) *flag.FlagSet { func CleanUpArgs(args []string) []string { out := make([]string, 0, len(args)) for _, arg := range args { + switch { // Rewrite --authkey to --auth-key, and --authkey=x to --auth-key=x, // and the same for the -authkey variant. 
- switch { case arg == "--authkey", arg == "-authkey": arg = "--auth-key" case strings.HasPrefix(arg, "--authkey="), strings.HasPrefix(arg, "-authkey="): - arg = strings.TrimLeft(arg, "-") - arg = strings.TrimPrefix(arg, "authkey=") - arg = "--auth-key=" + arg + _, val, _ := strings.Cut(arg, "=") + arg = "--auth-key=" + val + + // And the same, for posture-checking => report-posture + case arg == "--posture-checking", arg == "-posture-checking": + arg = "--report-posture" + case strings.HasPrefix(arg, "--posture-checking="), strings.HasPrefix(arg, "-posture-checking="): + _, val, _ := strings.Cut(arg, "=") + arg = "--report-posture=" + val + } out = append(out, arg) } return out } -var localClient = tailscale.LocalClient{ +var localClient = local.Client{ Socket: paths.DefaultTailscaledSocket(), } // Run runs the CLI. The args do not include the binary name. func Run(args []string) (err error) { - if runtime.GOOS == "linux" && os.Getenv("GOKRAZY_FIRST_START") == "1" && distro.Get() == distro.Gokrazy && os.Getppid() == 1 { - // We're running on gokrazy and it's the first start. - // Don't run the tailscale CLI as a service; just exit. + if runtime.GOOS == "linux" && os.Getenv("GOKRAZY_FIRST_START") == "1" && distro.Get() == distro.Gokrazy && os.Getppid() == 1 && len(args) == 0 { + // We're running on gokrazy and the user did not specify 'up'. + // Don't run the tailscale CLI and spam logs with usage; just exit. 
// See https://gokrazy.org/development/process-interface/ os.Exit(0) } args = CleanUpArgs(args) - if len(args) == 1 && (args[0] == "-V" || args[0] == "--version") { - args = []string{"version"} + if len(args) == 1 { + switch args[0] { + case "-V", "--version": + args = []string{"version"} + case "help": + args = []string{"--help"} + } } var warnOnce sync.Once @@ -150,7 +164,7 @@ func Run(args []string) (err error) { err = rootCmd.Run(context.Background()) if tailscale.IsAccessDeniedError(err) && os.Getuid() != 0 && runtime.GOOS != "windows" { - return fmt.Errorf("%v\n\nUse 'sudo tailscale %s' or 'tailscale up --operator=$USER' to not require root.", err, strings.Join(args, " ")) + return fmt.Errorf("%v\n\nUse 'sudo tailscale %s'.\nTo not require root, use 'sudo tailscale set --operator=$USER' once.", err, strings.Join(args, " ")) } if errors.Is(err, flag.ErrHelp) { return nil @@ -158,6 +172,43 @@ func Run(args []string) (err error) { return err } +type onceFlagValue struct { + flag.Value + set bool +} + +func (v *onceFlagValue) Set(s string) error { + if v.set { + return fmt.Errorf("flag provided multiple times") + } + v.set = true + return v.Value.Set(s) +} + +func (v *onceFlagValue) IsBoolFlag() bool { + type boolFlag interface { + IsBoolFlag() bool + } + bf, ok := v.Value.(boolFlag) + return ok && bf.IsBoolFlag() +} + +// noDupFlagify modifies c recursively to make all the +// flag values be wrappers that permit setting the value +// at most once. +func noDupFlagify(c *ffcli.Command) { + if c.FlagSet != nil { + c.FlagSet.VisitAll(func(f *flag.Flag) { + f.Value = &onceFlagValue{Value: f.Value} + }) + } + for _, sub := range c.Subcommands { + noDupFlagify(sub) + } +} + +var fileCmd func() *ffcli.Command + func newRootCmd() *ffcli.Command { rootfs := newFlagSet("tailscale") rootfs.Func("socket", "path to tailscaled socket", func(s string) error { @@ -177,18 +228,20 @@ For help on subcommands, add --help after: "tailscale status --help". 
This CLI is still under active development. Commands and flags will change in the future. `), - Subcommands: []*ffcli.Command{ + Subcommands: nonNilCmds( upCmd, downCmd, setCmd, loginCmd, logoutCmd, switchCmd, - configureCmd, + configureCmd(), + syspolicyCmd, netcheckCmd, ipCmd, dnsCmd, statusCmd, + metricsCmd, pingCmd, ncCmd, sshCmd, @@ -196,7 +249,7 @@ change in the future. serveCmd(), versionCmd, webCmd, - fileCmd, + nilOrCall(fileCmd), bugReportCmd, certCmd, netlockCmd, @@ -204,10 +257,12 @@ change in the future. exitNodeCmd(), updateCmd, whoisCmd, - debugCmd, + debugCmd(), driveCmd, idTokenCmd, - }, + advertiseCmd(), + configureHostCmd(), + ), FlagSet: rootfs, Exec: func(ctx context.Context, args []string) error { if len(args) > 0 { @@ -217,10 +272,6 @@ change in the future. }, } - if runtime.GOOS == "linux" && distro.Get() == distro.Synology { - rootCmd.Subcommands = append(rootCmd.Subcommands, configureHostCmd) - } - walkCommands(rootCmd, func(w cmdWalk) bool { if w.UsageFunc == nil { w.UsageFunc = usageFunc @@ -229,9 +280,21 @@ change in the future. }) ffcomplete.Inject(rootCmd, func(c *ffcli.Command) { c.LongHelp = hidden + c.LongHelp }, usageFunc) + noDupFlagify(rootCmd) return rootCmd } +func nonNilCmds(cmds ...*ffcli.Command) []*ffcli.Command { + return slicesx.AppendNonzero(cmds[:0], cmds) +} + +func nilOrCall(f func() *ffcli.Command) *ffcli.Command { + if f == nil { + return nil + } + return f() +} + func fatalf(format string, a ...any) { if Fatalf != nil { Fatalf(format, a...) 
diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index d103c8f7e9f5c..9aa3693fd92c5 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "net/netip" "reflect" "strings" @@ -16,6 +17,7 @@ import ( qt "github.com/frankban/quicktest" "github.com/google/go-cmp/cmp" + "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/envknob" "tailscale.com/health/healthmsg" "tailscale.com/ipn" @@ -23,10 +25,12 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/tstest" + "tailscale.com/tstest/deptest" "tailscale.com/types/logger" "tailscale.com/types/opt" "tailscale.com/types/persist" "tailscale.com/types/preftype" + "tailscale.com/util/set" "tailscale.com/version/distro" ) @@ -600,6 +604,19 @@ func TestCheckForAccidentalSettingReverts(t *testing.T) { goos: "linux", want: "", }, + { + name: "losing_report_posture", + flags: []string{"--accept-dns"}, + curPrefs: &ipn.Prefs{ + ControlURL: ipn.DefaultControlURL, + WantRunning: false, + CorpDNS: true, + PostureChecking: true, + NetfilterMode: preftype.NetfilterOn, + NoStatefulFiltering: opt.NewBool(true), + }, + want: accidentalUpPrefix + " --accept-dns --report-posture", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -946,6 +963,10 @@ func TestPrefFlagMapping(t *testing.T) { // Handled by the tailscale share subcommand, we don't want a CLI // flag for this. continue + case "AdvertiseServices": + // Handled by the tailscale advertise subcommand, we don't want a + // CLI flag for this. + continue case "InternalExitNodePrior": // Used internally by LocalBackend as part of exit node usage toggling. // No CLI flag for this. 
@@ -1040,6 +1061,7 @@ func TestUpdatePrefs(t *testing.T) { NoSNATSet: true, NoStatefulFilteringSet: true, OperatorUserSet: true, + PostureCheckingSet: true, RouteAllSet: true, RunSSHSet: true, ShieldsUpSet: true, @@ -1372,23 +1394,28 @@ var cmpIP = cmp.Comparer(func(a, b netip.Addr) bool { }) func TestCleanUpArgs(t *testing.T) { + type S = []string c := qt.New(t) tests := []struct { in []string want []string }{ - {in: []string{"something"}, want: []string{"something"}}, - {in: []string{}, want: []string{}}, - {in: []string{"--authkey=0"}, want: []string{"--auth-key=0"}}, - {in: []string{"a", "--authkey=1", "b"}, want: []string{"a", "--auth-key=1", "b"}}, - {in: []string{"a", "--auth-key=2", "b"}, want: []string{"a", "--auth-key=2", "b"}}, - {in: []string{"a", "-authkey=3", "b"}, want: []string{"a", "--auth-key=3", "b"}}, - {in: []string{"a", "-auth-key=4", "b"}, want: []string{"a", "-auth-key=4", "b"}}, - {in: []string{"a", "--authkey", "5", "b"}, want: []string{"a", "--auth-key", "5", "b"}}, - {in: []string{"a", "-authkey", "6", "b"}, want: []string{"a", "--auth-key", "6", "b"}}, - {in: []string{"a", "authkey", "7", "b"}, want: []string{"a", "authkey", "7", "b"}}, - {in: []string{"--authkeyexpiry", "8"}, want: []string{"--authkeyexpiry", "8"}}, - {in: []string{"--auth-key-expiry", "9"}, want: []string{"--auth-key-expiry", "9"}}, + {in: S{"something"}, want: S{"something"}}, + {in: S{}, want: S{}}, + {in: S{"--authkey=0"}, want: S{"--auth-key=0"}}, + {in: S{"a", "--authkey=1", "b"}, want: S{"a", "--auth-key=1", "b"}}, + {in: S{"a", "--auth-key=2", "b"}, want: S{"a", "--auth-key=2", "b"}}, + {in: S{"a", "-authkey=3", "b"}, want: S{"a", "--auth-key=3", "b"}}, + {in: S{"a", "-auth-key=4", "b"}, want: S{"a", "-auth-key=4", "b"}}, + {in: S{"a", "--authkey", "5", "b"}, want: S{"a", "--auth-key", "5", "b"}}, + {in: S{"a", "-authkey", "6", "b"}, want: S{"a", "--auth-key", "6", "b"}}, + {in: S{"a", "authkey", "7", "b"}, want: S{"a", "authkey", "7", "b"}}, + {in: 
S{"--authkeyexpiry", "8"}, want: S{"--authkeyexpiry", "8"}}, + {in: S{"--auth-key-expiry", "9"}, want: S{"--auth-key-expiry", "9"}}, + + {in: S{"--posture-checking"}, want: S{"--report-posture"}}, + {in: S{"-posture-checking"}, want: S{"--report-posture"}}, + {in: S{"--posture-checking=nein"}, want: S{"--report-posture=nein"}}, } for _, tt := range tests { @@ -1476,3 +1503,148 @@ func TestParseNLArgs(t *testing.T) { }) } } + +// makeQuietContinueOnError modifies c recursively to make all the +// flagsets have error mode flag.ContinueOnError and not +// spew all over stderr. +func makeQuietContinueOnError(c *ffcli.Command) { + if c.FlagSet != nil { + c.FlagSet.Init(c.Name, flag.ContinueOnError) + c.FlagSet.Usage = func() {} + c.FlagSet.SetOutput(io.Discard) + } + c.UsageFunc = func(*ffcli.Command) string { return "" } + for _, sub := range c.Subcommands { + makeQuietContinueOnError(sub) + } +} + +// see tailscale/tailscale#6813 +func TestNoDups(t *testing.T) { + tests := []struct { + name string + args []string + want string + }{ + { + name: "dup-boolean", + args: []string{"up", "--json", "--json"}, + want: "error parsing commandline arguments: invalid boolean flag json: flag provided multiple times", + }, + { + name: "dup-string", + args: []string{"up", "--hostname=foo", "--hostname=bar"}, + want: "error parsing commandline arguments: invalid value \"bar\" for flag -hostname: flag provided multiple times", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd := newRootCmd() + makeQuietContinueOnError(cmd) + err := cmd.Parse(tt.args) + if got := fmt.Sprint(err); got != tt.want { + t.Errorf("got %q, want %q", got, tt.want) + } + }) + } +} + +func TestHelpAlias(t *testing.T) { + var stdout, stderr bytes.Buffer + tstest.Replace[io.Writer](t, &Stdout, &stdout) + tstest.Replace[io.Writer](t, &Stderr, &stderr) + + gotExit0 := false + defer func() { + if !gotExit0 { + t.Error("expected os.Exit(0) to be called") + return + } + if 
!strings.Contains(stderr.String(), "SUBCOMMANDS") { + t.Errorf("expected help output to contain SUBCOMMANDS; got stderr=%q; stdout=%q", stderr.String(), stdout.String()) + } + }() + defer func() { + if e := recover(); e != nil { + if strings.Contains(fmt.Sprint(e), "unexpected call to os.Exit(0)") { + gotExit0 = true + } else { + t.Errorf("unexpected panic: %v", e) + } + } + }() + err := Run([]string{"help"}) + if err != nil { + t.Fatalf("Run: %v", err) + } +} + +func TestDocs(t *testing.T) { + root := newRootCmd() + check := func(t *testing.T, c *ffcli.Command) { + shortVerb, _, ok := strings.Cut(c.ShortHelp, " ") + if !ok || shortVerb == "" { + t.Errorf("couldn't find verb+space in ShortHelp") + } else { + if strings.HasSuffix(shortVerb, ".") { + t.Errorf("ShortHelp shouldn't end in period; got %q", c.ShortHelp) + } + if b := shortVerb[0]; b >= 'a' && b <= 'z' { + t.Errorf("ShortHelp should start with upper-case letter; got %q", c.ShortHelp) + } + if strings.HasSuffix(shortVerb, "s") && shortVerb != "Does" { + t.Errorf("verb %q ending in 's' is unexpected, from %q", shortVerb, c.ShortHelp) + } + } + + name := t.Name() + wantPfx := strings.ReplaceAll(strings.TrimPrefix(name, "TestDocs/"), "/", " ") + switch name { + case "TestDocs/tailscale/completion/bash", + "TestDocs/tailscale/completion/zsh": + wantPfx = "" // special-case exceptions + } + if !strings.HasPrefix(c.ShortUsage, wantPfx) { + t.Errorf("ShortUsage should start with %q; got %q", wantPfx, c.ShortUsage) + } + } + + var walk func(t *testing.T, c *ffcli.Command) + walk = func(t *testing.T, c *ffcli.Command) { + t.Run(c.Name, func(t *testing.T) { + check(t, c) + for _, sub := range c.Subcommands { + walk(t, sub) + } + }) + } + walk(t, root) +} + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "arm64", + WantDeps: set.Of( + "tailscale.com/feature/capture/dissector", // want the Lua by default + ), + BadDeps: map[string]string{ + "tailscale.com/feature/capture": "don't link 
capture code", + "tailscale.com/net/packet": "why we passing packets in the CLI?", + "tailscale.com/net/flowtrack": "why we tracking flows in the CLI?", + }, + }.Check(t) +} + +func TestDepsNoCapture(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "arm64", + Tags: "ts_omit_capture", + BadDeps: map[string]string{ + "tailscale.com/feature/capture": "don't link capture code", + "tailscale.com/feature/capture/dissector": "don't like the Lua", + }, + }.Check(t) + +} diff --git a/cmd/tailscale/cli/configure-kube.go b/cmd/tailscale/cli/configure-kube.go index 6af15e3d9ae7b..6bc4e202efd4e 100644 --- a/cmd/tailscale/cli/configure-kube.go +++ b/cmd/tailscale/cli/configure-kube.go @@ -20,33 +20,31 @@ import ( "tailscale.com/version" ) -func init() { - configureCmd.Subcommands = append(configureCmd.Subcommands, configureKubeconfigCmd) -} - -var configureKubeconfigCmd = &ffcli.Command{ - Name: "kubeconfig", - ShortHelp: "[ALPHA] Connect to a Kubernetes cluster using a Tailscale Auth Proxy", - ShortUsage: "tailscale configure kubeconfig ", - LongHelp: strings.TrimSpace(` +func configureKubeconfigCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "kubeconfig", + ShortHelp: "[ALPHA] Connect to a Kubernetes cluster using a Tailscale Auth Proxy", + ShortUsage: "tailscale configure kubeconfig ", + LongHelp: strings.TrimSpace(` Run this command to configure kubectl to connect to a Kubernetes cluster over Tailscale. The hostname argument should be set to the Tailscale hostname of the peer running as an auth proxy in the cluster. See: https://tailscale.com/s/k8s-auth-proxy `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("kubeconfig") - return fs - })(), - Exec: runConfigureKubeconfig, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("kubeconfig") + return fs + })(), + Exec: runConfigureKubeconfig, + } } // kubeconfigPath returns the path to the kubeconfig file for the current user. 
func kubeconfigPath() (string, error) { if kubeconfig := os.Getenv("KUBECONFIG"); kubeconfig != "" { if version.IsSandboxedMacOS() { - return "", errors.New("$KUBECONFIG is incompatible with the App Store version") + return "", errors.New("cannot read $KUBECONFIG on GUI builds of the macOS client: this requires the open-source tailscaled distribution") } var out string for _, out = range filepath.SplitList(kubeconfig) { diff --git a/cmd/tailscale/cli/configure-kube_omit.go b/cmd/tailscale/cli/configure-kube_omit.go new file mode 100644 index 0000000000000..130f2870fab44 --- /dev/null +++ b/cmd/tailscale/cli/configure-kube_omit.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ts_omit_kube + +package cli + +import "github.com/peterbourgon/ff/v3/ffcli" + +func configureKubeconfigCmd() *ffcli.Command { + // omitted from the build when the ts_omit_kube build tag is set + return nil +} diff --git a/cmd/tailscale/cli/configure-synology-cert.go b/cmd/tailscale/cli/configure-synology-cert.go index aabcb8dfad866..663d0c8790456 100644 --- a/cmd/tailscale/cli/configure-synology-cert.go +++ b/cmd/tailscale/cli/configure-synology-cert.go @@ -22,22 +22,27 @@ import ( "tailscale.com/version/distro" ) -var synologyConfigureCertCmd = &ffcli.Command{ - Name: "synology-cert", - Exec: runConfigureSynologyCert, - ShortHelp: "Configure Synology with a TLS certificate for your tailnet", - ShortUsage: "synology-cert [--domain ]", - LongHelp: strings.TrimSpace(` +func synologyConfigureCertCmd() *ffcli.Command { + if runtime.GOOS != "linux" || distro.Get() != distro.Synology { + return nil + } + return &ffcli.Command{ + Name: "synology-cert", + Exec: runConfigureSynologyCert, + ShortHelp: "Configure Synology with a TLS certificate for your tailnet", + ShortUsage: "synology-cert [--domain ]", + LongHelp: strings.TrimSpace(` This command is intended to run periodically as root on a Synology device to create or refresh the TLS 
certificate for the tailnet domain. See: https://tailscale.com/kb/1153/enabling-https `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("synology-cert") - fs.StringVar(&synologyConfigureCertArgs.domain, "domain", "", "Tailnet domain to create or refresh certificates for. Ignored if only one domain exists.") - return fs - })(), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("synology-cert") + fs.StringVar(&synologyConfigureCertArgs.domain, "domain", "", "Tailnet domain to create or refresh certificates for. Ignored if only one domain exists.") + return fs + })(), + } } var synologyConfigureCertArgs struct { diff --git a/cmd/tailscale/cli/configure-synology.go b/cmd/tailscale/cli/configure-synology.go index 9d674e56dd79a..f0f05f75765b9 100644 --- a/cmd/tailscale/cli/configure-synology.go +++ b/cmd/tailscale/cli/configure-synology.go @@ -21,34 +21,49 @@ import ( // configureHostCmd is the "tailscale configure-host" command which was once // used to configure Synology devices, but is now a compatibility alias to // "tailscale configure synology". -var configureHostCmd = &ffcli.Command{ - Name: "configure-host", - Exec: runConfigureSynology, - ShortUsage: "tailscale configure-host\n" + synologyConfigureCmd.ShortUsage, - ShortHelp: synologyConfigureCmd.ShortHelp, - LongHelp: hidden + synologyConfigureCmd.LongHelp, - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("configure-host") - return fs - })(), +// +// It returns nil if the actual "tailscale configure synology" command is not +// available. +func configureHostCmd() *ffcli.Command { + synologyConfigureCmd := synologyConfigureCmd() + if synologyConfigureCmd == nil { + // No need to offer this compatibility alias if the actual command is not available. 
+ return nil + } + return &ffcli.Command{ + Name: "configure-host", + Exec: runConfigureSynology, + ShortUsage: "tailscale configure-host\n" + synologyConfigureCmd.ShortUsage, + ShortHelp: synologyConfigureCmd.ShortHelp, + LongHelp: hidden + synologyConfigureCmd.LongHelp, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("configure-host") + return fs + })(), + } } -var synologyConfigureCmd = &ffcli.Command{ - Name: "synology", - Exec: runConfigureSynology, - ShortUsage: "tailscale configure synology", - ShortHelp: "Configure Synology to enable outbound connections", - LongHelp: strings.TrimSpace(` +func synologyConfigureCmd() *ffcli.Command { + if runtime.GOOS != "linux" || distro.Get() != distro.Synology { + return nil + } + return &ffcli.Command{ + Name: "synology", + Exec: runConfigureSynology, + ShortUsage: "tailscale configure synology", + ShortHelp: "Configure Synology to enable outbound connections", + LongHelp: strings.TrimSpace(` This command is intended to run at boot as root on a Synology device to create the /dev/net/tun device and give the tailscaled binary permission to use it. 
See: https://tailscale.com/s/synology-outbound `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("synology") - return fs - })(), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("synology") + return fs + })(), + } } func runConfigureSynology(ctx context.Context, args []string) error { diff --git a/cmd/tailscale/cli/configure.go b/cmd/tailscale/cli/configure.go index fd136d766360d..acb416755a586 100644 --- a/cmd/tailscale/cli/configure.go +++ b/cmd/tailscale/cli/configure.go @@ -5,32 +5,41 @@ package cli import ( "flag" - "runtime" "strings" "github.com/peterbourgon/ff/v3/ffcli" - "tailscale.com/version/distro" ) -var configureCmd = &ffcli.Command{ - Name: "configure", - ShortUsage: "tailscale configure ", - ShortHelp: "[ALPHA] Configure the host to enable more Tailscale features", - LongHelp: strings.TrimSpace(` +func configureCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "configure", + ShortUsage: "tailscale configure ", + ShortHelp: "Configure the host to enable more Tailscale features", + LongHelp: strings.TrimSpace(` The 'configure' set of commands are intended to provide a way to enable different services on the host to use Tailscale in more ways. `), - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("configure") - return fs - })(), - Subcommands: configureSubcommands(), + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("configure") + return fs + })(), + Subcommands: nonNilCmds( + configureKubeconfigCmd(), + synologyConfigureCmd(), + synologyConfigureCertCmd(), + ccall(maybeSysExtCmd), + ccall(maybeVPNConfigCmd), + ), + } } -func configureSubcommands() (out []*ffcli.Command) { - if runtime.GOOS == "linux" && distro.Get() == distro.Synology { - out = append(out, synologyConfigureCmd) - out = append(out, synologyConfigureCertCmd) +// ccall calls the function f if it is non-nil, and returns its result. +// +// It returns the zero value of the type T if f is nil. 
+func ccall[T any](f func() T) T { + var zero T + if f == nil { + return zero } - return out + return f() } diff --git a/cmd/tailscale/cli/configure_apple-all.go b/cmd/tailscale/cli/configure_apple-all.go new file mode 100644 index 0000000000000..5f0da9b95420e --- /dev/null +++ b/cmd/tailscale/cli/configure_apple-all.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import "github.com/peterbourgon/ff/v3/ffcli" + +var ( + maybeSysExtCmd func() *ffcli.Command // non-nil only on macOS, see configure_apple.go + maybeVPNConfigCmd func() *ffcli.Command // non-nil only on macOS, see configure_apple.go +) diff --git a/cmd/tailscale/cli/configure_apple.go b/cmd/tailscale/cli/configure_apple.go new file mode 100644 index 0000000000000..c0d99b90aa2c4 --- /dev/null +++ b/cmd/tailscale/cli/configure_apple.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package cli + +import ( + "context" + "errors" + + "github.com/peterbourgon/ff/v3/ffcli" +) + +func init() { + maybeSysExtCmd = sysExtCmd + maybeVPNConfigCmd = vpnConfigCmd +} + +// Functions in this file provide a dummy Exec function that only prints an error message for users of the open-source +// tailscaled distribution. On GUI builds, the Swift code in the macOS client handles these commands by not passing the +// flow of execution to the CLI. + +// sysExtCmd returns a command for managing the Tailscale system extension on macOS +// (for the Standalone variant of the client only). +func sysExtCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "sysext", + ShortUsage: "tailscale configure sysext [activate|deactivate|status]", + ShortHelp: "Manage the system extension for macOS (Standalone variant)", + LongHelp: "The sysext set of commands provides a way to activate, deactivate, or manage the state of the Tailscale system extension on macOS. 
" + + "This is only relevant if you are running the Standalone variant of the Tailscale client for macOS. " + + "To access more detailed information about system extensions installed on this Mac, run 'systemextensionsctl list'.", + Subcommands: []*ffcli.Command{ + { + Name: "activate", + ShortUsage: "tailscale configure sysext activate", + ShortHelp: "Register the Tailscale system extension with macOS.", + LongHelp: "This command registers the Tailscale system extension with macOS. To run Tailscale, you'll also need to install the VPN configuration separately (run `tailscale configure vpn-config install`). After running this command, you need to approve the extension in System Settings > Login Items and Extensions > Network Extensions.", + Exec: requiresStandalone, + }, + { + Name: "deactivate", + ShortUsage: "tailscale configure sysext deactivate", + ShortHelp: "Deactivate the Tailscale system extension on macOS", + LongHelp: "This command deactivates the Tailscale system extension on macOS. To completely remove Tailscale, you'll also need to delete the VPN configuration separately (use `tailscale configure vpn-config uninstall`).", + Exec: requiresStandalone, + }, + { + Name: "status", + ShortUsage: "tailscale configure sysext status", + ShortHelp: "Print the enablement status of the Tailscale system extension", + LongHelp: "This command prints the enablement status of the Tailscale system extension. If the extension is not enabled, run `tailscale sysext activate` to enable it.", + Exec: requiresStandalone, + }, + }, + Exec: requiresStandalone, + } +} + +// vpnConfigCmd returns a command for managing the Tailscale VPN configuration on macOS +// (the entry that appears in System Settings > VPN). 
+func vpnConfigCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "mac-vpn", + ShortUsage: "tailscale configure mac-vpn [install|uninstall]", + ShortHelp: "Manage the VPN configuration on macOS (App Store and Standalone variants)", + LongHelp: "The vpn-config set of commands provides a way to add or remove the Tailscale VPN configuration from the macOS settings. This is the entry that appears in System Settings > VPN.", + Subcommands: []*ffcli.Command{ + { + Name: "install", + ShortUsage: "tailscale configure mac-vpn install", + ShortHelp: "Write the Tailscale VPN configuration to the macOS settings", + LongHelp: "This command writes the Tailscale VPN configuration to the macOS settings. This is the entry that appears in System Settings > VPN. If you are running the Standalone variant of the client, you'll also need to install the system extension separately (run `tailscale configure sysext activate`).", + Exec: requiresGUI, + }, + { + Name: "uninstall", + ShortUsage: "tailscale configure mac-vpn uninstall", + ShortHelp: "Delete the Tailscale VPN configuration from the macOS settings", + LongHelp: "This command removes the Tailscale VPN configuration from the macOS settings. This is the entry that appears in System Settings > VPN. 
If you are running the Standalone variant of the client, you'll also need to deactivate the system extension separately (run `tailscale configure sysext deactivate`).", + Exec: requiresGUI, + }, + }, + Exec: func(ctx context.Context, args []string) error { + return errors.New("unsupported command: requires a GUI build of the macOS client") + }, + } +} + +func requiresStandalone(ctx context.Context, args []string) error { + return errors.New("unsupported command: requires the Standalone (.pkg installer) GUI build of the client") +} + +func requiresGUI(ctx context.Context, args []string) error { + return errors.New("unsupported command: requires a GUI build of the macOS client") +} diff --git a/cmd/tailscale/cli/debug-capture.go b/cmd/tailscale/cli/debug-capture.go new file mode 100644 index 0000000000000..a54066fa614cb --- /dev/null +++ b/cmd/tailscale/cli/debug-capture.go @@ -0,0 +1,80 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_capture + +package cli + +import ( + "context" + "flag" + "fmt" + "io" + "os" + "os/exec" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/feature/capture/dissector" +) + +func init() { + debugCaptureCmd = mkDebugCaptureCmd +} + +func mkDebugCaptureCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "capture", + ShortUsage: "tailscale debug capture", + Exec: runCapture, + ShortHelp: "Stream pcaps for debugging", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("capture") + fs.StringVar(&captureArgs.outFile, "o", "", "path to stream the pcap (or - for stdout), leave empty to start wireshark") + return fs + })(), + } +} + +var captureArgs struct { + outFile string +} + +func runCapture(ctx context.Context, args []string) error { + stream, err := localClient.StreamDebugCapture(ctx) + if err != nil { + return err + } + defer stream.Close() + + switch captureArgs.outFile { + case "-": + fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") + _, err = 
io.Copy(os.Stdout, stream) + return err + case "": + lua, err := os.CreateTemp("", "ts-dissector") + if err != nil { + return err + } + defer os.Remove(lua.Name()) + io.WriteString(lua, dissector.Lua) + if err := lua.Close(); err != nil { + return err + } + + wireshark := exec.CommandContext(ctx, "wireshark", "-X", "lua_script:"+lua.Name(), "-k", "-i", "-") + wireshark.Stdin = stream + wireshark.Stdout = os.Stdout + wireshark.Stderr = os.Stderr + return wireshark.Run() + } + + f, err := os.OpenFile(captureArgs.outFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) + if err != nil { + return err + } + defer f.Close() + fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") + _, err = io.Copy(f, stream) + return err +} diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index fdde9ef096ae3..213a0166e2aa5 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -20,7 +20,6 @@ import ( "net/netip" "net/url" "os" - "os/exec" "runtime" "runtime/debug" "strconv" @@ -36,6 +35,7 @@ import ( "tailscale.com/hostinfo" "tailscale.com/internal/noiseconn" "tailscale.com/ipn" + "tailscale.com/net/netmon" "tailscale.com/net/tsaddr" "tailscale.com/net/tshttpproxy" "tailscale.com/paths" @@ -43,301 +43,315 @@ import ( "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" "tailscale.com/util/must" - "tailscale.com/wgengine/capture" ) -var debugCmd = &ffcli.Command{ - Name: "debug", - Exec: runDebug, - ShortUsage: "tailscale debug ", - ShortHelp: "Debug commands", - LongHelp: hidden + `"tailscale debug" contains misc debug facilities; it is not a stable interface.`, - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("debug") - fs.StringVar(&debugArgs.file, "file", "", "get, delete:NAME, or NAME") - fs.StringVar(&debugArgs.cpuFile, "cpu-profile", "", "if non-empty, grab a CPU profile for --profile-seconds seconds and write it to this file; - for stdout") - fs.StringVar(&debugArgs.memFile, 
"mem-profile", "", "if non-empty, grab a memory profile and write it to this file; - for stdout") - fs.IntVar(&debugArgs.cpuSec, "profile-seconds", 15, "number of seconds to run a CPU profile for, when --cpu-profile is non-empty") - return fs - })(), - Subcommands: []*ffcli.Command{ - { - Name: "derp-map", - ShortUsage: "tailscale debug derp-map", - Exec: runDERPMap, - ShortHelp: "Print DERP map", - }, - { - Name: "component-logs", - ShortUsage: "tailscale debug component-logs [" + strings.Join(ipn.DebuggableComponents, "|") + "]", - Exec: runDebugComponentLogs, - ShortHelp: "Enable/disable debug logs for a component", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("component-logs") - fs.DurationVar(&debugComponentLogsArgs.forDur, "for", time.Hour, "how long to enable debug logs for; zero or negative means to disable") - return fs - })(), - }, - { - Name: "daemon-goroutines", - ShortUsage: "tailscale debug daemon-goroutines", - Exec: runDaemonGoroutines, - ShortHelp: "Print tailscaled's goroutines", - }, - { - Name: "daemon-logs", - ShortUsage: "tailscale debug daemon-logs", - Exec: runDaemonLogs, - ShortHelp: "Watch tailscaled's server logs", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("daemon-logs") - fs.IntVar(&daemonLogsArgs.verbose, "verbose", 0, "verbosity level") - fs.BoolVar(&daemonLogsArgs.time, "time", false, "include client time") - return fs - })(), - }, - { - Name: "metrics", - ShortUsage: "tailscale debug metrics", - Exec: runDaemonMetrics, - ShortHelp: "Print tailscaled's metrics", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("metrics") - fs.BoolVar(&metricsArgs.watch, "watch", false, "print JSON dump of delta values") - return fs - })(), - }, - { - Name: "env", - ShortUsage: "tailscale debug env", - Exec: runEnv, - ShortHelp: "Print cmd/tailscale environment", - }, - { - Name: "stat", - ShortUsage: "tailscale debug stat ", - Exec: runStat, - ShortHelp: "Stat a file", - }, - { - Name: "hostinfo", - ShortUsage: "tailscale 
debug hostinfo", - Exec: runHostinfo, - ShortHelp: "Print hostinfo", - }, - { - Name: "local-creds", - ShortUsage: "tailscale debug local-creds", - Exec: runLocalCreds, - ShortHelp: "Print how to access Tailscale LocalAPI", - }, - { - Name: "restun", - ShortUsage: "tailscale debug restun", - Exec: localAPIAction("restun"), - ShortHelp: "Force a magicsock restun", - }, - { - Name: "rebind", - ShortUsage: "tailscale debug rebind", - Exec: localAPIAction("rebind"), - ShortHelp: "Force a magicsock rebind", - }, - { - Name: "derp-set-on-demand", - ShortUsage: "tailscale debug derp-set-on-demand", - Exec: localAPIAction("derp-set-homeless"), - ShortHelp: "Enable DERP on-demand mode (breaks reachability)", - }, - { - Name: "derp-unset-on-demand", - ShortUsage: "tailscale debug derp-unset-on-demand", - Exec: localAPIAction("derp-unset-homeless"), - ShortHelp: "Disable DERP on-demand mode", - }, - { - Name: "break-tcp-conns", - ShortUsage: "tailscale debug break-tcp-conns", - Exec: localAPIAction("break-tcp-conns"), - ShortHelp: "Break any open TCP connections from the daemon", - }, - { - Name: "break-derp-conns", - ShortUsage: "tailscale debug break-derp-conns", - Exec: localAPIAction("break-derp-conns"), - ShortHelp: "Break any open DERP connections from the daemon", - }, - { - Name: "pick-new-derp", - ShortUsage: "tailscale debug pick-new-derp", - Exec: localAPIAction("pick-new-derp"), - ShortHelp: "Switch to some other random DERP home region for a short time", - }, - { - Name: "force-netmap-update", - ShortUsage: "tailscale debug force-netmap-update", - Exec: localAPIAction("force-netmap-update"), - ShortHelp: "Force a full no-op netmap update (for load testing)", - }, - { - // TODO(bradfitz,maisem): eventually promote this out of debug - Name: "reload-config", - ShortUsage: "tailscale debug reload-config", - Exec: reloadConfig, - ShortHelp: "Reload config", - }, - { - Name: "control-knobs", - ShortUsage: "tailscale debug control-knobs", - Exec: debugControlKnobs, - 
ShortHelp: "See current control knobs", - }, - { - Name: "prefs", - ShortUsage: "tailscale debug prefs", - Exec: runPrefs, - ShortHelp: "Print prefs", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("prefs") - fs.BoolVar(&prefsArgs.pretty, "pretty", false, "If true, pretty-print output") - return fs - })(), - }, - { - Name: "watch-ipn", - ShortUsage: "tailscale debug watch-ipn", - Exec: runWatchIPN, - ShortHelp: "Subscribe to IPN message bus", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("watch-ipn") - fs.BoolVar(&watchIPNArgs.netmap, "netmap", true, "include netmap in messages") - fs.BoolVar(&watchIPNArgs.initial, "initial", false, "include initial status") - fs.BoolVar(&watchIPNArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") - fs.IntVar(&watchIPNArgs.count, "count", 0, "exit after printing this many statuses, or 0 to keep going forever") - return fs - })(), - }, - { - Name: "netmap", - ShortUsage: "tailscale debug netmap", - Exec: runNetmap, - ShortHelp: "Print the current network map", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("netmap") - fs.BoolVar(&netmapArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") - return fs - })(), - }, - { - Name: "via", - ShortUsage: "tailscale debug via \n" + - "tailscale debug via ", - Exec: runVia, - ShortHelp: "Convert between site-specific IPv4 CIDRs and IPv6 'via' routes", - }, - { - Name: "ts2021", - ShortUsage: "tailscale debug ts2021", - Exec: runTS2021, - ShortHelp: "Debug ts2021 protocol connectivity", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("ts2021") - fs.StringVar(&ts2021Args.host, "host", "controlplane.tailscale.com", "hostname of control plane") - fs.IntVar(&ts2021Args.version, "version", int(tailcfg.CurrentCapabilityVersion), "protocol version") - fs.BoolVar(&ts2021Args.verbose, "verbose", false, "be extra verbose") - return fs - })(), - }, - { - Name: "set-expire", - ShortUsage: 
"tailscale debug set-expire --in=1m", - Exec: runSetExpire, - ShortHelp: "Manipulate node key expiry for testing", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("set-expire") - fs.DurationVar(&setExpireArgs.in, "in", 0, "if non-zero, set node key to expire this duration from now") - return fs - })(), - }, - { - Name: "dev-store-set", - ShortUsage: "tailscale debug dev-store-set", - Exec: runDevStoreSet, - ShortHelp: "Set a key/value pair during development", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("store-set") - fs.BoolVar(&devStoreSetArgs.danger, "danger", false, "accept danger") - return fs - })(), - }, - { - Name: "derp", - ShortUsage: "tailscale debug derp", - Exec: runDebugDERP, - ShortHelp: "Test a DERP configuration", - }, - { - Name: "capture", - ShortUsage: "tailscale debug capture", - Exec: runCapture, - ShortHelp: "Streams pcaps for debugging", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("capture") - fs.StringVar(&captureArgs.outFile, "o", "", "path to stream the pcap (or - for stdout), leave empty to start wireshark") - return fs - })(), - }, - { - Name: "portmap", - ShortUsage: "tailscale debug portmap", - Exec: debugPortmap, - ShortHelp: "Run portmap debugging", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("portmap") - fs.DurationVar(&debugPortmapArgs.duration, "duration", 5*time.Second, "timeout for port mapping") - fs.StringVar(&debugPortmapArgs.ty, "type", "", `portmap debug type (one of "", "pmp", "pcp", or "upnp")`) - fs.StringVar(&debugPortmapArgs.gatewayAddr, "gateway-addr", "", `override gateway IP (must also pass --self-addr)`) - fs.StringVar(&debugPortmapArgs.selfAddr, "self-addr", "", `override self IP (must also pass --gateway-addr)`) - fs.BoolVar(&debugPortmapArgs.logHTTP, "log-http", false, `print all HTTP requests and responses to the log`) - return fs - })(), - }, - { - Name: "peer-endpoint-changes", - ShortUsage: "tailscale debug peer-endpoint-changes ", - Exec: runPeerEndpointChanges, - 
ShortHelp: "Prints debug information about a peer's endpoint changes", - }, - { - Name: "dial-types", - ShortUsage: "tailscale debug dial-types ", - Exec: runDebugDialTypes, - ShortHelp: "Prints debug information about connecting to a given host or IP", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("dial-types") - fs.StringVar(&debugDialTypesArgs.network, "network", "tcp", `network type to dial ("tcp", "udp", etc.)`) - return fs - })(), - }, - { - Name: "resolve", - ShortUsage: "tailscale debug resolve ", - Exec: runDebugResolve, - ShortHelp: "Does a DNS lookup", - FlagSet: (func() *flag.FlagSet { - fs := newFlagSet("resolve") - fs.StringVar(&resolveArgs.net, "net", "ip", "network type to resolve (ip, ip4, ip6)") - return fs - })(), - }, - { - Name: "go-buildinfo", - ShortUsage: "tailscale debug go-buildinfo", - ShortHelp: "Prints Go's runtime/debug.BuildInfo", - Exec: runGoBuildInfo, - }, - }, +var ( + debugCaptureCmd func() *ffcli.Command // or nil +) + +func debugCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "debug", + Exec: runDebug, + ShortUsage: "tailscale debug ", + ShortHelp: "Debug commands", + LongHelp: hidden + `"tailscale debug" contains misc debug facilities; it is not a stable interface.`, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("debug") + fs.StringVar(&debugArgs.file, "file", "", "get, delete:NAME, or NAME") + fs.StringVar(&debugArgs.cpuFile, "cpu-profile", "", "if non-empty, grab a CPU profile for --profile-seconds seconds and write it to this file; - for stdout") + fs.StringVar(&debugArgs.memFile, "mem-profile", "", "if non-empty, grab a memory profile and write it to this file; - for stdout") + fs.IntVar(&debugArgs.cpuSec, "profile-seconds", 15, "number of seconds to run a CPU profile for, when --cpu-profile is non-empty") + return fs + })(), + Subcommands: nonNilCmds([]*ffcli.Command{ + { + Name: "derp-map", + ShortUsage: "tailscale debug derp-map", + Exec: runDERPMap, + ShortHelp: "Print DERP map", + }, + { + 
Name: "component-logs", + ShortUsage: "tailscale debug component-logs [" + strings.Join(ipn.DebuggableComponents, "|") + "]", + Exec: runDebugComponentLogs, + ShortHelp: "Enable/disable debug logs for a component", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("component-logs") + fs.DurationVar(&debugComponentLogsArgs.forDur, "for", time.Hour, "how long to enable debug logs for; zero or negative means to disable") + return fs + })(), + }, + { + Name: "daemon-goroutines", + ShortUsage: "tailscale debug daemon-goroutines", + Exec: runDaemonGoroutines, + ShortHelp: "Print tailscaled's goroutines", + }, + { + Name: "daemon-logs", + ShortUsage: "tailscale debug daemon-logs", + Exec: runDaemonLogs, + ShortHelp: "Watch tailscaled's server logs", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("daemon-logs") + fs.IntVar(&daemonLogsArgs.verbose, "verbose", 0, "verbosity level") + fs.BoolVar(&daemonLogsArgs.time, "time", false, "include client time") + return fs + })(), + }, + { + Name: "metrics", + ShortUsage: "tailscale debug metrics", + Exec: runDaemonMetrics, + ShortHelp: "Print tailscaled's metrics", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("metrics") + fs.BoolVar(&metricsArgs.watch, "watch", false, "print JSON dump of delta values") + return fs + })(), + }, + { + Name: "env", + ShortUsage: "tailscale debug env", + Exec: runEnv, + ShortHelp: "Print cmd/tailscale environment", + }, + { + Name: "stat", + ShortUsage: "tailscale debug stat ", + Exec: runStat, + ShortHelp: "Stat a file", + }, + { + Name: "hostinfo", + ShortUsage: "tailscale debug hostinfo", + Exec: runHostinfo, + ShortHelp: "Print hostinfo", + }, + { + Name: "local-creds", + ShortUsage: "tailscale debug local-creds", + Exec: runLocalCreds, + ShortHelp: "Print how to access Tailscale LocalAPI", + }, + { + Name: "localapi", + ShortUsage: "tailscale debug localapi [] []", + Exec: runLocalAPI, + ShortHelp: "Call a LocalAPI method directly", + FlagSet: (func() *flag.FlagSet { + fs := 
newFlagSet("localapi") + fs.BoolVar(&localAPIFlags.verbose, "v", false, "verbose; dump HTTP headers") + return fs + })(), + }, + { + Name: "restun", + ShortUsage: "tailscale debug restun", + Exec: localAPIAction("restun"), + ShortHelp: "Force a magicsock restun", + }, + { + Name: "rebind", + ShortUsage: "tailscale debug rebind", + Exec: localAPIAction("rebind"), + ShortHelp: "Force a magicsock rebind", + }, + { + Name: "derp-set-on-demand", + ShortUsage: "tailscale debug derp-set-on-demand", + Exec: localAPIAction("derp-set-homeless"), + ShortHelp: "Enable DERP on-demand mode (breaks reachability)", + }, + { + Name: "derp-unset-on-demand", + ShortUsage: "tailscale debug derp-unset-on-demand", + Exec: localAPIAction("derp-unset-homeless"), + ShortHelp: "Disable DERP on-demand mode", + }, + { + Name: "break-tcp-conns", + ShortUsage: "tailscale debug break-tcp-conns", + Exec: localAPIAction("break-tcp-conns"), + ShortHelp: "Break any open TCP connections from the daemon", + }, + { + Name: "break-derp-conns", + ShortUsage: "tailscale debug break-derp-conns", + Exec: localAPIAction("break-derp-conns"), + ShortHelp: "Break any open DERP connections from the daemon", + }, + { + Name: "pick-new-derp", + ShortUsage: "tailscale debug pick-new-derp", + Exec: localAPIAction("pick-new-derp"), + ShortHelp: "Switch to some other random DERP home region for a short time", + }, + { + Name: "force-prefer-derp", + ShortUsage: "tailscale debug force-prefer-derp", + Exec: forcePreferDERP, + ShortHelp: "Prefer the given region ID if reachable (until restart, or 0 to clear)", + }, + { + Name: "force-netmap-update", + ShortUsage: "tailscale debug force-netmap-update", + Exec: localAPIAction("force-netmap-update"), + ShortHelp: "Force a full no-op netmap update (for load testing)", + }, + { + // TODO(bradfitz,maisem): eventually promote this out of debug + Name: "reload-config", + ShortUsage: "tailscale debug reload-config", + Exec: reloadConfig, + ShortHelp: "Reload config", + }, + { + 
Name: "control-knobs", + ShortUsage: "tailscale debug control-knobs", + Exec: debugControlKnobs, + ShortHelp: "See current control knobs", + }, + { + Name: "prefs", + ShortUsage: "tailscale debug prefs", + Exec: runPrefs, + ShortHelp: "Print prefs", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("prefs") + fs.BoolVar(&prefsArgs.pretty, "pretty", false, "If true, pretty-print output") + return fs + })(), + }, + { + Name: "watch-ipn", + ShortUsage: "tailscale debug watch-ipn", + Exec: runWatchIPN, + ShortHelp: "Subscribe to IPN message bus", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("watch-ipn") + fs.BoolVar(&watchIPNArgs.netmap, "netmap", true, "include netmap in messages") + fs.BoolVar(&watchIPNArgs.initial, "initial", false, "include initial status") + fs.BoolVar(&watchIPNArgs.rateLimit, "rate-limit", true, "rate limit messags") + fs.BoolVar(&watchIPNArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") + fs.IntVar(&watchIPNArgs.count, "count", 0, "exit after printing this many statuses, or 0 to keep going forever") + return fs + })(), + }, + { + Name: "netmap", + ShortUsage: "tailscale debug netmap", + Exec: runNetmap, + ShortHelp: "Print the current network map", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("netmap") + fs.BoolVar(&netmapArgs.showPrivateKey, "show-private-key", false, "include node private key in printed netmap") + return fs + })(), + }, + { + Name: "via", + ShortUsage: "tailscale debug via \n" + + "tailscale debug via ", + Exec: runVia, + ShortHelp: "Convert between site-specific IPv4 CIDRs and IPv6 'via' routes", + }, + { + Name: "ts2021", + ShortUsage: "tailscale debug ts2021", + Exec: runTS2021, + ShortHelp: "Debug ts2021 protocol connectivity", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("ts2021") + fs.StringVar(&ts2021Args.host, "host", "controlplane.tailscale.com", "hostname of control plane") + fs.IntVar(&ts2021Args.version, "version", 
int(tailcfg.CurrentCapabilityVersion), "protocol version") + fs.BoolVar(&ts2021Args.verbose, "verbose", false, "be extra verbose") + return fs + })(), + }, + { + Name: "set-expire", + ShortUsage: "tailscale debug set-expire --in=1m", + Exec: runSetExpire, + ShortHelp: "Manipulate node key expiry for testing", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("set-expire") + fs.DurationVar(&setExpireArgs.in, "in", 0, "if non-zero, set node key to expire this duration from now") + return fs + })(), + }, + { + Name: "dev-store-set", + ShortUsage: "tailscale debug dev-store-set", + Exec: runDevStoreSet, + ShortHelp: "Set a key/value pair during development", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("store-set") + fs.BoolVar(&devStoreSetArgs.danger, "danger", false, "accept danger") + return fs + })(), + }, + { + Name: "derp", + ShortUsage: "tailscale debug derp", + Exec: runDebugDERP, + ShortHelp: "Test a DERP configuration", + }, + ccall(debugCaptureCmd), + { + Name: "portmap", + ShortUsage: "tailscale debug portmap", + Exec: debugPortmap, + ShortHelp: "Run portmap debugging", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("portmap") + fs.DurationVar(&debugPortmapArgs.duration, "duration", 5*time.Second, "timeout for port mapping") + fs.StringVar(&debugPortmapArgs.ty, "type", "", `portmap debug type (one of "", "pmp", "pcp", or "upnp")`) + fs.StringVar(&debugPortmapArgs.gatewayAddr, "gateway-addr", "", `override gateway IP (must also pass --self-addr)`) + fs.StringVar(&debugPortmapArgs.selfAddr, "self-addr", "", `override self IP (must also pass --gateway-addr)`) + fs.BoolVar(&debugPortmapArgs.logHTTP, "log-http", false, `print all HTTP requests and responses to the log`) + return fs + })(), + }, + { + Name: "peer-endpoint-changes", + ShortUsage: "tailscale debug peer-endpoint-changes ", + Exec: runPeerEndpointChanges, + ShortHelp: "Print debug information about a peer's endpoint changes", + }, + { + Name: "dial-types", + ShortUsage: "tailscale 
debug dial-types ", + Exec: runDebugDialTypes, + ShortHelp: "Print debug information about connecting to a given host or IP", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("dial-types") + fs.StringVar(&debugDialTypesArgs.network, "network", "tcp", `network type to dial ("tcp", "udp", etc.)`) + return fs + })(), + }, + { + Name: "resolve", + ShortUsage: "tailscale debug resolve ", + Exec: runDebugResolve, + ShortHelp: "Does a DNS lookup", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("resolve") + fs.StringVar(&resolveArgs.net, "net", "ip", "network type to resolve (ip, ip4, ip6)") + return fs + })(), + }, + { + Name: "go-buildinfo", + ShortUsage: "tailscale debug go-buildinfo", + ShortHelp: "Print Go's runtime/debug.BuildInfo", + Exec: runGoBuildInfo, + }, + }...), + } } func runGoBuildInfo(ctx context.Context, args []string) error { @@ -449,6 +463,81 @@ func runLocalCreds(ctx context.Context, args []string) error { return nil } +func looksLikeHTTPMethod(s string) bool { + if len(s) > len("OPTIONS") { + return false + } + for _, r := range s { + if r < 'A' || r > 'Z' { + return false + } + } + return true +} + +var localAPIFlags struct { + verbose bool +} + +func runLocalAPI(ctx context.Context, args []string) error { + if len(args) == 0 { + return errors.New("expected at least one argument") + } + method := "GET" + if looksLikeHTTPMethod(args[0]) { + method = args[0] + args = args[1:] + if len(args) == 0 { + return errors.New("expected at least one argument after method") + } + } + path := args[0] + if !strings.HasPrefix(path, "/localapi/") { + if !strings.Contains(path, "/") { + path = "/localapi/v0/" + path + } else { + path = "/localapi/" + path + } + } + + var body io.Reader + if len(args) > 1 { + if args[1] == "-" { + fmt.Fprintf(Stderr, "# reading request body from stdin...\n") + all, err := io.ReadAll(os.Stdin) + if err != nil { + return fmt.Errorf("reading Stdin: %q", err) + } + body = bytes.NewReader(all) + } else { + body = 
strings.NewReader(args[1]) + } + } + req, err := http.NewRequest(method, "http://local-tailscaled.sock"+path, body) + if err != nil { + return err + } + fmt.Fprintf(Stderr, "# doing request %s %s\n", method, path) + + res, err := localClient.DoLocalRequest(req) + if err != nil { + return err + } + is2xx := res.StatusCode >= 200 && res.StatusCode <= 299 + if localAPIFlags.verbose { + res.Write(Stdout) + } else { + if !is2xx { + fmt.Fprintf(Stderr, "# Response status %s\n", res.Status) + } + io.Copy(Stdout, res.Body) + } + if is2xx { + return nil + } + return errors.New(res.Status) +} + type localClientRoundTripper struct{} func (localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -500,6 +589,7 @@ var watchIPNArgs struct { netmap bool initial bool showPrivateKey bool + rateLimit bool count int } @@ -511,6 +601,9 @@ func runWatchIPN(ctx context.Context, args []string) error { if !watchIPNArgs.showPrivateKey { mask |= ipn.NotifyNoPrivateKeys } + if watchIPNArgs.rateLimit { + mask |= ipn.NotifyRateLimit + } watcher, err := localClient.WatchIPNBus(ctx, mask) if err != nil { return err @@ -571,6 +664,25 @@ func runDERPMap(ctx context.Context, args []string) error { return nil } +func forcePreferDERP(ctx context.Context, args []string) error { + var n int + if len(args) != 1 { + return errors.New("expected exactly one integer argument") + } + n, err := strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("expected exactly one integer argument: %w", err) + } + b, err := json.Marshal(n) + if err != nil { + return fmt.Errorf("failed to marshal DERP region: %w", err) + } + if err := localClient.DebugActionBody(ctx, "force-prefer-derp", bytes.NewReader(b)); err != nil { + return fmt.Errorf("failed to force preferred DERP: %w", err) + } + return nil +} + func localAPIAction(action string) func(context.Context, []string) error { return func(ctx context.Context, args []string) error { if len(args) > 0 { @@ -845,6 +957,14 @@ func runTS2021(ctx 
context.Context, args []string) error { logf = log.Printf } + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logger.WithPrefix(logf, "netmon: ")) + if err != nil { + return fmt.Errorf("creating netmon: %w", err) + } + noiseDialer := &controlhttp.Dialer{ Hostname: ts2021Args.host, HTTPPort: "80", @@ -854,6 +974,7 @@ func runTS2021(ctx context.Context, args []string) error { ProtocolVersion: uint16(ts2021Args.version), Dialer: dialFunc, Logf: logf, + NetMon: netMon, } const tries = 2 for i := range tries { @@ -999,50 +1120,6 @@ func runSetExpire(ctx context.Context, args []string) error { return localClient.DebugSetExpireIn(ctx, setExpireArgs.in) } -var captureArgs struct { - outFile string -} - -func runCapture(ctx context.Context, args []string) error { - stream, err := localClient.StreamDebugCapture(ctx) - if err != nil { - return err - } - defer stream.Close() - - switch captureArgs.outFile { - case "-": - fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") - _, err = io.Copy(os.Stdout, stream) - return err - case "": - lua, err := os.CreateTemp("", "ts-dissector") - if err != nil { - return err - } - defer os.Remove(lua.Name()) - lua.Write([]byte(capture.DissectorLua)) - if err := lua.Close(); err != nil { - return err - } - - wireshark := exec.CommandContext(ctx, "wireshark", "-X", "lua_script:"+lua.Name(), "-k", "-i", "-") - wireshark.Stdin = stream - wireshark.Stdout = os.Stdout - wireshark.Stderr = os.Stderr - return wireshark.Run() - } - - f, err := os.OpenFile(captureArgs.outFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644) - if err != nil { - return err - } - defer f.Close() - fmt.Fprintln(Stderr, "Press Ctrl-C to stop the capture.") - _, err = io.Copy(f, stream) - return err -} - var debugPortmapArgs struct { duration time.Duration gatewayAddr string diff --git a/cmd/tailscale/cli/dns.go b/cmd/tailscale/cli/dns.go index 042ce1a94161a..402f0cedf0a1e 100644 --- a/cmd/tailscale/cli/dns.go +++ b/cmd/tailscale/cli/dns.go @@ 
-20,7 +20,7 @@ var dnsCmd = &ffcli.Command{ Name: "status", ShortUsage: "tailscale dns status [--all]", Exec: runDNSStatus, - ShortHelp: "Prints the current DNS status and configuration", + ShortHelp: "Print the current DNS status and configuration", LongHelp: dnsStatusLongHelp(), FlagSet: (func() *flag.FlagSet { fs := newFlagSet("status") diff --git a/cmd/tailscale/cli/down.go b/cmd/tailscale/cli/down.go index 1eb85a13e6c78..224198a98deb5 100644 --- a/cmd/tailscale/cli/down.go +++ b/cmd/tailscale/cli/down.go @@ -9,6 +9,7 @@ import ( "fmt" "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" ) @@ -23,10 +24,12 @@ var downCmd = &ffcli.Command{ var downArgs struct { acceptedRisks string + reason string } func newDownFlagSet() *flag.FlagSet { downf := newFlagSet("down") + downf.StringVar(&downArgs.reason, "reason", "", "reason for the disconnect, if required by a policy") registerAcceptRiskFlag(downf, &downArgs.acceptedRisks) return downf } @@ -50,6 +53,7 @@ func runDown(ctx context.Context, args []string) error { fmt.Fprintf(Stderr, "Tailscale was already stopped.\n") return nil } + ctx = apitype.RequestReasonKey.WithValue(ctx, downArgs.reason) _, err = localClient.EditPrefs(ctx, &ipn.MaskedPrefs{ Prefs: ipn.Prefs{ WantRunning: false, diff --git a/cmd/tailscale/cli/exitnode.go b/cmd/tailscale/cli/exitnode.go index 6b9247a7bc303..ad7a8ccee5b42 100644 --- a/cmd/tailscale/cli/exitnode.go +++ b/cmd/tailscale/cli/exitnode.go @@ -15,10 +15,10 @@ import ( "github.com/kballard/go-shellquote" "github.com/peterbourgon/ff/v3/ffcli" - xmaps "golang.org/x/exp/maps" "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/util/slicesx" ) func exitNodeCmd() *ffcli.Command { @@ -41,7 +41,7 @@ func exitNodeCmd() *ffcli.Command { { Name: "suggest", ShortUsage: "tailscale exit-node suggest", - ShortHelp: "Suggests the best available exit node", + ShortHelp: "Suggest the best available exit node", Exec: 
runExitNodeSuggest, }}, (func() []*ffcli.Command { @@ -255,7 +255,7 @@ func filterFormatAndSortExitNodes(peers []*ipnstate.PeerStatus, filterBy string) } filteredExitNodes := filteredExitNodes{ - Countries: xmaps.Values(countries), + Countries: slicesx.MapValues(countries), } for _, country := range filteredExitNodes.Countries { diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index cd776244679c8..6f3aa40b5a806 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !ts_omit_taildrop + package cli import ( @@ -28,6 +30,7 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/cmd/tailscale/cli/ffcomplete" "tailscale.com/envknob" + "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" "tailscale.com/syncs" "tailscale.com/tailcfg" @@ -37,14 +40,20 @@ import ( "tailscale.com/version" ) -var fileCmd = &ffcli.Command{ - Name: "file", - ShortUsage: "tailscale file ...", - ShortHelp: "Send or receive files", - Subcommands: []*ffcli.Command{ - fileCpCmd, - fileGetCmd, - }, +func init() { + fileCmd = getFileCmd +} + +func getFileCmd() *ffcli.Command { + return &ffcli.Command{ + Name: "file", + ShortUsage: "tailscale file ...", + ShortHelp: "Send or receive files", + Subcommands: []*ffcli.Command{ + fileCpCmd, + fileGetCmd, + }, + } } type countingReader struct { @@ -268,46 +277,77 @@ func getTargetStableID(ctx context.Context, ipStr string) (id tailcfg.StableNode if err != nil { return "", false, err } - fts, err := localClient.FileTargets(ctx) + + st, err := localClient.Status(ctx) if err != nil { - return "", false, err - } - for _, ft := range fts { - n := ft.Node - for _, a := range n.Addresses { - if a.Addr() != ip { - continue + // This likely means tailscaled is unreachable or returned an error on /localapi/v0/status. 
+ return "", false, fmt.Errorf("failed to get local status: %w", err) + } + if st == nil { + // Handle the case if the daemon returns nil with no error. + return "", false, errors.New("no status available") + } + if st.Self == nil { + // We have a status structure, but it doesn’t include Self info. Probably not connected. + return "", false, errors.New("local node is not configured or missing Self information") + } + + // Find the PeerStatus that corresponds to ip. + var foundPeer *ipnstate.PeerStatus +peerLoop: + for _, ps := range st.Peer { + for _, pip := range ps.TailscaleIPs { + if pip == ip { + foundPeer = ps + break peerLoop } - isOffline = n.Online != nil && !*n.Online - return n.StableID, isOffline, nil } } - return "", false, fileTargetErrorDetail(ctx, ip) -} -// fileTargetErrorDetail returns a non-nil error saying why ip is an -// invalid file sharing target. -func fileTargetErrorDetail(ctx context.Context, ip netip.Addr) error { - found := false - if st, err := localClient.Status(ctx); err == nil && st.Self != nil { - for _, peer := range st.Peer { - for _, pip := range peer.TailscaleIPs { - if pip == ip { - found = true - if peer.UserID != st.Self.UserID { - return errors.New("owned by different user; can only send files to your own devices") - } - } - } + // If we didn’t find a matching peer at all: + if foundPeer == nil { + if !tsaddr.IsTailscaleIP(ip) { + return "", false, fmt.Errorf("unknown target; %v is not a Tailscale IP address", ip) } + return "", false, errors.New("unknown target; not in your Tailnet") } - if found { - return errors.New("target seems to be running an old Tailscale version") - } - if !tsaddr.IsTailscaleIP(ip) { - return fmt.Errorf("unknown target; %v is not a Tailscale IP address", ip) + + // We found a peer. 
Decide whether we can send files to it: + isOffline = !foundPeer.Online + + switch foundPeer.TaildropTarget { + case ipnstate.TaildropTargetAvailable: + return foundPeer.ID, isOffline, nil + + case ipnstate.TaildropTargetNoNetmapAvailable: + return "", isOffline, errors.New("cannot send files: no netmap available on this node") + + case ipnstate.TaildropTargetIpnStateNotRunning: + return "", isOffline, errors.New("cannot send files: local Tailscale is not connected to the tailnet") + + case ipnstate.TaildropTargetMissingCap: + return "", isOffline, errors.New("cannot send files: missing required Taildrop capability") + + case ipnstate.TaildropTargetOffline: + return "", isOffline, errors.New("cannot send files: peer is offline") + + case ipnstate.TaildropTargetNoPeerInfo: + return "", isOffline, errors.New("cannot send files: invalid or unrecognized peer") + + case ipnstate.TaildropTargetUnsupportedOS: + return "", isOffline, errors.New("cannot send files: target's OS does not support Taildrop") + + case ipnstate.TaildropTargetNoPeerAPI: + return "", isOffline, errors.New("cannot send files: target is not advertising a file sharing API") + + case ipnstate.TaildropTargetOwnedByOtherUser: + return "", isOffline, errors.New("cannot send files: peer is owned by a different user") + + case ipnstate.TaildropTargetUnknown: + fallthrough + default: + return "", isOffline, fmt.Errorf("cannot send files: unknown or indeterminate reason") } - return errors.New("unknown target; not in your Tailnet") } const maxSniff = 4 << 20 diff --git a/cmd/tailscale/cli/funnel.go b/cmd/tailscale/cli/funnel.go index a95f9e27083b6..f4a1c6bfdb3b8 100644 --- a/cmd/tailscale/cli/funnel.go +++ b/cmd/tailscale/cli/funnel.go @@ -19,7 +19,7 @@ import ( var funnelCmd = func() *ffcli.Command { se := &serveEnv{lc: &localClient} // previously used to serve legacy newFunnelCommand unless useWIPCode is true - // change is limited to make a revert easier and full cleanup to come after the relase. 
+ // change is limited to make a revert easier and full cleanup to come after the release. // TODO(tylersmalley): cleanup and removal of newFunnelCommand as of 2023-10-16 return newServeV2Command(se, funnel) } diff --git a/cmd/tailscale/cli/metrics.go b/cmd/tailscale/cli/metrics.go new file mode 100644 index 0000000000000..dbdedd5a61037 --- /dev/null +++ b/cmd/tailscale/cli/metrics.go @@ -0,0 +1,88 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/atomicfile" +) + +var metricsCmd = &ffcli.Command{ + Name: "metrics", + ShortHelp: "Show Tailscale metrics", + LongHelp: strings.TrimSpace(` + +The 'tailscale metrics' command shows Tailscale user-facing metrics (as opposed +to internal metrics printed by 'tailscale debug metrics'). + +For more information about Tailscale metrics, refer to +https://tailscale.com/s/client-metrics + +`), + ShortUsage: "tailscale metrics [flags]", + UsageFunc: usageFuncNoDefaultValues, + Exec: runMetricsNoSubcommand, + Subcommands: []*ffcli.Command{ + { + Name: "print", + ShortUsage: "tailscale metrics print", + Exec: runMetricsPrint, + ShortHelp: "Print current metric values in Prometheus text format", + }, + { + Name: "write", + ShortUsage: "tailscale metrics write ", + Exec: runMetricsWrite, + ShortHelp: "Write metric values to a file", + LongHelp: strings.TrimSpace(` + +The 'tailscale metrics write' command writes metric values to a text file provided as its +only argument. It's meant to be used alongside Prometheus node exporter, allowing Tailscale +metrics to be consumed and exported by the textfile collector. + +As an example, to export Tailscale metrics on an Ubuntu system running node exporter, you +can regularly run 'tailscale metrics write /var/lib/prometheus/node-exporter/tailscaled.prom' +using cron or a systemd timer. 
+ + `), + }, + }, +} + +// runMetricsNoSubcommand prints metric values if no subcommand is specified. +func runMetricsNoSubcommand(ctx context.Context, args []string) error { + if len(args) > 0 { + return fmt.Errorf("tailscale metrics: unknown subcommand: %s", args[0]) + } + + return runMetricsPrint(ctx, args) +} + +// runMetricsPrint prints metric values to stdout. +func runMetricsPrint(ctx context.Context, args []string) error { + out, err := localClient.UserMetrics(ctx) + if err != nil { + return err + } + Stdout.Write(out) + return nil +} + +// runMetricsWrite writes metric values to a file. +func runMetricsWrite(ctx context.Context, args []string) error { + if len(args) != 1 { + return errors.New("usage: tailscale metrics write ") + } + path := args[0] + out, err := localClient.UserMetrics(ctx) + if err != nil { + return err + } + return atomicfile.WriteFile(path, out, 0644) +} diff --git a/cmd/tailscale/cli/netcheck.go b/cmd/tailscale/cli/netcheck.go index 682cd99a3c6e4..3cf05a3b7987f 100644 --- a/cmd/tailscale/cli/netcheck.go +++ b/cmd/tailscale/cli/netcheck.go @@ -24,6 +24,7 @@ import ( "tailscale.com/net/tlsdial" "tailscale.com/tailcfg" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) var netcheckCmd = &ffcli.Command{ @@ -48,14 +49,19 @@ var netcheckArgs struct { func runNetcheck(ctx context.Context, args []string) error { logf := logger.WithPrefix(log.Printf, "portmap: ") - netMon, err := netmon.New(logf) + bus := eventbus.New() + defer bus.Close() + netMon, err := netmon.New(bus, logf) if err != nil { return err } // Ensure that we close the portmapper after running a netcheck; this // will release any port mappings created. 
- pm := portmapper.NewClient(logf, netMon, nil, nil, nil) + pm := portmapper.NewClient(portmapper.Config{ + Logf: logf, + NetMon: netMon, + }) defer pm.Close() c := &netcheck.Client{ @@ -136,6 +142,7 @@ func printReport(dm *tailcfg.DERPMap, report *netcheck.Report) error { } printf("\nReport:\n") + printf("\t* Time: %v\n", report.Now.Format(time.RFC3339Nano)) printf("\t* UDP: %v\n", report.UDP) if report.GlobalV4.IsValid() { printf("\t* IPv4: yes, %s\n", report.GlobalV4) diff --git a/cmd/tailscale/cli/network-lock.go b/cmd/tailscale/cli/network-lock.go index 45f989f1057a7..c7776707422ec 100644 --- a/cmd/tailscale/cli/network-lock.go +++ b/cmd/tailscale/cli/network-lock.go @@ -191,8 +191,7 @@ var nlStatusArgs struct { var nlStatusCmd = &ffcli.Command{ Name: "status", ShortUsage: "tailscale lock status", - ShortHelp: "Outputs the state of tailnet lock", - LongHelp: "Outputs the state of tailnet lock", + ShortHelp: "Output the state of tailnet lock", Exec: runNetworkLockStatus, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("lock status") @@ -293,8 +292,7 @@ func runNetworkLockStatus(ctx context.Context, args []string) error { var nlAddCmd = &ffcli.Command{ Name: "add", ShortUsage: "tailscale lock add ...", - ShortHelp: "Adds one or more trusted signing keys to tailnet lock", - LongHelp: "Adds one or more trusted signing keys to tailnet lock", + ShortHelp: "Add one or more trusted signing keys to tailnet lock", Exec: func(ctx context.Context, args []string) error { return runNetworkLockModify(ctx, args, nil) }, @@ -307,8 +305,7 @@ var nlRemoveArgs struct { var nlRemoveCmd = &ffcli.Command{ Name: "remove", ShortUsage: "tailscale lock remove [--re-sign=false] ...", - ShortHelp: "Removes one or more trusted signing keys from tailnet lock", - LongHelp: "Removes one or more trusted signing keys from tailnet lock", + ShortHelp: "Remove one or more trusted signing keys from tailnet lock", Exec: runNetworkLockRemove, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("lock 
remove") @@ -448,7 +445,7 @@ func runNetworkLockModify(ctx context.Context, addArgs, removeArgs []string) err var nlSignCmd = &ffcli.Command{ Name: "sign", ShortUsage: "tailscale lock sign []\ntailscale lock sign ", - ShortHelp: "Signs a node or pre-approved auth key", + ShortHelp: "Sign a node or pre-approved auth key", LongHelp: `Either: - signs a node key and transmits the signature to the coordination server, or @@ -510,7 +507,7 @@ func runNetworkLockSign(ctx context.Context, args []string) error { var nlDisableCmd = &ffcli.Command{ Name: "disable", ShortUsage: "tailscale lock disable ", - ShortHelp: "Consumes a disablement secret to shut down tailnet lock for the tailnet", + ShortHelp: "Consume a disablement secret to shut down tailnet lock for the tailnet", LongHelp: strings.TrimSpace(` The 'tailscale lock disable' command uses the specified disablement @@ -539,7 +536,7 @@ func runNetworkLockDisable(ctx context.Context, args []string) error { var nlLocalDisableCmd = &ffcli.Command{ Name: "local-disable", ShortUsage: "tailscale lock local-disable", - ShortHelp: "Disables tailnet lock for this node only", + ShortHelp: "Disable tailnet lock for this node only", LongHelp: strings.TrimSpace(` The 'tailscale lock local-disable' command disables tailnet lock for only @@ -561,8 +558,8 @@ func runNetworkLockLocalDisable(ctx context.Context, args []string) error { var nlDisablementKDFCmd = &ffcli.Command{ Name: "disablement-kdf", ShortUsage: "tailscale lock disablement-kdf ", - ShortHelp: "Computes a disablement value from a disablement secret (advanced users only)", - LongHelp: "Computes a disablement value from a disablement secret (advanced users only)", + ShortHelp: "Compute a disablement value from a disablement secret (advanced users only)", + LongHelp: "Compute a disablement value from a disablement secret (advanced users only)", Exec: runNetworkLockDisablementKDF, } diff --git a/cmd/tailscale/cli/risks.go b/cmd/tailscale/cli/risks.go index 
4cfa50d581ed4..c36ffafaeb11a 100644 --- a/cmd/tailscale/cli/risks.go +++ b/cmd/tailscale/cli/risks.go @@ -17,11 +17,18 @@ import ( ) var ( - riskTypes []string - riskLoseSSH = registerRiskType("lose-ssh") - riskAll = registerRiskType("all") + riskTypes []string + riskLoseSSH = registerRiskType("lose-ssh") + riskMacAppConnector = registerRiskType("mac-app-connector") + riskAll = registerRiskType("all") ) +const riskMacAppConnectorMessage = ` +You are trying to configure an app connector on macOS, which is not officially supported due to system limitations. This may result in performance and reliability issues. + +Do not use a macOS app connector for any mission-critical purposes. For the best experience, Linux is the only recommended platform for app connectors. +` + func registerRiskType(riskType string) string { riskTypes = append(riskTypes, riskType) return riskType @@ -70,7 +77,7 @@ func presentRiskToUser(riskType, riskMessage, acceptedRisks string) error { for left := riskAbortTimeSeconds; left > 0; left-- { msg := fmt.Sprintf("\rContinuing in %d seconds...", left) msgLen = len(msg) - printf(msg) + printf("%s", msg) select { case <-interrupt: printf("\r%s\r", strings.Repeat("x", msgLen+1)) diff --git a/cmd/tailscale/cli/serve_legacy.go b/cmd/tailscale/cli/serve_legacy.go index 443a404abcbf7..96629b5ad45ef 100644 --- a/cmd/tailscale/cli/serve_legacy.go +++ b/cmd/tailscale/cli/serve_legacy.go @@ -27,6 +27,7 @@ import ( "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/util/slicesx" "tailscale.com/version" ) @@ -129,7 +130,7 @@ func (e *serveEnv) newFlags(name string, setup func(fs *flag.FlagSet)) *flag.Fla } // localServeClient is an interface conforming to the subset of -// tailscale.LocalClient. It includes only the methods used by the +// local.Client. It includes only the methods used by the // serve command. // // The purpose of this interface is to allow tests to provide a mock. 
@@ -707,10 +708,7 @@ func (e *serveEnv) printWebStatusTree(sc *ipn.ServeConfig, hp ipn.HostPort) erro return "", "" } - var mounts []string - for k := range sc.Web[hp].Handlers { - mounts = append(mounts, k) - } + mounts := slicesx.MapKeys(sc.Web[hp].Handlers) sort.Slice(mounts, func(i, j int) bool { return len(mounts[i]) < len(mounts[j]) }) diff --git a/cmd/tailscale/cli/serve_legacy_test.go b/cmd/tailscale/cli/serve_legacy_test.go index 2eb982ca0a1f7..df68b5edd32a1 100644 --- a/cmd/tailscale/cli/serve_legacy_test.go +++ b/cmd/tailscale/cli/serve_legacy_test.go @@ -850,7 +850,7 @@ func TestVerifyFunnelEnabled(t *testing.T) { } } -// fakeLocalServeClient is a fake tailscale.LocalClient for tests. +// fakeLocalServeClient is a fake local.Client for tests. // It's not a full implementation, just enough to test the serve command. // // The fake client is stateful, and is used to test manipulating diff --git a/cmd/tailscale/cli/serve_v2.go b/cmd/tailscale/cli/serve_v2.go index 009a61198dad8..3e173ce28d8c1 100644 --- a/cmd/tailscale/cli/serve_v2.go +++ b/cmd/tailscale/cli/serve_v2.go @@ -28,6 +28,7 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" "tailscale.com/util/mak" + "tailscale.com/util/slicesx" "tailscale.com/version" ) @@ -439,11 +440,7 @@ func (e *serveEnv) messageForPort(sc *ipn.ServeConfig, st *ipnstate.Status, dnsN } if sc.Web[hp] != nil { - var mounts []string - - for k := range sc.Web[hp].Handlers { - mounts = append(mounts, k) - } + mounts := slicesx.MapKeys(sc.Web[hp].Handlers) sort.Slice(mounts, func(i, j int) bool { return len(mounts[i]) < len(mounts[j]) }) diff --git a/cmd/tailscale/cli/set.go b/cmd/tailscale/cli/set.go index 2e1251f04a4b9..aa59666987532 100644 --- a/cmd/tailscale/cli/set.go +++ b/cmd/tailscale/cli/set.go @@ -10,6 +10,8 @@ import ( "fmt" "net/netip" "os/exec" + "runtime" + "strconv" "strings" "github.com/peterbourgon/ff/v3/ffcli" @@ -21,6 +23,7 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/safesocket" 
"tailscale.com/types/opt" + "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/version" ) @@ -57,18 +60,19 @@ type setArgsT struct { forceDaemon bool updateCheck bool updateApply bool - postureChecking bool + reportPosture bool snat bool statefulFiltering bool netfilterMode string + relayServerPort string } func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { setf := newFlagSet("set") setf.StringVar(&setArgs.profileName, "nickname", "", "nickname for the current account") - setf.BoolVar(&setArgs.acceptRoutes, "accept-routes", false, "accept routes advertised by other Tailscale nodes") - setf.BoolVar(&setArgs.acceptDNS, "accept-dns", false, "accept DNS configuration from the admin panel") + setf.BoolVar(&setArgs.acceptRoutes, "accept-routes", acceptRouteDefault(goos), "accept routes advertised by other Tailscale nodes") + setf.BoolVar(&setArgs.acceptDNS, "accept-dns", true, "accept DNS configuration from the admin panel") setf.StringVar(&setArgs.exitNodeIP, "exit-node", "", "Tailscale exit node (IP or base name) for internet traffic, or empty string to not use an exit node") setf.BoolVar(&setArgs.exitNodeAllowLANAccess, "exit-node-allow-lan-access", false, "Allow direct access to the local network when routing traffic via an exit node") setf.BoolVar(&setArgs.shieldsUp, "shields-up", false, "don't allow incoming connections") @@ -79,8 +83,9 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet { setf.BoolVar(&setArgs.advertiseConnector, "advertise-connector", false, "offer to be an app connector for domain specific internet traffic for the tailnet") setf.BoolVar(&setArgs.updateCheck, "update-check", true, "notify about available Tailscale updates") setf.BoolVar(&setArgs.updateApply, "auto-update", false, "automatically update to the latest available version") - setf.BoolVar(&setArgs.postureChecking, "posture-checking", false, hidden+"allow management plane to gather device posture information") + 
setf.BoolVar(&setArgs.reportPosture, "report-posture", false, "allow management plane to gather device posture information") setf.BoolVar(&setArgs.runWebClient, "webclient", false, "expose the web interface for managing this node over Tailscale at port 5252") + setf.StringVar(&setArgs.relayServerPort, "relay-server-port", "", hidden+"UDP port number (0 will pick a random unused port) for the relay server to bind to, on all interfaces, or empty string to disable relay server functionality") ffcomplete.Flag(setf, "exit-node", func(args []string) ([]string, ffcomplete.ShellCompDirective, error) { st, err := localClient.Status(context.Background()) @@ -151,7 +156,7 @@ func runSet(ctx context.Context, args []string) (retErr error) { AppConnector: ipn.AppConnectorPrefs{ Advertise: setArgs.advertiseConnector, }, - PostureChecking: setArgs.postureChecking, + PostureChecking: setArgs.reportPosture, NoStatefulFiltering: opt.NewBool(!setArgs.statefulFiltering), }, } @@ -203,6 +208,12 @@ func runSet(ctx context.Context, args []string) (retErr error) { } } + if runtime.GOOS == "darwin" && maskedPrefs.AppConnector.Advertise { + if err := presentRiskToUser(riskMacAppConnector, riskMacAppConnectorMessage, setArgs.acceptedRisks); err != nil { + return err + } + } + if maskedPrefs.RunSSHSet { wantSSH, haveSSH := maskedPrefs.RunSSH, curPrefs.RunSSH if err := presentSSHToggleRisk(wantSSH, haveSSH, setArgs.acceptedRisks); err != nil { @@ -226,6 +237,15 @@ func runSet(ctx context.Context, args []string) (retErr error) { } } } + + if setArgs.relayServerPort != "" { + uport, err := strconv.ParseUint(setArgs.relayServerPort, 10, 16) + if err != nil { + return fmt.Errorf("failed to set relay server port: %v", err) + } + maskedPrefs.Prefs.RelayServerPort = ptr.To(int(uport)) + } + checkPrefs := curPrefs.Clone() checkPrefs.ApplyEdits(maskedPrefs) if err := localClient.CheckPrefs(ctx, checkPrefs); err != nil { diff --git a/cmd/tailscale/cli/set_test.go b/cmd/tailscale/cli/set_test.go index 
15305c3ce3ed3..a2f211f8cdc36 100644 --- a/cmd/tailscale/cli/set_test.go +++ b/cmd/tailscale/cli/set_test.go @@ -4,6 +4,7 @@ package cli import ( + "flag" "net/netip" "reflect" "testing" @@ -129,3 +130,24 @@ func TestCalcAdvertiseRoutesForSet(t *testing.T) { }) } } + +// TestSetDefaultsMatchUpDefaults is meant to ensure that the default values +// for `tailscale set` and `tailscale up` are the same. +// Since `tailscale set` only sets preferences that are explicitly mentioned, +// the default values for its flags are only used for `--help` documentation. +func TestSetDefaultsMatchUpDefaults(t *testing.T) { + upFlagSet.VisitAll(func(up *flag.Flag) { + if preflessFlag(up.Name) { + return + } + + set := setFlagSet.Lookup(up.Name) + if set == nil { + return + } + + if set.DefValue != up.DefValue { + t.Errorf("--%s: set defaults to %q, but up defaults to %q", up.Name, set.DefValue, up.DefValue) + } + }) +} diff --git a/cmd/tailscale/cli/ssh.go b/cmd/tailscale/cli/ssh.go index 68a6193af9d1e..ba70e97e9f925 100644 --- a/cmd/tailscale/cli/ssh.go +++ b/cmd/tailscale/cli/ssh.go @@ -84,10 +84,6 @@ func runSSH(ctx context.Context, args []string) error { // of failing. 
But for now: return fmt.Errorf("no system 'ssh' command found: %w", err) } - tailscaleBin, err := os.Executable() - if err != nil { - return err - } knownHostsFile, err := writeKnownHosts(st) if err != nil { return err @@ -116,7 +112,9 @@ func runSSH(ctx context.Context, args []string) error { argv = append(argv, "-o", fmt.Sprintf("ProxyCommand %q %s nc %%h %%p", - tailscaleBin, + // os.Executable() would return the real running binary but in case tailscale is built with the ts_include_cli tag, + // we need to return the started symlink instead + os.Args[0], socketArg, )) } diff --git a/cmd/tailscale/cli/switch.go b/cmd/tailscale/cli/switch.go index 731492daaa976..af8b513263d37 100644 --- a/cmd/tailscale/cli/switch.go +++ b/cmd/tailscale/cli/switch.go @@ -20,7 +20,7 @@ import ( var switchCmd = &ffcli.Command{ Name: "switch", ShortUsage: "tailscale switch ", - ShortHelp: "Switches to a different Tailscale account", + ShortHelp: "Switch to a different Tailscale account", LongHelp: `"tailscale switch" switches between logged in accounts. You can use the ID that's returned from 'tailnet switch -list' to pick which profile you want to switch to. 
Alternatively, you diff --git a/cmd/tailscale/cli/syspolicy.go b/cmd/tailscale/cli/syspolicy.go new file mode 100644 index 0000000000000..a71952a9f7f62 --- /dev/null +++ b/cmd/tailscale/cli/syspolicy.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "context" + "encoding/json" + "flag" + "fmt" + "os" + "slices" + "text/tabwriter" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/util/syspolicy/setting" +) + +var syspolicyArgs struct { + json bool // JSON output mode +} + +var syspolicyCmd = &ffcli.Command{ + Name: "syspolicy", + ShortHelp: "Diagnose the MDM and system policy configuration", + LongHelp: "The 'tailscale syspolicy' command provides tools for diagnosing the MDM and system policy configuration.", + ShortUsage: "tailscale syspolicy ", + UsageFunc: usageFuncNoDefaultValues, + Subcommands: []*ffcli.Command{ + { + Name: "list", + ShortUsage: "tailscale syspolicy list", + Exec: runSysPolicyList, + ShortHelp: "Print effective policy settings", + LongHelp: "The 'tailscale syspolicy list' subcommand displays the effective policy settings and their sources (e.g., MDM or environment variables).", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("syspolicy list") + fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format") + return fs + })(), + }, + { + Name: "reload", + ShortUsage: "tailscale syspolicy reload", + Exec: runSysPolicyReload, + ShortHelp: "Force a reload of policy settings, even if no changes are detected, and prints the result", + LongHelp: "The 'tailscale syspolicy reload' subcommand forces a reload of policy settings, even if no changes are detected, and prints the result.", + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("syspolicy reload") + fs.BoolVar(&syspolicyArgs.json, "json", false, "output in JSON format") + return fs + })(), + }, + }, +} + +func runSysPolicyList(ctx context.Context, args []string) error { + policy, err := 
localClient.GetEffectivePolicy(ctx, setting.DefaultScope()) + if err != nil { + return err + } + printPolicySettings(policy) + return nil + +} + +func runSysPolicyReload(ctx context.Context, args []string) error { + policy, err := localClient.ReloadEffectivePolicy(ctx, setting.DefaultScope()) + if err != nil { + return err + } + printPolicySettings(policy) + return nil +} + +func printPolicySettings(policy *setting.Snapshot) { + if syspolicyArgs.json { + json, err := json.MarshalIndent(policy, "", "\t") + if err != nil { + errf("syspolicy marshalling error: %v", err) + } else { + outln(string(json)) + } + return + } + if policy.Len() == 0 { + outln("No policy settings") + return + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "Name\tOrigin\tValue\tError") + fmt.Fprintln(w, "----\t------\t-----\t-----") + for _, k := range slices.Sorted(policy.Keys()) { + setting, _ := policy.GetSetting(k) + var origin string + if o := setting.Origin(); o != nil { + origin = o.String() + } + if err := setting.Error(); err != nil { + fmt.Fprintf(w, "%s\t%s\t\t{%v}\n", k, origin, err) + } else { + fmt.Fprintf(w, "%s\t%s\t%v\t\n", k, origin, setting.Value()) + } + } + w.Flush() + + fmt.Println() + return +} diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index bf6a9af773f60..e4bb6f576759e 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -27,8 +27,8 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" qrcode "github.com/skip2/go-qrcode" "golang.org/x/oauth2/clientcredentials" - "tailscale.com/client/tailscale" "tailscale.com/health/healthmsg" + "tailscale.com/internal/client/tailscale" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/netutil" @@ -39,7 +39,6 @@ import ( "tailscale.com/types/preftype" "tailscale.com/types/views" "tailscale.com/util/dnsname" - "tailscale.com/version" "tailscale.com/version/distro" ) @@ -79,14 +78,8 @@ func effectiveGOOS() string { // acceptRouteDefault returns the CLI's 
default value of --accept-routes as // a function of the platform it's running on. func acceptRouteDefault(goos string) bool { - switch goos { - case "windows": - return true - case "darwin": - return version.IsSandboxedMacOS() - default: - return false - } + var p *ipn.Prefs + return p.DefaultRouteAll(goos) } var upFlagSet = newUpFlagSet(effectiveGOOS(), &upArgsGlobal, "up") @@ -116,6 +109,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { upf.StringVar(&upArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes") upf.BoolVar(&upArgs.advertiseConnector, "advertise-connector", false, "advertise this node as an app connector") upf.BoolVar(&upArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet") + upf.BoolVar(&upArgs.postureChecking, "report-posture", false, hidden+"allow management plane to gather device posture information") if safesocket.GOOSUsesPeerCreds(goos) { upf.StringVar(&upArgs.opUser, "operator", "", "Unix username to allow to operate on tailscaled without sudo") @@ -138,7 +132,7 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { // Some flags are only for "up", not "login". 
upf.BoolVar(&upArgs.json, "json", false, "output in JSON format (WARNING: format subject to change)") upf.BoolVar(&upArgs.reset, "reset", false, "reset unspecified settings to their default values") - upf.BoolVar(&upArgs.forceReauth, "force-reauth", false, "force reauthentication") + upf.BoolVar(&upArgs.forceReauth, "force-reauth", false, "force reauthentication (WARNING: this will bring down the Tailscale connection and thus should not be done remotely over SSH or RDP)") registerAcceptRiskFlag(upf, &upArgs.acceptedRisks) } @@ -164,6 +158,9 @@ func defaultNetfilterMode() string { return "on" } +// upArgsT is the type of upArgs, the argument struct for `tailscale up`. +// As of 2024-10-08, upArgsT is frozen and no new arguments should be +// added to it. Add new arguments to setArgsT instead. type upArgsT struct { qr bool reset bool @@ -191,6 +188,7 @@ type upArgsT struct { timeout time.Duration acceptedRisks string profileName string + postureChecking bool } func (a upArgsT) getAuthKey() (string, error) { @@ -301,6 +299,7 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo prefs.OperatorUser = upArgs.opUser prefs.ProfileName = upArgs.profileName prefs.AppConnector.Advertise = upArgs.advertiseConnector + prefs.PostureChecking = upArgs.postureChecking if goos == "linux" { prefs.NoSNAT = !upArgs.snat @@ -376,6 +375,12 @@ func updatePrefs(prefs, curPrefs *ipn.Prefs, env upCheckEnv) (simpleUp bool, jus return false, nil, err } + if env.goos == "darwin" && env.upArgs.advertiseConnector { + if err := presentRiskToUser(riskMacAppConnector, riskMacAppConnectorMessage, env.upArgs.acceptedRisks); err != nil { + return false, nil, err + } + } + if env.upArgs.forceReauth && isSSHOverTailscale() { if err := presentRiskToUser(riskLoseSSH, `You are connected over Tailscale; this action will result in your SSH session disconnecting.`, env.upArgs.acceptedRisks); err != nil { return false, nil, err @@ -767,7 +772,8 @@ func init() { 
addPrefFlagMapping("update-check", "AutoUpdate.Check") addPrefFlagMapping("auto-update", "AutoUpdate.Apply") addPrefFlagMapping("advertise-connector", "AppConnector") - addPrefFlagMapping("posture-checking", "PostureChecking") + addPrefFlagMapping("report-posture", "PostureChecking") + addPrefFlagMapping("relay-server-port", "RelayServerPort") } func addPrefFlagMapping(flagName string, prefNames ...string) { @@ -1044,6 +1050,8 @@ func prefsToFlags(env upCheckEnv, prefs *ipn.Prefs) (flagVal map[string]any) { set(prefs.NetfilterMode.String()) case "unattended": set(prefs.ForceDaemon) + case "report-posture": + set(prefs.PostureChecking) } }) return ret @@ -1083,12 +1091,6 @@ func exitNodeIP(p *ipn.Prefs, st *ipnstate.Status) (ip netip.Addr) { return } -func init() { - // Required to use our client API. We're fine with the instability since the - // client lives in the same repo as this code. - tailscale.I_Acknowledge_This_API_Is_Unstable = true -} - // resolveAuthKey either returns v unchanged (in the common case) or, if it // starts with "tskey-client-" (as Tailscale OAuth secrets do) parses it like // @@ -1148,7 +1150,6 @@ func resolveAuthKey(ctx context.Context, v, tags string) (string, error) { ClientID: "some-client-id", // ignored ClientSecret: clientSecret, TokenURL: baseURL + "/api/v2/oauth/token", - Scopes: []string{"device"}, } tsClient := tailscale.NewClient("-", nil) diff --git a/cmd/tailscale/cli/up_test.go b/cmd/tailscale/cli/up_test.go new file mode 100644 index 0000000000000..eb06f84dce2ea --- /dev/null +++ b/cmd/tailscale/cli/up_test.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package cli + +import ( + "flag" + "testing" + + "tailscale.com/util/set" +) + +// validUpFlags are the only flags that are valid for tailscale up. The up +// command is frozen: no new preferences can be added. Instead, add them to +// tailscale set. +// See tailscale/tailscale#15460. 
+var validUpFlags = set.Of( + "accept-dns", + "accept-risk", + "accept-routes", + "advertise-connector", + "advertise-exit-node", + "advertise-routes", + "advertise-tags", + "auth-key", + "exit-node", + "exit-node-allow-lan-access", + "force-reauth", + "host-routes", + "hostname", + "json", + "login-server", + "netfilter-mode", + "nickname", + "operator", + "report-posture", + "qr", + "reset", + "shields-up", + "snat-subnet-routes", + "ssh", + "stateful-filtering", + "timeout", + "unattended", +) + +// TestUpFlagSetIsFrozen complains when new flags are added to tailscale up. +func TestUpFlagSetIsFrozen(t *testing.T) { + upFlagSet.VisitAll(func(f *flag.Flag) { + name := f.Name + if !validUpFlags.Contains(name) { + t.Errorf("--%s flag added to tailscale up, new prefs go in tailscale set: see tailscale/tailscale#15460", name) + } + }) +} diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 2c644d1be7d79..03bf2f94ca4df 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -5,7 +5,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/coder/websocket from tailscale.com/control/controlhttp+ + github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket @@ -30,7 +30,6 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep github.com/gorilla/csrf from tailscale.com/client/web github.com/gorilla/securecookie from github.com/gorilla/csrf github.com/hdevalence/ed25519consensus from 
tailscale.com/clientupdate/distsign+ - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink github.com/kballard/go-shellquote from tailscale.com/cmd/tailscale/cli @@ -62,11 +61,10 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep L đŸ’Ŗ github.com/tailscale/netlink from tailscale.com/util/linuxfw L đŸ’Ŗ github.com/tailscale/netlink/nl from github.com/tailscale/netlink github.com/tailscale/web-client-prebuilt from tailscale.com/client/web - github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck github.com/toqueteos/webbrowser from tailscale.com/cmd/tailscale/cli L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from tailscale.com/net/tsaddr W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/netmon+ k8s.io/client-go/util/homedir from tailscale.com/cmd/tailscale/cli @@ -75,8 +73,9 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep software.sslmate.com/src/go-pkcs12 from tailscale.com/cmd/tailscale/cli software.sslmate.com/src/go-pkcs12/internal/rc2 from software.sslmate.com/src/go-pkcs12 tailscale.com from tailscale.com/version - tailscale.com/atomicfile from tailscale.com/cmd/tailscale/cli+ - tailscale.com/client/tailscale from tailscale.com/client/web+ + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/cmd/tailscale/cli+ + tailscale.com/client/local from tailscale.com/client/tailscale+ + tailscale.com/client/tailscale from tailscale.com/cmd/tailscale/cli+ tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ tailscale.com/client/web from tailscale.com/cmd/tailscale/cli 
tailscale.com/clientupdate from tailscale.com/client/web+ @@ -86,58 +85,65 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/cmd/tailscale/cli/ffcomplete/internal from tailscale.com/cmd/tailscale/cli/ffcomplete tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ tailscale.com/control/controlhttp from tailscale.com/cmd/tailscale/cli + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/net/portmapper tailscale.com/derp from tailscale.com/derp/derphttp + tailscale.com/derp/derpconst from tailscale.com/derp+ tailscale.com/derp/derphttp from tailscale.com/net/netcheck tailscale.com/disco from tailscale.com/derp - tailscale.com/drive from tailscale.com/client/tailscale+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web + tailscale.com/feature from tailscale.com/tsweb + tailscale.com/feature/capture/dissector from tailscale.com/cmd/tailscale/cli tailscale.com/health from tailscale.com/net/tlsdial+ tailscale.com/health/healthmsg from tailscale.com/cmd/tailscale/cli tailscale.com/hostinfo from tailscale.com/client/web+ + tailscale.com/internal/client/tailscale from tailscale.com/cmd/tailscale/cli tailscale.com/internal/noiseconn from tailscale.com/cmd/tailscale/cli - tailscale.com/ipn from tailscale.com/client/tailscale+ - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ + tailscale.com/ipn from tailscale.com/client/local+ + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ tailscale.com/kube/kubetypes from tailscale.com/envknob tailscale.com/licenses from tailscale.com/client/web+ tailscale.com/metrics from tailscale.com/derp+ + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial 
tailscale.com/net/captivedetection from tailscale.com/net/netcheck tailscale.com/net/dns/recursive from tailscale.com/net/dnsfallback tailscale.com/net/dnscache from tailscale.com/control/controlhttp+ tailscale.com/net/dnsfallback from tailscale.com/control/controlhttp+ - tailscale.com/net/flowtrack from tailscale.com/net/packet tailscale.com/net/netaddr from tailscale.com/ipn+ tailscale.com/net/netcheck from tailscale.com/cmd/tailscale/cli tailscale.com/net/neterror from tailscale.com/net/netcheck+ tailscale.com/net/netknob from tailscale.com/net/netns đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/cmd/tailscale/cli+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp+ - tailscale.com/net/netutil from tailscale.com/client/tailscale+ - tailscale.com/net/packet from tailscale.com/wgengine/capture + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlhttp+ tailscale.com/net/ping from tailscale.com/net/netcheck tailscale.com/net/portmapper from tailscale.com/cmd/tailscale/cli+ tailscale.com/net/sockstats from tailscale.com/control/controlhttp+ tailscale.com/net/stun from tailscale.com/net/netcheck L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/cmd/tailscale/cli+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ - tailscale.com/paths from tailscale.com/client/tailscale+ - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale+ + tailscale.com/paths from tailscale.com/client/local+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ tailscale.com/syncs from tailscale.com/cmd/tailscale/cli+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ + tailscale.com/tailcfg from tailscale.com/client/local+ 
tailscale.com/tempfork/spf13/cobra from tailscale.com/cmd/tailscale/cli/ffcomplete+ - tailscale.com/tka from tailscale.com/client/tailscale+ + tailscale.com/tka from tailscale.com/client/local+ tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tstime from tailscale.com/control/controlhttp+ tailscale.com/tstime/mono from tailscale.com/tstime/rate tailscale.com/tstime/rate from tailscale.com/cmd/tailscale/cli+ - tailscale.com/tsweb/varz from tailscale.com/util/usermetric + tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ tailscale.com/types/dnstype from tailscale.com/tailcfg+ tailscale.com/types/empty from tailscale.com/ipn - tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/ipproto from tailscale.com/ipn+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/util/testenv+ tailscale.com/types/logger from tailscale.com/client/web+ tailscale.com/types/netmap from tailscale.com/ipn+ @@ -146,6 +152,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/types/persist from tailscale.com/ipn tailscale.com/types/preftype from tailscale.com/cmd/tailscale/cli+ tailscale.com/types/ptr from tailscale.com/hostinfo+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/ipn+ tailscale.com/types/tkatype from tailscale.com/types/key+ tailscale.com/types/views from tailscale.com/tailcfg+ @@ -153,37 +160,42 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/util/clientmetric from tailscale.com/net/netcheck+ tailscale.com/util/cloudenv from tailscale.com/net/dnscache+ tailscale.com/util/cmpver from tailscale.com/net/tshttpproxy+ - tailscale.com/util/ctxkey from tailscale.com/types/logger + tailscale.com/util/ctxkey 
from tailscale.com/types/logger+ đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/util/syspolicy/setting L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics tailscale.com/util/dnsname from tailscale.com/cmd/tailscale/cli+ + tailscale.com/util/eventbus from tailscale.com/net/portmapper+ tailscale.com/util/groupmember from tailscale.com/client/web đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns tailscale.com/util/mak from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/multierr from tailscale.com/control/controlhttp+ tailscale.com/util/must from tailscale.com/clientupdate/distsign+ tailscale.com/util/nocasemaps from tailscale.com/types/ipproto tailscale.com/util/quarantine from tailscale.com/cmd/tailscale/cli + tailscale.com/util/rands from tailscale.com/tsweb tailscale.com/util/set from tailscale.com/derp+ tailscale.com/util/singleflight from tailscale.com/net/dnscache+ tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ tailscale.com/util/syspolicy from tailscale.com/ipn tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy - tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + 
tailscale.com/util/testenv from tailscale.com/cmd/tailscale/cli+ tailscale.com/util/truncate from tailscale.com/cmd/tailscale/cli tailscale.com/util/usermetric from tailscale.com/health tailscale.com/util/vizerror from tailscale.com/tailcfg+ W đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/util/syspolicy/source W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/version from tailscale.com/client/web+ tailscale.com/version/distro from tailscale.com/client/web+ - tailscale.com/wgengine/capture from tailscale.com/cmd/tailscale/cli tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ @@ -193,14 +205,15 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/crypto/cryptobyte from crypto/ecdsa+ golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ golang.org/x/crypto/curve25519 from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/pbkdf2 from software.sslmate.com/src/go-pkcs12 golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ - W golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ - golang.org/x/exp/maps from tailscale.com/cmd/tailscale/cli+ + golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ + golang.org/x/exp/maps from 
tailscale.com/util/syspolicy/internal/metrics+ golang.org/x/net/bpf from github.com/mdlayher/netlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from net/http+ @@ -209,6 +222,10 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/net/http2/hpack from net/http+ golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + golang.org/x/net/internal/httpcommon from golang.org/x/net/http2 + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy golang.org/x/net/ipv4 from github.com/miekg/dns+ golang.org/x/net/ipv6 from github.com/miekg/dns+ golang.org/x/net/proxy from tailscale.com/net/netns @@ -217,7 +234,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/oauth2/clientcredentials from tailscale.com/cmd/tailscale/cli golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ - golang.org/x/sys/cpu from github.com/josharian/native+ + golang.org/x/sys/cpu from golang.org/x/crypto/argon2+ LD golang.org/x/sys/unix from github.com/google/nftables+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ @@ -238,7 +255,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509 @@ -247,21 +264,61 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ crypto/hmac from crypto/tls+ + 
crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from 
crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from github.com/miekg/dns+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ DW database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe - embed from crypto/internal/nistec+ + embed from github.com/peterbourgon/ff/v3+ encoding from encoding/gob+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ @@ -281,10 +338,51 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep hash/crc32 from compress/gzip+ hash/maphash from go4.org/mem html from html/template+ - html/template from github.com/gorilla/csrf + html/template from github.com/gorilla/csrf+ image from github.com/skip2/go-qrcode+ image/color from github.com/skip2/go-qrcode+ image/png from github.com/skip2/go-qrcode + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from archive/tar+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + 
internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/exithook from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/unsafeheader from internal/reflectlite+ io from archive/tar+ io/fs from archive/tar+ io/ioutil from github.com/mitchellh/go-ps+ @@ -303,13 +401,15 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep net from crypto/tls+ net/http from expvar+ net/http/cgi from tailscale.com/cmd/tailscale/cli - net/http/httptrace from github.com/tcnksm/go-httpstat+ + net/http/httptrace from golang.org/x/net/http2+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/pprof from tailscale.com/tsweb net/netip from go4.org/netipx+ net/textproto from golang.org/x/net/http/httpguts+ net/url from crypto/x509+ - os from crypto/rand+ + os from crypto/internal/sysrand+ os/exec from github.com/coreos/go-iptables/iptables+ os/signal from tailscale.com/cmd/tailscale/cli os/user from archive/tar+ @@ -318,7 +418,10 @@ tailscale.com/cmd/tailscale dependencies: (generated by 
github.com/tailscale/dep reflect from archive/tar+ regexp from github.com/coreos/go-iptables/iptables+ regexp/syntax from regexp - runtime/debug from github.com/coder/websocket/internal/xsync+ + runtime from archive/tar+ + runtime/debug from tailscale.com+ + runtime/pprof from net/http/pprof + runtime/trace from net/http/pprof slices from tailscale.com/client/web+ sort from compress/flate+ strconv from archive/tar+ @@ -334,3 +437,5 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique diff --git a/cmd/tailscale/tailscale.rc b/cmd/tailscale/tailscale.rc new file mode 100755 index 0000000000000..2cac53efbb243 --- /dev/null +++ b/cmd/tailscale/tailscale.rc @@ -0,0 +1,3 @@ +#!/bin/rc +# Plan 9 cmd/tailscale wrapper script to run cmd/tailscaled's embedded CLI. +TS_BE_CLI=1 tailscaled $* diff --git a/cmd/tailscaled/debug.go b/cmd/tailscaled/debug.go index b41604d29516e..2f469a0d189f7 100644 --- a/cmd/tailscaled/debug.go +++ b/cmd/tailscaled/debug.go @@ -27,6 +27,7 @@ import ( "tailscale.com/net/tshttpproxy" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/util/eventbus" ) var debugArgs struct { @@ -72,11 +73,14 @@ func debugMode(args []string) error { } func runMonitor(ctx context.Context, loop bool) error { + b := eventbus.New() + defer b.Close() + dump := func(st *netmon.State) { j, _ := json.MarshalIndent(st, "", " ") os.Stderr.Write(j) } - mon, err := netmon.New(log.Printf) + mon, err := netmon.New(b, log.Printf) if err != nil { return err } diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 6f71a88a93217..6de0ddc391794 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -10,7 +10,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/aws-sdk-go-v2/aws/arn from tailscale.com/ipn/store/awsstore L 
github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - L github.com/aws/aws-sdk-go-v2/aws/middleware/private/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts L github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts @@ -32,10 +31,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ L github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/ssm+ L github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/ssm/internal/endpoints+ L github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ L github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds @@ -70,16 +71,17 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L 
github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer L github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ L github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ L github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config L github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ L github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm - github.com/bits-and-blooms/bitset from github.com/gaissmai/bart - github.com/coder/websocket from tailscale.com/control/controlhttp+ + github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket github.com/coder/websocket/internal/xsync from github.com/coder/websocket @@ -94,6 +96,8 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de đŸ’Ŗ github.com/djherbis/times from tailscale.com/drive/driveimpl github.com/fxamacker/cbor/v2 from tailscale.com/tka github.com/gaissmai/bart from tailscale.com/net/tstun+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart github.com/go-json-experiment/json from tailscale.com/types/opt+ 
github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json/internal/jsonflags+ github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json/internal/jsonopts+ @@ -105,6 +109,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns+ github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ + github.com/google/go-tpm/legacy/tpm2 from github.com/google/go-tpm/tpm2/transport+ + github.com/google/go-tpm/tpm2 from tailscale.com/feature/tpm + github.com/google/go-tpm/tpm2/transport from github.com/google/go-tpm/tpm2/transport/linuxtpm+ + L github.com/google/go-tpm/tpm2/transport/linuxtpm from tailscale.com/feature/tpm + W github.com/google/go-tpm/tpm2/transport/windowstpm from tailscale.com/feature/tpm + github.com/google/go-tpm/tpmutil from github.com/google/go-tpm/legacy/tpm2+ + W đŸ’Ŗ github.com/google/go-tpm/tpmutil/tbs from github.com/google/go-tpm/legacy/tpm2+ L github.com/google/nftables from tailscale.com/util/linuxfw L đŸ’Ŗ github.com/google/nftables/alignedbuff from github.com/google/nftables/xt L đŸ’Ŗ github.com/google/nftables/binaryutil from github.com/google/nftables+ @@ -115,14 +126,14 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/gorilla/csrf from tailscale.com/client/web github.com/gorilla/securecookie from github.com/gorilla/csrf github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ - L đŸ’Ŗ github.com/illarion/gonotify/v2 from tailscale.com/net/dns - L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/net/tstun + L đŸ’Ŗ github.com/illarion/gonotify/v3 from tailscale.com/net/dns + L github.com/illarion/gonotify/v3/syscallf from github.com/illarion/gonotify/v3 + L github.com/insomniacslk/dhcp/dhcpv4 from tailscale.com/feature/tap L 
github.com/insomniacslk/dhcp/iana from github.com/insomniacslk/dhcp/dhcpv4 L github.com/insomniacslk/dhcp/interfaces from github.com/insomniacslk/dhcp/dhcpv4 L github.com/insomniacslk/dhcp/rfc1035label from github.com/insomniacslk/dhcp/dhcpv4 github.com/jellydator/ttlcache/v3 from tailscale.com/drive/driveimpl/compositedav L github.com/jmespath/go-jmespath from github.com/aws/aws-sdk-go-v2/service/ssm - L github.com/josharian/native from github.com/mdlayher/netlink+ L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink github.com/klauspost/compress from github.com/klauspost/compress/zstd @@ -132,7 +143,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd - github.com/kortschak/wol from tailscale.com/ipn/ipnlocal + github.com/kortschak/wol from tailscale.com/feature/wakeonlan LD github.com/kr/fs from github.com/pkg/sftp L github.com/mdlayher/genetlink from tailscale.com/net/tstun L đŸ’Ŗ github.com/mdlayher/netlink from github.com/google/nftables+ @@ -157,10 +168,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ - github.com/tailscale/golang-x-crypto/acme from tailscale.com/ipn/ipnlocal - LD github.com/tailscale/golang-x-crypto/internal/poly1305 from github.com/tailscale/golang-x-crypto/ssh - LD github.com/tailscale/golang-x-crypto/ssh from tailscale.com/ipn/ipnlocal+ - LD 
github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf from github.com/tailscale/golang-x-crypto/ssh github.com/tailscale/goupnp from github.com/tailscale/goupnp/dcps/internetgateway2+ github.com/tailscale/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper github.com/tailscale/goupnp/httpu from github.com/tailscale/goupnp+ @@ -185,13 +192,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ github.com/tailscale/xnet/webdav from tailscale.com/drive/driveimpl+ github.com/tailscale/xnet/webdav/internal/xml from github.com/tailscale/xnet/webdav - github.com/tcnksm/go-httpstat from tailscale.com/net/netcheck LD github.com/u-root/u-root/pkg/termios from tailscale.com/ssh/tailssh L github.com/u-root/uio/rand from github.com/insomniacslk/dhcp/dhcpv4 L github.com/u-root/uio/uio from github.com/insomniacslk/dhcp/dhcpv4+ L github.com/vishvananda/netns from github.com/tailscale/netlink+ github.com/x448/float16 from github.com/fxamacker/cbor/v2 - đŸ’Ŗ go4.org/mem from tailscale.com/client/tailscale+ + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ go4.org/netipx from github.com/tailscale/wf+ W đŸ’Ŗ golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun+ W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/cmd/tailscaled+ @@ -215,13 +221,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de gvisor.dev/gvisor/pkg/tcpip/hash/jenkins from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/header/parse+ gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/transport/tcp gvisor.dev/gvisor/pkg/tcpip/network/hash from 
gvisor.dev/gvisor/pkg/tcpip/network/ipv4 gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/net/tstun+ - gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/feature/tap+ + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack+ gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ @@ -238,42 +244,58 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ tailscale.com from tailscale.com/version tailscale.com/appc from tailscale.com/ipn/ipnlocal - tailscale.com/atomicfile from tailscale.com/ipn+ + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ LD tailscale.com/chirp from tailscale.com/cmd/tailscaled - tailscale.com/client/tailscale from tailscale.com/client/web+ + tailscale.com/client/local from tailscale.com/client/tailscale+ + tailscale.com/client/tailscale from tailscale.com/derp tailscale.com/client/tailscale/apitype from tailscale.com/client/tailscale+ tailscale.com/client/web from tailscale.com/ipn/ipnlocal tailscale.com/clientupdate from tailscale.com/client/web+ LW tailscale.com/clientupdate/distsign from tailscale.com/clientupdate tailscale.com/cmd/tailscaled/childproc from tailscale.com/cmd/tailscaled+ + tailscale.com/cmd/tailscaled/tailscaledhooks from tailscale.com/cmd/tailscaled+ tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ 
tailscale.com/control/controlclient from tailscale.com/cmd/tailscaled+ tailscale.com/control/controlhttp from tailscale.com/control/controlclient + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp+ tailscale.com/derp/derphttp from tailscale.com/cmd/tailscaled+ tailscale.com/disco from tailscale.com/derp+ tailscale.com/doctor from tailscale.com/ipn/ipnlocal tailscale.com/doctor/ethtool from tailscale.com/ipn/ipnlocal đŸ’Ŗ tailscale.com/doctor/permissions from tailscale.com/ipn/ipnlocal tailscale.com/doctor/routetable from tailscale.com/ipn/ipnlocal - tailscale.com/drive from tailscale.com/client/tailscale+ + tailscale.com/drive from tailscale.com/client/local+ tailscale.com/drive/driveimpl from tailscale.com/cmd/tailscaled tailscale.com/drive/driveimpl/compositedav from tailscale.com/drive/driveimpl tailscale.com/drive/driveimpl/dirfs from tailscale.com/drive/driveimpl+ tailscale.com/drive/driveimpl/shared from tailscale.com/drive/driveimpl+ - tailscale.com/envknob from tailscale.com/client/tailscale+ + tailscale.com/envknob from tailscale.com/client/local+ + tailscale.com/envknob/featureknob from tailscale.com/client/web+ + tailscale.com/feature from tailscale.com/feature/wakeonlan+ + tailscale.com/feature/capture from tailscale.com/feature/condregister + tailscale.com/feature/condregister from tailscale.com/cmd/tailscaled + tailscale.com/feature/relayserver from tailscale.com/feature/condregister + tailscale.com/feature/taildrop from tailscale.com/feature/condregister + L tailscale.com/feature/tap from tailscale.com/feature/condregister + tailscale.com/feature/tpm from tailscale.com/feature/condregister + tailscale.com/feature/wakeonlan from tailscale.com/feature/condregister tailscale.com/health from tailscale.com/control/controlclient+ 
tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal tailscale.com/hostinfo from tailscale.com/client/web+ tailscale.com/internal/noiseconn from tailscale.com/control/controlclient - tailscale.com/ipn from tailscale.com/client/tailscale+ + tailscale.com/ipn from tailscale.com/client/local+ + W tailscale.com/ipn/auditlog from tailscale.com/cmd/tailscaled tailscale.com/ipn/conffile from tailscale.com/cmd/tailscaled+ + W đŸ’Ŗ tailscale.com/ipn/desktop from tailscale.com/cmd/tailscaled đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/auditlog+ tailscale.com/ipn/ipnlocal from tailscale.com/cmd/tailscaled+ tailscale.com/ipn/ipnserver from tailscale.com/cmd/tailscaled - tailscale.com/ipn/ipnstate from tailscale.com/client/tailscale+ - tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ + tailscale.com/ipn/localapi from tailscale.com/ipn/ipnserver+ tailscale.com/ipn/policy from tailscale.com/ipn/ipnlocal tailscale.com/ipn/store from tailscale.com/cmd/tailscaled+ L tailscale.com/ipn/store/awsstore from tailscale.com/ipn/store @@ -281,7 +303,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ L tailscale.com/kube/kubeapi from tailscale.com/ipn/store/kubestore+ L tailscale.com/kube/kubeclient from tailscale.com/ipn/store/kubestore - tailscale.com/kube/kubetypes from tailscale.com/envknob + tailscale.com/kube/kubetypes from tailscale.com/envknob+ tailscale.com/licenses from tailscale.com/client/web tailscale.com/log/filelogger from tailscale.com/logpolicy tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal @@ -290,13 +312,14 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/logtail/backoff from tailscale.com/cmd/tailscaled+ tailscale.com/logtail/filch from 
tailscale.com/log/sockstatlog+ tailscale.com/metrics from tailscale.com/derp+ + tailscale.com/net/bakedroots from tailscale.com/net/tlsdial+ tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ tailscale.com/net/connstats from tailscale.com/net/tstun+ tailscale.com/net/dns from tailscale.com/cmd/tailscaled+ tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ tailscale.com/net/dns/recursive from tailscale.com/net/dnsfallback tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ - tailscale.com/net/dns/resolver from tailscale.com/net/dns + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ tailscale.com/net/dnscache from tailscale.com/control/controlclient+ tailscale.com/net/dnsfallback from tailscale.com/cmd/tailscaled+ tailscale.com/net/flowtrack from tailscale.com/net/packet+ @@ -309,7 +332,8 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/cmd/tailscaled+ đŸ’Ŗ tailscale.com/net/netns from tailscale.com/cmd/tailscaled+ W đŸ’Ŗ tailscale.com/net/netstat from tailscale.com/portlist - tailscale.com/net/netutil from tailscale.com/client/tailscale+ + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/net/connstats+ tailscale.com/net/packet/checksum from tailscale.com/net/tstun tailscale.com/net/ping from tailscale.com/net/netcheck+ @@ -321,40 +345,45 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/stun from tailscale.com/ipn/localapi+ L tailscale.com/net/tcpinfo from tailscale.com/derp tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial tailscale.com/net/tsaddr from tailscale.com/client/web+ tailscale.com/net/tsdial from tailscale.com/cmd/tailscaled+ đŸ’Ŗ tailscale.com/net/tshttpproxy from 
tailscale.com/clientupdate/distsign+ tailscale.com/net/tstun from tailscale.com/cmd/tailscaled+ - tailscale.com/net/wsconn from tailscale.com/control/controlhttp+ + tailscale.com/net/udprelay from tailscale.com/feature/relayserver + tailscale.com/net/udprelay/endpoint from tailscale.com/feature/relayserver+ tailscale.com/omit from tailscale.com/ipn/conffile - tailscale.com/paths from tailscale.com/client/tailscale+ + tailscale.com/paths from tailscale.com/client/local+ đŸ’Ŗ tailscale.com/portlist from tailscale.com/ipn/ipnlocal tailscale.com/posture from tailscale.com/ipn/ipnlocal tailscale.com/proxymap from tailscale.com/tsd+ - đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/tailscale+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ LD tailscale.com/sessionrecording from tailscale.com/ssh/tailssh LD đŸ’Ŗ tailscale.com/ssh/tailssh from tailscale.com/cmd/tailscaled tailscale.com/syncs from tailscale.com/cmd/tailscaled+ - tailscale.com/tailcfg from tailscale.com/client/tailscale+ - tailscale.com/taildrop from tailscale.com/ipn/ipnlocal+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tempfork/acme from tailscale.com/ipn/ipnlocal LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock - tailscale.com/tka from tailscale.com/client/tailscale+ + tailscale.com/tempfork/httprec from tailscale.com/control/controlclient + tailscale.com/tka from tailscale.com/client/local+ tailscale.com/tsconst from tailscale.com/net/netmon+ tailscale.com/tsd from tailscale.com/cmd/tailscaled+ tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/derp+ + tailscale.com/tsweb from tailscale.com/util/eventbus tailscale.com/tsweb/varz from tailscale.com/cmd/tailscaled+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal tailscale.com/types/dnstype 
from tailscale.com/ipn/ipnlocal+ tailscale.com/types/empty from tailscale.com/ipn+ tailscale.com/types/flagtype from tailscale.com/cmd/tailscaled tailscale.com/types/ipproto from tailscale.com/net/flowtrack+ - tailscale.com/types/key from tailscale.com/client/tailscale+ + tailscale.com/types/key from tailscale.com/client/local+ tailscale.com/types/lazy from tailscale.com/ipn/ipnlocal+ tailscale.com/types/logger from tailscale.com/appc+ tailscale.com/types/logid from tailscale.com/cmd/tailscaled+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext tailscale.com/types/netlogtype from tailscale.com/net/connstats+ tailscale.com/types/netmap from tailscale.com/control/controlclient+ tailscale.com/types/nettype from tailscale.com/ipn/localapi+ @@ -362,6 +391,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/types/persist from tailscale.com/control/controlclient+ tailscale.com/types/preftype from tailscale.com/ipn+ tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter tailscale.com/types/structs from tailscale.com/control/controlclient+ tailscale.com/types/tkatype from tailscale.com/tka+ tailscale.com/types/views from tailscale.com/ipn/ipnlocal+ @@ -373,13 +403,14 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+ L đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics+ tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/tsd+ tailscale.com/util/execqueue from tailscale.com/control/controlclient+ tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal tailscale.com/util/groupmember from tailscale.com/client/web+ đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash - tailscale.com/util/httphdr from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/httphdr from 
tailscale.com/feature/taildrop tailscale.com/util/httpm from tailscale.com/client/tailscale+ - tailscale.com/util/lineread from tailscale.com/hostinfo+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ L tailscale.com/util/linuxfw from tailscale.com/net/netns+ tailscale.com/util/mak from tailscale.com/control/controlclient+ tailscale.com/util/multierr from tailscale.com/cmd/tailscaled+ @@ -389,7 +420,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de W đŸ’Ŗ tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag tailscale.com/util/osshare from tailscale.com/cmd/tailscaled+ tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal+ - tailscale.com/util/progresstracking from tailscale.com/ipn/localapi + tailscale.com/util/progresstracking from tailscale.com/feature/taildrop tailscale.com/util/race from tailscale.com/net/dns/resolver tailscale.com/util/racebuild from tailscale.com/logpolicy tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ @@ -399,18 +430,20 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/util/slicesx from tailscale.com/net/dns/recursive+ tailscale.com/util/syspolicy from tailscale.com/cmd/tailscaled+ tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy/setting+ - tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy - tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy/internal/metrics+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/setting from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock tailscale.com/util/systemd from 
tailscale.com/control/controlclient+ tailscale.com/util/testenv from tailscale.com/ipn/ipnlocal+ tailscale.com/util/truncate from tailscale.com/logtail - tailscale.com/util/uniq from tailscale.com/ipn/ipnlocal+ tailscale.com/util/usermetric from tailscale.com/health+ tailscale.com/util/vizerror from tailscale.com/tailcfg+ đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ - W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns+ W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ @@ -418,7 +451,6 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/version/distro from tailscale.com/client/web+ W tailscale.com/wf from tailscale.com/cmd/tailscaled tailscale.com/wgengine from tailscale.com/cmd/tailscaled+ - tailscale.com/wgengine/capture from tailscale.com/ipn/ipnlocal+ tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ @@ -434,21 +466,23 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ - LD golang.org/x/crypto/blowfish from github.com/tailscale/golang-x-crypto/ssh/internal/bcrypt_pbkdf+ + LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/chacha20poly1305 from crypto/tls+ golang.org/x/crypto/cryptobyte 
from crypto/ecdsa+ golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ - golang.org/x/crypto/curve25519 from github.com/tailscale/golang-x-crypto/ssh+ - golang.org/x/crypto/hkdf from crypto/tls+ + golang.org/x/crypto/curve25519 from golang.org/x/crypto/ssh+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ golang.org/x/crypto/nacl/box from tailscale.com/types/key golang.org/x/crypto/nacl/secretbox from golang.org/x/crypto/nacl/box golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ - golang.org/x/crypto/sha3 from crypto/internal/mlkem768+ LD golang.org/x/crypto/ssh from github.com/pkg/sftp+ + LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ - golang.org/x/exp/maps from tailscale.com/appc+ + golang.org/x/exp/maps from tailscale.com/ipn/store/mem+ golang.org/x/net/bpf from github.com/mdlayher/genetlink+ golang.org/x/net/dns/dnsmessage from net+ golang.org/x/net/http/httpguts from golang.org/x/net/http2+ @@ -458,13 +492,17 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/net/http2/hpack from golang.org/x/net/http2+ golang.org/x/net/icmp from tailscale.com/net/ping+ golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + golang.org/x/net/internal/httpcommon from golang.org/x/net/http2 + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socks from golang.org/x/net/proxy golang.org/x/net/ipv4 from github.com/miekg/dns+ golang.org/x/net/ipv6 from github.com/miekg/dns+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from net+ 
golang.org/x/sync/errgroup from github.com/mdlayher/socket+ golang.org/x/sync/singleflight from github.com/jellydator/ttlcache/v3 - golang.org/x/sys/cpu from github.com/josharian/native+ + golang.org/x/sys/cpu from github.com/tailscale/certstore+ LD golang.org/x/sys/unix from github.com/google/nftables+ W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ @@ -488,7 +526,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de container/list from crypto/tls+ context from crypto/tls+ crypto from crypto/ecdh+ - crypto/aes from crypto/ecdsa+ + crypto/aes from crypto/internal/hpke+ crypto/cipher from crypto/aes+ crypto/des from crypto/tls+ crypto/dsa from crypto/x509+ @@ -497,21 +535,61 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de crypto/ed25519 from crypto/tls+ crypto/elliptic from crypto/ecdsa+ crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from 
crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ crypto/md5 from crypto/tls+ crypto/rand from crypto/ed25519+ crypto/rc4 from crypto/tls+ crypto/rsa from crypto/tls+ crypto/sha1 from crypto/tls+ crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash crypto/sha512 from crypto/ecdsa+ - crypto/subtle from crypto/aes+ + crypto/subtle from crypto/cipher+ crypto/tls from github.com/aws/aws-sdk-go-v2/aws/transport/http+ + crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ + D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ DW database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe - embed from crypto/internal/nistec+ + embed from github.com/tailscale/web-client-prebuilt+ encoding from encoding/gob+ encoding/asn1 from crypto/x509+ encoding/base32 from github.com/fxamacker/cbor/v2+ @@ -524,14 
+602,55 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de encoding/xml from github.com/aws/aws-sdk-go-v2/aws/protocol/xml+ errors from archive/tar+ expvar from tailscale.com/derp+ - flag from net/http/httptest+ + flag from tailscale.com/cmd/tailscaled+ fmt from archive/tar+ hash from compress/zlib+ hash/adler32 from compress/zlib+ hash/crc32 from compress/gzip+ hash/maphash from go4.org/mem html from html/template+ - html/template from github.com/gorilla/csrf + html/template from github.com/gorilla/csrf+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from archive/tar+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/exithook from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/sys from crypto/subtle+ + L internal/runtime/syscall from runtime+ + internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/syscall/execenv from os+ + LD internal/syscall/unix from crypto/internal/sysrand+ + W 
internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/unsafeheader from internal/reflectlite+ io from archive/tar+ io/fs from archive/tar+ io/ioutil from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ @@ -550,15 +669,15 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptest from tailscale.com/control/controlclient - net/http/httptrace from github.com/tcnksm/go-httpstat+ + net/http/httptrace from github.com/prometheus-community/pro-bing+ net/http/httputil from github.com/aws/smithy-go/transport/http+ net/http/internal from net/http+ + net/http/internal/ascii from net/http+ net/http/pprof from tailscale.com/cmd/tailscaled+ net/netip from github.com/tailscale/wireguard-go/conn+ net/textproto from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ net/url from crypto/x509+ - os from crypto/rand+ + os from crypto/internal/sysrand+ os/exec from github.com/aws/aws-sdk-go-v2/credentials/processcreds+ os/signal from tailscale.com/cmd/tailscaled os/user from archive/tar+ @@ -567,6 +686,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de reflect from archive/tar+ regexp from github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn+ regexp/syntax from regexp + runtime from archive/tar+ runtime/debug from github.com/aws/aws-sdk-go-v2/internal/sync/singleflight+ runtime/pprof from net/http/pprof+ runtime/trace from net/http/pprof @@ -585,3 +705,5 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de unicode/utf16 from crypto/x509+ unicode/utf8 from bufio+ unique from net/netip + unsafe from bytes+ + weak from unique diff --git a/cmd/tailscaled/deps_test.go b/cmd/tailscaled/deps_test.go new file mode 100644 index 
0000000000000..7f06abc6c5ba1 --- /dev/null +++ b/cmd/tailscaled/deps_test.go @@ -0,0 +1,29 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "tailscale.com/tstest/deptest" +) + +func TestOmitSSH(t *testing.T) { + const msg = "unexpected with ts_omit_ssh" + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_ssh", + BadDeps: map[string]string{ + "tailscale.com/ssh/tailssh": msg, + "tailscale.com/sessionrecording": msg, + "github.com/anmitsu/go-shlex": msg, + "github.com/creack/pty": msg, + "github.com/kr/fs": msg, + "github.com/pkg/sftp": msg, + "github.com/u-root/u-root/pkg/termios": msg, + "tempfork/gliderlabs/ssh": msg, + }, + }.Check(t) +} diff --git a/cmd/tailscaled/install_windows.go b/cmd/tailscaled/install_windows.go index c36418642d2b4..c667539b04d4f 100644 --- a/cmd/tailscaled/install_windows.go +++ b/cmd/tailscaled/install_windows.go @@ -15,9 +15,9 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/svc" "golang.org/x/sys/windows/svc/mgr" + "tailscale.com/cmd/tailscaled/tailscaledhooks" "tailscale.com/logtail/backoff" "tailscale.com/types/logger" - "tailscale.com/util/osshare" ) func init() { @@ -81,8 +81,9 @@ func installSystemDaemonWindows(args []string) (err error) { } func uninstallSystemDaemonWindows(args []string) (ret error) { - // Remove file sharing from Windows shell (noop in non-windows) - osshare.SetFileSharingEnabled(false, logger.Discard) + for _, f := range tailscaledhooks.UninstallSystemDaemonWindows { + f() + } m, err := mgr.Connect() if err != nil { diff --git a/cmd/tailscaled/ssh.go b/cmd/tailscaled/ssh.go index f7b0b367ead57..59a1ddd0df461 100644 --- a/cmd/tailscaled/ssh.go +++ b/cmd/tailscaled/ssh.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || darwin || freebsd || openbsd +//go:build (linux || darwin || freebsd || openbsd || plan9) && !ts_omit_ssh 
package main diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index 2831b4061973d..520939aa5a7bf 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -30,11 +30,12 @@ import ( "syscall" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/cmd/tailscaled/childproc" "tailscale.com/control/controlclient" "tailscale.com/drive/driveimpl" "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/conffile" @@ -81,7 +82,9 @@ func defaultTunName() string { // "utun" is recognized by wireguard-go/tun/tun_darwin.go // as a magic value that uses/creates any free number. return "utun" - case "plan9", "aix": + case "plan9": + return "auto" + case "aix", "solaris", "illumos": return "userspace-networking" case "linux": switch distro.Get() { @@ -150,10 +153,34 @@ var subCommands = map[string]*func([]string) error{ "serve-taildrive": &serveDriveFunc, } -var beCLI func() // non-nil if CLI is linked in +var beCLI func() // non-nil if CLI is linked in with the "ts_include_cli" build tag + +// shouldRunCLI reports whether we should run the Tailscale CLI (cmd/tailscale) +// instead of the daemon (cmd/tailscaled) in the case when the two are linked +// together into one binary for space savings reasons. +func shouldRunCLI() bool { + if beCLI == nil { + // Not linked in with the "ts_include_cli" build tag. + return false + } + if len(os.Args) > 0 && filepath.Base(os.Args[0]) == "tailscale" { + // nolint:misspell + // The binary was named (or hardlinked) as "tailscale". + return true + } + if envknob.Bool("TS_BE_CLI") { + // The environment variable was set to force it. 
+ return true + } + return false +} func main() { envknob.PanicIfAnyEnvCheckedInInit() + if shouldRunCLI() { + beCLI() + return + } envknob.ApplyDiskConfig() applyIntegrationTestEnvKnob() @@ -174,9 +201,8 @@ func main() { flag.BoolVar(&args.disableLogs, "no-logs-no-support", false, "disable log uploads; this also disables any technical support") flag.StringVar(&args.confFile, "config", "", "path to config file, or 'vm:user-data' to use the VM's user-data (EC2)") - if len(os.Args) > 0 && filepath.Base(os.Args[0]) == "tailscale" && beCLI != nil { - beCLI() - return + if runtime.GOOS == "plan9" && os.Getenv("_NETSHELL_CHILD_") != "" { + os.Args = []string{"tailscaled", "be-child", "plan9-netshell"} } if len(os.Args) > 1 { @@ -229,7 +255,18 @@ func main() { // Only apply a default statepath when neither have been provided, so that a // user may specify only --statedir if they wish. if args.statepath == "" && args.statedir == "" { - args.statepath = paths.DefaultTailscaledStateFile() + if runtime.GOOS == "plan9" { + home, err := os.UserHomeDir() + if err != nil { + log.Fatalf("failed to get home directory: %v", err) + } + args.statedir = filepath.Join(home, "tailscale-state") + if err := os.MkdirAll(args.statedir, 0700); err != nil { + log.Fatalf("failed to create state directory: %v", err) + } + } else { + args.statepath = paths.DefaultTailscaledStateFile() + } } if args.disableLogs { @@ -338,7 +375,9 @@ var debugMux *http.ServeMux func run() (err error) { var logf logger.Logf = log.Printf - sys := new(tsd.System) + // Install an event bus as early as possible, so that it's + // available universally when setting up everything else. + sys := tsd.NewSystem() // Parse config, if specified, to fail early if it's invalid. var conf *conffile.Config @@ -353,9 +392,7 @@ func run() (err error) { var netMon *netmon.Monitor isWinSvc := isWindowsService() if !isWinSvc { - netMon, err = netmon.New(func(format string, args ...any) { - logf(format, args...) 
- }) + netMon, err = netmon.New(sys.Bus.Get(), logf) if err != nil { return fmt.Errorf("netmon.New: %w", err) } @@ -537,7 +574,7 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID if ms, ok := sys.MagicSock.GetOK(); ok { debugMux.HandleFunc("/debug/magicsock", ms.ServeHTTPDebug) } - go runDebugServer(debugMux, args.debug) + go runDebugServer(logf, debugMux, args.debug) } ns, err := newNetstack(logf, sys) @@ -620,11 +657,10 @@ func getLocalBackend(ctx context.Context, logf logger.Logf, logID logid.PublicID if root := lb.TailscaleVarRoot(); root != "" { dnsfallback.SetCachePath(filepath.Join(root, "derpmap.cached.json"), logf) } - lb.ConfigureWebClient(&tailscale.LocalClient{ + lb.ConfigureWebClient(&local.Client{ Socket: args.socketpath, UseSocketOnly: args.socketpath != paths.DefaultTailscaledSocket(), }) - configureTaildrop(logf, lb) if err := ns.Start(lb); err != nil { log.Fatalf("failed to start netstack: %v", err) } @@ -665,7 +701,7 @@ func handleSubnetsInNetstack() bool { return true } switch runtime.GOOS { - case "windows", "darwin", "freebsd", "openbsd": + case "windows", "darwin", "freebsd", "openbsd", "solaris", "illumos": // Enable on Windows and tailscaled-on-macOS (this doesn't // affect the GUI clients), and on FreeBSD. return true @@ -730,6 +766,12 @@ func tryEngine(logf logger.Logf, sys *tsd.System, name string) (onlyNetstack boo return false, err } + if runtime.GOOS == "plan9" { + // TODO(bradfitz): why don't we do this on all platforms? + // We should. Doing it just on plan9 for now conservatively. 
+ sys.NetMon.Get().SetTailscaleInterfaceName(devName) + } + r, err := router.New(logf, dev, sys.NetMon.Get(), sys.HealthTracker()) if err != nil { dev.Close() @@ -777,18 +819,25 @@ func servePrometheusMetrics(w http.ResponseWriter, r *http.Request) { clientmetric.WritePrometheusExpositionFormat(w) } -func runDebugServer(mux *http.ServeMux, addr string) { +func runDebugServer(logf logger.Logf, mux *http.ServeMux, addr string) { + ln, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalf("debug server: %v", err) + } + if strings.HasSuffix(addr, ":0") { + // Log kernel-selected port number so integration tests + // can find it portably. + logf("DEBUG-ADDR=%v", ln.Addr()) + } srv := &http.Server{ - Addr: addr, Handler: mux, } - if err := srv.ListenAndServe(); err != nil { + if err := srv.Serve(ln); err != nil { log.Fatal(err) } } func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) { - tfs, _ := sys.DriveForLocal.GetOK() ret, err := netstack.Create(logf, sys.Tun.Get(), sys.Engine.Get(), @@ -796,7 +845,6 @@ func newNetstack(logf logger.Logf, sys *tsd.System) (*netstack.Impl, error) { sys.Dialer.Get(), sys.DNSManager.Get(), sys.ProxyMapper(), - tfs, ) if err != nil { return nil, err diff --git a/cmd/tailscaled/tailscaled_test.go b/cmd/tailscaled/tailscaled_test.go index 5045468d6543a..c50c237591170 100644 --- a/cmd/tailscaled/tailscaled_test.go +++ b/cmd/tailscaled/tailscaled_test.go @@ -22,6 +22,8 @@ func TestDeps(t *testing.T) { BadDeps: map[string]string{ "testing": "do not use testing package in production code", "gvisor.dev/gvisor/pkg/hostarch": "will crash on non-4K page sizes; see https://github.com/tailscale/tailscale/issues/8658", + "net/http/httptest": "do not use httptest in production code", + "net/http/internal/testcert": "do not use httptest in production code", }, }.Check(t) @@ -29,8 +31,10 @@ func TestDeps(t *testing.T) { GOOS: "linux", GOARCH: "arm64", BadDeps: map[string]string{ - "testing": "do not use testing package in 
production code", - "gvisor.dev/gvisor/pkg/hostarch": "will crash on non-4K page sizes; see https://github.com/tailscale/tailscale/issues/8658", + "testing": "do not use testing package in production code", + "gvisor.dev/gvisor/pkg/hostarch": "will crash on non-4K page sizes; see https://github.com/tailscale/tailscale/issues/8658", + "google.golang.org/protobuf/proto": "unexpected", + "github.com/prometheus/client_golang/prometheus": "use tailscale.com/metrics in tailscaled", }, }.Check(t) } diff --git a/cmd/tailscaled/tailscaled_windows.go b/cmd/tailscaled/tailscaled_windows.go index 35c878f38ece3..1b50688922968 100644 --- a/cmd/tailscaled/tailscaled_windows.go +++ b/cmd/tailscaled/tailscaled_windows.go @@ -44,6 +44,8 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/drive/driveimpl" "tailscale.com/envknob" + _ "tailscale.com/ipn/auditlog" + _ "tailscale.com/ipn/desktop" "tailscale.com/logpolicy" "tailscale.com/logtail/backoff" "tailscale.com/net/dns" @@ -55,6 +57,7 @@ import ( "tailscale.com/util/osdiag" "tailscale.com/util/syspolicy" "tailscale.com/util/winutil" + "tailscale.com/util/winutil/gp" "tailscale.com/version" "tailscale.com/wf" ) @@ -70,6 +73,22 @@ func init() { } } +// permitPolicyLocks is a function to be called to lift the restriction on acquiring +// [gp.PolicyLock]s once the service is running. +// It is safe to be called multiple times. +var permitPolicyLocks = func() {} + +func init() { + if isWindowsService() { + // We prevent [gp.PolicyLock]s from being acquired until the service enters the running state. + // Otherwise, if tailscaled starts due to a GPSI policy installing Tailscale, it may deadlock + // while waiting for the write counterpart of the GP lock to be released by Group Policy, + // which is itself waiting for the installation to complete and tailscaled to start. + // See tailscale/tailscale#14416 for more information. 
+ permitPolicyLocks = gp.RestrictPolicyLocks() + } +} + const serviceName = "Tailscale" // Application-defined command codes between 128 and 255 @@ -109,13 +128,13 @@ func tstunNewWithWindowsRetries(logf logger.Logf, tunName string) (_ tun.Device, } } -func isWindowsService() bool { +var isWindowsService = sync.OnceValue(func() bool { v, err := svc.IsWindowsService() if err != nil { log.Fatalf("svc.IsWindowsService failed: %v", err) } return v -} +}) // syslogf is a logger function that writes to the Windows event log (ie, the // one that you see in the Windows Event Viewer). tailscaled may optionally @@ -134,14 +153,13 @@ func runWindowsService(pol *logpolicy.Policy) error { logger.Logf(log.Printf).JSON(1, "SupportInfo", osdiag.SupportInfo(osdiag.LogSupportInfoReasonStartup)) }() - if logSCMInteractions, _ := syspolicy.GetBoolean(syspolicy.LogSCMInteractions, false); logSCMInteractions { - syslog, err := eventlog.Open(serviceName) - if err == nil { - syslogf = func(format string, args ...any) { + if syslog, err := eventlog.Open(serviceName); err == nil { + syslogf = func(format string, args ...any) { + if logSCMInteractions, _ := syspolicy.GetBoolean(syspolicy.LogSCMInteractions, false); logSCMInteractions { syslog.Info(0, fmt.Sprintf(format, args...)) } - defer syslog.Close() } + defer syslog.Close() } syslogf("Service entering svc.Run") @@ -160,10 +178,7 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch changes <- svc.Status{State: svc.StartPending} syslogf("Service start pending") - svcAccepts := svc.AcceptStop - if flushDNSOnSessionUnlock, _ := syspolicy.GetBoolean(syspolicy.FlushDNSOnSessionUnlock, false); flushDNSOnSessionUnlock { - svcAccepts |= svc.AcceptSessionChange - } + svcAccepts := svc.AcceptStop | svc.AcceptSessionChange ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -184,6 +199,10 @@ func (service *ipnService) Execute(args []string, r <-chan svc.ChangeRequest, ch changes <- 
svc.Status{State: svc.Running, Accepts: svcAccepts} syslogf("Service running") + // It is safe to allow GP locks to be acquired now that the service + // is running. + permitPolicyLocks() + for { select { case <-doneCh: @@ -309,8 +328,8 @@ func beWindowsSubprocess() bool { log.Printf("Error pre-loading \"%s\": %v", fqWintunPath, err) } - sys := new(tsd.System) - netMon, err := netmon.New(log.Printf) + sys := tsd.NewSystem() + netMon, err := netmon.New(sys.Bus.Get(), log.Printf) if err != nil { log.Fatalf("Could not create netMon: %v", err) } @@ -371,13 +390,15 @@ func handleSessionChange(chgRequest svc.ChangeRequest) { return } - log.Printf("Received WTS_SESSION_UNLOCK event, initiating DNS flush.") - go func() { - err := dns.Flush() - if err != nil { - log.Printf("Error flushing DNS on session unlock: %v", err) - } - }() + if flushDNSOnSessionUnlock, _ := syspolicy.GetBoolean(syspolicy.FlushDNSOnSessionUnlock, false); flushDNSOnSessionUnlock { + log.Printf("Received WTS_SESSION_UNLOCK event, initiating DNS flush.") + go func() { + err := dns.Flush() + if err != nil { + log.Printf("Error flushing DNS on session unlock: %v", err) + } + }() + } } var ( diff --git a/cmd/tailscaled/tailscaledhooks/tailscaledhooks.go b/cmd/tailscaled/tailscaledhooks/tailscaledhooks.go new file mode 100644 index 0000000000000..6ea662d39230c --- /dev/null +++ b/cmd/tailscaled/tailscaledhooks/tailscaledhooks.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tailscaledhooks provides hooks for optional features +// to add to during init that tailscaled calls at runtime. +package tailscaledhooks + +import "tailscale.com/feature" + +// UninstallSystemDaemonWindows is called when the Windows +// system daemon is uninstalled. 
+var UninstallSystemDaemonWindows feature.Hooks[func()] diff --git a/cmd/testwrapper/flakytest/flakytest.go b/cmd/testwrapper/flakytest/flakytest.go index 494ed080b26a1..6302900cbd3ab 100644 --- a/cmd/testwrapper/flakytest/flakytest.go +++ b/cmd/testwrapper/flakytest/flakytest.go @@ -9,8 +9,12 @@ package flakytest import ( "fmt" "os" + "path" "regexp" + "sync" "testing" + + "tailscale.com/util/mak" ) // FlakyTestLogMessage is a sentinel value that is printed to stderr when a @@ -25,6 +29,11 @@ const FlakeAttemptEnv = "TS_TESTWRAPPER_ATTEMPT" var issueRegexp = regexp.MustCompile(`\Ahttps://github\.com/tailscale/[a-zA-Z0-9_.-]+/issues/\d+\z`) +var ( + rootFlakesMu sync.Mutex + rootFlakes map[string]bool +) + // Mark sets the current test as a flaky test, such that if it fails, it will // be retried a few times on failure. issue must be a GitHub issue that tracks // the status of the flaky test being marked, of the format: @@ -41,4 +50,24 @@ func Mark(t testing.TB, issue string) { fmt.Fprintf(os.Stderr, "%s: %s\n", FlakyTestLogMessage, issue) } t.Logf("flakytest: issue tracking this flaky test: %s", issue) + + // Record the root test name as flakey. + rootFlakesMu.Lock() + defer rootFlakesMu.Unlock() + mak.Set(&rootFlakes, t.Name(), true) +} + +// Marked reports whether the current test or one of its parents was marked flaky. +func Marked(t testing.TB) bool { + n := t.Name() + for { + if rootFlakes[n] { + return true + } + n = path.Dir(n) + if n == "." || n == "/" { + break + } + } + return false } diff --git a/cmd/testwrapper/flakytest/flakytest_test.go b/cmd/testwrapper/flakytest/flakytest_test.go index 85e77a939c75d..64cbfd9a3cd1f 100644 --- a/cmd/testwrapper/flakytest/flakytest_test.go +++ b/cmd/testwrapper/flakytest/flakytest_test.go @@ -41,3 +41,49 @@ func TestFlakeRun(t *testing.T) { t.Fatal("First run in testwrapper, failing so that test is retried. 
This is expected.") } } + +func TestMarked_Root(t *testing.T) { + Mark(t, "https://github.com/tailscale/tailscale/issues/0") + + t.Run("child", func(t *testing.T) { + t.Run("grandchild", func(t *testing.T) { + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } +} + +func TestMarked_Subtest(t *testing.T) { + t.Run("flaky", func(t *testing.T) { + Mark(t, "https://github.com/tailscale/tailscale/issues/0") + + t.Run("child", func(t *testing.T) { + t.Run("grandchild", func(t *testing.T) { + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), true; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } + }) + + if got, want := Marked(t), false; got != want { + t.Fatalf("Marked(t) = %t, want %t", got, want) + } +} diff --git a/cmd/testwrapper/testwrapper.go b/cmd/testwrapper/testwrapper.go index 9b8d7a7c17ba5..53c1b1d05f7ca 100644 --- a/cmd/testwrapper/testwrapper.go +++ b/cmd/testwrapper/testwrapper.go @@ -10,6 +10,7 @@ package main import ( "bufio" "bytes" + "cmp" "context" "encoding/json" "errors" @@ -22,15 +23,9 @@ import ( "sort" "strings" "time" - "unicode" - "github.com/dave/courtney/scanner" - "github.com/dave/courtney/shared" - "github.com/dave/courtney/tester" - "github.com/dave/patsy" - "github.com/dave/patsy/vos" - xmaps "golang.org/x/exp/maps" "tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/util/slicesx" ) const ( @@ -42,6 +37,7 @@ type testAttempt struct { testName string // "TestFoo" outcome string // "pass", "fail", "skip" logs bytes.Buffer + start, end time.Time isMarkedFlaky bool 
// set if the test is marked as flaky issueURL string // set if the test is marked as flaky @@ -64,11 +60,12 @@ type packageTests struct { } type goTestOutput struct { - Time time.Time - Action string - Package string - Test string - Output string + Time time.Time + Action string + ImportPath string + Package string + Test string + Output string } var debug = os.Getenv("TS_TESTWRAPPER_DEBUG") != "" @@ -116,43 +113,56 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te for s.Scan() { var goOutput goTestOutput if err := json.Unmarshal(s.Bytes(), &goOutput); err != nil { - if errors.Is(err, io.EOF) || errors.Is(err, os.ErrClosed) { - break - } - - // `go test -json` outputs invalid JSON when a build fails. - // In that case, discard the the output and start reading again. - // The build error will be printed to stderr. - // See: https://github.com/golang/go/issues/35169 - if _, ok := err.(*json.SyntaxError); ok { - fmt.Println(s.Text()) - continue - } - panic(err) + return fmt.Errorf("failed to parse go test output %q: %w", s.Bytes(), err) } - pkg := goOutput.Package + pkg := cmp.Or( + goOutput.Package, + "build:"+goOutput.ImportPath, // can be "./cmd" while Package is "tailscale.com/cmd" so use separate namespace + ) pkgTests := resultMap[pkg] + if pkgTests == nil { + pkgTests = map[string]*testAttempt{ + "": {}, // Used for start time and build logs. 
+ } + resultMap[pkg] = pkgTests + } if goOutput.Test == "" { switch goOutput.Action { - case "fail", "pass", "skip": + case "start": + pkgTests[""].start = goOutput.Time + case "build-output": + pkgTests[""].logs.WriteString(goOutput.Output) + case "build-fail", "fail", "pass", "skip": for _, test := range pkgTests { - if test.outcome == "" { + if test.testName != "" && test.outcome == "" { test.outcome = "fail" ch <- test } } + outcome := goOutput.Action + if outcome == "build-fail" { + outcome = "fail" + } + pkgTests[""].logs.WriteString(goOutput.Output) ch <- &testAttempt{ pkg: goOutput.Package, - outcome: goOutput.Action, + outcome: outcome, + start: pkgTests[""].start, + end: goOutput.Time, + logs: pkgTests[""].logs, pkgFinished: true, } + case "output": + // Capture all output from the package except for the final + // "FAIL tailscale.io/control 0.684s" line, as + // printPkgOutcome will output a similar line + if !strings.HasPrefix(goOutput.Output, fmt.Sprintf("FAIL\t%s\t", goOutput.Package)) { + pkgTests[""].logs.WriteString(goOutput.Output) + } } + continue } - if pkgTests == nil { - pkgTests = make(map[string]*testAttempt) - resultMap[pkg] = pkgTests - } testName := goOutput.Test if test, _, isSubtest := strings.Cut(goOutput.Test, "/"); isSubtest { testName = test @@ -168,8 +178,10 @@ func runTests(ctx context.Context, attempt int, pt *packageTests, goTestArgs, te pkgTests[testName] = &testAttempt{ pkg: pkg, testName: testName, + start: goOutput.Time, } case "skip", "pass", "fail": + pkgTests[testName].end = goOutput.Time pkgTests[testName].outcome = goOutput.Action ch <- pkgTests[testName] case "output": @@ -213,7 +225,10 @@ func main() { firstRun.tests = append(firstRun.tests, &packageTests{Pattern: pkg}) } toRun := []*nextRun{firstRun} - printPkgOutcome := func(pkg, outcome string, attempt int) { + printPkgOutcome := func(pkg, outcome string, attempt int, runtime time.Duration) { + if pkg == "" { + return // We reach this path on a build error. 
+ } if outcome == "skip" { fmt.Printf("?\t%s [skipped/no tests] \n", pkg) return @@ -225,36 +240,12 @@ func main() { outcome = "FAIL" } if attempt > 1 { - fmt.Printf("%s\t%s [attempt=%d]\n", outcome, pkg, attempt) + fmt.Printf("%s\t%s\t%.3fs\t[attempt=%d]\n", outcome, pkg, runtime.Seconds(), attempt) return } - fmt.Printf("%s\t%s\n", outcome, pkg) + fmt.Printf("%s\t%s\t%.3fs\n", outcome, pkg, runtime.Seconds()) } - // Check for -coverprofile argument and filter it out - combinedCoverageFilename := "" - filteredGoTestArgs := make([]string, 0, len(goTestArgs)) - preceededByCoverProfile := false - for _, arg := range goTestArgs { - if arg == "-coverprofile" { - preceededByCoverProfile = true - } else if preceededByCoverProfile { - combinedCoverageFilename = strings.TrimSpace(arg) - preceededByCoverProfile = false - } else { - filteredGoTestArgs = append(filteredGoTestArgs, arg) - } - } - goTestArgs = filteredGoTestArgs - - runningWithCoverage := combinedCoverageFilename != "" - if runningWithCoverage { - fmt.Printf("Will log coverage to %v\n", combinedCoverageFilename) - } - - // Keep track of all test coverage files. With each retry, we'll end up - // with additional coverage files that will be combined when we finish. 
- coverageFiles := make([]string, 0) for len(toRun) > 0 { var thisRun *nextRun thisRun, toRun = toRun[0], toRun[1:] @@ -268,27 +259,14 @@ func main() { fmt.Printf("\n\nAttempt #%d: Retrying flaky tests:\n\nflakytest failures JSON: %s\n\n", thisRun.attempt, j) } - goTestArgsWithCoverage := testArgs - if runningWithCoverage { - coverageFile := fmt.Sprintf("/tmp/coverage_%d.out", thisRun.attempt) - coverageFiles = append(coverageFiles, coverageFile) - goTestArgsWithCoverage = make([]string, len(goTestArgs), len(goTestArgs)+2) - copy(goTestArgsWithCoverage, goTestArgs) - goTestArgsWithCoverage = append( - goTestArgsWithCoverage, - fmt.Sprintf("-coverprofile=%v", coverageFile), - "-covermode=set", - "-coverpkg=./...", - ) - } - + fatalFailures := make(map[string]struct{}) // pkg.Test key toRetry := make(map[string][]*testAttempt) // pkg -> tests to retry for _, pt := range thisRun.tests { ch := make(chan *testAttempt) runErr := make(chan error, 1) go func() { defer close(runErr) - runErr <- runTests(ctx, thisRun.attempt, pt, goTestArgsWithCoverage, testArgs, ch) + runErr <- runTests(ctx, thisRun.attempt, pt, goTestArgs, testArgs, ch) }() var failed bool @@ -307,7 +285,12 @@ func main() { // when a package times out. failed = true } - printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt) + if testingVerbose || tr.outcome == "fail" { + // Output package-level output which is where e.g. + // panics outside tests will be printed + io.Copy(os.Stdout, &tr.logs) + } + printPkgOutcome(tr.pkg, tr.outcome, thisRun.attempt, tr.end.Sub(tr.start)) continue } if testingVerbose || tr.outcome == "fail" { @@ -319,11 +302,24 @@ func main() { if tr.isMarkedFlaky { toRetry[tr.pkg] = append(toRetry[tr.pkg], tr) } else { + fatalFailures[tr.pkg+"."+tr.testName] = struct{}{} failed = true } } if failed { fmt.Println("\n\nNot retrying flaky tests because non-flaky tests failed.") + + // Print the list of non-flakytest failures. 
+ // We will later analyze the retried GitHub Action runs to see + // if non-flakytest failures succeeded upon retry. This will + // highlight tests which are flaky but not yet flagged as such. + if len(fatalFailures) > 0 { + tests := slicesx.MapKeys(fatalFailures) + sort.Strings(tests) + j, _ := json.Marshal(tests) + fmt.Printf("non-flakytest failures: %s\n", j) + } + fmt.Println() os.Exit(1) } @@ -343,7 +339,7 @@ func main() { if len(toRetry) == 0 { continue } - pkgs := xmaps.Keys(toRetry) + pkgs := slicesx.MapKeys(toRetry) sort.Strings(pkgs) nextRun := &nextRun{ attempt: thisRun.attempt + 1, @@ -365,107 +361,4 @@ func main() { } toRun = append(toRun, nextRun) } - - if runningWithCoverage { - intermediateCoverageFilename := "/tmp/coverage.out_intermediate" - if err := combineCoverageFiles(intermediateCoverageFilename, coverageFiles); err != nil { - fmt.Printf("error combining coverage files: %v\n", err) - os.Exit(2) - } - - if err := processCoverageWithCourtney(intermediateCoverageFilename, combinedCoverageFilename, testArgs); err != nil { - fmt.Printf("error processing coverage with courtney: %v\n", err) - os.Exit(3) - } - - fmt.Printf("Wrote combined coverage to %v\n", combinedCoverageFilename) - } -} - -func combineCoverageFiles(intermediateCoverageFilename string, coverageFiles []string) error { - combinedCoverageFile, err := os.OpenFile(intermediateCoverageFilename, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) - if err != nil { - return fmt.Errorf("create /tmp/coverage.out: %w", err) - } - defer combinedCoverageFile.Close() - w := bufio.NewWriter(combinedCoverageFile) - defer w.Flush() - - for fileNumber, coverageFile := range coverageFiles { - f, err := os.Open(coverageFile) - if err != nil { - return fmt.Errorf("open %v: %w", coverageFile, err) - } - defer f.Close() - in := bufio.NewReader(f) - line := 0 - for { - r, _, err := in.ReadRune() - if err != nil { - if err != io.EOF { - return fmt.Errorf("read %v: %w", coverageFile, err) - } - break - } - - // On 
all but the first coverage file, skip the coverage file header - if fileNumber > 0 && line == 0 { - continue - } - if r == '\n' { - line++ - } - - // filter for only printable characters because coverage file sometimes includes junk on 2nd line - if unicode.IsPrint(r) || r == '\n' { - if _, err := w.WriteRune(r); err != nil { - return fmt.Errorf("write %v: %w", combinedCoverageFile.Name(), err) - } - } - } - } - - return nil -} - -// processCoverageWithCourtney post-processes code coverage to exclude less -// meaningful sections like 'if err != nil { return err}', as well as -// anything marked with a '// notest' comment. -// -// instead of running the courtney as a separate program, this embeds -// courtney for easier integration. -func processCoverageWithCourtney(intermediateCoverageFilename, combinedCoverageFilename string, testArgs []string) error { - env := vos.Os() - - setup := &shared.Setup{ - Env: vos.Os(), - Paths: patsy.NewCache(env), - TestArgs: testArgs, - Load: intermediateCoverageFilename, - Output: combinedCoverageFilename, - } - if err := setup.Parse(testArgs); err != nil { - return fmt.Errorf("parse args: %w", err) - } - - s := scanner.New(setup) - if err := s.LoadProgram(); err != nil { - return fmt.Errorf("load program: %w", err) - } - if err := s.ScanPackages(); err != nil { - return fmt.Errorf("scan packages: %w", err) - } - - t := tester.New(setup) - if err := t.Load(); err != nil { - return fmt.Errorf("load: %w", err) - } - if err := t.ProcessExcludes(s.Excludes); err != nil { - return fmt.Errorf("process excludes: %w", err) - } - if err := t.Save(); err != nil { - return fmt.Errorf("save: %w", err) - } - - return nil } diff --git a/cmd/testwrapper/testwrapper_test.go b/cmd/testwrapper/testwrapper_test.go index d7dbccd093ef8..ace53ccd0e09a 100644 --- a/cmd/testwrapper/testwrapper_test.go +++ b/cmd/testwrapper/testwrapper_test.go @@ -10,6 +10,8 @@ import ( "os" "os/exec" "path/filepath" + "regexp" + "strings" "sync" "testing" ) @@ -76,7 +78,10 
@@ func TestFlakeRun(t *testing.T) { t.Fatalf("go run . %s: %s with output:\n%s", testfile, err, out) } - want := []byte("ok\t" + testfile + " [attempt=2]") + // Replace the unpredictable timestamp with "0.00s". + out = regexp.MustCompile(`\t\d+\.\d\d\ds\t`).ReplaceAll(out, []byte("\t0.00s\t")) + + want := []byte("ok\t" + testfile + "\t0.00s\t[attempt=2]") if !bytes.Contains(out, want) { t.Fatalf("wanted output containing %q but got:\n%s", want, out) } @@ -150,24 +155,24 @@ func TestBuildError(t *testing.T) { t.Fatalf("writing package: %s", err) } - buildErr := []byte("builderror_test.go:3:1: expected declaration, found derp\nFAIL command-line-arguments [setup failed]") + wantErr := "builderror_test.go:3:1: expected declaration, found derp\nFAIL" // Confirm `go test` exits with code 1. goOut, err := exec.Command("go", "test", testfile).CombinedOutput() if code, ok := errExitCode(err); !ok || code != 1 { - t.Fatalf("go test %s: expected error with exit code 0 but got: %v", testfile, err) + t.Fatalf("go test %s: got exit code %d, want 1 (err: %v)", testfile, code, err) } - if !bytes.Contains(goOut, buildErr) { - t.Fatalf("go test %s: expected build error containing %q but got:\n%s", testfile, buildErr, goOut) + if !strings.Contains(string(goOut), wantErr) { + t.Fatalf("go test %s: got output %q, want output containing %q", testfile, goOut, wantErr) } // Confirm `testwrapper` exits with code 1. 
twOut, err := cmdTestwrapper(t, testfile).CombinedOutput() if code, ok := errExitCode(err); !ok || code != 1 { - t.Fatalf("testwrapper %s: expected error with exit code 0 but got: %v", testfile, err) + t.Fatalf("testwrapper %s: got exit code %d, want 1 (err: %v)", testfile, code, err) } - if !bytes.Contains(twOut, buildErr) { - t.Fatalf("testwrapper %s: expected build error containing %q but got:\n%s", testfile, buildErr, twOut) + if !strings.Contains(string(twOut), wantErr) { + t.Fatalf("testwrapper %s: got output %q, want output containing %q", testfile, twOut, wantErr) } if testing.Verbose() { diff --git a/cmd/tl-longchain/tl-longchain.go b/cmd/tl-longchain/tl-longchain.go index c92714505b8be..2a4dc10ba331c 100644 --- a/cmd/tl-longchain/tl-longchain.go +++ b/cmd/tl-longchain/tl-longchain.go @@ -22,7 +22,7 @@ import ( "log" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/ipn/ipnstate" "tailscale.com/tka" "tailscale.com/types/key" @@ -37,7 +37,7 @@ var ( func main() { flag.Parse() - lc := tailscale.LocalClient{Socket: *flagSocket} + lc := local.Client{Socket: *flagSocket} if lc.Socket != "" { lc.UseSocketOnly = true } diff --git a/cmd/tsconnect/common.go b/cmd/tsconnect/common.go index a387c00c9758e..ff10e4efbb5d3 100644 --- a/cmd/tsconnect/common.go +++ b/cmd/tsconnect/common.go @@ -150,6 +150,7 @@ func runEsbuildServe(buildOptions esbuild.BuildOptions) { log.Fatalf("Cannot start esbuild server: %v", err) } log.Printf("Listening on http://%s:%d\n", result.Host, result.Port) + select {} } func runEsbuild(buildOptions esbuild.BuildOptions) esbuild.BuildResult { @@ -175,6 +176,10 @@ func runEsbuild(buildOptions esbuild.BuildOptions) esbuild.BuildResult { // wasm_exec.js runtime helper library from the Go toolchain. 
func setupEsbuildWasmExecJS(build esbuild.PluginBuild) { wasmExecSrcPath := filepath.Join(runtime.GOROOT(), "misc", "wasm", "wasm_exec.js") + if _, err := os.Stat(wasmExecSrcPath); os.IsNotExist(err) { + // Go 1.24+ location: + wasmExecSrcPath = filepath.Join(runtime.GOROOT(), "lib", "wasm", "wasm_exec.js") + } build.OnResolve(esbuild.OnResolveOptions{ Filter: "./wasm_exec$", }, func(args esbuild.OnResolveArgs) (esbuild.OnResolveResult, error) { diff --git a/cmd/tsconnect/tsconnect.go b/cmd/tsconnect/tsconnect.go index 4c8a0a52ece34..ef55593b49268 100644 --- a/cmd/tsconnect/tsconnect.go +++ b/cmd/tsconnect/tsconnect.go @@ -53,12 +53,12 @@ func main() { } func usage() { - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` usage: tsconnect {dev|build|serve} `[1:]) flag.PrintDefaults() - fmt.Fprintf(os.Stderr, ` + fmt.Fprint(os.Stderr, ` tsconnect implements development/build/serving workflows for Tailscale Connect. It can be invoked with one of three subcommands: diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 8291ac9b4735f..779a87e49dec9 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -100,7 +100,7 @@ func newIPN(jsConfig js.Value) map[string]any { logtail := logtail.NewLogger(c, log.Printf) logf := logtail.Logf - sys := new(tsd.System) + sys := tsd.NewSystem() sys.Set(store) dialer := &tsdial.Dialer{Logf: logf} eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ @@ -108,13 +108,14 @@ func newIPN(jsConfig js.Value) map[string]any { SetSubsystem: sys.Set, ControlKnobs: sys.ControlKnobs(), HealthTracker: sys.HealthTracker(), + Metrics: sys.UserMetricsRegistry(), }) if err != nil { log.Fatal(err) } sys.Set(eng) - ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := netstack.Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { 
log.Fatalf("netstack.Create: %v", err) } @@ -128,6 +129,9 @@ func newIPN(jsConfig js.Value) map[string]any { dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { return ns.DialContextTCP(ctx, dst) } + dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) { + return ns.DialContextUDP(ctx, dst) + } sys.NetstackRouter.Set(true) sys.Tun.Get().Start() @@ -268,8 +272,8 @@ func (i *jsIPN) run(jsCallbacks js.Value) { name = p.Hostinfo().Hostname() } addrs := make([]string, p.Addresses().Len()) - for i := range p.Addresses().Len() { - addrs[i] = p.Addresses().At(i).Addr().String() + for i, ap := range p.Addresses().All() { + addrs[i] = ap.Addr().String() } return jsNetMapPeerNode{ jsNetMapNode: jsNetMapNode{ @@ -278,7 +282,7 @@ func (i *jsIPN) run(jsCallbacks js.Value) { MachineKey: p.Machine().String(), NodeKey: p.Key().String(), }, - Online: p.Online(), + Online: p.Online().Clone(), TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), } }), @@ -585,8 +589,8 @@ func mapSlice[T any, M any](a []T, f func(T) M) []M { func mapSliceView[T any, M any](a views.Slice[T], f func(T) M) []M { n := make([]M, a.Len()) - for i := range a.Len() { - n[i] = f(a.At(i)) + for i, v := range a.All() { + n[i] = f(v) } return n } diff --git a/cmd/tsconnect/yarn.lock b/cmd/tsconnect/yarn.lock index 663a1244ebf69..d9d9db32f66a0 100644 --- a/cmd/tsconnect/yarn.lock +++ b/cmd/tsconnect/yarn.lock @@ -90,11 +90,11 @@ binary-extensions@^2.0.0: integrity sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA== braces@^3.0.2, braces@~3.0.2: - version "3.0.2" - resolved "https://registry.yarnpkg.com/braces/-/braces-3.0.2.tgz#3454e1a462ee8d599e236df336cd9ea4f8afe107" - integrity sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A== + version "3.0.3" + resolved 
"https://registry.yarnpkg.com/braces/-/braces-3.0.3.tgz#490332f40919452272d55a8480adc0c441358789" + integrity sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA== dependencies: - fill-range "^7.0.1" + fill-range "^7.1.1" camelcase-css@^2.0.1: version "2.0.1" @@ -231,10 +231,10 @@ fastq@^1.6.0: dependencies: reusify "^1.0.4" -fill-range@^7.0.1: - version "7.0.1" - resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.0.1.tgz#1919a6a7c75fe38b2c7c77e5198535da9acdda40" - integrity sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ== +fill-range@^7.1.1: + version "7.1.1" + resolved "https://registry.yarnpkg.com/fill-range/-/fill-range-7.1.1.tgz#44265d3cac07e3ea7dc247516380643754a05292" + integrity sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg== dependencies: to-regex-range "^5.0.1" @@ -349,9 +349,9 @@ minimist@^1.2.6: integrity sha512-Jsjnk4bw3YJqYzbdyBiNsPWHPfO++UGG749Cxs6peCu5Xg4nrena6OVxOYxrQTqww0Jmwt+Ref8rggumkTLz9Q== nanoid@^3.3.4: - version "3.3.4" - resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.4.tgz#730b67e3cd09e2deacf03c027c81c9d9dbc5e8ab" - integrity sha512-MqBkQh/OHTS2egovRtLk45wEyNXwF+cokD+1YPf9u5VfJiRdAiRwB2froX5Co9Rh20xs4siNPm8naNotSD6RBw== + version "3.3.8" + resolved "https://registry.yarnpkg.com/nanoid/-/nanoid-3.3.8.tgz#b1be3030bee36aaff18bacb375e5cce521684baf" + integrity sha512-WNLf5Sd8oZxOm+TzppcYk8gVOgP+l58xNy58D0nbUnOxOWRWvlcCV4kUF7ltmI6PsrLl/BgKEyS4mqsGChFN0w== normalize-path@^3.0.0, normalize-path@~3.0.0: version "3.0.0" diff --git a/cmd/tsidp/Dockerfile b/cmd/tsidp/Dockerfile new file mode 100644 index 0000000000000..c4f352ed01839 --- /dev/null +++ b/cmd/tsidp/Dockerfile @@ -0,0 +1,41 @@ +# Build stage +FROM golang:alpine AS builder + +# Install build dependencies +RUN apk add --no-cache git + +# Set working directory +WORKDIR /src + +# Copy only go.mod and go.sum first to leverage Docker 
caching +COPY go.mod go.sum ./ +RUN go mod download + +# Copy the entire repository +COPY . . + +# Build the tsidp binary +RUN go build -o /bin/tsidp ./cmd/tsidp + +# Final stage +FROM alpine:latest + +# Create necessary directories +RUN mkdir -p /var/lib/tsidp + +# Copy binary from builder stage +COPY --from=builder /bin/tsidp /app/tsidp + +# Set working directory +WORKDIR /app + +# Environment variables +ENV TAILSCALE_USE_WIP_CODE=1 \ + TS_HOSTNAME=idp \ + TS_STATE_DIR=/var/lib/tsidp + +# Expose the default port +EXPOSE 443 + +# Run the application +ENTRYPOINT ["/bin/sh", "-c", "/app/tsidp --hostname=${TS_HOSTNAME} --dir=${TS_STATE_DIR}"] diff --git a/cmd/tsidp/README.md b/cmd/tsidp/README.md new file mode 100644 index 0000000000000..61a81e8aee7ae --- /dev/null +++ b/cmd/tsidp/README.md @@ -0,0 +1,101 @@ +# `tsidp` - Tailscale OpenID Connect (OIDC) Identity Provider + +[![status: experimental](https://img.shields.io/badge/status-experimental-blue)](https://tailscale.com/kb/1167/release-stages/#experimental) + +`tsidp` is an OIDC Identity Provider (IdP) server that integrates with your Tailscale network. It allows you to use Tailscale identities for authentication in applications that support OpenID Connect, enabling single sign-on (SSO) capabilities within your tailnet. + +## Prerequisites + +- A Tailscale network (tailnet) with magicDNS and HTTPS enabled +- A Tailscale authentication key from your tailnet +- Docker installed on your system + +## Installation using Docker + +1. **Build the Docker Image** + + The Dockerfile uses a multi-stage build process to: + - Build the `tsidp` binary from source + - Create a minimal Alpine-based image with just the necessary components + + ```bash + # Clone the Tailscale repository + git clone https://github.com/tailscale/tailscale.git + cd tailscale + ``` + + ```bash + # Build the Docker image + docker build -t tsidp:latest -f cmd/tsidp/Dockerfile . + ``` + +2. 
**Run the Container** + + Replace `YOUR_TAILSCALE_AUTHKEY` with your Tailscale authentication key. + + ```bash + docker run -d \ + --name tsidp \ + -p 443:443 \ + -e TS_AUTHKEY=YOUR_TAILSCALE_AUTHKEY \ + -e TS_HOSTNAME=idp \ + -v tsidp-data:/var/lib/tsidp \ + tsidp:latest + ``` + +3. **Verify Installation** + ```bash + docker logs tsidp + ``` + + Visit `https://idp.tailnet.ts.net` to confirm the service is running. + +## Usage Example: Proxmox Integration + +Here's how to configure Proxmox to use `tsidp` for authentication: + +1. In Proxmox, navigate to Datacenter > Realms > Add OpenID Connect Server + +2. Configure the following settings: + - Issuer URL: `https://idp.velociraptor.ts.net` + - Realm: `tailscale` (or your preferred name) + - Client ID: `unused` + - Client Key: `unused` + - Default: `true` + - Autocreate users: `true` + - Username claim: `email` + +3. Set up user permissions: + - Go to Datacenter > Permissions > Groups + - Create a new group (e.g., "tsadmins") + - Click Permissions in the sidebar + - Add Group Permission + - Set Path to `/` for full admin access or scope as needed + - Set the group and role + - Add Tailscale-authenticated users to the group + +## Configuration Options + +The `tsidp` server supports several command-line flags: + +- `--verbose`: Enable verbose logging +- `--port`: Port to listen on (default: 443) +- `--local-port`: Allow requests from localhost +- `--use-local-tailscaled`: Use local tailscaled instead of tsnet +- `--hostname`: tsnet hostname +- `--dir`: tsnet state directory + +## Environment Variables + +- `TS_AUTHKEY`: Your Tailscale authentication key (required) +- `TS_HOSTNAME`: Hostname for the `tsidp` server (default: "idp", Docker only) +- `TS_STATE_DIR`: State directory (default: "/var/lib/tsidp", Docker only) +- `TAILSCALE_USE_WIP_CODE`: Enable work-in-progress code (default: "1") + +## Support + +This is an [experimental](https://tailscale.com/kb/1167/release-stages#experimental), work in progress feature. 
For issues or questions, file issues on the [GitHub repository](https://github.com/tailscale/tailscale) + +## License + +BSD-3-Clause License. See [LICENSE](../../LICENSE) for details. diff --git a/cmd/tsidp/tsidp.go b/cmd/tsidp/tsidp.go index 1bdca8919a085..e2b777fa1b68e 100644 --- a/cmd/tsidp/tsidp.go +++ b/cmd/tsidp/tsidp.go @@ -11,6 +11,7 @@ import ( "context" crand "crypto/rand" "crypto/rsa" + "crypto/subtle" "crypto/tls" "crypto/x509" "encoding/base64" @@ -35,7 +36,7 @@ import ( "gopkg.in/square/go-jose.v2" "gopkg.in/square/go-jose.v2/jwt" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" "tailscale.com/envknob" "tailscale.com/ipn" @@ -64,6 +65,7 @@ var ( flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") flagFunnel = flag.Bool("funnel", false, "use Tailscale Funnel to make tsidp available on the public internet") + flagHostname = flag.String("hostname", "idp", "tsnet hostname to use instead of idp") flagDir = flag.String("dir", "", "tsnet state directory; a default one will be created if not provided") ) @@ -75,7 +77,7 @@ func main() { } var ( - lc *tailscale.LocalClient + lc *local.Client st *ipnstate.Status err error watcherChan chan error @@ -84,7 +86,7 @@ func main() { lns []net.Listener ) if *flagUseLocalTailscaled { - lc = &tailscale.LocalClient{} + lc = &local.Client{} st, err = lc.StatusWithoutPeers(ctx) if err != nil { log.Fatalf("getting status: %v", err) @@ -120,7 +122,7 @@ func main() { defer cleanup() } else { ts := &tsnet.Server{ - Hostname: "idp", + Hostname: *flagHostname, Dir: *flagDir, } if *flagVerbose { @@ -212,7 +214,7 @@ func main() { // serveOnLocalTailscaled starts a serve session using an already-running // tailscaled instead of starting a fresh tsnet server, making something // listening on clientDNSName:dstPort accessible over serve/funnel. 
-func serveOnLocalTailscaled(ctx context.Context, lc *tailscale.LocalClient, st *ipnstate.Status, dstPort uint16, shouldFunnel bool) (cleanup func(), watcherChan chan error, err error) { +func serveOnLocalTailscaled(ctx context.Context, lc *local.Client, st *ipnstate.Status, dstPort uint16, shouldFunnel bool) (cleanup func(), watcherChan chan error, err error) { // In order to support funneling out in local tailscaled mode, we need // to add a serve config to forward the listeners we bound above and // allow those forwarders to be funneled out. @@ -275,7 +277,7 @@ func serveOnLocalTailscaled(ctx context.Context, lc *tailscale.LocalClient, st * } type idpServer struct { - lc *tailscale.LocalClient + lc *local.Client loopbackURL string serverURL string // "https://foo.bar.ts.net" funnel bool @@ -328,7 +330,7 @@ type authRequest struct { // allowRelyingParty validates that a relying party identified either by a // known remoteAddr or a valid client ID/secret pair is allowed to proceed // with the authorization flow associated with this authRequest. 
-func (ar *authRequest) allowRelyingParty(r *http.Request, lc *tailscale.LocalClient) error { +func (ar *authRequest) allowRelyingParty(r *http.Request, lc *local.Client) error { if ar.localRP { ra, err := netip.ParseAddrPort(r.RemoteAddr) if err != nil { @@ -345,7 +347,9 @@ func (ar *authRequest) allowRelyingParty(r *http.Request, lc *tailscale.LocalCli clientID = r.FormValue("client_id") clientSecret = r.FormValue("client_secret") } - if ar.funnelRP.ID != clientID || ar.funnelRP.Secret != clientSecret { + clientIDcmp := subtle.ConstantTimeCompare([]byte(clientID), []byte(ar.funnelRP.ID)) + clientSecretcmp := subtle.ConstantTimeCompare([]byte(clientSecret), []byte(ar.funnelRP.Secret)) + if clientIDcmp != 1 || clientSecretcmp != 1 { return fmt.Errorf("tsidp: invalid client credentials") } return nil @@ -494,6 +498,7 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) return } + ui.Sub = ar.remoteUser.Node.User.String() ui.Name = ar.remoteUser.UserProfile.DisplayName ui.Email = ar.remoteUser.UserProfile.LoginName @@ -502,8 +507,29 @@ func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { // TODO(maisem): not sure if this is the right thing to do ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") + rules, err := tailcfg.UnmarshalCapJSON[capRule](ar.remoteUser.CapMap, tailcfg.PeerCapabilityTsIDP) + if err != nil { + http.Error(w, "tsidp: failed to unmarshal capability: %v", http.StatusBadRequest) + return + } + + // Only keep rules where IncludeInUserInfo is true + var filtered []capRule + for _, r := range rules { + if r.IncludeInUserInfo { + filtered = append(filtered, r) + } + } + + userInfo, err := withExtraClaims(ui, filtered) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // Write the final result w.Header().Set("Content-Type", "application/json") - if err := 
json.NewEncoder(w).Encode(ui); err != nil { + if err := json.NewEncoder(w).Encode(userInfo); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } } @@ -516,6 +542,140 @@ type userInfo struct { UserName string `json:"username"` } +type capRule struct { + IncludeInUserInfo bool `json:"includeInUserInfo"` + ExtraClaims map[string]any `json:"extraClaims,omitempty"` // list of features peer is allowed to edit +} + +// flattenExtraClaims merges all ExtraClaims from a slice of capRule into a single map. +// It deduplicates values for each claim and preserves the original input type: +// scalar values remain scalars, and slices are returned as deduplicated []any slices. +func flattenExtraClaims(rules []capRule) map[string]any { + // sets stores deduplicated stringified values for each claim key. + sets := make(map[string]map[string]struct{}) + + // isSlice tracks whether each claim was originally provided as a slice. + isSlice := make(map[string]bool) + + for _, rule := range rules { + for claim, raw := range rule.ExtraClaims { + // Track whether the claim was provided as a slice + switch raw.(type) { + case []string, []any: + isSlice[claim] = true + default: + // Only mark as scalar if this is the first time we've seen this claim + if _, seen := isSlice[claim]; !seen { + isSlice[claim] = false + } + } + + // Add the claim value(s) into the deduplication set + addClaimValue(sets, claim, raw) + } + } + + // Build final result: either scalar or slice depending on original type + result := make(map[string]any) + for claim, valSet := range sets { + if isSlice[claim] { + // Claim was provided as a slice: output as []any + var vals []any + for val := range valSet { + vals = append(vals, val) + } + result[claim] = vals + } else { + // Claim was a scalar: return a single value + for val := range valSet { + result[claim] = val + break // only one value is expected + } + } + } + + return result +} + +// addClaimValue adds a claim value to the deduplication set 
for a given claim key. +// It accepts scalars (string, int, float64), slices of strings or interfaces, +// and recursively handles nested slices. Unsupported types are ignored with a log message. +func addClaimValue(sets map[string]map[string]struct{}, claim string, val any) { + switch v := val.(type) { + case string, float64, int, int64: + // Ensure the claim set is initialized + if sets[claim] == nil { + sets[claim] = make(map[string]struct{}) + } + // Add the stringified scalar to the set + sets[claim][fmt.Sprintf("%v", v)] = struct{}{} + + case []string: + // Ensure the claim set is initialized + if sets[claim] == nil { + sets[claim] = make(map[string]struct{}) + } + // Add each string value to the set + for _, s := range v { + sets[claim][s] = struct{}{} + } + + case []any: + // Recursively handle each item in the slice + for _, item := range v { + addClaimValue(sets, claim, item) + } + + default: + // Log unsupported types for visibility and debugging + log.Printf("Unsupported claim type for %q: %#v (type %T)", claim, val, val) + } +} + +// withExtraClaims merges flattened extra claims from a list of capRule into the provided struct v, +// returning a map[string]any that combines both sources. +// +// v is any struct whose fields represent static claims; it is first marshaled to JSON, then unmarshalled into a generic map. +// rules is a slice of capRule objects that may define additional (extra) claims to merge. +// +// These extra claims are flattened and merged into the base map unless they conflict with protected claims. +// Claims defined in openIDSupportedClaims are considered protected and cannot be overwritten. +// If an extra claim attempts to overwrite a protected claim, an error is returned. +// +// Returns the merged claims map or an error if any protected claim is violated or JSON (un)marshaling fails. 
+func withExtraClaims(v any, rules []capRule) (map[string]any, error) { + // Marshal the static struct + data, err := json.Marshal(v) + if err != nil { + return nil, err + } + + // Unmarshal into a generic map + var claimMap map[string]any + if err := json.Unmarshal(data, &claimMap); err != nil { + return nil, err + } + + // Convert views.Slice to a map[string]struct{} for efficient lookup + protected := make(map[string]struct{}, len(openIDSupportedClaims.AsSlice())) + for _, claim := range openIDSupportedClaims.AsSlice() { + protected[claim] = struct{}{} + } + + // Merge extra claims + extra := flattenExtraClaims(rules) + for k, v := range extra { + if _, isProtected := protected[k]; isProtected { + log.Printf("Skip overwriting of existing claim %q", k) + return nil, fmt.Errorf("extra claim %q overwriting existing claim", k) + } + + claimMap[k] = v + } + + return claimMap, nil +} + func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) @@ -592,8 +752,22 @@ func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { tsClaims.Issuer = s.loopbackURL } + rules, err := tailcfg.UnmarshalCapJSON[capRule](who.CapMap, tailcfg.PeerCapabilityTsIDP) + if err != nil { + log.Printf("tsidp: failed to unmarshal capability: %v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + tsClaimsWithExtra, err := withExtraClaims(tsClaims, rules) + if err != nil { + log.Printf("tsidp: failed to merge extra claims: %v", err) + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + // Create an OIDC token using this issuer's signer. 
- token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize() + token, err := jwt.Signed(signer).Claims(tsClaimsWithExtra).CompactSerialize() if err != nil { log.Printf("Error getting token: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) @@ -762,6 +936,18 @@ var ( ) func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Set("Access-Control-Allow-Origin", "*") + h.Set("Access-Control-Allow-Method", "GET, OPTIONS") + // allow all to prevent errors from client sending their own bespoke headers + // and having the server reject the request. + h.Set("Access-Control-Allow-Headers", "*") + + // early return for pre-flight OPTIONS requests. + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } if r.URL.Path != oidcConfigPath { http.Error(w, "tsidp: not found", http.StatusNotFound) return diff --git a/cmd/tsidp/tsidp_test.go b/cmd/tsidp/tsidp_test.go new file mode 100644 index 0000000000000..76a11899187f7 --- /dev/null +++ b/cmd/tsidp/tsidp_test.go @@ -0,0 +1,826 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause +package main + +import ( + "crypto/rand" + "crypto/rsa" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "reflect" + "sort" + "strings" + "testing" + "time" + + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +// normalizeMap recursively sorts []any values in a map[string]any +func normalizeMap(t *testing.T, m map[string]any) map[string]any { + t.Helper() + normalized := make(map[string]any, len(m)) + for k, v := range m { + switch val := v.(type) { + case []any: + sorted := make([]string, len(val)) + for i, item := range val { + sorted[i] = fmt.Sprintf("%v", item) // convert everything to string for sorting + } + sort.Strings(sorted) 
+ + // convert back to []any + sortedIface := make([]any, len(sorted)) + for i, s := range sorted { + sortedIface[i] = s + } + normalized[k] = sortedIface + + default: + normalized[k] = v + } + } + return normalized +} + +func mustMarshalJSON(t *testing.T, v any) tailcfg.RawMessage { + t.Helper() + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return tailcfg.RawMessage(b) +} + +var privateKey *rsa.PrivateKey = nil + +func oidcTestingSigner(t *testing.T) jose.Signer { + t.Helper() + privKey := mustGeneratePrivateKey(t) + sig, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privKey}, nil) + if err != nil { + t.Fatalf("failed to create signer: %v", err) + } + return sig +} + +func oidcTestingPublicKey(t *testing.T) *rsa.PublicKey { + t.Helper() + privKey := mustGeneratePrivateKey(t) + return &privKey.PublicKey +} + +func mustGeneratePrivateKey(t *testing.T) *rsa.PrivateKey { + t.Helper() + if privateKey != nil { + return privateKey + } + + var err error + privateKey, err = rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("failed to generate key: %v", err) + } + + return privateKey +} + +func TestFlattenExtraClaims(t *testing.T) { + log.SetOutput(io.Discard) // suppress log output during tests + + tests := []struct { + name string + input []capRule + expected map[string]any + }{ + { + name: "empty extra claims", + input: []capRule{ + {ExtraClaims: map[string]any{}}, + }, + expected: map[string]any{}, + }, + { + name: "string and number values", + input: []capRule{ + { + ExtraClaims: map[string]any{ + "featureA": "read", + "featureB": 42, + }, + }, + }, + expected: map[string]any{ + "featureA": "read", + "featureB": "42", + }, + }, + { + name: "slice of strings and ints", + input: []capRule{ + { + ExtraClaims: map[string]any{ + "roles": []any{"admin", "user", 1}, + }, + }, + }, + expected: map[string]any{ + "roles": []any{"admin", "user", "1"}, + }, + }, + { + name: "duplicate values deduplicated (slice input)", + input: 
[]capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar", "baz"}, + }, + }, + { + ExtraClaims: map[string]any{ + "foo": []any{"bar", "qux"}, + }, + }, + }, + expected: map[string]any{ + "foo": []any{"bar", "baz", "qux"}, + }, + }, + { + name: "ignore unsupported map type, keep valid scalar", + input: []capRule{ + { + ExtraClaims: map[string]any{ + "invalid": map[string]any{"bad": "yes"}, + "valid": "ok", + }, + }, + }, + expected: map[string]any{ + "valid": "ok", + }, + }, + { + name: "scalar first, slice second", + input: []capRule{ + {ExtraClaims: map[string]any{"foo": "bar"}}, + {ExtraClaims: map[string]any{"foo": []any{"baz"}}}, + }, + expected: map[string]any{ + "foo": []any{"bar", "baz"}, // since first was scalar, second being a slice forces slice output + }, + }, + { + name: "conflicting scalar and unsupported map", + input: []capRule{ + {ExtraClaims: map[string]any{"foo": "bar"}}, + {ExtraClaims: map[string]any{"foo": map[string]any{"bad": "entry"}}}, + }, + expected: map[string]any{ + "foo": "bar", // map should be ignored + }, + }, + { + name: "multiple slices with overlap", + input: []capRule{ + {ExtraClaims: map[string]any{"roles": []any{"admin", "user"}}}, + {ExtraClaims: map[string]any{"roles": []any{"admin", "guest"}}}, + }, + expected: map[string]any{ + "roles": []any{"admin", "user", "guest"}, + }, + }, + { + name: "slice with unsupported values", + input: []capRule{ + {ExtraClaims: map[string]any{ + "mixed": []any{"ok", 42, map[string]string{"oops": "fail"}}, + }}, + }, + expected: map[string]any{ + "mixed": []any{"ok", "42"}, // map is ignored + }, + }, + { + name: "duplicate scalar value", + input: []capRule{ + {ExtraClaims: map[string]any{"env": "prod"}}, + {ExtraClaims: map[string]any{"env": "prod"}}, + }, + expected: map[string]any{ + "env": "prod", // not converted to slice + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := flattenExtraClaims(tt.input) + + gotNormalized := 
normalizeMap(t, got) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("mismatch\nGot:\n%s\nWant:\n%s", gotNormalized, expectedNormalized) + } + }) + } +} + +func TestExtraClaims(t *testing.T) { + tests := []struct { + name string + claim tailscaleClaims + extraClaims []capRule + expected map[string]any + expectError bool + }{ + { + name: "extra claim", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []any{"bar"}, + }, + }, + { + name: "duplicate claim distinct values", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }, + { + ExtraClaims: map[string]any{ + "foo": []string{"foobar"}, + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []any{"foobar", "bar"}, + }, + }, + { + name: "multiple extra claims", + claim: tailscaleClaims{ + Claims: 
jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }, + { + ExtraClaims: map[string]any{ + "bar": []string{"foo"}, + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + "foo": []any{"bar"}, + "bar": []any{"foo"}, + }, + }, + { + name: "overwrite claim", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{ + { + ExtraClaims: map[string]any{ + "username": "foobar", + }, + }, + }, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": "test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "foobar", + }, + expectError: true, + }, + { + name: "empty extra claims", + claim: tailscaleClaims{ + Claims: jwt.Claims{}, + Nonce: "foobar", + Key: key.NodePublic{}, + Addresses: views.Slice[netip.Prefix]{}, + NodeID: 0, + NodeName: "test-node", + Tailnet: "test.ts.net", + Email: "test@example.com", + UserID: 0, + UserName: "test", + }, + extraClaims: []capRule{{ExtraClaims: map[string]any{}}}, + expected: map[string]any{ + "nonce": "foobar", + "key": "nodekey:0000000000000000000000000000000000000000000000000000000000000000", + "addresses": nil, + "nid": float64(0), + "node": 
"test-node", + "tailnet": "test.ts.net", + "email": "test@example.com", + "username": "test", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims, err := withExtraClaims(tt.claim, tt.extraClaims) + if err != nil && !tt.expectError { + t.Fatalf("claim.withExtraClaims() unexpected error = %v", err) + } else if err == nil && tt.expectError { + t.Fatalf("expected error, got nil") + } else if err != nil && tt.expectError { + return // just as expected + } + + // Marshal to JSON then unmarshal back to map[string]any + gotClaims, err := json.Marshal(claims) + if err != nil { + t.Errorf("json.Marshal(claims) error = %v", err) + } + + var gotClaimsMap map[string]any + if err := json.Unmarshal(gotClaims, &gotClaimsMap); err != nil { + t.Fatalf("json.Unmarshal(gotClaims) error = %v", err) + } + + gotNormalized := normalizeMap(t, gotClaimsMap) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("claims mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized) + } + }) + } +} + +func TestServeToken(t *testing.T) { + tests := []struct { + name string + caps tailcfg.PeerCapMap + method string + grantType string + code string + omitCode bool + redirectURI string + remoteAddr string + expectError bool + expected map[string]any + }{ + { + name: "GET not allowed", + method: "GET", + grantType: "authorization_code", + expectError: true, + }, + { + name: "unsupported grant type", + method: "POST", + grantType: "pkcs", + expectError: true, + }, + { + name: "invalid code", + method: "POST", + grantType: "authorization_code", + code: "invalid-code", + expectError: true, + }, + { + name: "omit code from form", + method: "POST", + grantType: "authorization_code", + omitCode: true, + expectError: true, + }, + { + name: "invalid redirect uri", + method: "POST", + grantType: "authorization_code", + code: "valid-code", + redirectURI: 
"https://invalid.example.com/callback", + remoteAddr: "127.0.0.1:12345", + expectError: true, + }, + { + name: "invalid remoteAddr", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + remoteAddr: "192.168.0.1:12345", + expectError: true, + }, + { + name: "extra claim included", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + remoteAddr: "127.0.0.1:12345", + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": "bar", + }, + }), + }, + }, + expected: map[string]any{ + "foo": "bar", + }, + }, + { + name: "attempt to overwrite protected claim", + method: "POST", + grantType: "authorization_code", + redirectURI: "https://rp.example.com/callback", + code: "valid-code", + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "sub": "should-not-overwrite", + }, + }), + }, + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := time.Now() + + // Fake user/node + profile := &tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: "https://example.com/alice.jpg", + } + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + Key: key.NodePublic{}, + Cap: 1, + DiscoKey: key.DiscoPublic{}, + } + + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: profile, + CapMap: tt.caps, + } + + s := &idpServer{ + code: map[string]*authRequest{ + "valid-code": { + clientID: "client-id", + nonce: "nonce123", + redirectURI: "https://rp.example.com/callback", + validTill: now.Add(5 * time.Minute), + remoteUser: remoteUser, + localRP: true, + }, + }, + } + // Inject a working signer + 
s.lazySigner.Set(oidcTestingSigner(t)) + + form := url.Values{} + form.Set("grant_type", tt.grantType) + form.Set("redirect_uri", tt.redirectURI) + if !tt.omitCode { + form.Set("code", tt.code) + } + + req := httptest.NewRequest(tt.method, "/token", strings.NewReader(form.Encode())) + req.RemoteAddr = tt.remoteAddr + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rr := httptest.NewRecorder() + + s.serveToken(rr, req) + + if tt.expectError { + if rr.Code == http.StatusOK { + t.Fatalf("expected error, got 200 OK: %s", rr.Body.String()) + } + return + } + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp struct { + IDToken string `json:"id_token"` + } + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to unmarshal response: %v", err) + } + + tok, err := jwt.ParseSigned(resp.IDToken) + if err != nil { + t.Fatalf("failed to parse ID token: %v", err) + } + + out := make(map[string]any) + if err := tok.Claims(oidcTestingPublicKey(t), &out); err != nil { + t.Fatalf("failed to extract claims: %v", err) + } + + for k, want := range tt.expected { + got, ok := out[k] + if !ok { + t.Errorf("missing expected claim %q", k) + continue + } + if !reflect.DeepEqual(got, want) { + t.Errorf("claim %q: got %v, want %v", k, got, want) + } + } + }) + } +} + +func TestExtraUserInfo(t *testing.T) { + tests := []struct { + name string + caps tailcfg.PeerCapMap + tokenValidTill time.Time + expected map[string]any + expectError bool + }{ + { + name: "extra claim", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": []string{"bar"}, + }, + }), + }, + }, + expected: map[string]any{ + "foo": []any{"bar"}, + }, + }, + { + name: "duplicate claim distinct values", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: 
tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": []string{"bar", "foobar"}, + }, + }), + }, + }, + expected: map[string]any{ + "foo": []any{"bar", "foobar"}, + }, + }, + { + name: "multiple extra claims", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "foo": "bar", + "bar": "foo", + }, + }), + }, + }, + expected: map[string]any{ + "foo": "bar", + "bar": "foo", + }, + }, + { + name: "empty extra claims", + caps: tailcfg.PeerCapMap{}, + tokenValidTill: time.Now().Add(1 * time.Minute), + expected: map[string]any{}, + }, + { + name: "attempt to overwrite protected claim", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: true, + ExtraClaims: map[string]any{ + "sub": "should-not-overwrite", + "foo": "ok", + }, + }), + }, + }, + expectError: true, + }, + { + name: "extra claim omitted", + tokenValidTill: time.Now().Add(1 * time.Minute), + caps: tailcfg.PeerCapMap{ + tailcfg.PeerCapabilityTsIDP: { + mustMarshalJSON(t, capRule{ + IncludeInUserInfo: false, + ExtraClaims: map[string]any{ + "foo": "ok", + }, + }), + }, + }, + expected: map[string]any{}, + }, + { + name: "expired token", + caps: tailcfg.PeerCapMap{}, + tokenValidTill: time.Now().Add(-1 * time.Minute), + expected: map[string]any{}, + expectError: true, + }, + } + token := "valid-token" + + // Create a fake tailscale Node + node := &tailcfg.Node{ + ID: 123, + Name: "test-node.test.ts.net.", + User: 456, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Construct the remote user + profile := tailcfg.UserProfile{ + LoginName: "alice@example.com", + DisplayName: "Alice Example", + ProfilePicURL: 
"https://example.com/alice.jpg", + } + + remoteUser := &apitype.WhoIsResponse{ + Node: node, + UserProfile: &profile, + CapMap: tt.caps, + } + + // Insert a valid token into the idpServer + s := &idpServer{ + accessToken: map[string]*authRequest{ + token: { + validTill: tt.tokenValidTill, + remoteUser: remoteUser, + }, + }, + } + + // Construct request + req := httptest.NewRequest("GET", "/userinfo", nil) + req.Header.Set("Authorization", "Bearer "+token) + rr := httptest.NewRecorder() + + // Call the method under test + s.serveUserInfo(rr, req) + + if tt.expectError { + if rr.Code == http.StatusOK { + t.Fatalf("expected error, got %d: %s", rr.Code, rr.Body.String()) + } + return + } + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 OK, got %d: %s", rr.Code, rr.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to parse JSON response: %v", err) + } + + // Construct expected + tt.expected["sub"] = remoteUser.Node.User.String() + tt.expected["name"] = profile.DisplayName + tt.expected["email"] = profile.LoginName + tt.expected["picture"] = profile.ProfilePicURL + tt.expected["username"], _, _ = strings.Cut(profile.LoginName, "@") + + gotNormalized := normalizeMap(t, resp) + expectedNormalized := normalizeMap(t, tt.expected) + + if !reflect.DeepEqual(gotNormalized, expectedNormalized) { + t.Errorf("UserInfo mismatch:\n got: %#v\nwant: %#v", gotNormalized, expectedNormalized) + } + }) + } +} diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index 4a4c4a6beebfa..9f8f002958d61 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -30,7 +30,7 @@ import ( "time" "tailscale.com/atomicfile" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/hostinfo" "tailscale.com/util/mak" "tailscale.com/util/must" @@ -64,7 +64,7 @@ func serveCmd(w http.ResponseWriter, cmd string, args ...string) { } type localClientRoundTripper struct { - lc tailscale.LocalClient + lc 
local.Client } func (rt *localClientRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { diff --git a/cmd/viewer/tests/tests.go b/cmd/viewer/tests/tests.go index 14a4888615bc1..4020e5651978a 100644 --- a/cmd/viewer/tests/tests.go +++ b/cmd/viewer/tests/tests.go @@ -37,9 +37,14 @@ type Map struct { StructWithPtrKey map[StructWithPtrs]int `json:"-"` } +type StructWithNoView struct { + Value int +} + type StructWithPtrs struct { - Value *StructWithoutPtrs - Int *int + Value *StructWithoutPtrs + Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs `codegen:"noclone"` } @@ -135,7 +140,7 @@ func (c *Container[T]) Clone() *Container[T] { panic(fmt.Errorf("%T contains pointers, but is not cloneable", c.Item)) } -// ContainerView is a pre-defined readonly view of a Container[T]. +// ContainerView is a pre-defined read-only view of a Container[T]. type ContainerView[T views.ViewCloner[T, V], V views.StructView[T]] struct { // Đļ is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. @@ -173,7 +178,7 @@ func (c *MapContainer[K, V]) Clone() *MapContainer[K, V] { return &MapContainer[K, V]{m} } -// MapContainerView is a pre-defined readonly view of a [MapContainer][K, T]. +// MapContainerView is a pre-defined read-only view of a [MapContainer][K, T]. type MapContainerView[K comparable, T views.ViewCloner[T, V], V views.StructView[T]] struct { // Đļ is the underlying mutable value, named with a hard-to-type // character that looks pointy like a pointer. 
diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 9131f5040c45d..106a9b6843b56 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -28,6 +28,9 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { if dst.Int != nil { dst.Int = ptr.To(*src.Int) } + if dst.NoView != nil { + dst.NoView = ptr.To(*src.NoView) + } return dst } @@ -35,6 +38,7 @@ func (src *StructWithPtrs) Clone() *StructWithPtrs { var _StructWithPtrsCloneNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) diff --git a/cmd/viewer/tests/tests_view.go b/cmd/viewer/tests/tests_view.go index 9c74c94261e08..f1d8f424ff01b 100644 --- a/cmd/viewer/tests/tests_view.go +++ b/cmd/viewer/tests/tests_view.go @@ -16,7 +16,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=StructWithPtrs,StructWithoutPtrs,Map,StructWithSlices,OnlyGetClone,StructWithEmbedded,GenericIntStruct,GenericNoPtrsStruct,GenericCloneableStruct,StructWithContainers,StructWithTypeAliasFields,GenericTypeAliasStruct -// View returns a readonly view of StructWithPtrs. +// View returns a read-only view of StructWithPtrs. func (p *StructWithPtrs) View() StructWithPtrsView { return StructWithPtrsView{Đļ: p} } @@ -32,7 +32,7 @@ type StructWithPtrsView struct { Đļ *StructWithPtrs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v StructWithPtrsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -61,20 +61,11 @@ func (v *StructWithPtrsView) UnmarshalJSON(b []byte) error { return nil } -func (v StructWithPtrsView) Value() *StructWithoutPtrs { - if v.Đļ.Value == nil { - return nil - } - x := *v.Đļ.Value - return &x -} +func (v StructWithPtrsView) Value() StructWithoutPtrsView { return v.Đļ.Value.View() } +func (v StructWithPtrsView) Int() views.ValuePointer[int] { return views.ValuePointerOf(v.Đļ.Int) } -func (v StructWithPtrsView) Int() *int { - if v.Đļ.Int == nil { - return nil - } - x := *v.Đļ.Int - return &x +func (v StructWithPtrsView) NoView() views.ValuePointer[StructWithNoView] { + return views.ValuePointerOf(v.Đļ.NoView) } func (v StructWithPtrsView) NoCloneValue() *StructWithoutPtrs { return v.Đļ.NoCloneValue } @@ -85,10 +76,11 @@ func (v StructWithPtrsView) Equal(v2 StructWithPtrsView) bool { return v.Đļ.Equa var _StructWithPtrsViewNeedsRegeneration = StructWithPtrs(struct { Value *StructWithoutPtrs Int *int + NoView *StructWithNoView NoCloneValue *StructWithoutPtrs }{}) -// View returns a readonly view of StructWithoutPtrs. +// View returns a read-only view of StructWithoutPtrs. func (p *StructWithoutPtrs) View() StructWithoutPtrsView { return StructWithoutPtrsView{Đļ: p} } @@ -104,7 +96,7 @@ type StructWithoutPtrsView struct { Đļ *StructWithoutPtrs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithoutPtrsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -142,7 +134,7 @@ var _StructWithoutPtrsViewNeedsRegeneration = StructWithoutPtrs(struct { Pfx netip.Prefix }{}) -// View returns a readonly view of Map. +// View returns a read-only view of Map. 
func (p *Map) View() MapView { return MapView{Đļ: p} } @@ -158,7 +150,7 @@ type MapView struct { Đļ *Map } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v MapView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -248,7 +240,7 @@ var _MapViewNeedsRegeneration = Map(struct { StructWithPtrKey map[StructWithPtrs]int }{}) -// View returns a readonly view of StructWithSlices. +// View returns a read-only view of StructWithSlices. func (p *StructWithSlices) View() StructWithSlicesView { return StructWithSlicesView{Đļ: p} } @@ -264,7 +256,7 @@ type StructWithSlicesView struct { Đļ *StructWithSlices } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithSlicesView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -322,7 +314,7 @@ var _StructWithSlicesViewNeedsRegeneration = StructWithSlices(struct { Ints []*int }{}) -// View returns a readonly view of StructWithEmbedded. +// View returns a read-only view of StructWithEmbedded. func (p *StructWithEmbedded) View() StructWithEmbeddedView { return StructWithEmbeddedView{Đļ: p} } @@ -338,7 +330,7 @@ type StructWithEmbeddedView struct { Đļ *StructWithEmbedded } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithEmbeddedView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -378,7 +370,7 @@ var _StructWithEmbeddedViewNeedsRegeneration = StructWithEmbedded(struct { StructWithSlices }{}) -// View returns a readonly view of GenericIntStruct. +// View returns a read-only view of GenericIntStruct. 
func (p *GenericIntStruct[T]) View() GenericIntStructView[T] { return GenericIntStructView[T]{Đļ: p} } @@ -394,7 +386,7 @@ type GenericIntStructView[T constraints.Integer] struct { Đļ *GenericIntStruct[T] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v GenericIntStructView[T]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -424,12 +416,8 @@ func (v *GenericIntStructView[T]) UnmarshalJSON(b []byte) error { } func (v GenericIntStructView[T]) Value() T { return v.Đļ.Value } -func (v GenericIntStructView[T]) Pointer() *T { - if v.Đļ.Pointer == nil { - return nil - } - x := *v.Đļ.Pointer - return &x +func (v GenericIntStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.Đļ.Pointer) } func (v GenericIntStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.Đļ.Slice) } @@ -454,7 +442,7 @@ func _GenericIntStructViewNeedsRegeneration[T constraints.Integer](GenericIntStr }{}) } -// View returns a readonly view of GenericNoPtrsStruct. +// View returns a read-only view of GenericNoPtrsStruct. func (p *GenericNoPtrsStruct[T]) View() GenericNoPtrsStructView[T] { return GenericNoPtrsStructView[T]{Đļ: p} } @@ -470,7 +458,7 @@ type GenericNoPtrsStructView[T StructWithoutPtrs | netip.Prefix | BasicType] str Đļ *GenericNoPtrsStruct[T] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v GenericNoPtrsStructView[T]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -500,12 +488,8 @@ func (v *GenericNoPtrsStructView[T]) UnmarshalJSON(b []byte) error { } func (v GenericNoPtrsStructView[T]) Value() T { return v.Đļ.Value } -func (v GenericNoPtrsStructView[T]) Pointer() *T { - if v.Đļ.Pointer == nil { - return nil - } - x := *v.Đļ.Pointer - return &x +func (v GenericNoPtrsStructView[T]) Pointer() views.ValuePointer[T] { + return views.ValuePointerOf(v.Đļ.Pointer) } func (v GenericNoPtrsStructView[T]) Slice() views.Slice[T] { return views.SliceOf(v.Đļ.Slice) } @@ -530,7 +514,7 @@ func _GenericNoPtrsStructViewNeedsRegeneration[T StructWithoutPtrs | netip.Prefi }{}) } -// View returns a readonly view of GenericCloneableStruct. +// View returns a read-only view of GenericCloneableStruct. func (p *GenericCloneableStruct[T, V]) View() GenericCloneableStructView[T, V] { return GenericCloneableStructView[T, V]{Đļ: p} } @@ -546,7 +530,7 @@ type GenericCloneableStructView[T views.ViewCloner[T, V], V views.StructView[T]] Đļ *GenericCloneableStruct[T, V] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v GenericCloneableStructView[T, V]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -605,7 +589,7 @@ func _GenericCloneableStructViewNeedsRegeneration[T views.ViewCloner[T, V], V vi }{}) } -// View returns a readonly view of StructWithContainers. +// View returns a read-only view of StructWithContainers. func (p *StructWithContainers) View() StructWithContainersView { return StructWithContainersView{Đļ: p} } @@ -621,7 +605,7 @@ type StructWithContainersView struct { Đļ *StructWithContainers } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v StructWithContainersView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -677,7 +661,7 @@ var _StructWithContainersViewNeedsRegeneration = StructWithContainers(struct { CloneableGenericMap MapContainer[int, *GenericNoPtrsStruct[int]] }{}) -// View returns a readonly view of StructWithTypeAliasFields. +// View returns a read-only view of StructWithTypeAliasFields. func (p *StructWithTypeAliasFields) View() StructWithTypeAliasFieldsView { return StructWithTypeAliasFieldsView{Đļ: p} } @@ -693,7 +677,7 @@ type StructWithTypeAliasFieldsView struct { Đļ *StructWithTypeAliasFields } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v StructWithTypeAliasFieldsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -722,19 +706,14 @@ func (v *StructWithTypeAliasFieldsView) UnmarshalJSON(b []byte) error { return nil } -func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsView { return v.Đļ.WithPtr.View() } +func (v StructWithTypeAliasFieldsView) WithPtr() StructWithPtrsAliasView { return v.Đļ.WithPtr.View() } func (v StructWithTypeAliasFieldsView) WithoutPtr() StructWithoutPtrsAlias { return v.Đļ.WithoutPtr } func (v StructWithTypeAliasFieldsView) WithPtrByPtr() StructWithPtrsAliasView { return v.Đļ.WithPtrByPtr.View() } -func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() *StructWithoutPtrsAlias { - if v.Đļ.WithoutPtrByPtr == nil { - return nil - } - x := *v.Đļ.WithoutPtrByPtr - return &x +func (v StructWithTypeAliasFieldsView) WithoutPtrByPtr() StructWithoutPtrsAliasView { + return v.Đļ.WithoutPtrByPtr.View() } - func (v StructWithTypeAliasFieldsView) SliceWithPtrs() views.SliceView[*StructWithPtrsAlias, StructWithPtrsAliasView] { return views.SliceOfViews[*StructWithPtrsAlias, StructWithPtrsAliasView](v.Đļ.SliceWithPtrs) } @@ -780,7 +759,7 
@@ var _StructWithTypeAliasFieldsViewNeedsRegeneration = StructWithTypeAliasFields( MapOfSlicesWithoutPtrs map[string][]*StructWithoutPtrsAlias }{}) -// View returns a readonly view of GenericTypeAliasStruct. +// View returns a read-only view of GenericTypeAliasStruct. func (p *GenericTypeAliasStruct[T, T2, V2]) View() GenericTypeAliasStructView[T, T2, V2] { return GenericTypeAliasStructView[T, T2, V2]{Đļ: p} } @@ -796,7 +775,7 @@ type GenericTypeAliasStructView[T integer, T2 views.ViewCloner[T2, V2], V2 views Đļ *GenericTypeAliasStruct[T, T2, V2] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v GenericTypeAliasStructView[T, T2, V2]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with diff --git a/cmd/viewer/viewer.go b/cmd/viewer/viewer.go index 96223297b46e2..2d30cc2eb1f2d 100644 --- a/cmd/viewer/viewer.go +++ b/cmd/viewer/viewer.go @@ -21,7 +21,7 @@ import ( ) const viewTemplateStr = `{{define "common"}} -// View returns a readonly view of {{.StructName}}. +// View returns a read-only view of {{.StructName}}. func (p *{{.StructName}}{{.TypeParamNames}}) View() {{.ViewName}}{{.TypeParamNames}} { return {{.ViewName}}{{.TypeParamNames}}{Đļ: p} } @@ -37,7 +37,7 @@ type {{.ViewName}}{{.TypeParams}} struct { Đļ *{{.StructName}}{{.TypeParamNames}} } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v {{.ViewName}}{{.TypeParamNames}}) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -79,13 +79,7 @@ func (v *{{.ViewName}}{{.TypeParamNames}}) UnmarshalJSON(b []byte) error { {{end}} {{define "makeViewField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldViewName}} { return {{.MakeViewFnName}}(&v.Đļ.{{.FieldName}}) } {{end}} -{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() {{.FieldType}} { - if v.Đļ.{{.FieldName}} == nil { - return nil - } - x := *v.Đļ.{{.FieldName}} - return &x -} +{{define "valuePointerField"}}func (v {{.ViewName}}{{.TypeParamNames}}) {{.FieldName}}() views.ValuePointer[{{.FieldType}}] { return views.ValuePointerOf(v.Đļ.{{.FieldName}}) } {{end}} {{define "mapField"}} @@ -126,7 +120,7 @@ func requiresCloning(t types.Type) (shallow, deep bool, base types.Type) { return p, p, t } -func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thisPkg *types.Package) { +func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, _ *types.Package) { t, ok := typ.Underlying().(*types.Struct) if !ok || codegen.IsViewType(t) { return @@ -149,7 +143,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi MapValueView string MapFn string - // MakeViewFnName is the name of the function that accepts a value and returns a readonly view of it. + // MakeViewFnName is the name of the function that accepts a value and returns a read-only view of it. 
MakeViewFnName string }{ StructName: typ.Obj().Name(), @@ -258,6 +252,7 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi writeTemplate("unsupportedField") continue } + it.Import("tailscale.com/types/views") args.MapKeyType = it.QualifiedName(key) mElem := m.Elem() var template string @@ -353,10 +348,32 @@ func genView(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named, thi } else { writeTemplate("unsupportedField") } - } else { - args.FieldType = it.QualifiedName(ptr) - writeTemplate("valuePointerField") + continue + } + + // If a view type is already defined for the base type, use it as the field's view type. + if viewType := viewTypeForValueType(base); viewType != nil { + args.FieldType = it.QualifiedName(base) + args.FieldViewName = it.QualifiedName(viewType) + writeTemplate("viewField") + continue + } + + // Otherwise, if the unaliased base type is a named type whose view type will be generated by this viewer invocation, + // append the "View" suffix to the unaliased base type name and use it as the field's view type. + if base, ok := types.Unalias(base).(*types.Named); ok && slices.Contains(typeNames, it.QualifiedName(base)) { + baseTypeName := it.QualifiedName(base) + args.FieldType = baseTypeName + args.FieldViewName = appendNameSuffix(args.FieldType, "View") + writeTemplate("viewField") + continue } + + // Otherwise, if the base type does not require deep cloning, has no existing view type, + // and will not have a generated view type, use views.ValuePointer[T] as the field's view type. + // Its Get/GetOk methods return stack-allocated shallow copies of the field's value. + args.FieldType = it.QualifiedName(base) + writeTemplate("valuePointerField") continue case *types.Interface: // If fieldType is an interface with a "View() {ViewType}" method, it can be used to clone the field. 
@@ -404,6 +421,33 @@ func appendNameSuffix(name, suffix string) string { return name + suffix } +func typeNameOf(typ types.Type) (name *types.TypeName, ok bool) { + switch t := typ.(type) { + case *types.Alias: + return t.Obj(), true + case *types.Named: + return t.Obj(), true + default: + return nil, false + } +} + +func lookupViewType(typ types.Type) types.Type { + for { + if typeName, ok := typeNameOf(typ); ok && typeName.Pkg() != nil { + if viewTypeObj := typeName.Pkg().Scope().Lookup(typeName.Name() + "View"); viewTypeObj != nil { + return viewTypeObj.Type() + } + } + switch alias := typ.(type) { + case *types.Alias: + typ = alias.Rhs() + default: + return nil + } + } +} + func viewTypeForValueType(typ types.Type) types.Type { if ptr, ok := typ.(*types.Pointer); ok { return viewTypeForValueType(ptr.Elem()) @@ -416,7 +460,12 @@ func viewTypeForValueType(typ types.Type) types.Type { if !ok || sig.Results().Len() != 1 { return nil } - return sig.Results().At(0).Type() + viewType := sig.Results().At(0).Type() + // Check if the typ's package defines an alias for the view type, and use it if so. 
+ if viewTypeAlias, ok := lookupViewType(typ).(*types.Alias); ok && types.AssignableTo(viewType, viewTypeAlias) { + viewType = viewTypeAlias + } + return viewType } func viewTypeForContainerType(typ types.Type) (*types.Named, *types.Func) { diff --git a/cmd/viewer/viewer_test.go b/cmd/viewer/viewer_test.go new file mode 100644 index 0000000000000..cd5f3d95f9c93 --- /dev/null +++ b/cmd/viewer/viewer_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "testing" + + "tailscale.com/util/codegen" +) + +func TestViewerImports(t *testing.T) { + tests := []struct { + name string + content string + typeNames []string + wantImports []string + }{ + { + name: "Map", + content: `type Test struct { Map map[string]int }`, + typeNames: []string{"Test"}, + wantImports: []string{"tailscale.com/types/views"}, + }, + { + name: "Slice", + content: `type Test struct { Slice []int }`, + typeNames: []string{"Test"}, + wantImports: []string{"tailscale.com/types/views"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", "package test\n\n"+tt.content, 0) + if err != nil { + fmt.Println("Error parsing:", err) + return + } + + info := &types.Info{ + Types: make(map[ast.Expr]types.TypeAndValue), + } + + conf := types.Config{} + pkg, err := conf.Check("", fset, []*ast.File{f}, info) + if err != nil { + t.Fatal(err) + } + + var output bytes.Buffer + tracker := codegen.NewImportTracker(pkg) + for i := range tt.typeNames { + typeName, ok := pkg.Scope().Lookup(tt.typeNames[i]).(*types.TypeName) + if !ok { + t.Fatalf("type %q does not exist", tt.typeNames[i]) + } + namedType, ok := typeName.Type().(*types.Named) + if !ok { + t.Fatalf("%q is not a named type", tt.typeNames[i]) + } + genView(&output, tracker, namedType, pkg) + } + + for _, pkgName := 
range tt.wantImports { + if !tracker.Has(pkgName) { + t.Errorf("missing import %q", pkgName) + } + } + }) + } +} diff --git a/cmd/vnet/vnet-main.go b/cmd/vnet/vnet-main.go index 1eb4f65ef2070..9dd4d8cfafe94 100644 --- a/cmd/vnet/vnet-main.go +++ b/cmd/vnet/vnet-main.go @@ -7,15 +7,21 @@ package main import ( "context" + "encoding/binary" "flag" + "fmt" + "io" "log" "net" "net/http" "net/http/httputil" "net/url" "os" + "path/filepath" + "slices" "time" + "github.com/coder/websocket" "tailscale.com/tstest/natlab/vnet" "tailscale.com/types/logger" "tailscale.com/util/must" @@ -31,10 +37,18 @@ var ( pcapFile = flag.String("pcap", "", "if non-empty, filename to write pcap") v4 = flag.Bool("v4", true, "enable IPv4") v6 = flag.Bool("v6", true, "enable IPv6") + + wsproxyListen = flag.String("wsproxy", "", "if non-empty, TCP address to run websocket server on. See https://github.com/copy/v86/blob/master/docs/networking.md#backend-url-schemes") ) func main() { flag.Parse() + if *wsproxyListen != "" { + if err := runWSProxy(); err != nil { + log.Fatalf("runWSProxy: %v", err) + } + return + } if _, err := os.Stat(*listen); err == nil { os.Remove(*listen) @@ -137,3 +151,168 @@ func main() { go s.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) } } + +func runWSProxy() error { + ln, err := net.Listen("tcp", *wsproxyListen) + if err != nil { + return err + } + defer ln.Close() + + log.Printf("Running wsproxy mode on %v ...", *wsproxyListen) + + var hs http.Server + hs.Handler = http.HandlerFunc(handleWebSocket) + + return hs.Serve(ln) +} + +func handleWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + InsecureSkipVerify: true, + }) + if err != nil { + log.Printf("Upgrade error: %v", err) + return + } + defer conn.Close(websocket.StatusInternalError, "closing") + log.Printf("WebSocket client connected: %s", r.RemoteAddr) + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + messageType, 
firstData, err := conn.Read(ctx) + if err != nil { + log.Printf("ReadMessage first: %v", err) + return + } + if messageType != websocket.MessageBinary { + log.Printf("Ignoring non-binary message") + return + } + if len(firstData) < 12 { + log.Printf("Ignoring short message") + return + } + clientMAC := vnet.MAC(firstData[6:12]) + + // Set up a qemu-protocol Unix socket pair. We'll fake the qemu protocol here + // to avoid changing the vnet package. + td, err := os.MkdirTemp("", "vnet") + if err != nil { + panic(fmt.Errorf("MkdirTemp: %v", err)) + } + defer os.RemoveAll(td) + + unixSrv := filepath.Join(td, "vnet.sock") + + srv, err := net.Listen("unix", unixSrv) + if err != nil { + panic(fmt.Errorf("Listen: %v", err)) + } + defer srv.Close() + + var c vnet.Config + c.SetBlendReality(true) + + var net1opt = []any{vnet.NAT("easy")} + net1opt = append(net1opt, "2.1.1.1", "192.168.1.1/24") + net1opt = append(net1opt, "2000:52::1/64") + + c.AddNode(c.AddNetwork(net1opt...), clientMAC) + + vs, err := vnet.New(&c) + if err != nil { + panic(fmt.Errorf("newServer: %v", err)) + } + if err := vs.PopulateDERPMapIPs(); err != nil { + log.Printf("warning: ignoring failure to populate DERP map: %v", err) + return + } + + errc := make(chan error, 1) + fail := func(err error) { + select { + case errc <- err: + log.Printf("failed: %v", err) + case <-ctx.Done(): + } + } + + go func() { + c, err := srv.Accept() + if err != nil { + fail(err) + return + } + vs.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) + }() + + uc, err := net.Dial("unix", unixSrv) + if err != nil { + panic(fmt.Errorf("Dial: %v", err)) + } + defer uc.Close() + + var frameBuf []byte + writeDataToUnixConn := func(data []byte) error { + frameBuf = slices.Grow(frameBuf[:0], len(data)+4)[:len(data)+4] + binary.BigEndian.PutUint32(frameBuf[:4], uint32(len(data))) + copy(frameBuf[4:], data) + + _, err = uc.Write(frameBuf) + return err + } + if err := writeDataToUnixConn(firstData); err != nil { + fail(err) + return + } 
+ + go func() { + for { + messageType, data, err := conn.Read(ctx) + if err != nil { + fail(fmt.Errorf("ReadMessage: %v", err)) + break + } + + if messageType != websocket.MessageBinary { + log.Printf("Ignoring non-binary message") + continue + } + + if err := writeDataToUnixConn(data); err != nil { + fail(err) + return + } + } + }() + + go func() { + const maxBuf = 4096 + frameBuf := make([]byte, maxBuf) + for { + _, err := io.ReadFull(uc, frameBuf[:4]) + if err != nil { + fail(err) + return + } + frameLen := binary.BigEndian.Uint32(frameBuf[:4]) + if frameLen > maxBuf { + fail(fmt.Errorf("frame too large: %d", frameLen)) + return + } + if _, err := io.ReadFull(uc, frameBuf[:frameLen]); err != nil { + fail(err) + return + } + + if err := conn.Write(ctx, websocket.MessageBinary, frameBuf[:frameLen]); err != nil { + fail(err) + return + } + } + }() + + <-ctx.Done() +} diff --git a/cmd/xdpderper/xdpderper.go b/cmd/xdpderper/xdpderper.go index 599034ae7259c..c127baf54e340 100644 --- a/cmd/xdpderper/xdpderper.go +++ b/cmd/xdpderper/xdpderper.go @@ -18,6 +18,9 @@ import ( "tailscale.com/derp/xdp" "tailscale.com/net/netutil" "tailscale.com/tsweb" + + // Support for prometheus varz in tsweb + _ "tailscale.com/tsweb/promvarz" ) var ( diff --git a/control/controlbase/conn_test.go b/control/controlbase/conn_test.go index 8a0f46967e342..ed4642d3b179c 100644 --- a/control/controlbase/conn_test.go +++ b/control/controlbase/conn_test.go @@ -280,7 +280,7 @@ func TestConnMemoryOverhead(t *testing.T) { growthTotal := int64(ms.HeapAlloc) - int64(ms0.HeapAlloc) growthEach := float64(growthTotal) / float64(num) t.Logf("Alloced %v bytes, %.2f B/each", growthTotal, growthEach) - const max = 2000 + const max = 2048 if growthEach > max { t.Errorf("allocated more than expected; want max %v bytes/each", max) } diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index edd0ae29c645d..e0168c19db6c0 100644 --- a/control/controlclient/auto.go +++ 
b/control/controlclient/auto.go @@ -21,6 +21,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/persist" "tailscale.com/types/structs" + "tailscale.com/util/clientmetric" "tailscale.com/util/execqueue" ) @@ -118,6 +119,7 @@ type Auto struct { updateCh chan struct{} // readable when we should inform the server of a change observer Observer // called to update Client status; always non-nil observerQueue execqueue.ExecQueue + shutdownFn func() // to be called prior to shutdown or nil unregisterHealthWatch func() @@ -131,6 +133,8 @@ type Auto struct { // the server. lastUpdateGen updateGen + lastStatus atomic.Pointer[Status] + paused bool // whether we should stop making HTTP requests unpauseWaiters []chan bool // chans that gets sent true (once) on wake, or false on Shutdown loggedIn bool // true if currently logged in @@ -186,6 +190,7 @@ func NewNoStart(opts Options) (_ *Auto, err error) { mapDone: make(chan struct{}), updateDone: make(chan struct{}), observer: opts.Observer, + shutdownFn: opts.Shutdown, } c.authCtx, c.authCancel = context.WithCancel(context.Background()) c.authCtx = sockstats.WithSockStats(c.authCtx, sockstats.LabelControlClientAuto, opts.Logf) @@ -596,21 +601,90 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM // not logged in. nm = nil } - new := Status{ + newSt := &Status{ URL: url, Persist: p, NetMap: nm, Err: err, state: state, } + c.lastStatus.Store(newSt) // Launch a new goroutine to avoid blocking the caller while the observer // does its thing, which may result in a call back into the client. 
+ metricQueued.Add(1) c.observerQueue.Add(func() { - c.observer.SetControlClientStatus(c, new) + if canSkipStatus(newSt, c.lastStatus.Load()) { + metricSkippable.Add(1) + if !c.direct.controlKnobs.DisableSkipStatusQueue.Load() { + metricSkipped.Add(1) + return + } + } + c.observer.SetControlClientStatus(c, *newSt) + + // Best effort stop retaining the memory now that we've sent it to the + // observer (LocalBackend). We CAS here because the caller goroutine is + // doing a Store which we want to win a race. This is only a memory + // optimization and is not for correctness. + // + // If the CAS fails, that means somebody else's Store replaced our + // pointer (so mission accomplished: our netmap is no longer retained in + // any case) and that Store caller will be responsible for removing + // their own netmap (or losing their race too, down the chain). + // Eventually the last caller will win this CAS and zero lastStatus. + c.lastStatus.CompareAndSwap(newSt, nil) }) } +var ( + metricQueued = clientmetric.NewCounter("controlclient_auto_status_queued") + metricSkippable = clientmetric.NewCounter("controlclient_auto_status_queue_skippable") + metricSkipped = clientmetric.NewCounter("controlclient_auto_status_queue_skipped") +) + +// canSkipStatus reports whether we can skip sending s1, knowing +// that s2 is enqueued sometime in the future after s1. +// +// s1 must be non-nil. s2 may be nil. +func canSkipStatus(s1, s2 *Status) bool { + if s2 == nil { + // Nothing in the future. + return false + } + if s1 == s2 { + // If the last item in the queue is the same as s1, + // we can't skip it. + return false + } + if s1.Err != nil || s1.URL != "" { + // If s1 has an error or a URL, we shouldn't skip it, lest the error go + // away in s2 or in-between. We want to make sure all the subsystems see + // it. Plus there aren't many of these, so not worth skipping. 
+ return false + } + if !s1.Persist.Equals(s2.Persist) || s1.state != s2.state { + // If s1 has a different Persist or state than s2, + // don't skip it. We only care about skipping the typical + // entries where the only difference is the NetMap. + return false + } + // If nothing above precludes it, and both s1 and s2 have NetMaps, then + // we can skip it, because s2's NetMap is a newer version and we can + // jump straight from whatever state we had before to s2's state, + // without passing through s1's state first. A NetMap is regrettably a + // full snapshot of the state, not an incremental delta. We're slowly + // moving towards passing around only deltas around internally at all + // layers, but this is explicitly the case where we didn't have a delta + // path for the message we received over the wire and had to resort + // to the legacy full NetMap path. And then we can get behind processing + // these full NetMap snapshots in LocalBackend/wgengine/magicsock/netstack + // and this path (when it returns true) lets us skip over useless work + // and not get behind in the queue. This matters in particular for tailnets + // that are both very large + very churny. 
+ return s1.NetMap != nil && s2.NetMap != nil +} + func (c *Auto) Login(flags LoginFlags) { c.logf("client.Login(%v)", flags) @@ -683,6 +757,7 @@ func (c *Auto) Shutdown() { return } c.logf("client.Shutdown ...") + shutdownFn := c.shutdownFn direct := c.direct c.closed = true @@ -695,6 +770,10 @@ func (c *Auto) Shutdown() { c.unpauseWaiters = nil c.mu.Unlock() + if shutdownFn != nil { + shutdownFn() + } + c.unregisterHealthWatch() <-c.authDone <-c.mapDone diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index b376234511c09..f8882a4e796ca 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -4,8 +4,15 @@ package controlclient import ( + "errors" + "fmt" + "io" "reflect" + "slices" "testing" + + "tailscale.com/types/netmap" + "tailscale.com/types/persist" ) func fieldsOf(t reflect.Type) (fields []string) { @@ -62,3 +69,122 @@ func TestStatusEqual(t *testing.T) { } } } + +// tests [canSkipStatus]. 
+func TestCanSkipStatus(t *testing.T) { + st := new(Status) + nm1 := &netmap.NetworkMap{} + nm2 := &netmap.NetworkMap{} + + tests := []struct { + name string + s1, s2 *Status + want bool + }{ + { + name: "nil-s2", + s1: st, + s2: nil, + want: false, + }, + { + name: "equal", + s1: st, + s2: st, + want: false, + }, + { + name: "s1-error", + s1: &Status{Err: io.EOF, NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-url", + s1: &Status{URL: "foo", NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-persist-diff", + s1: &Status{Persist: new(persist.Persist).View(), NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-state-diff", + s1: &Status{state: 123, NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-no-netmap1", + s1: &Status{NetMap: nil}, + s2: &Status{NetMap: nm2}, + want: false, + }, + { + name: "s1-no-netmap2", + s1: &Status{NetMap: nm1}, + s2: &Status{NetMap: nil}, + want: false, + }, + { + name: "skip", + s1: &Status{NetMap: nm1}, + s2: &Status{NetMap: nm2}, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := canSkipStatus(tt.s1, tt.s2); got != tt.want { + t.Errorf("canSkipStatus = %v, want %v", got, tt.want) + } + }) + } + + want := []string{"Err", "URL", "NetMap", "Persist", "state"} + if f := fieldsOf(reflect.TypeFor[Status]()); !slices.Equal(f, want) { + t.Errorf("Status fields = %q; this code was only written to handle fields %q", f, want) + } +} + +func TestRetryableErrors(t *testing.T) { + errorTests := []struct { + err error + want bool + }{ + {errNoNoiseClient, true}, + {errNoNodeKey, true}, + {fmt.Errorf("%w: %w", errNoNoiseClient, errors.New("no noise")), true}, + {fmt.Errorf("%w: %w", errHTTPPostFailure, errors.New("bad post")), true}, + {fmt.Errorf("%w: %w", errNoNodeKey, errors.New("not node key")), true}, + {errBadHTTPResponse(429, "too may requests"), true}, + {errBadHTTPResponse(500, 
"internal server eror"), true}, + {errBadHTTPResponse(502, "bad gateway"), true}, + {errBadHTTPResponse(503, "service unavailable"), true}, + {errBadHTTPResponse(504, "gateway timeout"), true}, + {errBadHTTPResponse(1234, "random error"), false}, + } + + for _, tt := range errorTests { + t.Run(tt.err.Error(), func(t *testing.T) { + if isRetryableErrorForTest(tt.err) != tt.want { + t.Fatalf("retriable: got %v, want %v", tt.err, tt.want) + } + }) + } +} + +type retryableForTest interface { + Retryable() bool +} + +func isRetryableErrorForTest(err error) bool { + var ae retryableForTest + if errors.As(err, &ae) { + return ae.Retryable() + } + return false +} diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 9cbd0e14ead52..ac799e2d916dc 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -15,7 +15,6 @@ import ( "log" "net" "net/http" - "net/http/httptest" "net/netip" "net/url" "os" @@ -38,10 +37,12 @@ import ( "tailscale.com/net/dnsfallback" "tailscale.com/net/netmon" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/tlsdial" "tailscale.com/net/tsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/tailcfg" + "tailscale.com/tempfork/httprec" "tailscale.com/tka" "tailscale.com/tstime" "tailscale.com/types/key" @@ -94,15 +95,16 @@ type Direct struct { sfGroup singleflight.Group[struct{}, *NoiseClient] // protects noiseClient creation. 
noiseClient *NoiseClient - persist persist.PersistView - authKey string - tryingNewKey key.NodePrivate - expiry time.Time // or zero value if none/unknown - hostinfo *tailcfg.Hostinfo // always non-nil - netinfo *tailcfg.NetInfo - endpoints []tailcfg.Endpoint - tkaHead string - lastPingURL string // last PingRequest.URL received, for dup suppression + persist persist.PersistView + authKey string + tryingNewKey key.NodePrivate + expiry time.Time // or zero value if none/unknown + hostinfo *tailcfg.Hostinfo // always non-nil + netinfo *tailcfg.NetInfo + endpoints []tailcfg.Endpoint + tkaHead string + lastPingURL string // last PingRequest.URL received, for dup suppression + connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest } // Observer is implemented by users of the control client (such as LocalBackend) @@ -156,6 +158,11 @@ type Options struct { // If we receive a new DialPlan from the server, this value will be // updated. DialPlan ControlDialPlanner + + // Shutdown is an optional function that will be called before client shutdown is + // attempted. It is used to allow the client to clean up any resources or complete any + // tasks that are dependent on a live client. 
+ Shutdown func() } // ControlDialPlanner is the interface optionally supplied when creating a @@ -267,7 +274,7 @@ func NewDirect(opts Options) (*Direct, error) { tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) tr.TLSClientConfig = tlsdial.Config(serverURL.Hostname(), opts.HealthTracker, tr.TLSClientConfig) - var dialFunc dialFunc + var dialFunc netx.DialFunc dialFunc, interceptedDial = makeScreenTimeDetectingDialFunc(opts.Dialer.SystemDial) tr.DialContext = dnscache.Dialer(dialFunc, dnsCache) tr.DialTLSContext = dnscache.TLSDialer(dialFunc, dnsCache, tr.TLSClientConfig) @@ -397,6 +404,14 @@ func (c *Direct) SetTKAHead(tkaHead string) bool { return true } +// SetConnectionHandleForTest stores a new MapRequest.ConnectionHandleForTest +// value for the next update. +func (c *Direct) SetConnectionHandleForTest(handle string) { + c.mu.Lock() + defer c.mu.Unlock() + c.connectionHandleForTest = handle +} + func (c *Direct) GetPersist() persist.PersistView { c.mu.Lock() defer c.mu.Unlock() @@ -650,7 +665,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new c.logf("RegisterReq sign error: %v", err) } } - if debugRegister() { + if DevKnob.DumpRegister() { j, _ := json.MarshalIndent(request, "", "\t") c.logf("RegisterRequest: %s", j) } @@ -691,7 +706,7 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new c.logf("error decoding RegisterResponse with server key %s and machine key %s: %v", serverKey, machinePrivKey.Public(), err) return regen, opt.URL, nil, fmt.Errorf("register request: %v", err) } - if debugRegister() { + if DevKnob.DumpRegister() { j, _ := json.MarshalIndent(resp, "", "\t") c.logf("RegisterResponse: %s", j) } @@ -845,6 +860,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap serverNoiseKey := c.serverNoiseKey hi := c.hostInfoLocked() backendLogID := hi.BackendLogID + connectionHandleForTest := 
c.connectionHandleForTest var epStrs []string var eps []netip.AddrPort var epTypes []tailcfg.EndpointType @@ -877,7 +893,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap c.logf("[v1] PollNetMap: stream=%v ep=%v", isStreaming, epStrs) vlogf := logger.Discard - if DevKnob.DumpNetMaps() { + if DevKnob.DumpNetMapsVerbose() { // TODO(bradfitz): update this to use "[v2]" prefix perhaps? but we don't // want to upload it always. vlogf = c.logf @@ -885,17 +901,18 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap nodeKey := persist.PublicNodeKey() request := &tailcfg.MapRequest{ - Version: tailcfg.CurrentCapabilityVersion, - KeepAlive: true, - NodeKey: nodeKey, - DiscoKey: c.discoPubKey, - Endpoints: eps, - EndpointTypes: epTypes, - Stream: isStreaming, - Hostinfo: hi, - DebugFlags: c.debugFlags, - OmitPeers: nu == nil, - TKAHead: c.tkaHead, + Version: tailcfg.CurrentCapabilityVersion, + KeepAlive: true, + NodeKey: nodeKey, + DiscoKey: c.discoPubKey, + Endpoints: eps, + EndpointTypes: epTypes, + Stream: isStreaming, + Hostinfo: hi, + DebugFlags: c.debugFlags, + OmitPeers: nu == nil, + TKAHead: c.tkaHead, + ConnectionHandleForTest: connectionHandleForTest, } var extraDebugFlags []string if hi != nil && c.netMon != nil && !c.skipIPForwardingCheck && @@ -1003,7 +1020,9 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap if persist == c.persist { newPersist := persist.AsStruct() newPersist.NodeID = nm.SelfNode.StableID() - newPersist.UserProfile = nm.UserProfiles[nm.User()] + if up, ok := nm.UserProfiles[nm.User()]; ok { + newPersist.UserProfile = *up.AsStruct() + } c.persist = newPersist.View() persist = c.persist @@ -1079,7 +1098,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap } else { vlogf("netmap: got new map") } - if resp.ControlDialPlan != nil { + if resp.ControlDialPlan != nil && !ignoreDialPlan() { if c.dialPlan != nil { 
c.logf("netmap: got new dial plan from control") c.dialPlan.Store(resp.ControlDialPlan) @@ -1170,11 +1189,6 @@ func decode(res *http.Response, v any) error { return json.Unmarshal(msg, v) } -var ( - debugMap = envknob.RegisterBool("TS_DEBUG_MAP") - debugRegister = envknob.RegisterBool("TS_DEBUG_REGISTER") -) - var jsonEscapedZero = []byte(`\u0000`) // decodeMsg is responsible for uncompressing msg and unmarshaling into v. @@ -1183,7 +1197,7 @@ func (c *Direct) decodeMsg(compressedMsg []byte, v any) error { if err != nil { return err } - if debugMap() { + if DevKnob.DumpNetMaps() { var buf bytes.Buffer json.Indent(&buf, b, "", " ") log.Printf("MapResponse: %s", buf.Bytes()) @@ -1205,7 +1219,7 @@ func encode(v any) ([]byte, error) { if err != nil { return nil, err } - if debugMap() { + if DevKnob.DumpNetMaps() { if _, ok := v.(*tailcfg.MapRequest); ok { log.Printf("MapRequest: %s", b) } @@ -1253,18 +1267,25 @@ func loadServerPubKeys(ctx context.Context, httpc *http.Client, serverURL string var DevKnob = initDevKnob() type devKnobs struct { - DumpNetMaps func() bool - ForceProxyDNS func() bool - StripEndpoints func() bool // strip endpoints from control (only use disco messages) - StripCaps func() bool // strip all local node's control-provided capabilities + DumpRegister func() bool + DumpNetMaps func() bool + DumpNetMapsVerbose func() bool + ForceProxyDNS func() bool + StripEndpoints func() bool // strip endpoints from control (only use disco messages) + StripHomeDERP func() bool // strip Home DERP from control + StripCaps func() bool // strip all local node's control-provided capabilities } func initDevKnob() devKnobs { + nm := envknob.RegisterInt("TS_DEBUG_MAP") return devKnobs{ - DumpNetMaps: envknob.RegisterBool("TS_DEBUG_NETMAP"), - ForceProxyDNS: envknob.RegisterBool("TS_DEBUG_PROXY_DNS"), - StripEndpoints: envknob.RegisterBool("TS_DEBUG_STRIP_ENDPOINTS"), - StripCaps: envknob.RegisterBool("TS_DEBUG_STRIP_CAPS"), + DumpNetMaps: func() bool { return nm() > 0 }, 
+ DumpNetMapsVerbose: func() bool { return nm() > 1 }, + DumpRegister: envknob.RegisterBool("TS_DEBUG_REGISTER"), + ForceProxyDNS: envknob.RegisterBool("TS_DEBUG_PROXY_DNS"), + StripEndpoints: envknob.RegisterBool("TS_DEBUG_STRIP_ENDPOINTS"), + StripHomeDERP: envknob.RegisterBool("TS_DEBUG_STRIP_HOME_DERP"), + StripCaps: envknob.RegisterBool("TS_DEBUG_STRIP_CAPS"), } } @@ -1384,7 +1405,7 @@ func answerC2NPing(logf logger.Logf, c2nHandler http.Handler, c *http.Client, pr handlerCtx, cancel := context.WithTimeout(context.Background(), handlerTimeout) defer cancel() hreq = hreq.WithContext(handlerCtx) - rec := httptest.NewRecorder() + rec := httprec.NewRecorder() c2nHandler.ServeHTTP(rec, hreq) cancel() @@ -1643,20 +1664,109 @@ func (c *Direct) ReportHealthChange(w *health.Warnable, us *health.UnhealthyStat res.Body.Close() } +// SetDeviceAttrs does a synchronous call to the control plane to update +// the node's attributes. +// +// See docs on [tailcfg.SetDeviceAttributesRequest] for background. +func (c *Auto) SetDeviceAttrs(ctx context.Context, attrs tailcfg.AttrUpdate) error { + return c.direct.SetDeviceAttrs(ctx, attrs) +} + +// SetDeviceAttrs does a synchronous call to the control plane to update +// the node's attributes. +// +// See docs on [tailcfg.SetDeviceAttributesRequest] for background. +func (c *Direct) SetDeviceAttrs(ctx context.Context, attrs tailcfg.AttrUpdate) error { + nc, err := c.getNoiseClient() + if err != nil { + return fmt.Errorf("%w: %w", errNoNoiseClient, err) + } + nodeKey, ok := c.GetPersist().PublicNodeKeyOK() + if !ok { + return errNoNodeKey + } + if c.panicOnUse { + panic("tainted client") + } + req := &tailcfg.SetDeviceAttributesRequest{ + NodeKey: nodeKey, + Version: tailcfg.CurrentCapabilityVersion, + Update: attrs, + } + + // TODO(bradfitz): unify the callers using doWithBody vs those using + // DoNoiseRequest. 
There seems to be a ~50/50 split and they're very close, + // but doWithBody sets the load balancing header and auto-JSON-encodes the + // body, but DoNoiseRequest is exported. Clean it up so they're consistent + // one way or another. + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + res, err := nc.doWithBody(ctx, "PATCH", "/machine/set-device-attr", nodeKey, req) + if err != nil { + return err + } + defer res.Body.Close() + all, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + return fmt.Errorf("HTTP error from control plane: %v: %s", res.Status, all) + } + return nil +} + +// SendAuditLog implements [auditlog.Transport] by sending an audit log synchronously to the control plane. +// +// See docs on [tailcfg.AuditLogRequest] and [auditlog.Logger] for background. +func (c *Auto) SendAuditLog(ctx context.Context, auditLog tailcfg.AuditLogRequest) (err error) { + return c.direct.sendAuditLog(ctx, auditLog) +} + +func (c *Direct) sendAuditLog(ctx context.Context, auditLog tailcfg.AuditLogRequest) (err error) { + nc, err := c.getNoiseClient() + if err != nil { + return fmt.Errorf("%w: %w", errNoNoiseClient, err) + } + + nodeKey, ok := c.GetPersist().PublicNodeKeyOK() + if !ok { + return errNoNodeKey + } + + req := &tailcfg.AuditLogRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: nodeKey, + Action: auditLog.Action, + Details: auditLog.Details, + } + + if c.panicOnUse { + panic("tainted client") + } + + res, err := nc.post(ctx, "/machine/audit-log", nodeKey, req) + if err != nil { + return fmt.Errorf("%w: %w", errHTTPPostFailure, err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + all, _ := io.ReadAll(res.Body) + return errBadHTTPResponse(res.StatusCode, string(all)) + } + return nil +} + func addLBHeader(req *http.Request, nodeKey key.NodePublic) { if !nodeKey.IsZero() { req.Header.Add(tailcfg.LBHeader, nodeKey.String()) } } -type dialFunc = func(ctx context.Context, network, addr string) (net.Conn, 
error) - // makeScreenTimeDetectingDialFunc returns dialFunc, optionally wrapped (on // Apple systems) with a func that sets the returned atomic.Bool for whether // Screen Time seemed to intercept the connection. // // The returned *atomic.Bool is nil on non-Apple systems. -func makeScreenTimeDetectingDialFunc(dial dialFunc) (dialFunc, *atomic.Bool) { +func makeScreenTimeDetectingDialFunc(dial netx.DialFunc) (netx.DialFunc, *atomic.Bool) { switch runtime.GOOS { case "darwin", "ios": // Continue below. @@ -1674,6 +1784,13 @@ func makeScreenTimeDetectingDialFunc(dial dialFunc) (dialFunc, *atomic.Bool) { }, ab } +func ignoreDialPlan() bool { + // If we're running in v86 (a JavaScript-based emulation of a 32-bit x86) + // our networking is very limited. Let's ignore the dial plan since it's too + // complicated to race that many IPs anyway. + return hostinfo.IsInVM86() +} + func isTCPLoopback(a net.Addr) bool { if ta, ok := a.(*net.TCPAddr); ok { return ta.IP.IsLoopback() diff --git a/control/controlclient/errors.go b/control/controlclient/errors.go new file mode 100644 index 0000000000000..9b4dab84467b8 --- /dev/null +++ b/control/controlclient/errors.go @@ -0,0 +1,51 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package controlclient + +import ( + "errors" + "fmt" + "net/http" +) + +// apiResponseError is an error type that can be returned by controlclient +// api requests. +// +// It wraps an underlying error and a flag for clients to query if the +// error is retryable via the Retryable() method. +type apiResponseError struct { + err error + retryable bool +} + +// Error implements [error]. +func (e *apiResponseError) Error() string { + return e.err.Error() +} + +// Retryable reports whether the error is retryable. 
+func (e *apiResponseError) Retryable() bool { + return e.retryable +} + +func (e *apiResponseError) Unwrap() error { return e.err } + +var ( + errNoNodeKey = &apiResponseError{errors.New("no node key"), true} + errNoNoiseClient = &apiResponseError{errors.New("no noise client"), true} + errHTTPPostFailure = &apiResponseError{errors.New("http failure"), true} +) + +func errBadHTTPResponse(code int, msg string) error { + retryable := false + switch code { + case http.StatusTooManyRequests, + http.StatusInternalServerError, + http.StatusBadGateway, + http.StatusServiceUnavailable, + http.StatusGatewayTimeout: + retryable = true + } + return &apiResponseError{fmt.Errorf("http error %d: %s", code, msg), retryable} +} diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 7879122229e37..3173040fe31d8 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -7,20 +7,19 @@ import ( "cmp" "context" "encoding/json" - "fmt" "maps" "net" "reflect" "runtime" "runtime/debug" "slices" - "sort" "strconv" "sync" "time" "tailscale.com/control/controlknobs" "tailscale.com/envknob" + "tailscale.com/hostinfo" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -31,6 +30,7 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/util/slicesx" "tailscale.com/wgengine/filter" ) @@ -75,11 +75,10 @@ type mapSession struct { lastPrintMap time.Time lastNode tailcfg.NodeView lastCapSet set.Set[tailcfg.NodeCapability] - peers map[tailcfg.NodeID]*tailcfg.NodeView // pointer to view (oddly). same pointers as sortedPeers. 
- sortedPeers []*tailcfg.NodeView // same pointers as peers, but sorted by Node.ID + peers map[tailcfg.NodeID]tailcfg.NodeView lastDNSConfig *tailcfg.DNSConfig lastDERPMap *tailcfg.DERPMap - lastUserProfile map[tailcfg.UserID]tailcfg.UserProfile + lastUserProfile map[tailcfg.UserID]tailcfg.UserProfileView lastPacketFilterRules views.Slice[tailcfg.FilterRule] // concatenation of all namedPacketFilters namedPacketFilters map[string]views.Slice[tailcfg.FilterRule] lastParsedPacketFilter []filter.Match @@ -91,7 +90,6 @@ type mapSession struct { lastPopBrowserURL string lastTKAInfo *tailcfg.TKAInfo lastNetmapSummary string // from NetworkMap.VeryConcise - lastMaxExpiry time.Duration } // newMapSession returns a mostly unconfigured new mapSession. @@ -106,7 +104,7 @@ func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater, controlKnob privateNodeKey: privateNodeKey, publicNodeKey: privateNodeKey.Public(), lastDNSConfig: new(tailcfg.DNSConfig), - lastUserProfile: map[tailcfg.UserID]tailcfg.UserProfile{}, + lastUserProfile: map[tailcfg.UserID]tailcfg.UserProfileView{}, // Non-nil no-op defaults, to be optionally overridden by the caller. logf: logger.Discard, @@ -167,6 +165,7 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t // For responses that mutate the self node, check for updated nodeAttrs. if resp.Node != nil { + upgradeNode(resp.Node) if DevKnob.StripCaps() { resp.Node.Capabilities = nil resp.Node.CapMap = nil @@ -182,6 +181,13 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t ms.controlKnobs.UpdateFromNodeAttributes(resp.Node.CapMap) } + for _, p := range resp.Peers { + upgradeNode(p) + } + for _, p := range resp.PeersChanged { + upgradeNode(p) + } + // Call Node.InitDisplayNames on any changed nodes. 
initDisplayNames(cmp.Or(resp.Node.View(), ms.lastNode), resp) @@ -217,6 +223,33 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t return nil } +// upgradeNode upgrades Node fields from the server into the modern forms +// not using deprecated fields. +func upgradeNode(n *tailcfg.Node) { + if n == nil { + return + } + if n.LegacyDERPString != "" { + if n.HomeDERP == 0 { + ip, portStr, err := net.SplitHostPort(n.LegacyDERPString) + if ip == tailcfg.DerpMagicIP && err == nil { + port, err := strconv.Atoi(portStr) + if err == nil { + n.HomeDERP = port + } + } + } + n.LegacyDERPString = "" + } + if DevKnob.StripHomeDERP() { + n.HomeDERP = 0 + } + + if n.AllowedIPs == nil { + n.AllowedIPs = slices.Clone(n.Addresses) + } +} + func (ms *mapSession) tryHandleIncrementally(res *tailcfg.MapResponse) bool { if ms.controlKnobs != nil && ms.controlKnobs.DisableDeltaUpdates.Load() { return false @@ -260,13 +293,47 @@ func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) { } for _, up := range resp.UserProfiles { - ms.lastUserProfile[up.ID] = up + ms.lastUserProfile[up.ID] = up.View() } // TODO(bradfitz): clean up old user profiles? maybe not worth it. if dm := resp.DERPMap; dm != nil { ms.vlogf("netmap: new map contains DERP map") + // Guard against the control server accidentally sending + // a nil region definition, which at least Headscale was + // observed to send. + for rid, r := range dm.Regions { + if r == nil { + delete(dm.Regions, rid) + } + } + + // In the copy/v86 wasm environment with limited networking, if the + // control plane didn't pick our DERP home for us, do it ourselves and + // mark all but the lowest region as NoMeasureNoHome. For prod, this + // will be Region 1, NYC, a compromise between the US and Europe. But + // really the control plane should pick this. This is only a fallback. 
+ if hostinfo.IsInVM86() { + numCanMeasure := 0 + lowest := 0 + for rid, r := range dm.Regions { + if !r.NoMeasureNoHome { + numCanMeasure++ + if lowest == 0 || rid < lowest { + lowest = rid + } + } + } + if numCanMeasure > 1 { + for rid, r := range dm.Regions { + if rid != lowest { + r.NoMeasureNoHome = true + } + } + } + } + // Zero-valued fields in a DERPMap mean that we're not changing // anything and are using the previous value(s). if ldm := ms.lastDERPMap; ldm != nil { @@ -345,9 +412,6 @@ func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) { if resp.TKAInfo != nil { ms.lastTKAInfo = resp.TKAInfo } - if resp.MaxKeyDuration > 0 { - ms.lastMaxExpiry = resp.MaxKeyDuration - } } var ( @@ -366,16 +430,11 @@ var ( patchifiedPeerEqual = clientmetric.NewCounter("controlclient_patchified_peer_equal") ) -// updatePeersStateFromResponseres updates ms.peers and ms.sortedPeers from res. It takes ownership of res. +// updatePeersStateFromResponseres updates ms.peers from resp. +// It takes ownership of resp. 
func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (stats updateStats) { - defer func() { - if stats.removed > 0 || stats.added > 0 { - ms.rebuildSorted() - } - }() - if ms.peers == nil { - ms.peers = make(map[tailcfg.NodeID]*tailcfg.NodeView) + ms.peers = make(map[tailcfg.NodeID]tailcfg.NodeView) } if len(resp.Peers) > 0 { @@ -384,12 +443,12 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s keep := make(map[tailcfg.NodeID]bool, len(resp.Peers)) for _, n := range resp.Peers { keep[n.ID] = true - if vp, ok := ms.peers[n.ID]; ok { + lenBefore := len(ms.peers) + ms.peers[n.ID] = n.View() + if len(ms.peers) == lenBefore { stats.changed++ - *vp = n.View() } else { stats.added++ - ms.peers[n.ID] = ptr.To(n.View()) } } for id := range ms.peers { @@ -410,12 +469,12 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s } for _, n := range resp.PeersChanged { - if vp, ok := ms.peers[n.ID]; ok { + lenBefore := len(ms.peers) + ms.peers[n.ID] = n.View() + if len(ms.peers) == lenBefore { stats.changed++ - *vp = n.View() } else { stats.added++ - ms.peers[n.ID] = ptr.To(n.View()) } } @@ -427,7 +486,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s } else { mut.LastSeen = nil } - *vp = mut.View() + ms.peers[nodeID] = mut.View() stats.changed++ } } @@ -436,7 +495,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s if vp, ok := ms.peers[nodeID]; ok { mut := vp.AsStruct() mut.Online = ptr.To(online) - *vp = mut.View() + ms.peers[nodeID] = mut.View() stats.changed++ } } @@ -449,7 +508,7 @@ func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s stats.changed++ mut := vp.AsStruct() if pc.DERPRegion != 0 { - mut.DERP = fmt.Sprintf("%s:%v", tailcfg.DerpMagicIP, pc.DERPRegion) + mut.HomeDERP = pc.DERPRegion patchDERPRegion.Add(1) } if pc.Cap != 0 { @@ -488,31 +547,12 @@ func (ms *mapSession) 
updatePeersStateFromResponse(resp *tailcfg.MapResponse) (s mut.CapMap = v patchCapMap.Add(1) } - *vp = mut.View() + ms.peers[pc.NodeID] = mut.View() } return } -// rebuildSorted rebuilds ms.sortedPeers from ms.peers. It should be called -// after any additions or removals from peers. -func (ms *mapSession) rebuildSorted() { - if ms.sortedPeers == nil { - ms.sortedPeers = make([]*tailcfg.NodeView, 0, len(ms.peers)) - } else { - if len(ms.sortedPeers) > len(ms.peers) { - clear(ms.sortedPeers[len(ms.peers):]) - } - ms.sortedPeers = ms.sortedPeers[:0] - } - for _, p := range ms.peers { - ms.sortedPeers = append(ms.sortedPeers, p) - } - sort.Slice(ms.sortedPeers, func(i, j int) bool { - return ms.sortedPeers[i].ID() < ms.sortedPeers[j].ID() - }) -} - func (ms *mapSession) addUserProfile(nm *netmap.NetworkMap, userID tailcfg.UserID) { if userID == 0 { return @@ -576,7 +616,7 @@ func (ms *mapSession) patchifyPeer(n *tailcfg.Node) (_ *tailcfg.PeerChange, ok b if !ok { return nil, false } - return peerChangeDiff(*was, n) + return peerChangeDiff(was, n) } // peerChangeDiff returns the difference from 'was' to 'n', if possible. 
@@ -656,17 +696,13 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang if !views.SliceEqual(was.Endpoints(), views.SliceOf(n.Endpoints)) { pc().Endpoints = slices.Clone(n.Endpoints) } - case "DERP": - if was.DERP() != n.DERP { - ip, portStr, err := net.SplitHostPort(n.DERP) - if err != nil || ip != "127.3.3.40" { - return nil, false - } - port, err := strconv.Atoi(portStr) - if err != nil || port < 1 || port > 65535 { - return nil, false - } - pc().DERPRegion = port + case "LegacyDERPString": + if was.LegacyDERPString() != "" || n.LegacyDERPString != "" { + panic("unexpected; caller should've already called upgradeNode") + } + case "HomeDERP": + if was.HomeDERP() != n.HomeDERP { + pc().DERPRegion = n.HomeDERP } case "Hostinfo": if !was.Hostinfo().Valid() && !n.Hostinfo.Valid() { @@ -688,21 +724,23 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "CapMap": if len(n.CapMap) != was.CapMap().Len() { + // If they have different lengths, they're different. if n.CapMap == nil { pc().CapMap = make(tailcfg.NodeCapMap) } else { pc().CapMap = maps.Clone(n.CapMap) } - break - } - was.CapMap().Range(func(k tailcfg.NodeCapability, v views.Slice[tailcfg.RawMessage]) bool { - nv, ok := n.CapMap[k] - if !ok || !views.SliceEqual(v, views.SliceOf(nv)) { - pc().CapMap = maps.Clone(n.CapMap) - return false + } else { + // If they have the same length, check that all their keys + // have the same values. 
+ for k, v := range was.CapMap().All() { + nv, ok := n.CapMap[k] + if !ok || !views.SliceEqual(v, views.SliceOf(nv)) { + pc().CapMap = maps.Clone(n.CapMap) + break + } } - return true - }) + } case "Tags": if !views.SliceEqual(was.Tags(), views.SliceOf(n.Tags)) { return nil, false @@ -712,13 +750,11 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return nil, false } case "Online": - wasOnline := was.Online() - if n.Online != nil && wasOnline != nil && *n.Online != *wasOnline { + if wasOnline, ok := was.Online().GetOk(); ok && n.Online != nil && *n.Online != wasOnline { pc().Online = ptr.To(*n.Online) } case "LastSeen": - wasSeen := was.LastSeen() - if n.LastSeen != nil && wasSeen != nil && !wasSeen.Equal(*n.LastSeen) { + if wasSeen, ok := was.LastSeen().GetOk(); ok && n.LastSeen != nil && !wasSeen.Equal(*n.LastSeen) { pc().LastSeen = ptr.To(*n.LastSeen) } case "MachineAuthorized": @@ -743,18 +779,18 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "SelfNodeV4MasqAddrForThisPeer": va, vb := was.SelfNodeV4MasqAddrForThisPeer(), n.SelfNodeV4MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "SelfNodeV6MasqAddrForThisPeer": va, vb := was.SelfNodeV6MasqAddrForThisPeer(), n.SelfNodeV6MasqAddrForThisPeer - if va == nil && vb == nil { + if !va.Valid() && vb == nil { continue } - if va == nil || vb == nil || *va != *vb { + if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { return nil, false } case "ExitNodeDNSResolvers": @@ -778,21 +814,26 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return ret, true } +func (ms *mapSession) sortedPeers() []tailcfg.NodeView { + ret := slicesx.MapValues(ms.peers) + slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { + return cmp.Compare(a.ID(), 
b.ID()) + }) + return ret +} + // netmap returns a fully populated NetworkMap from the last state seen from // a call to updateStateFromResponse, filling in omitted // information from prior MapResponse values. func (ms *mapSession) netmap() *netmap.NetworkMap { - peerViews := make([]tailcfg.NodeView, len(ms.sortedPeers)) - for i, vp := range ms.sortedPeers { - peerViews[i] = *vp - } + peerViews := ms.sortedPeers() nm := &netmap.NetworkMap{ NodeKey: ms.publicNodeKey, PrivateKey: ms.privateNodeKey, MachineKey: ms.machinePubKey, Peers: peerViews, - UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfile), + UserProfiles: make(map[tailcfg.UserID]tailcfg.UserProfileView), Domain: ms.lastDomain, DomainAuditLogID: ms.lastDomainAuditLogID, DNS: *ms.lastDNSConfig, @@ -803,7 +844,6 @@ func (ms *mapSession) netmap() *netmap.NetworkMap { DERPMap: ms.lastDERPMap, ControlHealth: ms.lastHealth, TKAEnabled: ms.lastTKAInfo != nil && !ms.lastTKAInfo.Disabled, - MaxKeyDuration: ms.lastMaxExpiry, } if ms.lastTKAInfo != nil && ms.lastTKAInfo.Head != "" { diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index 897036a942f49..ccc57ae2b86a8 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -17,6 +17,7 @@ import ( "github.com/google/go-cmp/cmp" "go4.org/mem" "tailscale.com/control/controlknobs" + "tailscale.com/health" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstime" @@ -50,9 +51,9 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { n.LastSeen = &t } } - withDERP := func(d string) func(*tailcfg.Node) { + withDERP := func(regionID int) func(*tailcfg.Node) { return func(n *tailcfg.Node) { - n.DERP = d + n.HomeDERP = regionID } } withEP := func(ep string) func(*tailcfg.Node) { @@ -189,14 +190,14 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { }, { name: "ep_change_derp", - prev: peers(n(1, "foo", withDERP("127.3.3.40:3"))), + prev: peers(n(1, "foo", withDERP(3))), mapRes: 
&tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, DERPRegion: 4, }}, }, - want: peers(n(1, "foo", withDERP("127.3.3.40:4"))), + want: peers(n(1, "foo", withDERP(4))), wantStats: updateStats{changed: 1}, }, { @@ -213,19 +214,19 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { }, { name: "ep_change_udp_2", - prev: peers(n(1, "foo", withDERP("127.3.3.40:3"), withEP("1.2.3.4:111"))), + prev: peers(n(1, "foo", withDERP(3), withEP("1.2.3.4:111"))), mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, Endpoints: eps("1.2.3.4:56"), }}, }, - want: peers(n(1, "foo", withDERP("127.3.3.40:3"), withEP("1.2.3.4:56"))), + want: peers(n(1, "foo", withDERP(3), withEP("1.2.3.4:56"))), wantStats: updateStats{changed: 1}, }, { name: "ep_change_both", - prev: peers(n(1, "foo", withDERP("127.3.3.40:3"), withEP("1.2.3.4:111"))), + prev: peers(n(1, "foo", withDERP(3), withEP("1.2.3.4:111"))), mapRes: &tailcfg.MapResponse{ PeersChangedPatch: []*tailcfg.PeerChange{{ NodeID: 1, @@ -233,7 +234,7 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { Endpoints: eps("1.2.3.4:56"), }}, }, - want: peers(n(1, "foo", withDERP("127.3.3.40:2"), withEP("1.2.3.4:56"))), + want: peers(n(1, "foo", withDERP(2), withEP("1.2.3.4:56"))), wantStats: updateStats{changed: 1}, }, { @@ -340,19 +341,18 @@ func TestUpdatePeersStateFromResponse(t *testing.T) { } ms := newTestMapSession(t, nil) for _, n := range tt.prev { - mak.Set(&ms.peers, n.ID, ptr.To(n.View())) + mak.Set(&ms.peers, n.ID, n.View()) } - ms.rebuildSorted() gotStats := ms.updatePeersStateFromResponse(tt.mapRes) - - got := make([]*tailcfg.Node, len(ms.sortedPeers)) - for i, vp := range ms.sortedPeers { - got[i] = vp.AsStruct() - } if gotStats != tt.wantStats { t.Errorf("got stats = %+v; want %+v", gotStats, tt.wantStats) } + + var got []*tailcfg.Node + for _, vp := range ms.sortedPeers() { + got = append(got, vp.AsStruct()) + } if !reflect.DeepEqual(got, tt.want) { t.Errorf("wrong 
results\n got: %s\nwant: %s", formatNodes(got), formatNodes(tt.want)) } @@ -745,8 +745,8 @@ func TestPeerChangeDiff(t *testing.T) { }, { name: "patch-derp", - a: &tailcfg.Node{ID: 1, DERP: "127.3.3.40:1"}, - b: &tailcfg.Node{ID: 1, DERP: "127.3.3.40:2"}, + a: &tailcfg.Node{ID: 1, HomeDERP: 1}, + b: &tailcfg.Node{ID: 1, HomeDERP: 2}, want: &tailcfg.PeerChange{NodeID: 1, DERPRegion: 2}, }, { @@ -930,23 +930,23 @@ func TestPatchifyPeersChanged(t *testing.T) { mr0: &tailcfg.MapResponse{ Node: &tailcfg.Node{Name: "foo.bar.ts.net."}, Peers: []*tailcfg.Node{ - {ID: 1, DERP: "127.3.3.40:1", Hostinfo: hi}, - {ID: 2, DERP: "127.3.3.40:2", Hostinfo: hi}, - {ID: 3, DERP: "127.3.3.40:3", Hostinfo: hi}, + {ID: 1, HomeDERP: 1, Hostinfo: hi}, + {ID: 2, HomeDERP: 2, Hostinfo: hi}, + {ID: 3, HomeDERP: 3, Hostinfo: hi}, }, }, mr1: &tailcfg.MapResponse{ PeersChanged: []*tailcfg.Node{ - {ID: 1, DERP: "127.3.3.40:11", Hostinfo: hi}, + {ID: 1, HomeDERP: 11, Hostinfo: hi}, {ID: 2, StableID: "other-change", Hostinfo: hi}, - {ID: 3, DERP: "127.3.3.40:33", Hostinfo: hi}, - {ID: 4, DERP: "127.3.3.40:4", Hostinfo: hi}, + {ID: 3, HomeDERP: 33, Hostinfo: hi}, + {ID: 4, HomeDERP: 4, Hostinfo: hi}, }, }, want: &tailcfg.MapResponse{ PeersChanged: []*tailcfg.Node{ {ID: 2, StableID: "other-change", Hostinfo: hi}, - {ID: 4, DERP: "127.3.3.40:4", Hostinfo: hi}, + {ID: 4, HomeDERP: 4, Hostinfo: hi}, }, PeersChangedPatch: []*tailcfg.PeerChange{ {NodeID: 1, DERPRegion: 11}, @@ -1007,6 +1007,85 @@ func TestPatchifyPeersChanged(t *testing.T) { } } +func TestUpgradeNode(t *testing.T) { + a1 := netip.MustParsePrefix("0.0.0.1/32") + a2 := netip.MustParsePrefix("0.0.0.2/32") + a3 := netip.MustParsePrefix("0.0.0.3/32") + a4 := netip.MustParsePrefix("0.0.0.4/32") + + tests := []struct { + name string + in *tailcfg.Node + want *tailcfg.Node + also func(t *testing.T, got *tailcfg.Node) // optional + }{ + { + name: "nil", + in: nil, + want: nil, + }, + { + name: "empty", + in: new(tailcfg.Node), + want: 
new(tailcfg.Node), + }, + { + name: "derp-both", + in: &tailcfg.Node{HomeDERP: 1, LegacyDERPString: tailcfg.DerpMagicIP + ":2"}, + want: &tailcfg.Node{HomeDERP: 1}, + }, + { + name: "derp-str-only", + in: &tailcfg.Node{LegacyDERPString: tailcfg.DerpMagicIP + ":2"}, + want: &tailcfg.Node{HomeDERP: 2}, + }, + { + name: "derp-int-only", + in: &tailcfg.Node{HomeDERP: 2}, + want: &tailcfg.Node{HomeDERP: 2}, + }, + { + name: "implicit-allowed-ips-all-set", + in: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{a3, a4}}, + want: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{a3, a4}}, + }, + { + name: "implicit-allowed-ips-only-address-set", + in: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}}, + want: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{a1, a2}}, + also: func(t *testing.T, got *tailcfg.Node) { + if t.Failed() { + return + } + if &got.Addresses[0] == &got.AllowedIPs[0] { + t.Error("Addresses and AllowIPs alias the same memory") + } + }, + }, + { + name: "implicit-allowed-ips-set-empty-slice", + in: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{}}, + want: &tailcfg.Node{Addresses: []netip.Prefix{a1, a2}, AllowedIPs: []netip.Prefix{}}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got *tailcfg.Node + if tt.in != nil { + got = ptr.To(*tt.in) // shallow clone + } + upgradeNode(got) + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("wrong result (-want +got):\n%s", diff) + } + if tt.also != nil { + tt.also(t, got) + } + }) + } + +} + func BenchmarkMapSessionDelta(b *testing.B) { for _, size := range []int{10, 100, 1_000, 10_000} { b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { @@ -1023,7 +1102,7 @@ func BenchmarkMapSessionDelta(b *testing.B) { res.Peers = append(res.Peers, &tailcfg.Node{ ID: tailcfg.NodeID(i + 2), Name: fmt.Sprintf("peer%d.bar.ts.net.", i), - DERP: "127.3.3.40:10", + HomeDERP: 
10, Addresses: []netip.Prefix{netip.MustParsePrefix("100.100.2.3/32"), netip.MustParsePrefix("fd7a:115c:a1e0::123/128")}, AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.100.2.3/32"), netip.MustParsePrefix("fd7a:115c:a1e0::123/128")}, Endpoints: eps("192.168.1.2:345", "192.168.1.3:678"), @@ -1058,3 +1137,34 @@ func BenchmarkMapSessionDelta(b *testing.B) { }) } } + +// TestNetmapHealthIntegration checks that we get the expected health warnings +// from processing a map response and passing the NetworkMap to a health tracker +func TestNetmapHealthIntegration(t *testing.T) { + ms := newTestMapSession(t, nil) + ht := health.Tracker{} + + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Health: []string{"Test message"}, + }) + ht.SetControlHealth(nm.ControlHealth) + + state := ht.CurrentState() + warning, ok := state.Warnings["control-health"] + + if !ok { + t.Fatal("no warning found in current state with code 'control-health'") + } + if got, want := warning.Title, "Coordination server reports an issue"; got != want { + t.Errorf("warning.Title = %q, want %q", got, want) + } + if got, want := warning.Severity, health.SeverityMedium; got != want { + t.Errorf("warning.Severity = %s, want %s", got, want) + } + if got, want := warning.Text, "The coordination server is reporting an health issue: Test message"; got != want { + t.Errorf("warning.Text = %q, want %q", got, want) + } +} diff --git a/control/controlclient/noise.go b/control/controlclient/noise.go index 3994af056fc3b..4bd8cfc25ee96 100644 --- a/control/controlclient/noise.go +++ b/control/controlclient/noise.go @@ -11,13 +11,13 @@ import ( "errors" "math" "net/http" + "net/netip" "net/url" "sync" "time" "golang.org/x/net/http2" "tailscale.com/control/controlhttp" - "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/internal/noiseconn" "tailscale.com/net/dnscache" @@ -30,7 +30,6 @@ import ( "tailscale.com/util/mak" 
"tailscale.com/util/multierr" "tailscale.com/util/singleflight" - "tailscale.com/util/testenv" ) // NoiseClient provides a http.Client to connect to tailcontrol over @@ -107,35 +106,45 @@ type NoiseOpts struct { DialPlan func() *tailcfg.ControlDialPlan } -// controlIsPlaintext is whether we should assume that the controlplane is only accessible -// over plaintext HTTP (as the first hop, before the ts2021 encryption begins). -// This is used by some tests which don't have a real TLS certificate. -var controlIsPlaintext = envknob.RegisterBool("TS_CONTROL_IS_PLAINTEXT_HTTP") - // NewNoiseClient returns a new noiseClient for the provided server and machine key. // serverURL is of the form https://: (no trailing slash). // // netMon may be nil, if non-nil it's used to do faster interface lookups. // dialPlan may be nil func NewNoiseClient(opts NoiseOpts) (*NoiseClient, error) { + logf := opts.Logf u, err := url.Parse(opts.ServerURL) if err != nil { return nil, err } + + if u.Scheme != "http" && u.Scheme != "https" { + return nil, errors.New("invalid ServerURL scheme, must be http or https") + } + var httpPort string var httpsPort string + addr, _ := netip.ParseAddr(u.Hostname()) + isPrivateHost := addr.IsPrivate() || addr.IsLoopback() || u.Hostname() == "localhost" if port := u.Port(); port != "" { - // If there is an explicit port specified, trust the scheme and hope for the best - if u.Scheme == "http" { + // If there is an explicit port specified, entirely rely on the scheme, + // unless it's http with a private host in which case we never try using HTTPS. 
+ if u.Scheme == "https" { + httpPort = "" + httpsPort = port + } else if u.Scheme == "http" { httpPort = port httpsPort = "443" - if (testenv.InTest() || controlIsPlaintext()) && (u.Hostname() == "127.0.0.1" || u.Hostname() == "localhost") { + if isPrivateHost { + logf("setting empty HTTPS port with http scheme and private host %s", u.Hostname()) httpsPort = "" } - } else { - httpPort = "80" - httpsPort = port } + } else if u.Scheme == "http" && isPrivateHost { + // Whenever the scheme is http and the hostname is an IP address, do not set the HTTPS port, + // as there cannot be a TLS certificate issued for an IP, unless it's a public IP. + httpPort = "80" + httpsPort = "" } else { // Otherwise, use the standard ports httpPort = "80" @@ -387,17 +396,20 @@ func (nc *NoiseClient) dial(ctx context.Context) (*noiseconn.Conn, error) { // post does a POST to the control server at the given path, JSON-encoding body. // The provided nodeKey is an optional load balancing hint. func (nc *NoiseClient) post(ctx context.Context, path string, nodeKey key.NodePublic, body any) (*http.Response, error) { + return nc.doWithBody(ctx, "POST", path, nodeKey, body) +} + +func (nc *NoiseClient) doWithBody(ctx context.Context, method, path string, nodeKey key.NodePublic, body any) (*http.Response, error) { jbody, err := json.Marshal(body) if err != nil { return nil, err } - req, err := http.NewRequestWithContext(ctx, "POST", "https://"+nc.host+path, bytes.NewReader(jbody)) + req, err := http.NewRequestWithContext(ctx, method, "https://"+nc.host+path, bytes.NewReader(jbody)) if err != nil { return nil, err } addLBHeader(req, nodeKey) req.Header.Set("Content-Type", "application/json") - conn, err := nc.getConn(ctx) if err != nil { return nil, err diff --git a/control/controlclient/noise_test.go b/control/controlclient/noise_test.go index f2627bd0a50fa..4904016f2f082 100644 --- a/control/controlclient/noise_test.go +++ b/control/controlclient/noise_test.go @@ -10,16 +10,16 @@ import ( "io" 
"math" "net/http" - "net/http/httptest" "testing" "time" "golang.org/x/net/http2" - "tailscale.com/control/controlhttp" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/internal/noiseconn" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" + "tailscale.com/tstest/nettest" "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -54,6 +54,123 @@ func TestNoiseClientHTTP2Upgrade_earlyPayload(t *testing.T) { }.run(t) } +func makeClientWithURL(t *testing.T, url string) *NoiseClient { + nc, err := NewNoiseClient(NoiseOpts{ + Logf: t.Logf, + ServerURL: url, + }) + if err != nil { + t.Fatal(err) + } + return nc +} + +func TestNoiseClientPortsAreSet(t *testing.T) { + tests := []struct { + name string + url string + wantHTTPS string + wantHTTP string + }{ + { + name: "https-url", + url: "https://example.com", + wantHTTPS: "443", + wantHTTP: "80", + }, + { + name: "http-url", + url: "http://example.com", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? + wantHTTP: "80", + }, + { + name: "https-url-custom-port", + url: "https://example.com:123", + wantHTTPS: "123", + wantHTTP: "", + }, + { + name: "http-url-custom-port", + url: "http://example.com:123", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? 
+ wantHTTP: "123", + }, + { + name: "http-loopback-no-port", + url: "http://127.0.0.1", + wantHTTPS: "", + wantHTTP: "80", + }, + { + name: "http-loopback-custom-port", + url: "http://127.0.0.1:8080", + wantHTTPS: "", + wantHTTP: "8080", + }, + { + name: "http-localhost-no-port", + url: "http://localhost", + wantHTTPS: "", + wantHTTP: "80", + }, + { + name: "http-localhost-custom-port", + url: "http://localhost:8080", + wantHTTPS: "", + wantHTTP: "8080", + }, + { + name: "http-private-ip-no-port", + url: "http://192.168.2.3", + wantHTTPS: "", + wantHTTP: "80", + }, + { + name: "http-private-ip-custom-port", + url: "http://192.168.2.3:8080", + wantHTTPS: "", + wantHTTP: "8080", + }, + { + name: "http-public-ip", + url: "http://1.2.3.4", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? + wantHTTP: "80", + }, + { + name: "http-public-ip-custom-port", + url: "http://1.2.3.4:8080", + wantHTTPS: "443", // TODO(bradfitz): questionable; change? + wantHTTP: "8080", + }, + { + name: "https-public-ip", + url: "https://1.2.3.4", + wantHTTPS: "443", + wantHTTP: "80", + }, + { + name: "https-public-ip-custom-port", + url: "https://1.2.3.4:8080", + wantHTTPS: "8080", + wantHTTP: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nc := makeClientWithURL(t, tt.url) + if nc.httpsPort != tt.wantHTTPS { + t.Errorf("nc.httpsPort = %q; want %q", nc.httpsPort, tt.wantHTTPS) + } + if nc.httpPort != tt.wantHTTP { + t.Errorf("nc.httpPort = %q; want %q", nc.httpPort, tt.wantHTTP) + } + }) + } +} + func (tt noiseClientTest) run(t *testing.T) { serverPrivate := key.NewMachine() clientPrivate := key.NewMachine() @@ -61,7 +178,8 @@ func (tt noiseClientTest) run(t *testing.T) { const msg = "Hello, client" h2 := &http2.Server{} - hs := httptest.NewServer(&Upgrader{ + nw := nettest.GetNetwork(t) + hs := nettest.NewHTTPServer(nw, &Upgrader{ h2srv: h2, noiseKeyPriv: serverPrivate, sendEarlyPayload: tt.sendEarlyPayload, @@ -76,11 +194,16 @@ func (tt 
noiseClientTest) run(t *testing.T) { defer hs.Close() dialer := tsdial.NewDialer(netmon.NewStatic()) + if nettest.PreferMemNetwork() { + dialer.SetSystemDialerForTest(nw.Dial) + } + nc, err := NewNoiseClient(NoiseOpts{ PrivKey: clientPrivate, ServerPubKey: serverPrivate.Public(), ServerURL: hs.URL, Dialer: dialer, + Logf: t.Logf, }) if err != nil { t.Fatal(err) @@ -201,7 +324,7 @@ func (up *Upgrader) ServeHTTP(w http.ResponseWriter, r *http.Request) { return nil } - cbConn, err := controlhttp.AcceptHTTP(r.Context(), w, r, up.noiseKeyPriv, earlyWriteFn) + cbConn, err := controlhttpserver.AcceptHTTP(r.Context(), w, r, up.noiseKeyPriv, earlyWriteFn) if err != nil { up.logf("controlhttp: Accept: %v", err) return diff --git a/control/controlclient/sign_supported.go b/control/controlclient/sign_supported.go index 0e3dd038e4ed7..a5d42ad7df4a2 100644 --- a/control/controlclient/sign_supported.go +++ b/control/controlclient/sign_supported.go @@ -13,7 +13,6 @@ import ( "crypto/x509" "errors" "fmt" - "sync" "time" "github.com/tailscale/certstore" @@ -22,11 +21,6 @@ import ( "tailscale.com/util/syspolicy" ) -var getMachineCertificateSubjectOnce struct { - sync.Once - v string // Subject of machine certificate to search for -} - // getMachineCertificateSubject returns the exact name of a Subject that needs // to be present in an identity's certificate chain to sign a RegisterRequest, // formatted as per pkix.Name.String(). 
The Subject may be that of the identity @@ -37,11 +31,8 @@ var getMachineCertificateSubjectOnce struct { // // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" func getMachineCertificateSubject() string { - getMachineCertificateSubjectOnce.Do(func() { - getMachineCertificateSubjectOnce.v, _ = syspolicy.GetString(syspolicy.MachineCertificateSubject, "") - }) - - return getMachineCertificateSubjectOnce.v + machineCertSubject, _ := syspolicy.GetString(syspolicy.MachineCertificateSubject, "") + return machineCertSubject } var ( diff --git a/control/controlhttp/client.go b/control/controlhttp/client.go index 7e5263e3317fe..869bcb599c9f3 100644 --- a/control/controlhttp/client.go +++ b/control/controlhttp/client.go @@ -38,11 +38,13 @@ import ( "time" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/dnsfallback" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" @@ -95,6 +97,9 @@ func (a *Dialer) httpsFallbackDelay() time.Duration { var _ = envknob.RegisterBool("TS_USE_CONTROL_DIAL_PLAN") // to record at init time whether it's in use func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { + + a.logPort80Failure.Store(true) + // If we don't have a dial plan, just fall back to dialing the single // host we know about. useDialPlan := envknob.BoolDefaultTrue("TS_USE_CONTROL_DIAL_PLAN") @@ -245,6 +250,11 @@ func (a *Dialer) dial(ctx context.Context) (*ClientConn, error) { results[i].conn = nil // so we don't close it in the defer return conn, nil } + if ctx.Err() != nil { + a.logf("controlhttp: context aborted dialing") + return nil, ctx.Err() + } + merr := multierr.New(errs...) // If we get here, then we didn't get anywhere with our dial plan; fall back to just using DNS. 
@@ -266,6 +276,15 @@ var forceNoise443 = envknob.RegisterBool("TS_FORCE_NOISE_443") // use HTTPS connections as its underlay connection (double crypto). This can // be necessary when networks or middle boxes are messing with port 80. func (d *Dialer) forceNoise443() bool { + if runtime.GOOS == "plan9" { + // For running demos of Plan 9 in a browser with network relays, + // we want to minimize the number of connections we're making. + // The main reason to use port 80 is to avoid double crypto + // costs server-side but the costs are tiny and number of Plan 9 + // users doesn't make it worth it. Just disable this and always use + // HTTPS for Plan 9. That also reduces some log spam. + return true + } if forceNoise443() { return true } @@ -277,7 +296,9 @@ func (d *Dialer) forceNoise443() bool { // This heuristic works around networks where port 80 is MITMed and // appears to work for a bit post-Upgrade but then gets closed, // such as seen in https://github.com/tailscale/tailscale/issues/13597. 
- d.logf("controlhttp: forcing port 443 dial due to recent noise dial") + if d.logPort80Failure.CompareAndSwap(true, false) { + d.logf("controlhttp: forcing port 443 dial due to recent noise dial") + } return true } @@ -474,7 +495,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad dns = a.resolver() } - var dialer dnscache.DialContextFunc + var dialer netx.DialFunc if a.Dialer != nil { dialer = a.Dialer } else { @@ -571,9 +592,9 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad Method: "POST", URL: u, Header: http.Header{ - "Upgrade": []string{upgradeHeaderValue}, - "Connection": []string{"upgrade"}, - handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, + "Upgrade": []string{controlhttpcommon.UpgradeHeaderValue}, + "Connection": []string{"upgrade"}, + controlhttpcommon.HandshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }, } req = req.WithContext(ctx) @@ -597,7 +618,7 @@ func (a *Dialer) tryURLUpgrade(ctx context.Context, u *url.URL, optAddr netip.Ad return nil, fmt.Errorf("httptrace didn't provide a connection") } - if next := resp.Header.Get("Upgrade"); next != upgradeHeaderValue { + if next := resp.Header.Get("Upgrade"); next != controlhttpcommon.UpgradeHeaderValue { resp.Body.Close() return nil, fmt.Errorf("server switched to unexpected protocol %q", next) } diff --git a/control/controlhttp/client_js.go b/control/controlhttp/client_js.go index 4b7126b52cf38..cc05b5b192766 100644 --- a/control/controlhttp/client_js.go +++ b/control/controlhttp/client_js.go @@ -12,6 +12,7 @@ import ( "github.com/coder/websocket" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/net/wsconn" ) @@ -42,11 +43,11 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) { // Can't set HTTP headers on the websocket request, so we have to to send // the handshake via an HTTP header. 
RawQuery: url.Values{ - handshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, + controlhttpcommon.HandshakeHeaderName: []string{base64.StdEncoding.EncodeToString(init)}, }.Encode(), } wsConn, _, err := websocket.Dial(ctx, wsURL.String(), &websocket.DialOptions{ - Subprotocols: []string{upgradeHeaderValue}, + Subprotocols: []string{controlhttpcommon.UpgradeHeaderValue}, }) if err != nil { return nil, err diff --git a/control/controlhttp/constants.go b/control/controlhttp/constants.go index ea1725e76d438..12038fae45b1c 100644 --- a/control/controlhttp/constants.go +++ b/control/controlhttp/constants.go @@ -6,11 +6,13 @@ package controlhttp import ( "net/http" "net/url" + "sync/atomic" "time" "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/tailcfg" "tailscale.com/tstime" "tailscale.com/types/key" @@ -18,15 +20,6 @@ import ( ) const ( - // upgradeHeader is the value of the Upgrade HTTP header used to - // indicate the Tailscale control protocol. - upgradeHeaderValue = "tailscale-control-protocol" - - // handshakeHeaderName is the HTTP request header that can - // optionally contain base64-encoded initial handshake - // payload, to save an RTT. - handshakeHeaderName = "X-Tailscale-Handshake" - // serverUpgradePath is where the server-side HTTP handler to // to do the protocol switch is located. serverUpgradePath = "/ts2021" @@ -74,7 +67,7 @@ type Dialer struct { // Dialer is the dialer used to make outbound connections. // // If not specified, this defaults to net.Dialer.DialContext. - Dialer dnscache.DialContextFunc + Dialer netx.DialFunc // DNSCache is the caching Resolver used by this Dialer. // @@ -85,6 +78,8 @@ type Dialer struct { // dropped. Logf logger.Logf + // NetMon is the [netmon.Monitor] to use for this Dialer. It must be + // non-nil. NetMon *netmon.Monitor // HealthTracker, if non-nil, is the health tracker to use. 
@@ -97,6 +92,11 @@ type Dialer struct { proxyFunc func(*http.Request) (*url.URL, error) // or nil + // logPort80Failure is whether we should log about port 80 interceptions + // and forcing a port 443 dial. We do this only once per "dial" method + // which can result in many concurrent racing dialHost calls. + logPort80Failure atomic.Bool + // For tests only drainFinished chan struct{} omitCertErrorLogging bool diff --git a/control/controlhttp/controlhttpcommon/controlhttpcommon.go b/control/controlhttp/controlhttpcommon/controlhttpcommon.go new file mode 100644 index 0000000000000..a86b7ca04a7f4 --- /dev/null +++ b/control/controlhttp/controlhttpcommon/controlhttpcommon.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package controlhttpcommon contains common constants for used +// by the controlhttp client and controlhttpserver packages. +package controlhttpcommon + +// UpgradeHeader is the value of the Upgrade HTTP header used to +// indicate the Tailscale control protocol. +const UpgradeHeaderValue = "tailscale-control-protocol" + +// handshakeHeaderName is the HTTP request header that can +// optionally contain base64-encoded initial handshake +// payload, to save an RTT. +const HandshakeHeaderName = "X-Tailscale-Handshake" diff --git a/control/controlhttp/server.go b/control/controlhttp/controlhttpserver/controlhttpserver.go similarity index 92% rename from control/controlhttp/server.go rename to control/controlhttp/controlhttpserver/controlhttpserver.go index 7c3dd5618c4a3..af320781069d1 100644 --- a/control/controlhttp/server.go +++ b/control/controlhttp/controlhttpserver/controlhttpserver.go @@ -3,7 +3,8 @@ //go:build !ios -package controlhttp +// Package controlhttpserver contains the HTTP server side of the ts2021 control protocol. 
+package controlhttpserver import ( "context" @@ -18,6 +19,7 @@ import ( "github.com/coder/websocket" "tailscale.com/control/controlbase" + "tailscale.com/control/controlhttp/controlhttpcommon" "tailscale.com/net/netutil" "tailscale.com/net/wsconn" "tailscale.com/types/key" @@ -45,12 +47,12 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri if next == "websocket" { return acceptWebsocket(ctx, w, r, private) } - if next != upgradeHeaderValue { + if next != controlhttpcommon.UpgradeHeaderValue { http.Error(w, "unknown next protocol", http.StatusBadRequest) return nil, fmt.Errorf("client requested unhandled next protocol %q", next) } - initB64 := r.Header.Get(handshakeHeaderName) + initB64 := r.Header.Get(controlhttpcommon.HandshakeHeaderName) if initB64 == "" { http.Error(w, "missing Tailscale handshake header", http.StatusBadRequest) return nil, errors.New("no tailscale handshake header in HTTP request") @@ -67,7 +69,7 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri return nil, errors.New("can't hijack client connection") } - w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) @@ -117,7 +119,7 @@ func acceptHTTP(ctx context.Context, w http.ResponseWriter, r *http.Request, pri // speak HTTP) to a Tailscale control protocol base transport connection. func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request, private key.MachinePrivate) (*controlbase.Conn, error) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{upgradeHeaderValue}, + Subprotocols: []string{controlhttpcommon.UpgradeHeaderValue}, OriginPatterns: []string{"*"}, // Disable compression because we transmit Noise messages that are not // compressible. 
@@ -129,7 +131,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request if err != nil { return nil, fmt.Errorf("Could not accept WebSocket connection %v", err) } - if c.Subprotocol() != upgradeHeaderValue { + if c.Subprotocol() != controlhttpcommon.UpgradeHeaderValue { c.Close(websocket.StatusPolicyViolation, "client must speak the control subprotocol") return nil, fmt.Errorf("Unexpected subprotocol %q", c.Subprotocol()) } @@ -137,7 +139,7 @@ func acceptWebsocket(ctx context.Context, w http.ResponseWriter, r *http.Request c.Close(websocket.StatusPolicyViolation, "Could not parse parameters") return nil, fmt.Errorf("parse query parameters: %v", err) } - initB64 := r.Form.Get(handshakeHeaderName) + initB64 := r.Form.Get(controlhttpcommon.HandshakeHeaderName) if initB64 == "" { c.Close(websocket.StatusPolicyViolation, "missing Tailscale handshake parameter") return nil, errors.New("no tailscale handshake parameter in HTTP request") diff --git a/control/controlhttp/http_test.go b/control/controlhttp/http_test.go index 8c8ed7f5701b0..daf262023da97 100644 --- a/control/controlhttp/http_test.go +++ b/control/controlhttp/http_test.go @@ -23,8 +23,11 @@ import ( "time" "tailscale.com/control/controlbase" - "tailscale.com/net/dnscache" + "tailscale.com/control/controlhttp/controlhttpcommon" + "tailscale.com/control/controlhttp/controlhttpserver" + "tailscale.com/health" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/net/socks5" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" @@ -158,7 +161,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { return err } } - conn, err := AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, earlyWriteFn) if err != nil { log.Print(err) } @@ -225,6 +228,7 @@ func testControlHTTP(t *testing.T, param httpTestParam) { omitCertErrorLogging: true, testFallbackDelay: fallbackDelay, Clock: clock, + 
HealthTracker: new(health.Tracker), } if param.httpInDial { @@ -529,7 +533,7 @@ EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA== func brokenMITMHandler(clock tstime.Clock) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Upgrade", upgradeHeaderValue) + w.Header().Set("Upgrade", controlhttpcommon.UpgradeHeaderValue) w.Header().Set("Connection", "upgrade") w.WriteHeader(http.StatusSwitchingProtocols) w.(http.Flusher).Flush() @@ -574,7 +578,7 @@ func TestDialPlan(t *testing.T) { close(done) }) var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := AcceptHTTP(context.Background(), w, r, server, nil) + conn, err := controlhttpserver.AcceptHTTP(context.Background(), w, r, server, nil) if err != nil { log.Print(err) } else { @@ -726,6 +730,7 @@ func TestDialPlan(t *testing.T) { omitCertErrorLogging: true, testFallbackDelay: 50 * time.Millisecond, Clock: clock, + HealthTracker: new(health.Tracker), } conn, err := a.dial(ctx) @@ -754,7 +759,7 @@ func TestDialPlan(t *testing.T) { type closeTrackDialer struct { t testing.TB - inner dnscache.DialContextFunc + inner netx.DialFunc mu sync.Mutex conns map[*closeTrackConn]bool } diff --git a/control/controlknobs/controlknobs.go b/control/controlknobs/controlknobs.go index dd76a3abdba5b..a86f0af53a829 100644 --- a/control/controlknobs/controlknobs.go +++ b/control/controlknobs/controlknobs.go @@ -6,6 +6,8 @@ package controlknobs import ( + "fmt" + "reflect" "sync/atomic" "tailscale.com/syncs" @@ -103,6 +105,11 @@ type Knobs struct { // DisableCaptivePortalDetection is whether the node should not perform captive portal detection // automatically when the network state changes. DisableCaptivePortalDetection atomic.Bool + + // DisableSkipStatusQueue is whether the node should disable skipping + // of queued netmap.NetworkMap between the controlclient and LocalBackend. + // See tailscale/tailscale#14768. 
+ DisableSkipStatusQueue atomic.Bool } // UpdateFromNodeAttributes updates k (if non-nil) based on the provided self @@ -132,6 +139,7 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { disableLocalDNSOverrideViaNRPT = has(tailcfg.NodeAttrDisableLocalDNSOverrideViaNRPT) disableCryptorouting = has(tailcfg.NodeAttrDisableMagicSockCryptoRouting) disableCaptivePortalDetection = has(tailcfg.NodeAttrDisableCaptivePortalDetection) + disableSkipStatusQueue = has(tailcfg.NodeAttrDisableSkipStatusQueue) ) if has(tailcfg.NodeAttrOneCGNATEnable) { @@ -159,6 +167,7 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { k.DisableLocalDNSOverrideViaNRPT.Store(disableLocalDNSOverrideViaNRPT) k.DisableCryptorouting.Store(disableCryptorouting) k.DisableCaptivePortalDetection.Store(disableCaptivePortalDetection) + k.DisableSkipStatusQueue.Store(disableSkipStatusQueue) } // AsDebugJSON returns k as something that can be marshalled with json.Marshal @@ -167,25 +176,19 @@ func (k *Knobs) AsDebugJSON() map[string]any { if k == nil { return nil } - return map[string]any{ - "DisableUPnP": k.DisableUPnP.Load(), - "KeepFullWGConfig": k.KeepFullWGConfig.Load(), - "RandomizeClientPort": k.RandomizeClientPort.Load(), - "OneCGNAT": k.OneCGNAT.Load(), - "ForceBackgroundSTUN": k.ForceBackgroundSTUN.Load(), - "DisableDeltaUpdates": k.DisableDeltaUpdates.Load(), - "PeerMTUEnable": k.PeerMTUEnable.Load(), - "DisableDNSForwarderTCPRetries": k.DisableDNSForwarderTCPRetries.Load(), - "SilentDisco": k.SilentDisco.Load(), - "LinuxForceIPTables": k.LinuxForceIPTables.Load(), - "LinuxForceNfTables": k.LinuxForceNfTables.Load(), - "SeamlessKeyRenewal": k.SeamlessKeyRenewal.Load(), - "ProbeUDPLifetime": k.ProbeUDPLifetime.Load(), - "AppCStoreRoutes": k.AppCStoreRoutes.Load(), - "UserDialUseRoutes": k.UserDialUseRoutes.Load(), - "DisableSplitDNSWhenNoCustomResolvers": k.DisableSplitDNSWhenNoCustomResolvers.Load(), - "DisableLocalDNSOverrideViaNRPT": 
k.DisableLocalDNSOverrideViaNRPT.Load(), - "DisableCryptorouting": k.DisableCryptorouting.Load(), - "DisableCaptivePortalDetection": k.DisableCaptivePortalDetection.Load(), + ret := map[string]any{} + rt := reflect.TypeFor[Knobs]() + rv := reflect.ValueOf(k).Elem() // of *k + for i := 0; i < rt.NumField(); i++ { + name := rt.Field(i).Name + switch v := rv.Field(i).Addr().Interface().(type) { + case *atomic.Bool: + ret[name] = v.Load() + case *syncs.AtomicValue[opt.Bool]: + ret[name] = v.Load() + default: + panic(fmt.Sprintf("unknown field type %T for %v", v, name)) + } } + return ret } diff --git a/control/controlknobs/controlknobs_test.go b/control/controlknobs/controlknobs_test.go index a78a486f3aaae..7618b7121c500 100644 --- a/control/controlknobs/controlknobs_test.go +++ b/control/controlknobs/controlknobs_test.go @@ -6,6 +6,8 @@ package controlknobs import ( "reflect" "testing" + + "tailscale.com/types/logger" ) func TestAsDebugJSON(t *testing.T) { @@ -18,4 +20,5 @@ func TestAsDebugJSON(t *testing.T) { if want := reflect.TypeFor[Knobs]().NumField(); len(got) != want { t.Errorf("AsDebugJSON map has %d fields; want %v", len(got), want) } + t.Logf("Got: %v", logger.AsJSON(got)) } diff --git a/derp/derp.go b/derp/derp.go index f9b0706477358..65acd43210234 100644 --- a/derp/derp.go +++ b/derp/derp.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "io" + "net" "time" ) @@ -79,8 +80,7 @@ const ( // framePeerGone to B so B can forget that a reverse path // exists on that connection to get back to A. It is also sent // if A tries to send a CallMeMaybe to B and the server has no - // record of B (which currently would only happen if there was - // a bug). 
+ // record of B framePeerGone = frameType(0x08) // 32B pub key of peer that's gone + 1 byte reason // framePeerPresent is like framePeerGone, but for other members of the DERP @@ -131,8 +131,8 @@ const ( type PeerGoneReasonType byte const ( - PeerGoneReasonDisconnected = PeerGoneReasonType(0x00) // peer disconnected from this server - PeerGoneReasonNotHere = PeerGoneReasonType(0x01) // server doesn't know about this peer, unexpected + PeerGoneReasonDisconnected = PeerGoneReasonType(0x00) // is only sent when a peer disconnects from this server + PeerGoneReasonNotHere = PeerGoneReasonType(0x01) // server doesn't know about this peer PeerGoneReasonMeshConnBroke = PeerGoneReasonType(0xf0) // invented by Client.RunWatchConnectionLoop on disconnect; not sent on the wire ) @@ -147,6 +147,7 @@ const ( PeerPresentIsRegular = 1 << 0 PeerPresentIsMeshPeer = 1 << 1 PeerPresentIsProber = 1 << 2 + PeerPresentNotIdeal = 1 << 3 // client said derp server is not its Region.Nodes[0] ideal node ) var bin = binary.BigEndian @@ -254,3 +255,14 @@ func writeFrame(bw *bufio.Writer, t frameType, b []byte) error { } return bw.Flush() } + +// Conn is the subset of the underlying net.Conn the DERP Server needs. +// It is a defined type so that non-net connections can be used. +type Conn interface { + io.WriteCloser + LocalAddr() net.Addr + // The *Deadline methods follow the semantics of net.Conn. 
+ SetDeadline(time.Time) error + SetReadDeadline(time.Time) error + SetWriteDeadline(time.Time) error +} diff --git a/derp/derp_server.go b/derp/derp_server.go index 8c5d6e890567b..abda9da73a6fc 100644 --- a/derp/derp_server.go +++ b/derp/derp_server.go @@ -11,6 +11,7 @@ import ( "context" "crypto/ed25519" crand "crypto/rand" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/binary" @@ -23,9 +24,9 @@ import ( "math" "math/big" "math/rand/v2" - "net" "net/http" "net/netip" + "os" "os/exec" "runtime" "strconv" @@ -36,7 +37,9 @@ import ( "go4.org/mem" "golang.org/x/sync/errgroup" + "tailscale.com/client/local" "tailscale.com/client/tailscale" + "tailscale.com/derp/derpconst" "tailscale.com/disco" "tailscale.com/envknob" "tailscale.com/metrics" @@ -46,6 +49,7 @@ import ( "tailscale.com/tstime/rate" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/ctxkey" "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/slicesx" @@ -56,6 +60,16 @@ import ( // verbosely log whenever DERP drops a packet. var verboseDropKeys = map[key.NodePublic]bool{} +// IdealNodeHeader is the HTTP request header sent on DERP HTTP client requests +// to indicate that they're connecting to their ideal (Region.Nodes[0]) node. +// The HTTP header value is the name of the node they wish they were connected +// to. This is an optional header. +const IdealNodeHeader = "Ideal-Node" + +// IdealNodeContextKey is the context key used to pass the IdealNodeHeader value +// from the HTTP handler to the DERP server's Accept method. 
+var IdealNodeContextKey = ctxkey.New[string]("ideal-node", "") + func init() { keys := envknob.String("TS_DEBUG_VERBOSE_DROPS") if keys == "" { @@ -72,10 +86,19 @@ func init() { } const ( - perClientSendQueueDepth = 32 // packets buffered for sending - writeTimeout = 2 * time.Second + defaultPerClientSendQueueDepth = 32 // default packets buffered for sending + DefaultTCPWiteTimeout = 2 * time.Second + privilegedWriteTimeout = 30 * time.Second // for clients with the mesh key ) +func getPerClientSendQueueDepth() int { + if v, ok := envknob.LookupInt("TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH"); ok { + return v + } + + return defaultPerClientSendQueueDepth +} + // dupPolicy is a temporary (2021-08-30) mechanism to change the policy // of how duplicate connection for the same key are handled. type dupPolicy int8 @@ -91,6 +114,14 @@ const ( disableFighters ) +// packetKind is the kind of packet being sent through DERP +type packetKind string + +const ( + packetKindDisco packetKind = "disco" + packetKindOther packetKind = "other" +) + type align64 [0]atomic.Int64 // for side effect of its 64-bit alignment // Server is a DERP server. 
@@ -108,44 +139,40 @@ type Server struct { metaCert []byte // the encoded x509 cert to send after LetsEncrypt cert+intermediate dupPolicy dupPolicy debug bool + localClient local.Client // Counters: - packetsSent, bytesSent expvar.Int - packetsRecv, bytesRecv expvar.Int - packetsRecvByKind metrics.LabelMap - packetsRecvDisco *expvar.Int - packetsRecvOther *expvar.Int - _ align64 - packetsDropped expvar.Int - packetsDroppedReason metrics.LabelMap - packetsDroppedReasonCounters []*expvar.Int // indexed by dropReason - packetsDroppedType metrics.LabelMap - packetsDroppedTypeDisco *expvar.Int - packetsDroppedTypeOther *expvar.Int - _ align64 - packetsForwardedOut expvar.Int - packetsForwardedIn expvar.Int - peerGoneDisconnectedFrames expvar.Int // number of peer disconnected frames sent - peerGoneNotHereFrames expvar.Int // number of peer not here frames sent - gotPing expvar.Int // number of ping frames from client - sentPong expvar.Int // number of pong frames enqueued to client - accepts expvar.Int - curClients expvar.Int - curHomeClients expvar.Int // ones with preferred - dupClientKeys expvar.Int // current number of public keys we have 2+ connections for - dupClientConns expvar.Int // current number of connections sharing a public key - dupClientConnTotal expvar.Int // total number of accepted connections when a dup key existed - unknownFrames expvar.Int - homeMovesIn expvar.Int // established clients announce home server moves in - homeMovesOut expvar.Int // established clients announce home server moves out - multiForwarderCreated expvar.Int - multiForwarderDeleted expvar.Int - removePktForwardOther expvar.Int - avgQueueDuration *uint64 // In milliseconds; accessed atomically - tcpRtt metrics.LabelMap // histogram - meshUpdateBatchSize *metrics.Histogram - meshUpdateLoopCount *metrics.Histogram - bufferedWriteFrames *metrics.Histogram // how many sendLoop frames (or groups of related frames) get written per flush + packetsSent, bytesSent expvar.Int + 
packetsRecv, bytesRecv expvar.Int + packetsRecvByKind metrics.LabelMap + packetsRecvDisco *expvar.Int + packetsRecvOther *expvar.Int + _ align64 + packetsForwardedOut expvar.Int + packetsForwardedIn expvar.Int + peerGoneDisconnectedFrames expvar.Int // number of peer disconnected frames sent + peerGoneNotHereFrames expvar.Int // number of peer not here frames sent + gotPing expvar.Int // number of ping frames from client + sentPong expvar.Int // number of pong frames enqueued to client + accepts expvar.Int + curClients expvar.Int + curClientsNotIdeal expvar.Int + curHomeClients expvar.Int // ones with preferred + dupClientKeys expvar.Int // current number of public keys we have 2+ connections for + dupClientConns expvar.Int // current number of connections sharing a public key + dupClientConnTotal expvar.Int // total number of accepted connections when a dup key existed + unknownFrames expvar.Int + homeMovesIn expvar.Int // established clients announce home server moves in + homeMovesOut expvar.Int // established clients announce home server moves out + multiForwarderCreated expvar.Int + multiForwarderDeleted expvar.Int + removePktForwardOther expvar.Int + sclientWriteTimeouts expvar.Int + avgQueueDuration *uint64 // In milliseconds; accessed atomically + tcpRtt metrics.LabelMap // histogram + meshUpdateBatchSize *metrics.Histogram + meshUpdateLoopCount *metrics.Histogram + bufferedWriteFrames *metrics.Histogram // how many sendLoop frames (or groups of related frames) get written per flush // verifyClientsLocalTailscaled only accepts client connections to the DERP // server if the clientKey is a known peer in the network, as specified by a @@ -175,6 +202,11 @@ type Server struct { // maps from netip.AddrPort to a client's public key keyOfAddr map[netip.AddrPort]key.NodePublic + // Sets the client send queue depth for the server. 
+ perClientSendQueueDepth int + + tcpWriteTimeout time.Duration + clock tstime.Clock } @@ -314,16 +346,16 @@ type PacketForwarder interface { String() string } -// Conn is the subset of the underlying net.Conn the DERP Server needs. -// It is a defined type so that non-net connections can be used. -type Conn interface { - io.WriteCloser - LocalAddr() net.Addr - // The *Deadline methods follow the semantics of net.Conn. - SetDeadline(time.Time) error - SetReadDeadline(time.Time) error - SetWriteDeadline(time.Time) error -} +var packetsDropped = metrics.NewMultiLabelMap[dropReasonKindLabels]( + "derp_packets_dropped", + "counter", + "DERP packets dropped by reason and by kind") + +var bytesDropped = metrics.NewMultiLabelMap[dropReasonKindLabels]( + "derp_bytes_dropped", + "counter", + "DERP bytes dropped by reason and by kind", +) // NewServer returns a new DERP server. It doesn't listen on its own. // Connections are given to it via Server.Accept. @@ -332,59 +364,100 @@ func NewServer(privateKey key.NodePrivate, logf logger.Logf) *Server { runtime.ReadMemStats(&ms) s := &Server{ - debug: envknob.Bool("DERP_DEBUG_LOGS"), - privateKey: privateKey, - publicKey: privateKey.Public(), - logf: logf, - limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), - packetsRecvByKind: metrics.LabelMap{Label: "kind"}, - packetsDroppedReason: metrics.LabelMap{Label: "reason"}, - packetsDroppedType: metrics.LabelMap{Label: "type"}, - clients: map[key.NodePublic]*clientSet{}, - clientsMesh: map[key.NodePublic]PacketForwarder{}, - netConns: map[Conn]chan struct{}{}, - memSys0: ms.Sys, - watchers: set.Set[*sclient]{}, - peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{}, - avgQueueDuration: new(uint64), - tcpRtt: metrics.LabelMap{Label: "le"}, - meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}), - meshUpdateLoopCount: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100}), - bufferedWriteFrames: 
metrics.NewHistogram([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 50, 100}), - keyOfAddr: map[netip.AddrPort]key.NodePublic{}, - clock: tstime.StdClock{}, + debug: envknob.Bool("DERP_DEBUG_LOGS"), + privateKey: privateKey, + publicKey: privateKey.Public(), + logf: logf, + limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), + packetsRecvByKind: metrics.LabelMap{Label: "kind"}, + clients: map[key.NodePublic]*clientSet{}, + clientsMesh: map[key.NodePublic]PacketForwarder{}, + netConns: map[Conn]chan struct{}{}, + memSys0: ms.Sys, + watchers: set.Set[*sclient]{}, + peerGoneWatchers: map[key.NodePublic]set.HandleSet[func(key.NodePublic)]{}, + avgQueueDuration: new(uint64), + tcpRtt: metrics.LabelMap{Label: "le"}, + meshUpdateBatchSize: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000}), + meshUpdateLoopCount: metrics.NewHistogram([]float64{0, 1, 2, 5, 10, 20, 50, 100}), + bufferedWriteFrames: metrics.NewHistogram([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 50, 100}), + keyOfAddr: map[netip.AddrPort]key.NodePublic{}, + clock: tstime.StdClock{}, + tcpWriteTimeout: DefaultTCPWiteTimeout, } s.initMetacert() - s.packetsRecvDisco = s.packetsRecvByKind.Get("disco") - s.packetsRecvOther = s.packetsRecvByKind.Get("other") + s.packetsRecvDisco = s.packetsRecvByKind.Get(string(packetKindDisco)) + s.packetsRecvOther = s.packetsRecvByKind.Get(string(packetKindOther)) - s.packetsDroppedReasonCounters = s.genPacketsDroppedReasonCounters() + genDroppedCounters() - s.packetsDroppedTypeDisco = s.packetsDroppedType.Get("disco") - s.packetsDroppedTypeOther = s.packetsDroppedType.Get("other") + s.perClientSendQueueDepth = getPerClientSendQueueDepth() return s } -func (s *Server) genPacketsDroppedReasonCounters() []*expvar.Int { - getMetric := s.packetsDroppedReason.Get - ret := []*expvar.Int{ - dropReasonUnknownDest: getMetric("unknown_dest"), - dropReasonUnknownDestOnFwd: getMetric("unknown_dest_on_fwd"), - 
dropReasonGoneDisconnected: getMetric("gone_disconnected"), - dropReasonQueueHead: getMetric("queue_head"), - dropReasonQueueTail: getMetric("queue_tail"), - dropReasonWriteError: getMetric("write_error"), - dropReasonDupClient: getMetric("dup_client"), +func genDroppedCounters() { + initMetrics := func(reason dropReason) { + packetsDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }, 0) + packetsDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }, 0) + bytesDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }, 0) + bytesDropped.Add(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }, 0) + } + getMetrics := func(reason dropReason) []expvar.Var { + return []expvar.Var{ + packetsDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }), + packetsDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }), + bytesDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindDisco), + Reason: string(reason), + }), + bytesDropped.Get(dropReasonKindLabels{ + Kind: string(packetKindOther), + Reason: string(reason), + }), + } } - if len(ret) != int(numDropReasons) { - panic("dropReason metrics out of sync") + + dropReasons := []dropReason{ + dropReasonUnknownDest, + dropReasonUnknownDestOnFwd, + dropReasonGoneDisconnected, + dropReasonQueueHead, + dropReasonQueueTail, + dropReasonWriteError, + dropReasonDupClient, } - for i := range numDropReasons { - if ret[i] == nil { + + for _, dr := range dropReasons { + initMetrics(dr) + m := getMetrics(dr) + if len(m) != 4 { panic("dropReason metrics out of sync") } + + for _, v := range m { + if v == nil { + panic("dropReason metrics out of sync") + } + } } - return ret } // SetMesh sets the pre-shared key that regional DERP servers used to mesh @@ -415,6 +488,23 @@ func (s *Server) 
SetVerifyClientURLFailOpen(v bool) { s.verifyClientsURLFailOpen = v } +// SetTailscaledSocketPath sets the unix socket path to use to talk to +// tailscaled if client verification is enabled. +// +// If unset or set to the empty string, the default path for the operating +// system is used. +func (s *Server) SetTailscaledSocketPath(path string) { + s.localClient.Socket = path + s.localClient.UseSocketOnly = path != "" +} + +// SetTCPWriteTimeout sets the timeout for writing to connected clients. +// This timeout does not apply to mesh connections. +// Defaults to 2 seconds. +func (s *Server) SetTCPWriteTimeout(d time.Duration) { + s.tcpWriteTimeout = d +} + // HasMeshKey reports whether the server is configured with a mesh key. func (s *Server) HasMeshKey() bool { return s.meshKey != "" } @@ -528,7 +618,7 @@ func (s *Server) initMetacert() { tmpl := &x509.Certificate{ SerialNumber: big.NewInt(ProtocolVersion), Subject: pkix.Name{ - CommonName: fmt.Sprintf("derpkey%s", s.publicKey.UntypedHexString()), + CommonName: derpconst.MetaCertCommonNamePrefix + s.publicKey.UntypedHexString(), }, // Windows requires NotAfter and NotBefore set: NotAfter: s.clock.Now().Add(30 * 24 * time.Hour), @@ -548,6 +638,25 @@ func (s *Server) initMetacert() { // TLS server to let the client skip a round trip during start-up. func (s *Server) MetaCert() []byte { return s.metaCert } +// ModifyTLSConfigToAddMetaCert modifies c.GetCertificate to make +// it append s.MetaCert to the returned certificates. +// +// It panics if c or c.GetCertificate is nil. 
+func (s *Server) ModifyTLSConfigToAddMetaCert(c *tls.Config) { + getCert := c.GetCertificate + if getCert == nil { + panic("c.GetCertificate is nil") + } + c.GetCertificate = func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := getCert(hi) + if err != nil { + return nil, err + } + cert.Certificate = append(cert.Certificate, s.MetaCert()) + return cert, nil + } +} + // registerClient notes that client c is now authenticated and ready for packets. // // If c.key is connected more than once, the earlier connection(s) are @@ -600,6 +709,9 @@ func (s *Server) registerClient(c *sclient) { } s.keyOfAddr[c.remoteIPPort] = c.key s.curClients.Add(1) + if c.isNotIdealConn { + s.curClientsNotIdeal.Add(1) + } s.broadcastPeerStateChangeLocked(c.key, c.remoteIPPort, c.presentFlags(), true) } @@ -690,6 +802,9 @@ func (s *Server) unregisterClient(c *sclient) { if c.preferred { s.curHomeClients.Add(-1) } + if c.isNotIdealConn { + s.curClientsNotIdeal.Add(-1) + } } // addPeerGoneFromRegionWatcher adds a function to be called when peer is gone @@ -806,8 +921,8 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem return fmt.Errorf("receive client key: %v", err) } - clientAP, _ := netip.ParseAddrPort(remoteAddr) - if err := s.verifyClient(ctx, clientKey, clientInfo, clientAP.Addr()); err != nil { + remoteIPPort, _ := netip.ParseAddrPort(remoteAddr) + if err := s.verifyClient(ctx, clientKey, clientInfo, remoteIPPort.Addr()); err != nil { return fmt.Errorf("client %v rejected: %v", clientKey, err) } @@ -817,8 +932,6 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem ctx, cancel := context.WithCancel(ctx) defer cancel() - remoteIPPort, _ := netip.ParseAddrPort(remoteAddr) - c := &sclient{ connNum: connNum, s: s, @@ -830,11 +943,12 @@ func (s *Server) accept(ctx context.Context, nc Conn, brw *bufio.ReadWriter, rem done: ctx.Done(), remoteIPPort: remoteIPPort, connectedAt: s.clock.Now(), - sendQueue: 
make(chan pkt, perClientSendQueueDepth), - discoSendQueue: make(chan pkt, perClientSendQueueDepth), + sendQueue: make(chan pkt, s.perClientSendQueueDepth), + discoSendQueue: make(chan pkt, s.perClientSendQueueDepth), sendPongCh: make(chan [8]byte, 1), peerGone: make(chan peerGoneMsg), canMesh: s.isMeshPeer(clientInfo), + isNotIdealConn: IdealNodeContextKey.Value(ctx) != "", peerGoneLim: rate.NewLimiter(rate.Every(time.Second), 3), } @@ -881,6 +995,9 @@ func (c *sclient) run(ctx context.Context) error { if errors.Is(err, context.Canceled) { c.debugLogf("sender canceled by reader exiting") } else { + if errors.Is(err, os.ErrDeadlineExceeded) { + c.s.sclientWriteTimeouts.Add(1) + } c.logf("sender failed: %v", err) } } @@ -1116,31 +1233,37 @@ func (c *sclient) debugLogf(format string, v ...any) { } } -// dropReason is why we dropped a DERP frame. -type dropReason int +type dropReasonKindLabels struct { + Reason string // metric label corresponding to a given dropReason + Kind string // either `disco` or `other` +} -//go:generate go run tailscale.com/cmd/addlicense -file dropreason_string.go go run golang.org/x/tools/cmd/stringer -type=dropReason -trimprefix=dropReason +// dropReason is why we dropped a DERP frame. 
+type dropReason string const ( - dropReasonUnknownDest dropReason = iota // unknown destination pubkey - dropReasonUnknownDestOnFwd // unknown destination pubkey on a derp-forwarded packet - dropReasonGoneDisconnected // destination tailscaled disconnected before we could send - dropReasonQueueHead // destination queue is full, dropped packet at queue head - dropReasonQueueTail // destination queue is full, dropped packet at queue tail - dropReasonWriteError // OS write() failed - dropReasonDupClient // the public key is connected 2+ times (active/active, fighting) - numDropReasons // unused; keep last + dropReasonUnknownDest dropReason = "unknown_dest" // unknown destination pubkey + dropReasonUnknownDestOnFwd dropReason = "unknown_dest_on_fwd" // unknown destination pubkey on a derp-forwarded packet + dropReasonGoneDisconnected dropReason = "gone_disconnected" // destination tailscaled disconnected before we could send + dropReasonQueueHead dropReason = "queue_head" // destination queue is full, dropped packet at queue head + dropReasonQueueTail dropReason = "queue_tail" // destination queue is full, dropped packet at queue tail + dropReasonWriteError dropReason = "write_error" // OS write() failed + dropReasonDupClient dropReason = "dup_client" // the public key is connected 2+ times (active/active, fighting) ) func (s *Server) recordDrop(packetBytes []byte, srcKey, dstKey key.NodePublic, reason dropReason) { - s.packetsDropped.Add(1) - s.packetsDroppedReasonCounters[reason].Add(1) + labels := dropReasonKindLabels{ + Reason: string(reason), + } looksDisco := disco.LooksLikeDiscoWrapper(packetBytes) if looksDisco { - s.packetsDroppedTypeDisco.Add(1) + labels.Kind = string(packetKindDisco) } else { - s.packetsDroppedTypeOther.Add(1) + labels.Kind = string(packetKindOther) } + packetsDropped.Add(labels, 1) + bytesDropped.Add(labels, int64(len(packetBytes))) + if verboseDropKeys[dstKey] { // Preformat the log string prior to calling limitedLogf. 
The // limiter acts based on the format string, and we want to @@ -1229,8 +1352,6 @@ func (c *sclient) requestMeshUpdate() { } } -var localClient tailscale.LocalClient - // isMeshPeer reports whether the client is a trusted mesh peer // node in the DERP region. func (s *Server) isMeshPeer(info *clientInfo) bool { @@ -1249,7 +1370,7 @@ func (s *Server) verifyClient(ctx context.Context, clientKey key.NodePublic, inf // tailscaled-based verification: if s.verifyClientsLocalTailscaled { - _, err := localClient.WhoIsNodeKey(ctx, clientKey) + _, err := s.localClient.WhoIsNodeKey(ctx, clientKey) if err == tailscale.ErrPeerNotFound { return fmt.Errorf("peer %v not authorized (not found in local tailscaled)", clientKey) } @@ -1505,6 +1626,7 @@ type sclient struct { peerGone chan peerGoneMsg // write request that a peer is not at this server (not used by mesh peers) meshUpdate chan struct{} // write request to write peerStateChange canMesh bool // clientInfo had correct mesh token for inter-region routing + isNotIdealConn bool // client indicated it is not its ideal node in the region isDup atomic.Bool // whether more than 1 sclient for key is connected isDisabled atomic.Bool // whether sends to this peer are disabled due to active/active dups debug bool // turn on for verbose logging @@ -1540,6 +1662,9 @@ func (c *sclient) presentFlags() PeerPresentFlags { if c.canMesh { f |= PeerPresentIsMeshPeer } + if c.isNotIdealConn { + f |= PeerPresentNotIdeal + } if f == 0 { return PeerPresentIsRegular } @@ -1721,7 +1846,30 @@ func (c *sclient) sendLoop(ctx context.Context) error { } func (c *sclient) setWriteDeadline() { - c.nc.SetWriteDeadline(time.Now().Add(writeTimeout)) + d := c.s.tcpWriteTimeout + if c.canMesh { + // Trusted peers get more tolerance. + // + // The "canMesh" is a bit of a misnomer; mesh peers typically run over a + // different interface for a per-region private VPC and are not + // throttled. 
But monitoring software elsewhere over the internet also + // use the private mesh key to subscribe to connect/disconnect events + // and might hit throttling and need more time to get the initial dump + // of connected peers. + d = privilegedWriteTimeout + } + if d == 0 { + // A zero value should disable the write deadline per + // --tcp-write-timeout docs. The flag should only be applicable for + // non-mesh connections, again per its docs. If mesh happened to use a + // zero value constant above it would be a bug, so we don't bother + // with a condition on c.canMesh. + return + } + // Ignore the error from setting the write deadline. In practice, + // setting the deadline will only fail if the connection is closed + // or closing, so the subsequent Write() will fail anyway. + _ = c.nc.SetWriteDeadline(time.Now().Add(d)) } // sendKeepAlive sends a keep-alive frame, without flushing. @@ -2033,6 +2181,7 @@ func (s *Server) ExpVar() expvar.Var { m.Set("gauge_current_file_descriptors", expvar.Func(func() any { return metrics.CurrentFDs() })) m.Set("gauge_current_connections", &s.curClients) m.Set("gauge_current_home_connections", &s.curHomeClients) + m.Set("gauge_current_notideal_connections", &s.curClientsNotIdeal) m.Set("gauge_clients_total", expvar.Func(func() any { return len(s.clientsMesh) })) m.Set("gauge_clients_local", expvar.Func(func() any { return len(s.clients) })) m.Set("gauge_clients_remote", expvar.Func(func() any { return len(s.clientsMesh) - len(s.clients) })) @@ -2042,9 +2191,6 @@ func (s *Server) ExpVar() expvar.Var { m.Set("accepts", &s.accepts) m.Set("bytes_received", &s.bytesRecv) m.Set("bytes_sent", &s.bytesSent) - m.Set("packets_dropped", &s.packetsDropped) - m.Set("counter_packets_dropped_reason", &s.packetsDroppedReason) - m.Set("counter_packets_dropped_type", &s.packetsDroppedType) m.Set("counter_packets_received_kind", &s.packetsRecvByKind) m.Set("packets_sent", &s.packetsSent) m.Set("packets_received", &s.packetsRecv) @@ -2060,6 +2206,7 
@@ func (s *Server) ExpVar() expvar.Var { m.Set("multiforwarder_created", &s.multiForwarderCreated) m.Set("multiforwarder_deleted", &s.multiForwarderDeleted) m.Set("packet_forwarder_delete_other_value", &s.removePktForwardOther) + m.Set("sclient_write_timeouts", &s.sclientWriteTimeouts) m.Set("average_queue_duration_ms", expvar.Func(func() any { return math.Float64frombits(atomic.LoadUint64(s.avgQueueDuration)) })) @@ -2123,7 +2270,7 @@ func (s *Server) ConsistencyCheck() error { func (s *Server) checkVerifyClientsLocalTailscaled() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - status, err := localClient.StatusWithoutPeers(ctx) + status, err := s.localClient.StatusWithoutPeers(ctx) if err != nil { return fmt.Errorf("localClient.Status: %w", err) } diff --git a/derp/derp_server_default.go b/derp/derp_server_default.go index 3e0b5b5e96763..014cfffd642c2 100644 --- a/derp/derp_server_default.go +++ b/derp/derp_server_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || android package derp diff --git a/derp/derp_server_linux.go b/derp/derp_server_linux.go index bfc2aade6588c..5a40e114eecd2 100644 --- a/derp/derp_server_linux.go +++ b/derp/derp_server_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package derp import ( diff --git a/derp/derp_test.go b/derp/derp_test.go index 9185194dd79cf..c5a92bafae1dd 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -6,6 +6,7 @@ package derp import ( "bufio" "bytes" + "cmp" "context" "crypto/x509" "encoding/asn1" @@ -23,8 +24,10 @@ import ( "testing" "time" + qt "github.com/frankban/quicktest" "go4.org/mem" "golang.org/x/time/rate" + "tailscale.com/derp/derpconst" "tailscale.com/disco" "tailscale.com/net/memnet" "tailscale.com/tstest" @@ -928,7 +931,7 @@ func TestMetaCert(t *testing.T) { if 
fmt.Sprint(cert.SerialNumber) != fmt.Sprint(ProtocolVersion) { t.Errorf("serial = %v; want %v", cert.SerialNumber, ProtocolVersion) } - if g, w := cert.Subject.CommonName, fmt.Sprintf("derpkey%s", pub.UntypedHexString()); g != w { + if g, w := cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix+pub.UntypedHexString(); g != w { t.Errorf("CommonName = %q; want %q", g, w) } if n := len(cert.Extensions); n != 1 { @@ -1598,3 +1601,29 @@ func TestServerRepliesToPing(t *testing.T) { } } } + +func TestGetPerClientSendQueueDepth(t *testing.T) { + c := qt.New(t) + envKey := "TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH" + + testCases := []struct { + envVal string + want int + }{ + // Empty case, envknob treats empty as missing also. + { + "", defaultPerClientSendQueueDepth, + }, + { + "64", 64, + }, + } + + for _, tc := range testCases { + t.Run(cmp.Or(tc.envVal, "empty"), func(t *testing.T) { + t.Setenv(envKey, tc.envVal) + val := getPerClientSendQueueDepth() + c.Assert(val, qt.Equals, tc.want) + }) + } +} diff --git a/derp/derpconst/derpconst.go b/derp/derpconst/derpconst.go new file mode 100644 index 0000000000000..74ca09ccb734b --- /dev/null +++ b/derp/derpconst/derpconst.go @@ -0,0 +1,11 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package derpconst contains constants used by the DERP client and server. +package derpconst + +// MetaCertCommonNamePrefix is the prefix that the DERP server +// puts on for the common name of its "metacert". The suffix of +// the common name after "derpkey" is the hex key.NodePublic +// of the DERP server. 
+const MetaCertCommonNamePrefix = "derpkey" diff --git a/derp/derphttp/derphttp_client.go b/derp/derphttp/derphttp_client.go index b8cce8cdcb4fa..faa218ca25f0a 100644 --- a/derp/derphttp/derphttp_client.go +++ b/derp/derphttp/derphttp_client.go @@ -30,11 +30,13 @@ import ( "go4.org/mem" "tailscale.com/derp" + "tailscale.com/derp/derpconst" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/dnscache" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" @@ -313,6 +315,9 @@ func (c *Client) preferIPv6() bool { var dialWebsocketFunc func(ctx context.Context, urlStr string) (net.Conn, error) func useWebsockets() bool { + if !canWebsockets { + return false + } if runtime.GOOS == "js" { return true } @@ -383,7 +388,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien var node *tailcfg.DERPNode // nil when using c.url to dial var idealNodeInRegion bool switch { - case useWebsockets(): + case canWebsockets && useWebsockets(): var urlStr string if c.url != nil { urlStr = c.url.String() @@ -498,7 +503,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien req.Header.Set("Connection", "Upgrade") if !idealNodeInRegion && reg != nil { // This is purely informative for now (2024-07-06) for stats: - req.Header.Set("Ideal-Node", reg.Nodes[0].Name) + req.Header.Set(derp.IdealNodeHeader, reg.Nodes[0].Name) // TODO(bradfitz,raggi): start a time.AfterFunc for 30m-1h or so to // dialNode(reg.Nodes[0]) and see if we can even TCP connect to it. If // so, TLS handshake it as well (which is mixed up in this massive @@ -584,7 +589,7 @@ func (c *Client) connect(ctx context.Context, caller string) (client *derp.Clien // // The primary use for this is the derper mesh mode to connect to each // other over a VPC network. 
-func (c *Client) SetURLDialer(dialer func(ctx context.Context, network, addr string) (net.Conn, error)) { +func (c *Client) SetURLDialer(dialer netx.DialFunc) { c.dialer = dialer } @@ -649,7 +654,11 @@ func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { tlsConf.VerifyConnection = nil } if node.CertName != "" { - tlsdial.SetConfigExpectedCert(tlsConf, node.CertName) + if suf, ok := strings.CutPrefix(node.CertName, "sha256-raw:"); ok { + tlsdial.SetConfigExpectedCertHash(tlsConf, suf) + } else { + tlsdial.SetConfigExpectedCert(tlsConf, node.CertName) + } } } return tls.Client(nc, tlsConf) @@ -663,7 +672,7 @@ func (c *Client) tlsClient(nc net.Conn, node *tailcfg.DERPNode) *tls.Conn { func (c *Client) DialRegionTLS(ctx context.Context, reg *tailcfg.DERPRegion) (tlsConn *tls.Conn, connClose io.Closer, node *tailcfg.DERPNode, err error) { tcpConn, node, err := c.dialRegion(ctx, reg) if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("dialRegion(%d): %w", reg.RegionID, err) } done := make(chan bool) // unbuffered defer close(done) @@ -738,6 +747,17 @@ func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, e nwait := 0 startDial := func(dstPrimary, proto string) { + dst := cmp.Or(dstPrimary, n.HostName) + + // If dialing an IP address directly, check its address family + // and bail out before incrementing nwait. 
+ if ip, err := netip.ParseAddr(dst); err == nil { + if proto == "tcp4" && ip.Is6() || + proto == "tcp6" && ip.Is4() { + return + } + } + nwait++ go func() { if proto == "tcp4" && c.preferIPv6() { @@ -752,8 +772,10 @@ func (c *Client) dialNode(ctx context.Context, n *tailcfg.DERPNode) (net.Conn, e // Start v4 dial } } - dst := cmp.Or(dstPrimary, n.HostName) port := "443" + if !c.useHTTPS() { + port = "3340" + } if n.DERPPort != 0 { port = fmt.Sprint(n.DERPPort) } @@ -1131,7 +1153,7 @@ var ErrClientClosed = errors.New("derphttp.Client closed") func parseMetaCert(certs []*x509.Certificate) (serverPub key.NodePublic, serverProtoVersion int) { for _, cert := range certs { // Look for derpkey prefix added by initMetacert() on the server side. - if pubHex, ok := strings.CutPrefix(cert.Subject.CommonName, "derpkey"); ok { + if pubHex, ok := strings.CutPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix); ok { var err error serverPub, err = key.ParseNodePublicUntyped(mem.S(pubHex)) if err == nil && cert.SerialNumber.BitLen() <= 8 { // supports up to version 255 diff --git a/derp/derphttp/derphttp_server.go b/derp/derphttp/derphttp_server.go index 41ce86764f66a..50aba774a9f1c 100644 --- a/derp/derphttp/derphttp_server.go +++ b/derp/derphttp/derphttp_server.go @@ -21,6 +21,8 @@ const fastStartHeader = "Derp-Fast-Start" // Handler returns an http.Handler to be mounted at /derp, serving s. func Handler(s *derp.Server) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + // These are installed both here and in cmd/derper. 
The check here // catches both cmd/derper run with DERP disabled (STUN only mode) as // well as DERP being run in tests with derphttp.Handler directly, @@ -66,7 +68,11 @@ func Handler(s *derp.Server) http.Handler { pubKey.UntypedHexString()) } - s.Accept(r.Context(), netConn, conn, netConn.RemoteAddr().String()) + if v := r.Header.Get(derp.IdealNodeHeader); v != "" { + ctx = derp.IdealNodeContextKey.WithValue(ctx, v) + } + + s.Accept(ctx, netConn, conn, netConn.RemoteAddr().String()) }) } @@ -92,6 +98,7 @@ func ServeNoContent(w http.ResponseWriter, r *http.Request) { w.Header().Set(NoContentResponseHeader, "response "+challenge) } } + w.Header().Set("Cache-Control", "no-cache, no-store, must-revalidate, no-transform, max-age=0") w.WriteHeader(http.StatusNoContent) } @@ -99,7 +106,7 @@ func isChallengeChar(c rune) bool { // Semi-randomly chosen as a limited set of valid characters return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || ('0' <= c && c <= '9') || - c == '.' || c == '-' || c == '_' + c == '.' 
|| c == '-' || c == '_' || c == ':' } const ( diff --git a/derp/derphttp/websocket.go b/derp/derphttp/websocket.go index 6ef47473a2532..9dd640ee37083 100644 --- a/derp/derphttp/websocket.go +++ b/derp/derphttp/websocket.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || js +//go:build js || ((linux || darwin) && ts_debug_websockets) package derphttp @@ -14,6 +14,8 @@ import ( "tailscale.com/net/wsconn" ) +const canWebsockets = true + func init() { dialWebsocketFunc = dialWebsocket } diff --git a/derp/derphttp/websocket_stub.go b/derp/derphttp/websocket_stub.go new file mode 100644 index 0000000000000..d84bfba571f80 --- /dev/null +++ b/derp/derphttp/websocket_stub.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !(js || ((linux || darwin) && ts_debug_websockets)) + +package derphttp + +const canWebsockets = false diff --git a/derp/dropreason_string.go b/derp/dropreason_string.go deleted file mode 100644 index 3ad0728194b73..0000000000000 --- a/derp/dropreason_string.go +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Code generated by "stringer -type=dropReason -trimprefix=dropReason"; DO NOT EDIT. - -package derp - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. 
- var x [1]struct{} - _ = x[dropReasonUnknownDest-0] - _ = x[dropReasonUnknownDestOnFwd-1] - _ = x[dropReasonGoneDisconnected-2] - _ = x[dropReasonQueueHead-3] - _ = x[dropReasonQueueTail-4] - _ = x[dropReasonWriteError-5] - _ = x[dropReasonDupClient-6] - _ = x[numDropReasons-7] -} - -const _dropReason_name = "UnknownDestUnknownDestOnFwdGoneDisconnectedQueueHeadQueueTailWriteErrorDupClientnumDropReasons" - -var _dropReason_index = [...]uint8{0, 11, 27, 43, 52, 61, 71, 80, 94} - -func (i dropReason) String() string { - if i < 0 || i >= dropReason(len(_dropReason_index)-1) { - return "dropReason(" + strconv.FormatInt(int64(i), 10) + ")" - } - return _dropReason_name[_dropReason_index[i]:_dropReason_index[i+1]] -} diff --git a/disco/disco.go b/disco/disco.go index b9a90029d9ca8..0854eb4c0af5a 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -25,6 +25,7 @@ import ( "fmt" "net" "net/netip" + "time" "go4.org/mem" "tailscale.com/types/key" @@ -41,9 +42,13 @@ const NonceLen = 24 type MessageType byte const ( - TypePing = MessageType(0x01) - TypePong = MessageType(0x02) - TypeCallMeMaybe = MessageType(0x03) + TypePing = MessageType(0x01) + TypePong = MessageType(0x02) + TypeCallMeMaybe = MessageType(0x03) + TypeBindUDPRelayEndpoint = MessageType(0x04) + TypeBindUDPRelayEndpointChallenge = MessageType(0x05) + TypeBindUDPRelayEndpointAnswer = MessageType(0x06) + TypeCallMeMaybeVia = MessageType(0x07) ) const v0 = byte(0) @@ -77,12 +82,21 @@ func Parse(p []byte) (Message, error) { } t, ver, p := MessageType(p[0]), p[1], p[2:] switch t { + // TODO(jwhited): consider using a signature matching encoding.BinaryUnmarshaler case TypePing: return parsePing(ver, p) case TypePong: return parsePong(ver, p) case TypeCallMeMaybe: return parseCallMeMaybe(ver, p) + case TypeBindUDPRelayEndpoint: + return parseBindUDPRelayEndpoint(ver, p) + case TypeBindUDPRelayEndpointChallenge: + return parseBindUDPRelayEndpointChallenge(ver, p) + case TypeBindUDPRelayEndpointAnswer: + return 
parseBindUDPRelayEndpointAnswer(ver, p) + case TypeCallMeMaybeVia: + return parseCallMeMaybeVia(ver, p) default: return nil, fmt.Errorf("unknown message type 0x%02x", byte(t)) } @@ -91,6 +105,7 @@ func Parse(p []byte) (Message, error) { // Message a discovery message. type Message interface { // AppendMarshal appends the message's marshaled representation. + // TODO(jwhited): consider using a signature matching encoding.BinaryAppender AppendMarshal([]byte) []byte } @@ -266,7 +281,209 @@ func MessageSummary(m Message) string { return fmt.Sprintf("pong tx=%x", m.TxID[:6]) case *CallMeMaybe: return "call-me-maybe" + case *BindUDPRelayEndpoint: + return "bind-udp-relay-endpoint" + case *BindUDPRelayEndpointChallenge: + return "bind-udp-relay-endpoint-challenge" + case *BindUDPRelayEndpointAnswer: + return "bind-udp-relay-endpoint-answer" default: return fmt.Sprintf("%#v", m) } } + +// BindUDPRelayHandshakeState represents the state of the 3-way bind handshake +// between UDP relay client and UDP relay server. Its potential values include +// those for both participants, UDP relay client and UDP relay server. A UDP +// relay server implementation can be found in net/udprelay. This is currently +// considered experimental. +type BindUDPRelayHandshakeState int + +const ( + // BindUDPRelayHandshakeStateInit represents the initial state prior to any + // message being transmitted. + BindUDPRelayHandshakeStateInit BindUDPRelayHandshakeState = iota + // BindUDPRelayHandshakeStateBindSent is the first client state after + // transmitting a BindUDPRelayEndpoint message to a UDP relay server. + BindUDPRelayHandshakeStateBindSent + // BindUDPRelayHandshakeStateChallengeSent is the first server state after + // receiving a BindUDPRelayEndpoint message from a UDP relay client and + // replying with a BindUDPRelayEndpointChallenge. 
+ BindUDPRelayHandshakeStateChallengeSent + // BindUDPRelayHandshakeStateAnswerSent is a client state that is entered + // after transmitting a BindUDPRelayEndpointAnswer message towards a UDP + // relay server in response to a BindUDPRelayEndpointChallenge message. + BindUDPRelayHandshakeStateAnswerSent + // BindUDPRelayHandshakeStateAnswerReceived is a server state that is + // entered after it has received a correct BindUDPRelayEndpointAnswer + // message from a UDP relay client in response to a + // BindUDPRelayEndpointChallenge message. + BindUDPRelayHandshakeStateAnswerReceived +) + +// bindUDPRelayEndpointLen is the length of a marshalled BindUDPRelayEndpoint +// message, without the message header. +const bindUDPRelayEndpointLen = BindUDPRelayEndpointChallengeLen + +// BindUDPRelayEndpoint is the first messaged transmitted from UDP relay client +// towards UDP relay server as part of the 3-way bind handshake. It is padded to +// match the length of BindUDPRelayEndpointChallenge. This message type is +// currently considered experimental and is not yet tied to a +// tailcfg.CapabilityVersion. +type BindUDPRelayEndpoint struct { +} + +func (m *BindUDPRelayEndpoint) AppendMarshal(b []byte) []byte { + ret, _ := appendMsgHeader(b, TypeBindUDPRelayEndpoint, v0, bindUDPRelayEndpointLen) + return ret +} + +func parseBindUDPRelayEndpoint(ver uint8, p []byte) (m *BindUDPRelayEndpoint, err error) { + m = new(BindUDPRelayEndpoint) + return m, nil +} + +// BindUDPRelayEndpointChallengeLen is the length of a marshalled +// BindUDPRelayEndpointChallenge message, without the message header. +const BindUDPRelayEndpointChallengeLen = 32 + +// BindUDPRelayEndpointChallenge is transmitted from UDP relay server towards +// UDP relay client in response to a BindUDPRelayEndpoint message as part of the +// 3-way bind handshake. This message type is currently considered experimental +// and is not yet tied to a tailcfg.CapabilityVersion. 
+type BindUDPRelayEndpointChallenge struct { + Challenge [BindUDPRelayEndpointChallengeLen]byte +} + +func (m *BindUDPRelayEndpointChallenge) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointChallenge, v0, BindUDPRelayEndpointChallengeLen) + copy(d, m.Challenge[:]) + return ret +} + +func parseBindUDPRelayEndpointChallenge(ver uint8, p []byte) (m *BindUDPRelayEndpointChallenge, err error) { + if len(p) < BindUDPRelayEndpointChallengeLen { + return nil, errShort + } + m = new(BindUDPRelayEndpointChallenge) + copy(m.Challenge[:], p[:]) + return m, nil +} + +// bindUDPRelayEndpointAnswerLen is the length of a marshalled +// BindUDPRelayEndpointAnswer message, without the message header. +const bindUDPRelayEndpointAnswerLen = BindUDPRelayEndpointChallengeLen + +// BindUDPRelayEndpointAnswer is transmitted from UDP relay client to UDP relay +// server in response to a BindUDPRelayEndpointChallenge message. This message +// type is currently considered experimental and is not yet tied to a +// tailcfg.CapabilityVersion. +type BindUDPRelayEndpointAnswer struct { + Answer [bindUDPRelayEndpointAnswerLen]byte +} + +func (m *BindUDPRelayEndpointAnswer) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeBindUDPRelayEndpointAnswer, v0, bindUDPRelayEndpointAnswerLen) + copy(d, m.Answer[:]) + return ret +} + +func parseBindUDPRelayEndpointAnswer(ver uint8, p []byte) (m *BindUDPRelayEndpointAnswer, err error) { + if len(p) < bindUDPRelayEndpointAnswerLen { + return nil, errShort + } + m = new(BindUDPRelayEndpointAnswer) + copy(m.Answer[:], p[:]) + return m, nil +} + +// CallMeMaybeVia is a message sent only over DERP to request that the recipient +// try to open up a magicsock path back to the sender. The 'Via' in +// CallMeMaybeVia highlights that candidate paths are served through an +// intermediate relay, likely a [tailscale.com/net/udprelay.Server]. 
+// +// Usage of the candidate paths in magicsock requires a 3-way handshake +// involving [BindUDPRelayEndpoint], [BindUDPRelayEndpointChallenge], and +// [BindUDPRelayEndpointAnswer]. +// +// CallMeMaybeVia mirrors [tailscale.com/net/udprelay/endpoint.ServerEndpoint], +// which contains field documentation. +// +// The recipient may choose to not open a path back if it's already happy with +// its path. Direct connections, e.g. [CallMeMaybe]-signaled, take priority over +// CallMeMaybeVia paths. +// +// This message type is currently considered experimental and is not yet tied to +// a [tailscale.com/tailcfg.CapabilityVersion]. +type CallMeMaybeVia struct { + // ServerDisco is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.ServerDisco] + ServerDisco key.DiscoPublic + // LamportID is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.LamportID] + LamportID uint64 + // VNI is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.VNI] + VNI uint32 + // BindLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.BindLifetime] + BindLifetime time.Duration + // SteadyStateLifetime is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.SteadyStateLifetime] + SteadyStateLifetime time.Duration + // AddrPorts is [tailscale.com/net/udprelay/endpoint.ServerEndpoint.AddrPorts] + AddrPorts []netip.AddrPort +} + +const cmmvDataLenMinusEndpoints = key.DiscoPublicRawLen + // ServerDisco + 8 + // LamportID + 4 + // VNI + 8 + // BindLifetime + 8 // SteadyStateLifetime + +func (m *CallMeMaybeVia) AppendMarshal(b []byte) []byte { + endpointsLen := epLength * len(m.AddrPorts) + ret, p := appendMsgHeader(b, TypeCallMeMaybeVia, v0, cmmvDataLenMinusEndpoints+endpointsLen) + disco := m.ServerDisco.AppendTo(nil) + copy(p, disco) + p = p[key.DiscoPublicRawLen:] + binary.BigEndian.PutUint64(p[:8], m.LamportID) + p = p[8:] + binary.BigEndian.PutUint32(p[:4], m.VNI) + p = p[4:] + binary.BigEndian.PutUint64(p[:8], uint64(m.BindLifetime)) + p = p[8:] + 
binary.BigEndian.PutUint64(p[:8], uint64(m.SteadyStateLifetime)) + p = p[8:] + for _, ipp := range m.AddrPorts { + a := ipp.Addr().As16() + copy(p, a[:]) + binary.BigEndian.PutUint16(p[16:18], ipp.Port()) + p = p[epLength:] + } + return ret +} + +func parseCallMeMaybeVia(ver uint8, p []byte) (m *CallMeMaybeVia, err error) { + m = new(CallMeMaybeVia) + if len(p) < cmmvDataLenMinusEndpoints+epLength || + (len(p)-cmmvDataLenMinusEndpoints)%epLength != 0 || + ver != 0 { + return m, nil + } + m.ServerDisco = key.DiscoPublicFromRaw32(mem.B(p[:key.DiscoPublicRawLen])) + p = p[key.DiscoPublicRawLen:] + m.LamportID = binary.BigEndian.Uint64(p[:8]) + p = p[8:] + m.VNI = binary.BigEndian.Uint32(p[:4]) + p = p[4:] + m.BindLifetime = time.Duration(binary.BigEndian.Uint64(p[:8])) + p = p[8:] + m.SteadyStateLifetime = time.Duration(binary.BigEndian.Uint64(p[:8])) + p = p[8:] + m.AddrPorts = make([]netip.AddrPort, 0, len(p)-cmmvDataLenMinusEndpoints/epLength) + for len(p) > 0 { + var a [16]byte + copy(a[:], p) + m.AddrPorts = append(m.AddrPorts, netip.AddrPortFrom( + netip.AddrFrom16(a).Unmap(), + binary.BigEndian.Uint16(p[16:18]))) + p = p[epLength:] + } + return m, nil +} diff --git a/disco/disco_test.go b/disco/disco_test.go index 1a56324a5a423..f2a29a744992f 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -9,6 +9,7 @@ import ( "reflect" "strings" "testing" + "time" "go4.org/mem" "tailscale.com/types/key" @@ -83,6 +84,44 @@ func TestMarshalAndParse(t *testing.T) { }, want: "03 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", }, + { + name: "bind_udp_relay_endpoint", + m: &BindUDPRelayEndpoint{}, + want: "04 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00", + }, + { + name: "bind_udp_relay_endpoint_challenge", + m: &BindUDPRelayEndpointChallenge{ + Challenge: [BindUDPRelayEndpointChallengeLen]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 
0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }, + }, + want: "05 00 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + }, + { + name: "bind_udp_relay_endpoint_answer", + m: &BindUDPRelayEndpointAnswer{ + Answer: [bindUDPRelayEndpointAnswerLen]byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + }, + }, + want: "06 00 00 01 02 03 04 05 06 07 08 09 0a 0b 0c 0d 0e 0f 10 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f", + }, + { + name: "call_me_maybe_via", + m: &CallMeMaybeVia{ + ServerDisco: key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 30: 30, 31: 31})), + LamportID: 123, + VNI: 456, + BindLifetime: time.Second, + SteadyStateLifetime: time.Minute, + AddrPorts: []netip.AddrPort{ + netip.MustParseAddrPort("1.2.3.4:567"), + netip.MustParseAddrPort("[2001::3456]:789"), + }, + }, + want: "07 00 00 01 02 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f 00 00 00 00 00 00 00 7b 00 00 01 c8 00 00 00 00 3b 9a ca 00 00 00 00 0d f8 47 58 00 00 00 00 00 00 00 00 00 00 00 ff ff 01 02 03 04 02 37 20 01 00 00 00 00 00 00 00 00 00 00 00 00 34 56 03 15", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { diff --git a/docs/commit-messages.md b/docs/commit-messages.md new file mode 100644 index 0000000000000..b3881eaeb9fbb --- /dev/null +++ b/docs/commit-messages.md @@ -0,0 +1,194 @@ +# Commit messages + +There are different styles of commit messages followed by different projects. +This is Tailscale's style guide for writing git commit messages. +As with all style guides, many things here are subjective and exist primarily to +codify existing conventions and promote uniformity and thus ease of reading by +others. 
Others have stronger reasons, such as interop with tooling or making +future git archaeology easier. + +Our commit message style is largely based on the Go language's style, which +shares much in common with the Linux kernel's git commit message style (for +which git was invented): + +* Go's high-level example: https://go.dev/doc/contribute#commit_messages +* Go's details: https://golang.org/wiki/CommitMessage +* Linux's style: https://www.kernel.org/doc/html/v4.10/process/submitting-patches.html#describe-your-changes + +(We do *not* use the [Conventional +Commits](https://www.conventionalcommits.org/en/v1.0.0/) style or [Semantic +Commits](https://gist.github.com/joshbuchea/6f47e86d2510bce28f8e7f42ae84c716) +styles. They're reasonable, but we have already been using the Go and Linux +style of commit messages and there is little justification for switching styles. +Consistency is valuable.) + +In a nutshell, our commit messages should look like: + +``` +net/http: handle foo when bar + +[longer description here in the body] + +Fixes #nnnn +``` + +Notably, for the subject (the first line of description): + +- the primary director(ies) from the root affected by the change goes before the colon, e.g. “derp/derphttp:” (if a lot of packages are involved, you can abbreviate to top-level names e.g. ”derp,magicsock:”, and/or remove less relevant packages) +- the part after the colon is a verb, ideally an imperative verb (Linux style, telling the code what to do) or alternatively an infinitive verb that completes the blank in, *"this change modifies Tailscale to ___________"*. e.g. say *“fix the foobar feature”*, not *“fixing”*, *“fixed”*, or *“fixes”*. Or, as Linux guidelines say: + > Describe your changes in imperative mood, e.g. “make xyzzy do frotz” instead of “[This patch] makes xyzzy do frotz” or “[I] changed xyzzy to do frotz”, as if you are giving orders to the codebase to change its behaviour." 
+- the verb after the colon is lowercase +- there is no trailing period +- it should be kept as short as possible (many git viewing tools prefer under ~76 characters, though we aren’t super strict about this) + + Examples: + + | Good Example | notes | + | ------- | --- | + | `foo/bar: fix memory leak` | | + | `foo/bar: bump deps` | | + | `foo/bar: temporarily restrict access` | adverbs are okay | + | `foo/bar: implement new UI design` | | + | `control/{foo,bar}: optimize bar` | feel free to use {foo,bar} for common subpackages| + + | Bad Example | notes | + | ------- | --- | + | `fixed memory leak` | BAD: missing package prefix | + | `foo/bar: fixed memory leak` | BAD: past tense | + | `foo/bar: fixing memory leak` | BAD: present continuous tense; no `-ing` verbs | + | `foo/bar: bumping deps` | BAD: present continuous tense; no `-ing` verbs | + | `foo/bar: new UI design` | BAD: that's a noun phrase; no verb | + | `foo/bar: made things larger` | BAD: that's past tense | + | `foo/bar: faster algorithm` | BAD: that's an adjective and a noun, not a verb | + | `foo/bar: Fix memory leak` | BAD: capitalized verb | + | `foo/bar: fix memory leak.` | BAD: trailing period | + | `foo/bar:fix memory leak` | BAD: no space after colon | + | `foo/bar : fix memory leak` | BAD: space before colon | + | `foo/bar: fix memory leak Fixes #123` | BAD: the "Fixes" shouldn't be part of the title | + | `!fixup reviewer feedback` | BAD: we don't check in fixup commits; the history should always bissect to a clean, working tree | + + +For the body (the rest of the description): + +- blank line after the subject (first) line +- the text should be wrapped to ~76 characters (to appease git viewing tools, mainly), unless you really need longer lines (e.g. for ASCII art, tables, or long links) +- there must be a `Fixes` or `Updates` line for all non-cleanup commits linking to a tracking bug. This goes after the body with a blank newline separating the two. 
[Cleanup commits](#is-it-a-cleanup) can use `Updates #cleanup` instead of an issue. +- `Change-Id` lines should ideally be included in commits in the `corp` repo and are more optional in `tailscale/tailscale`. You can configure Git to do this for you by running `./tool/go run misc/install-git-hooks.go` from the root of the corp repo. This was originally a Gerrit thing and we don't use Gerrit, but it lets us tooling track commits as they're cherry-picked between branches. Also, tools like [git-cleanup](https://github.com/bradfitz/gitutil) use it to clean up your old local branches once they're merged upstream. +- we don't use Markdown in commit messages. (Accidental Markdown like bulleted lists or even headings is fine, but not links) +- we require `Signed-off-by` lines in public repos (such as `tailscale/tailscale`). Add them using `git commit --signoff` or `git commit -s` for short. You can use them in private repos but do not have to. +- when moving code between repos, include the repository name, and git hash that it was moved from/to, so it is easier to trace history/blame. + +Please don't use [alternate GitHub-supported +aliases](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue) +like `Close` or `Resolves`. Tailscale only uses the verbs `Fixes` and `Updates`. + +To link a commit to an issue without marking it fixed—for example, if the commit +is working toward a fix but not yet a complete fix—GitHub requires only that the +issue is mentioned by number in the commit message. By convention, our commits +mention this at the bottom of the message using `Updates`, where `Fixes` might +be expected, even if the number is also mentioned in the body of the commit +message. + +For example: + +``` +some/dir: refactor func Foo + +This will make the handling of +shorter and easier to test. 
+ +Updates #nnnn +``` + +Please say `Updates` and not other common Github-recognized conventions (that is, don't use `For #nnnn`) + +## Public release notes + +For changes in `tailscale/tailscale` that fix a significant bug or add a new feature that should be included in the release notes for the next release, +add `RELNOTE: ` toward the end of the commit message. +This will aid the release engineer in writing the release notes for the next release. + +## Is it a #cleanup? + +Our issuebot permits writing `Updates #cleanup` instead of an actual GitHub issue number. + +But only do that if it’s actually a cleanup. Don’t use that as an excuse to avoid filing an issue. + +Shortcuts[^1] to file issues: +- [go/bugc](http://go/bugc) (corp, safe choice) +- [go/bugo](http://go/bugo) (open source, if you want it public to the world). + +[^1]: These shortcuts point to our Tailscale’s internal URL shortener service, which you too [can run in your own Tailnet](https://tailscale.com/blog/golink). + +The following guide can help you decide whether a tracking issue is warranted. + +| | | +| --- | --- | +| Was there a crash/panic? | Not a cleanup. Put the panic in a bug. Talk about when it was introduced, why, why a test didn’t catch it, note what followup work might need to be done. | +| Did a customer report it? | Not a cleanup. Make a corp bug with links to the customer ticket. | +| Is it from an incident, get paged? | Not a cleanup. Let’s track why we got paged. | +| Does it change behavior? | Not a cleanup. File a bug to track why. | +| Adding a test for a recently fixed bug? | Not a cleanup. Use the recently fixed bug’s bug number. | +| Does it tweak a constant/parameter? | Not a cleanup. File a bug to track the debugging/tuning effort and record past results and goals for the future state. | +| Fixing a regression from an earlier change? | Not a cleanup. At minimum, reference the PR that caused the regression, but if users noticed, it might warrant its own bug. 
| +| Is it part of an overall effort that’ll take a hundred small steps? | Not a cleanup. The overall effort should have a tracking bug to collect all the minor efforts. | +| Is it a security fix? Is it a security hardening? | Not a cleanup. There should be a bug about security incidents or security hardening efforts and backporting to previous releases, etc. | +| Is it a feature flag being removed? | Not a cleanup. File a task to coordinate with other teams and to track the work. | + +### Actual cleanup examples + +- Fixing typos in internal comments that users would’ve never seen +- Simple, mechanical replacement of a deprecated API to its equivalently behaving replacement + - [`errors.Wrapf`](https://pkg.go.dev/github.com/pkg/errors#Wrapf) → [`fmt.Errorf("%w")`](https://pkg.go.dev/fmt#Errorf) + - [math/rand](https://pkg.go.dev/math/rand) → [math/rand/v2](https://pkg.go.dev/math/rand/v2) +- Code movement +- Removing dead code that doesn’t change behavior (API changes, feature flags, etc) +- Refactoring in prep for another change (but maybe mention the upcoming change’s bug as motivation) +- Adding a test that you just noticed was missing, not as a result of any bug or report or new feature coming +- Formatting (gofmt / prettifier) that was missed earlier + +### What’s the point of an issue? 
+ +- Let us capture information that is inappropriate for a commit message +- Let us have conversations on a change after the fact +- Let us track metadata on issues and decide what to backport +- Let us associate related changes to each other, including after the fact +- Lets you write the backstory once on an overall bug/effort and re-use that issue number for N future commits, without having to repeat yourself on each commit message +- Provides archaeological breadcrumbs to future debuggers, providing context on why things were changed + +# Reverts + +When you use `git revert` to revert a commit, the default commit message will identify the commit SHA and message that was reverted. You must expand this message to explain **why** it is being reverted, including a link to the associated issue. + +Don't revert reverts. That gets ugly. Send the change anew but reference +the original & earlier revert. + +# Other repos + +To reference an issue in one repo from a commit in another (for example, fixing an issue in corp with a commit in `tailscale/tailscale`), you need to fully-qualify the issue number with the GitHub org/repo syntax: + +``` +cipher/rot13: add new super secure cipher + +Fixes tailscale/corp#1234 +``` + +Referencing a full URL to the issue is also acceptable, but try to prefer the shorter way. + +It's okay to reference the `corp` repo in open source repo commit messages. + +# GitHub Pull Requests + +In the future we plan to make a bot rewrite all PR bodies programmatically from +the commit messages. But for now (2023-07-25).... + +By convention, GitHub Pull Requests follow similar rules to commits, especially +the title of the PR (which should be the first line of the commit). It is less +important to follow these conventions in the PR itself, as it’s the commits that +become a permanent part of the commit history. + +It's okay (but rare) for a PR to contain multiple commits. 
When a PR does +contain multiple commits, call that out in the PR body for reviewers so they can +review each separately. + +You don't need to include the `Change-Id` in the description of your PR. diff --git a/docs/k8s/operator-architecture.md b/docs/k8s/operator-architecture.md new file mode 100644 index 0000000000000..29672f6a39bd9 --- /dev/null +++ b/docs/k8s/operator-architecture.md @@ -0,0 +1,602 @@ +# Operator architecture diagrams + +The Tailscale [Kubernetes operator][kb-operator] has a collection of use-cases +that can be mixed and matched as required. The following diagrams illustrate +how the operator implements each use-case. + +In each diagram, the "tailscale" namespace is entirely managed by the operator +once the operator itself has been deployed. + +Tailscale devices are highlighted as black nodes. The salient devices for each +use-case are marked as "src" or "dst" to denote which node is a source or a +destination in the context of ACL rules that will apply to network traffic. + +Note, in some cases, the config and the state Secret may be the same Kubernetes +Secret. + +## API server proxy + +[Documentation][kb-operator-proxy] + +The operator runs the API server proxy in-process. If the proxy is running in +"noauth" mode, it forwards HTTP requests unmodified. If the proxy is running in +"auth" mode, it deletes any existing auth headers and adds +[impersonation headers][k8s-impersonation] to the request before forwarding to +the API server. 
A request with impersonation headers will look something like: + +``` +GET /api/v1/namespaces/default/pods HTTP/1.1 +Host: k8s-api.example.com +Authorization: Bearer +Impersonate-Group: tailnet-readers +Accept: application/json +``` + +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart LR + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator(("operator (dst)")):::tsnode + end + + subgraph controlplane["Control plane"] + api[kube-apiserver] + end + end + + client["client (src)"]:::tsnode --> operator + operator -->|"proxy (maybe with impersonation headers)"| api + + linkStyle 0 stroke:red; + linkStyle 2 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 3 stroke:blue; + +``` + +## L3 ingress + +[Documentation][kb-operator-l3-ingress] + +The user deploys an app to the default namespace, and creates a normal Service +that selects the app's Pods. Either add the annotation +`tailscale.com/expose: "true"` or specify `.spec.type` as `Loadbalancer` and +`.spec.loadBalancerClass` as `tailscale`. The operator will create an ingress +proxy that allows devices anywhere on the tailnet to access the Service. + +The proxy Pod uses `iptables` or `nftables` rules to DNAT traffic bound for the +proxy's tailnet IP to the Service's internal Cluster IP instead. 
+ +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + ingress-sts["StatefulSet"] + ingress(("ingress proxy (dst)")):::tsnode + config-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph defaultns[namespace=default] + svc[annotated Service] + svc --> pod1((pod1)) + svc --> pod2((pod2)) + end + end + + client["client (src)"]:::tsnode --> ingress + ingress -->|forwards traffic| svc + operator -.->|creates| ingress-sts + ingress-sts -.->|manages| ingress + operator -.->|reads| svc + operator -.->|creates| config-secret + config-secret -.->|mounted| ingress + ingress -.->|stores state| state-secret + + linkStyle 0 stroke:red; + linkStyle 4 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 5 stroke:blue; + +``` + +## L7 ingress + +[Documentation][kb-operator-l7-ingress] + +The L7 ingress architecture diagram is relatively similar to L3 ingress. It is +configured via an `Ingress` object instead of a `Service`, and uses +`tailscale serve` to accept traffic instead of configuring `iptables` or +`nftables` rules. Note that we use tailscaled's local API (`SetServeConfig`) to +set serve config, not the `tailscale serve` command. 
+ +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + ingress-sts["StatefulSet"] + ingress-pod(("ingress proxy (dst)")):::tsnode + config-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph cluster-scope[Cluster scoped resources] + ingress-class[Tailscale IngressClass] + end + + subgraph defaultns[namespace=default] + ingress[tailscale Ingress] + svc["Service"] + svc --> pod1((pod1)) + svc --> pod2((pod2)) + end + end + + client["client (src)"]:::tsnode --> ingress-pod + ingress-pod -->|forwards /api prefix traffic| svc + operator -.->|creates| ingress-sts + ingress-sts -.->|manages| ingress-pod + operator -.->|reads| ingress + operator -.->|creates| config-secret + config-secret -.->|mounted| ingress-pod + ingress-pod -.->|stores state| state-secret + ingress -.->|/api prefix| svc + + linkStyle 0 stroke:red; + linkStyle 4 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 5 stroke:blue; + +``` + +## L3 egress + +[Documentation][kb-operator-l3-egress] + +1. The user deploys a Service with `type: ExternalName` and an annotation + `tailscale.com/tailnet-fqdn: db.tails-scales.ts.net`. +1. The operator creates a proxy Pod managed by a single replica StatefulSet, and a headless Service pointing at the proxy Pod. +1. The operator updates the `ExternalName` Service's `spec.externalName` field to point + at the headless Service it created in the previous step. 
+ +(Optional) If the user also adds the `tailscale.com/proxy-group: egress-proxies` +annotation to their `ExternalName` Service, the operator will skip creating a +proxy Pod and instead point the headless Service at the existing ProxyGroup's +pods. In this case, ports are also required in the `ExternalName` Service spec. +See below for a more representative diagram. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + egress(("egress proxy (src)")):::tsnode + egress-sts["StatefulSet"] + headless-svc[headless Service] + cfg-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph defaultns[namespace=default] + svc[ExternalName Service] + pod1((pod1)) --> svc + pod2((pod2)) --> svc + end + end + + node["db.tails-scales.ts.net (dst)"]:::tsnode + + svc -->|DNS points to| headless-svc + headless-svc -->|selects egress Pod| egress + egress -->|forwards traffic| node + operator -.->|creates| egress-sts + egress-sts -.->|manages| egress + operator -.->|creates| headless-svc + operator -.->|creates| cfg-secret + operator -.->|watches & updates| svc + cfg-secret -.->|mounted| egress + egress -.->|stores state| state-secret + + linkStyle 0 stroke:red; + linkStyle 6 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 4 stroke:blue; + linkStyle 5 stroke:blue; + +``` + +## `ProxyGroup` + +### Egress + +[Documentation][kb-operator-l3-egress-proxygroup] + +The `ProxyGroup` custom resource manages a collection of proxy Pods that +can be configured to egress traffic out of the cluster via ExternalName +Services. 
A `ProxyGroup` is both a high availability (HA) version of L3 +egress, and a mechanism to serve multiple ExternalName Services on a single +set of Tailscale devices (coalescing). + +In this diagram, the `ProxyGroup` is named `pg`. The Secrets associated with +the `ProxyGroup` Pods are omitted for simplicity. They are similar to the L3 +egress case above, but there is a pair of config + state Secrets _per Pod_. + +Each ExternalName Service defines which ports should be mapped to their defined +egress target. The operator maps from these ports to randomly chosen ephemeral +ports via the ClusterIP Service and its EndpointSlice. The operator then +generates the egress ConfigMap that tells the `ProxyGroup` Pods which incoming +ports map to which egress targets. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart LR + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + pg-sts[StatefulSet] + pg-0(("pg-0 (src)")):::tsnode + pg-1(("pg-1 (src)")):::tsnode + db-cluster-ip[db ClusterIP Service] + api-cluster-ip[api ClusterIP Service] + egress-cm["egress ConfigMap"] + end + + subgraph cluster-scope["Cluster scoped resources"] + pg["ProxyGroup 'pg'"] + end + + subgraph defaultns[namespace=default] + db-svc[db ExternalName Service] + api-svc[api ExternalName Service] + pod1((pod1)) --> db-svc + pod2((pod2)) --> db-svc + pod1((pod1)) --> api-svc + pod2((pod2)) --> api-svc + end + end + + db["db.tails-scales.ts.net (dst)"]:::tsnode + api["api.tails-scales.ts.net (dst)"]:::tsnode + + db-svc -->|DNS points to| db-cluster-ip + api-svc -->|DNS points to| api-cluster-ip + db-cluster-ip -->|maps to ephemeral db ports| pg-0 + db-cluster-ip -->|maps to ephemeral db ports| pg-1 + 
api-cluster-ip -->|maps to ephemeral api ports| pg-0 + api-cluster-ip -->|maps to ephemeral api ports| pg-1 + pg-0 -->|forwards db port traffic| db + pg-0 -->|forwards api port traffic| api + pg-1 -->|forwards db port traffic| db + pg-1 -->|forwards api port traffic| api + operator -.->|creates & populates endpointslice| db-cluster-ip + operator -.->|creates & populates endpointslice| api-cluster-ip + operator -.->|stores port mapping| egress-cm + egress-cm -.->|mounted| pg-0 + egress-cm -.->|mounted| pg-1 + operator -.->|watches| pg + operator -.->|creates| pg-sts + pg-sts -.->|manages| pg-0 + pg-sts -.->|manages| pg-1 + operator -.->|watches| db-svc + operator -.->|watches| api-svc + + linkStyle 0 stroke:red; + linkStyle 12 stroke:red; + linkStyle 13 stroke:red; + linkStyle 14 stroke:red; + linkStyle 15 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 4 stroke:blue; + linkStyle 5 stroke:blue; + linkStyle 6 stroke:blue; + linkStyle 7 stroke:blue; + linkStyle 8 stroke:blue; + linkStyle 9 stroke:blue; + linkStyle 10 stroke:blue; + linkStyle 11 stroke:blue; + +``` + +### Ingress + +A ProxyGroup can also serve as a highly available set of proxies for an +Ingress resource. The `-0` Pod is always the replica that will issue a certificate +from Let's Encrypt. + +If the same Ingress config is applied in multiple clusters, ProxyGroup proxies +from each cluster will be valid targets for the ts.net DNS name, and the proxy +each client is routed to will depend on the same rules as for [high availability][kb-ha] +subnet routers, and is encoded in the client's netmap. 
+ +```mermaid +%%{ init: { 'theme':'neutral' } }%% +flowchart LR + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + ingress-sts["StatefulSet"] + serve-cm[serve config ConfigMap] + ingress-0(("pg-0 (dst)")):::tsnode + ingress-1(("pg-1 (dst)")):::tsnode + tls-secret[myapp.tails.ts.net Secret] + end + + subgraph defaultns[namespace=default] + ingress[myapp.tails.ts.net Ingress] + svc["myapp Service"] + svc --> pod1((pod1)) + svc --> pod2((pod2)) + end + + subgraph cluster[Cluster scoped resources] + ingress-class[Tailscale IngressClass] + pg[ProxyGroup 'pg'] + end + end + + control["Tailscale control plane"] + ts-svc["myapp Tailscale Service"] + + client["client (src)"]:::tsnode -->|dials https\://myapp.tails.ts.net/api| ingress-1 + ingress-0 -->|forwards traffic| svc + ingress-1 -->|forwards traffic| svc + control -.->|creates| ts-svc + operator -.->|creates myapp Tailscale Service| control + control -.->|netmap points myapp Tailscale Service to pg-1| client + operator -.->|creates| ingress-sts + ingress-sts -.->|manages| ingress-0 + ingress-sts -.->|manages| ingress-1 + ingress-0 -.->|issues myapp.tails.ts.net cert| le[Let's Encrypt] + ingress-0 -.->|stores cert| tls-secret + ingress-1 -.->|reads cert| tls-secret + operator -.->|watches| ingress + operator -.->|watches| pg + operator -.->|creates| serve-cm + serve-cm -.->|mounted| ingress-0 + serve-cm -.->|mounted| ingress-1 + ingress -.->|/api prefix| svc + + linkStyle 0 stroke:red; + linkStyle 4 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 2 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 5 stroke:blue; + linkStyle 6 stroke:blue; + +``` + +## Connector + +[Subnet router and exit node 
documentation][kb-operator-connector] + +[App connector documentation][kb-operator-app-connector] + +The Connector Custom Resource can deploy either a subnet router, an exit node, +or an app connector. The following diagram shows all 3, but only one workflow +can be configured per Connector resource. + +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + classDef hidden display:none; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph grouping[" "] + subgraph k8s[Kubernetes cluster] + subgraph tailscale-ns[namespace=tailscale] + operator((operator)):::tsnode + cn-sts[StatefulSet] + cn-pod(("tailscale (dst)")):::tsnode + cfg-secret["config Secret"] + state-secret["state Secret"] + end + + subgraph cluster-scope["Cluster scoped resources"] + cn["Connector"] + end + + subgraph defaultns["namespace=default"] + pod1 + end + end + + client["client (src)"]:::tsnode + Internet + end + + client --> cn-pod + cn-pod -->|app connector or exit node routes| Internet + cn-pod -->|subnet route| pod1 + operator -.->|watches| cn + operator -.->|creates| cn-sts + cn-sts -.->|manages| cn-pod + operator -.->|creates| cfg-secret + cfg-secret -.->|mounted| cn-pod + cn-pod -.->|stores state| state-secret + + class grouping hidden + + linkStyle 0 stroke:red; + linkStyle 2 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 3 stroke:blue; + linkStyle 4 stroke:blue; + +``` + +## Recorder nodes + +[Documentation][kb-operator-recorder] + +The `Recorder` custom resource makes it easier to deploy `tsrecorder` to a cluster. +It currently only supports a single replica. 
+ +```mermaid +%%{ init: { 'theme':'neutral' } }%% + +flowchart TD + classDef tsnode color:#fff,fill:#000; + classDef pod fill:#fff; + classDef hidden display:none; + + subgraph Key + ts[Tailscale device]:::tsnode + pod((Pod)):::pod + blank[" "]-->|WireGuard traffic| blank2[" "] + blank3[" "]-->|Other network traffic| blank4[" "] + end + + subgraph grouping[" "] + subgraph k8s[Kubernetes cluster] + api["kube-apiserver"] + + subgraph tailscale-ns[namespace=tailscale] + operator(("operator (dst)")):::tsnode + rec-sts[StatefulSet] + rec-0(("tsrecorder")):::tsnode + cfg-secret-0["config Secret"] + state-secret-0["state Secret"] + end + + subgraph cluster-scope["Cluster scoped resources"] + rec["Recorder"] + end + end + + client["client (src)"]:::tsnode + kubectl-exec["kubectl exec (src)"]:::tsnode + server["server (dst)"]:::tsnode + s3["S3-compatible storage"] + end + + kubectl-exec -->|exec session| operator + operator -->|exec session recording| rec-0 + operator -->|exec session| api + client -->|ssh session| server + server -->|ssh session recording| rec-0 + rec-0 -->|session recordings| s3 + operator -.->|watches| rec + operator -.->|creates| rec-sts + rec-sts -.->|manages| rec-0 + operator -.->|creates| cfg-secret-0 + cfg-secret-0 -.->|mounted| rec-0 + rec-0 -.->|stores state| state-secret-0 + + class grouping hidden + + linkStyle 0 stroke:red; + linkStyle 2 stroke:red; + linkStyle 3 stroke:red; + linkStyle 5 stroke:red; + linkStyle 6 stroke:red; + + linkStyle 1 stroke:blue; + linkStyle 4 stroke:blue; + linkStyle 7 stroke:blue; + +``` + +[kb-operator]: https://tailscale.com/kb/1236/kubernetes-operator +[kb-operator-proxy]: https://tailscale.com/kb/1437/kubernetes-operator-api-server-proxy +[kb-operator-l3-ingress]: https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress#exposing-a-cluster-workload-using-a-kubernetes-service +[kb-operator-l7-ingress]: 
https://tailscale.com/kb/1439/kubernetes-operator-cluster-ingress#exposing-cluster-workloads-using-a-kubernetes-ingress +[kb-operator-l3-egress]: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress +[kb-operator-l3-egress-proxygroup]: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress#configure-an-egress-service-using-proxygroup +[kb-operator-connector]: https://tailscale.com/kb/1441/kubernetes-operator-connector +[kb-operator-app-connector]: https://tailscale.com/kb/1517/kubernetes-operator-app-connector +[kb-operator-recorder]: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder +[kb-ha]: https://tailscale.com/kb/1115/high-availability +[k8s-impersonation]: https://kubernetes.io/docs/reference/access-authn-authz/authentication/#user-impersonation diff --git a/docs/k8s/proxy.yaml b/docs/k8s/proxy.yaml index 2ab7ed334395d..048fd7a5bddf9 100644 --- a/docs/k8s/proxy.yaml +++ b/docs/k8s/proxy.yaml @@ -44,7 +44,13 @@ spec: value: "{{TS_DEST_IP}}" - name: TS_AUTH_ONCE value: "true" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/role.yaml b/docs/k8s/role.yaml index 6d6a8117d1bbd..d7d0846ab29a6 100644 --- a/docs/k8s/role.yaml +++ b/docs/k8s/role.yaml @@ -13,3 +13,6 @@ rules: resourceNames: ["{{TS_KUBE_SECRET}}"] resources: ["secrets"] verbs: ["get", "update", "patch"] +- apiGroups: [""] # "" indicates the core API group + resources: ["events"] + verbs: ["get", "create", "patch"] diff --git a/docs/k8s/sidecar.yaml b/docs/k8s/sidecar.yaml index 7efd32a38d0ac..520e4379ad9ee 100644 --- a/docs/k8s/sidecar.yaml +++ b/docs/k8s/sidecar.yaml @@ -26,7 +26,13 @@ spec: name: tailscale-auth key: TS_AUTHKEY optional: true + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: 
metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/subnet.yaml b/docs/k8s/subnet.yaml index 4b7066fb3460a..ef4e4748c0ceb 100644 --- a/docs/k8s/subnet.yaml +++ b/docs/k8s/subnet.yaml @@ -28,7 +28,13 @@ spec: optional: true - name: TS_ROUTES value: "{{TS_ROUTES}}" + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid securityContext: - capabilities: - add: - - NET_ADMIN + privileged: true diff --git a/docs/k8s/userspace-sidecar.yaml b/docs/k8s/userspace-sidecar.yaml index fc4ed63502dbc..ee19b10a5e5dd 100644 --- a/docs/k8s/userspace-sidecar.yaml +++ b/docs/k8s/userspace-sidecar.yaml @@ -27,3 +27,11 @@ spec: name: tailscale-auth key: TS_AUTHKEY optional: true + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid diff --git a/docs/windows/policy/en-US/tailscale.adml b/docs/windows/policy/en-US/tailscale.adml index 7a658422cd7f6..62ff94da7096d 100644 --- a/docs/windows/policy/en-US/tailscale.adml +++ b/docs/windows/policy/en-US/tailscale.adml @@ -15,34 +15,42 @@ Tailscale version 1.58.0 and later Tailscale version 1.62.0 and later Tailscale version 1.74.0 and later + Tailscale version 1.78.0 and later + Tailscale version 1.80.0 and later + Tailscale version 1.82.0 and later + Tailscale version 1.84.0 and later Tailscale UI customization Settings + Allowed (with audit) + Not Allowed Require using a specific Tailscale coordination server +If you disable or do not configure this policy, the Tailscale SaaS coordination server will be used by default, but a non-standard Tailscale coordination server can be configured using the CLI. 
+ +See https://tailscale.com/kb/1315/mdm-keys#set-a-custom-control-server-url for more details.]]> Require using a specific Tailscale log server Specify which Tailnet should be used for Login +See https://tailscale.com/kb/1315/mdm-keys#set-a-suggested-or-required-tailnet for more details.]]> Specify the auth key to authenticate devices without user interaction Require using a specific Exit Node +If you do not configure this policy, no exit node will be used by default but an exit node (if one is available and permitted by ACLs) can be chosen by the user if desired. + +See https://tailscale.com/kb/1315/mdm-keys#force-an-exit-node-to-always-be-used and https://tailscale.com/kb/1103/exit-nodes for more details.]]> + Limit automated Exit Node suggestions to specific nodes + Allow incoming connections +If you do not configure this policy, then Allow Incoming Connections depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-to-allow-incoming-connections and https://tailscale.com/kb/1072/client-preferences#allow-incoming-connections for more details.]]> Run Tailscale in Unattended Mode +If you do not configure this policy, then Run Unattended depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-unattended-mode and https://tailscale.com/kb/1088/run-unattended for more details.]]> + Restrict users from disconnecting Tailscale (always-on mode) + + Configure automatic reconnect delay + Allow Local Network Access when an Exit Node is in use +If you do not configure this policy, then Allow Local Network Access depends on what is selected in the Exit Node submenu. 
+ +See https://tailscale.com/kb/1315/mdm-keys#toggle-local-network-access-when-an-exit-node-is-in-use and https://tailscale.com/kb/1103/exit-nodes#step-4-use-the-exit-node for more details.]]> Use Tailscale DNS Settings +If you do not configure this policy, then Use Tailscale DNS depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-the-device-uses-tailscale-dns-settings for more details.]]> Use Tailscale Subnets +If you do not configure this policy, then Use Tailscale Subnets depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1315/mdm-keys#set-whether-the-device-accepts-tailscale-subnets or https://tailscale.com/kb/1019/subnets for more details.]]> + Always register + Use adapter properties + Register Tailscale IP addresses in DNS + Automatically install updates +If you do not configure this policy, then Automatically Install Updates depends on what is selected in the Preferences submenu. + +See https://tailscale.com/kb/1067/update#auto-updates for more details.]]> Run Tailscale as an Exit Node - Show the "Admin Panel" menu item - + Show the "Admin Console" menu item + Show the "Debug" submenu +If you disable this policy, the Debug submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-debug-menu for more details.]]> Show the "Update Available" menu item +If you disable this policy, the Update Available item will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-update-menu for more details.]]> Show the "Run Exit Node" menu item +If you disable this policy, the Run Exit Node item will be hidden from the Exit Node submenu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-run-as-exit-node-menu-item for more details.]]> Show the "Preferences" submenu +If you disable this policy, the Preferences submenu will be hidden from the Tailscale menu. 
+ +See https://tailscale.com/kb/1315/mdm-keys#hide-the-preferences-menu for more details.]]> Show the "Exit Node" submenu +If you disable this policy, the Exit Node submenu will be hidden from the Tailscale menu. + +See https://tailscale.com/kb/1315/mdm-keys#hide-the-exit-node-picker for more details.]]> Specify a custom key expiration notification time +If you disable or don't configure this policy, the default time period will be used (as of Tailscale 1.56, this is 24 hours). + +See https://tailscale.com/kb/1315/mdm-keys#set-the-key-expiration-notice-period for more details.]]> Log extra details about service events Collect data for posture checking +If you do not configure this policy, then data collection depends on if it has been enabled from the CLI (as of Tailscale 1.56), it may be present in the GUI in later versions. + +See https://tailscale.com/kb/1315/mdm-keys#enable-gathering-device-posture-data and https://tailscale.com/kb/1326/device-identity for more details.]]> Show the "Managed By {Organization}" menu item + Show the onboarding flow + @@ -239,11 +300,27 @@ See https://tailscale.com/kb/1315/mdm-keys#set-your-organization-name for more d + + The options below allow configuring exceptions where disconnecting Tailscale is permitted. + Disconnects with reason: + + + The delay must be a valid Go duration string, such as 30s, 5m, or 1h30m, all without spaces or any other symbols. 
+ + + + + + Registration mode: + + + Target IDs: + diff --git a/docs/windows/policy/tailscale.admx b/docs/windows/policy/tailscale.admx index e70f124ed1a36..d97b24c36b5df 100644 --- a/docs/windows/policy/tailscale.admx +++ b/docs/windows/policy/tailscale.admx @@ -50,6 +50,22 @@ displayName="$(string.SINCE_V1_74)"> + + + + + + + + + + + + @@ -97,6 +113,13 @@ + + + + + + + @@ -117,6 +140,37 @@ never + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -147,6 +201,24 @@ never + + + + + + + + always + + + + + user-decides + + + + + @@ -197,7 +269,7 @@ - + @@ -207,7 +279,7 @@ hide - + @@ -217,7 +289,7 @@ hide - + @@ -227,7 +299,7 @@ hide - + @@ -237,7 +309,7 @@ hide - + @@ -247,7 +319,7 @@ hide - + @@ -257,7 +329,7 @@ hide - + @@ -267,7 +339,17 @@ hide - + + + + + show + + + hide + + + @@ -276,7 +358,7 @@ - + diff --git a/doctor/ethtool/ethtool_linux.go b/doctor/ethtool/ethtool_linux.go index b8cc0800240de..f6eaac1df0542 100644 --- a/doctor/ethtool/ethtool_linux.go +++ b/doctor/ethtool/ethtool_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package ethtool import ( diff --git a/doctor/ethtool/ethtool_other.go b/doctor/ethtool/ethtool_other.go index 9aaa9dda8ba5f..7af74eec8f872 100644 --- a/doctor/ethtool/ethtool_other.go +++ b/doctor/ethtool/ethtool_other.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || android package ethtool diff --git a/drive/drive_view.go b/drive/drive_view.go index a6adfbc705378..0f6686f24da68 100644 --- a/drive/drive_view.go +++ b/drive/drive_view.go @@ -14,7 +14,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=Share -// View returns a readonly view of Share. +// View returns a read-only view of Share. 
func (p *Share) View() ShareView { return ShareView{Đļ: p} } @@ -30,7 +30,7 @@ type ShareView struct { Đļ *Share } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ShareView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with diff --git a/drive/driveimpl/dirfs/dirfs.go b/drive/driveimpl/dirfs/dirfs.go index c1f28bb9dee4f..50a3330a9d751 100644 --- a/drive/driveimpl/dirfs/dirfs.go +++ b/drive/driveimpl/dirfs/dirfs.go @@ -44,7 +44,7 @@ func (c *Child) isAvailable() bool { // Any attempts to perform operations on paths inside of children will result // in a panic, as these are not expected to be performed on this FS. // -// An FS an optionally have a StaticRoot, which will insert a folder with that +// An FS can optionally have a StaticRoot, which will insert a folder with that // StaticRoot into the tree, like this: // // -- diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go index 20b179511ea7c..e7dd832918cec 100644 --- a/drive/driveimpl/drive_test.go +++ b/drive/driveimpl/drive_test.go @@ -133,6 +133,71 @@ func TestPermissions(t *testing.T) { } } +// TestMissingPaths verifies that the fileserver running at localhost +// correctly handles paths with missing required components. 
+// +// Expected path format: +// http://localhost:[PORT]//[/] +func TestMissingPaths(t *testing.T) { + s := newSystem(t) + + fileserverAddr := s.addRemote(remote1) + s.addShare(remote1, share11, drive.PermissionReadWrite) + + client := &http.Client{ + Transport: &http.Transport{DisableKeepAlives: true}, + } + addr := strings.Split(fileserverAddr, "|")[1] + secretToken := strings.Split(fileserverAddr, "|")[0] + + testCases := []struct { + name string + path string + wantStatus int + }{ + { + name: "empty path", + path: "", + wantStatus: http.StatusForbidden, + }, + { + name: "single slash", + path: "/", + wantStatus: http.StatusForbidden, + }, + { + name: "only token", + path: "/" + secretToken, + wantStatus: http.StatusBadRequest, + }, + { + name: "token with trailing slash", + path: "/" + secretToken + "/", + wantStatus: http.StatusBadRequest, + }, + { + name: "token and invalid share", + path: "/" + secretToken + "/nonexistentshare", + wantStatus: http.StatusNotFound, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + u := fmt.Sprintf("http://%s%s", addr, tc.path) + resp, err := client.Get(u) + if err != nil { + t.Fatalf("unexpected error making request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != tc.wantStatus { + t.Errorf("got status code %d, want %d", resp.StatusCode, tc.wantStatus) + } + }) + } +} + // TestSecretTokenAuth verifies that the fileserver running at localhost cannot // be accessed directly without the correct secret token. 
This matters because // if a victim can be induced to visit the localhost URL and access a malicious @@ -704,8 +769,8 @@ func (a *noopAuthenticator) Close() error { return nil } -const lockBody = ` - - - +const lockBody = ` + + + ` diff --git a/drive/driveimpl/fileserver.go b/drive/driveimpl/fileserver.go index 0067c1cc7db63..113cb3b440218 100644 --- a/drive/driveimpl/fileserver.go +++ b/drive/driveimpl/fileserver.go @@ -61,7 +61,7 @@ func NewFileServer() (*FileServer, error) { }, nil } -// generateSecretToken generates a hex-encoded 256 bit secet. +// generateSecretToken generates a hex-encoded 256 bit secret. func generateSecretToken() (string, error) { tokenBytes := make([]byte, 32) _, err := rand.Read(tokenBytes) @@ -142,6 +142,10 @@ func (s *FileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if len(parts) < 2 { + w.WriteHeader(http.StatusBadRequest) + return + } r.URL.Path = shared.Join(parts[2:]...) share := parts[1] s.sharesMu.RLock() diff --git a/drive/driveimpl/shared/pathutil.go b/drive/driveimpl/shared/pathutil.go index efa9f5f320c74..fcadcdd5aa0e0 100644 --- a/drive/driveimpl/shared/pathutil.go +++ b/drive/driveimpl/shared/pathutil.go @@ -22,6 +22,9 @@ const ( // CleanAndSplit cleans the provided path p and splits it into its constituent // parts. This is different from path.Split which just splits a path into prefix // and suffix. +// +// If p is empty or contains only path separators, CleanAndSplit returns a slice +// of length 1 whose only element is "". func CleanAndSplit(p string) []string { return strings.Split(strings.Trim(path.Clean(p), sepStringAndDot), sepString) } @@ -38,6 +41,8 @@ func Parent(p string) string { } // Join behaves like path.Join() but also includes a leading slash. +// +// When parts are missing, the result is "/". 
func Join(parts ...string) string { fullParts := make([]string, 0, len(parts)) fullParts = append(fullParts, sepString) diff --git a/drive/driveimpl/shared/pathutil_test.go b/drive/driveimpl/shared/pathutil_test.go index 662adbd8b0b48..daee695632ff4 100644 --- a/drive/driveimpl/shared/pathutil_test.go +++ b/drive/driveimpl/shared/pathutil_test.go @@ -40,6 +40,7 @@ func TestJoin(t *testing.T) { parts []string want string }{ + {[]string{}, "/"}, {[]string{""}, "/"}, {[]string{"a"}, "/a"}, {[]string{"/a"}, "/a"}, diff --git a/drive/remote_permissions.go b/drive/remote_permissions.go index d3d41c6ec3311..420eff9a0e743 100644 --- a/drive/remote_permissions.go +++ b/drive/remote_permissions.go @@ -32,7 +32,7 @@ type grant struct { Access string } -// ParsePermissions builds a Permissions map from a lis of raw grants. +// ParsePermissions builds a Permissions map from a list of raw grants. func ParsePermissions(rawGrants [][]byte) (Permissions, error) { permissions := make(Permissions) for _, rawGrant := range rawGrants { diff --git a/envknob/envknob.go b/envknob/envknob.go index 59a6d90af213b..e581eb27e11cb 100644 --- a/envknob/envknob.go +++ b/envknob/envknob.go @@ -411,12 +411,35 @@ func TKASkipSignatureCheck() bool { return Bool("TS_UNSAFE_SKIP_NKS_VERIFICATION // Kubernetes Operator components. func App() string { a := os.Getenv("TS_INTERNAL_APP") - if a == kubetypes.AppConnector || a == kubetypes.AppEgressProxy || a == kubetypes.AppIngressProxy || a == kubetypes.AppIngressResource { + if a == kubetypes.AppConnector || a == kubetypes.AppEgressProxy || a == kubetypes.AppIngressProxy || a == kubetypes.AppIngressResource || a == kubetypes.AppProxyGroupEgress || a == kubetypes.AppProxyGroupIngress { return a } return "" } +// IsCertShareReadOnlyMode returns true if this replica should never attempt to +// issue or renew TLS credentials for any of the HTTPS endpoints that it is +// serving. It should only return certs found in its cert store. 
Currently, +// this is used by the Kubernetes Operator's HA Ingress via VIPServices, where +// multiple Ingress proxy instances serve the same HTTPS endpoint with a shared +// TLS credentials. The TLS credentials should only be issued by one of the +// replicas. +// For HTTPS Ingress the operator and containerboot ensure +// that read-only replicas will not be serving the HTTPS endpoints before there +// is a shared cert available. +func IsCertShareReadOnlyMode() bool { + m := String("TS_CERT_SHARE_MODE") + return m == "ro" +} + +// IsCertShareReadWriteMode returns true if this instance is the replica +// responsible for issuing and renewing TLS certs in an HA setup with certs +// shared between multiple replicas. +func IsCertShareReadWriteMode() bool { + m := String("TS_CERT_SHARE_MODE") + return m == "rw" +} + // CrashOnUnexpected reports whether the Tailscale client should panic // on unexpected conditions. If TS_DEBUG_CRASH_ON_UNEXPECTED is set, that's // used. Otherwise the default value is true for unstable builds. diff --git a/envknob/featureknob/featureknob.go b/envknob/featureknob/featureknob.go new file mode 100644 index 0000000000000..e9b871f74a8c0 --- /dev/null +++ b/envknob/featureknob/featureknob.go @@ -0,0 +1,67 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package featureknob provides a facility to control whether features +// can run based on either an envknob or running OS / distro. +package featureknob + +import ( + "errors" + "runtime" + + "tailscale.com/envknob" + "tailscale.com/hostinfo" + "tailscale.com/version" + "tailscale.com/version/distro" +) + +// CanRunTailscaleSSH reports whether serving a Tailscale SSH server is +// supported for the current os/distro. 
+func CanRunTailscaleSSH() error { + switch runtime.GOOS { + case "linux": + if distro.Get() == distro.Synology && !envknob.UseWIPCode() { + return errors.New("The Tailscale SSH server does not run on Synology.") + } + if distro.Get() == distro.QNAP && !envknob.UseWIPCode() { + return errors.New("The Tailscale SSH server does not run on QNAP.") + } + + // Setting SSH on Home Assistant causes trouble on startup + // (since the flag is not being passed to `tailscale up`). + // Although Tailscale SSH does work here, + // it's not terribly useful since it's running in a separate container. + if hostinfo.GetEnvType() == hostinfo.HomeAssistantAddOn { + return errors.New("The Tailscale SSH server does not run on HomeAssistant.") + } + // otherwise okay + case "darwin": + // okay only in tailscaled mode for now. + if version.IsSandboxedMacOS() { + return errors.New("The Tailscale SSH server does not run in sandboxed Tailscale GUI builds.") + } + case "freebsd", "openbsd", "plan9": + default: + return errors.New("The Tailscale SSH server is not supported on " + runtime.GOOS) + } + if !envknob.CanSSHD() { + return errors.New("The Tailscale SSH server has been administratively disabled.") + } + return nil +} + +// CanUseExitNode reports whether using an exit node is supported for the +// current os/distro. 
+func CanUseExitNode() error { + switch dist := distro.Get(); dist { + case distro.Synology, // see https://github.com/tailscale/tailscale/issues/1995 + distro.QNAP: + return errors.New("Tailscale exit nodes cannot be used on " + string(dist)) + } + + if hostinfo.GetEnvType() == hostinfo.HomeAssistantAddOn { + return errors.New("Tailscale exit nodes cannot be used on HomeAssistant.") + } + + return nil +} diff --git a/envknob/features.go b/envknob/features.go deleted file mode 100644 index 9e5909de309f0..0000000000000 --- a/envknob/features.go +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package envknob - -import ( - "errors" - "runtime" - - "tailscale.com/version" - "tailscale.com/version/distro" -) - -// CanRunTailscaleSSH reports whether serving a Tailscale SSH server is -// supported for the current os/distro. -func CanRunTailscaleSSH() error { - switch runtime.GOOS { - case "linux": - if distro.Get() == distro.Synology && !UseWIPCode() { - return errors.New("The Tailscale SSH server does not run on Synology.") - } - if distro.Get() == distro.QNAP && !UseWIPCode() { - return errors.New("The Tailscale SSH server does not run on QNAP.") - } - // otherwise okay - case "darwin": - // okay only in tailscaled mode for now. 
- if version.IsSandboxedMacOS() { - return errors.New("The Tailscale SSH server does not run in sandboxed Tailscale GUI builds.") - } - case "freebsd", "openbsd": - default: - return errors.New("The Tailscale SSH server is not supported on " + runtime.GOOS) - } - if !CanSSHD() { - return errors.New("The Tailscale SSH server has been administratively disabled.") - } - return nil -} diff --git a/envknob/logknob/logknob.go b/envknob/logknob/logknob.go index 350384b8626e3..93302d0d2bd5c 100644 --- a/envknob/logknob/logknob.go +++ b/envknob/logknob/logknob.go @@ -11,7 +11,6 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/logger" - "tailscale.com/types/views" ) // TODO(andrew-d): should we have a package-global registry of logknobs? It @@ -59,7 +58,7 @@ func (lk *LogKnob) Set(v bool) { // about; we use this rather than a concrete type to avoid a circular // dependency. type NetMap interface { - SelfCapabilities() views.Slice[tailcfg.NodeCapability] + HasSelfCapability(tailcfg.NodeCapability) bool } // UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap @@ -68,8 +67,7 @@ func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { if lk.capName == "" { return } - - lk.cap.Store(views.SliceContains(nm.SelfCapabilities(), lk.capName)) + lk.cap.Store(nm.HasSelfCapability(lk.capName)) } // Do will call log with the provided format and arguments if any of the diff --git a/envknob/logknob/logknob_test.go b/envknob/logknob/logknob_test.go index b2a376a25b371..aa4fb44214e12 100644 --- a/envknob/logknob/logknob_test.go +++ b/envknob/logknob/logknob_test.go @@ -11,6 +11,7 @@ import ( "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/netmap" + "tailscale.com/util/set" ) var testKnob = NewLogKnob( @@ -63,11 +64,7 @@ func TestLogKnob(t *testing.T) { } testKnob.UpdateFromNetMap(&netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - Capabilities: []tailcfg.NodeCapability{ - "https://tailscale.com/cap/testing", - }, - }).View(), 
+ AllCaps: set.Of(tailcfg.NodeCapability("https://tailscale.com/cap/testing")), }) if !testKnob.shouldLog() { t.Errorf("expected shouldLog()=true") diff --git a/wgengine/capture/capture.go b/feature/capture/capture.go similarity index 79% rename from wgengine/capture/capture.go rename to feature/capture/capture.go index 6ea5a9549b4f1..e5e150de8e761 100644 --- a/wgengine/capture/capture.go +++ b/feature/capture/capture.go @@ -13,21 +13,44 @@ import ( "sync" "time" - _ "embed" - + "tailscale.com/feature" + "tailscale.com/ipn/localapi" "tailscale.com/net/packet" "tailscale.com/util/set" ) -//go:embed ts-dissector.lua -var DissectorLua string +func init() { + feature.Register("capture") + localapi.Register("debug-capture", serveLocalAPIDebugCapture) +} + +func serveLocalAPIDebugCapture(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + if r.Method != "POST" { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + b := h.LocalBackend() + s := b.GetOrSetCaptureSink(newSink) -// Callback describes a function which is called to -// record packets when debugging packet-capture. -// Such callbacks must not take ownership of the -// provided data slice: it may only copy out of it -// within the lifetime of the function. -type Callback func(Path, time.Time, []byte, packet.CaptureMeta) + unregister := s.RegisterOutput(w) + + select { + case <-ctx.Done(): + case <-s.WaitCh(): + } + unregister() + + b.ClearCaptureSink() +} var bufferPool = sync.Pool{ New: func() any { @@ -57,29 +80,8 @@ func writePktHeader(w *bytes.Buffer, when time.Time, length int) { binary.Write(w, binary.LittleEndian, uint32(length)) // total length } -// Path describes where in the data path the packet was captured. -type Path uint8 - -// Valid Path values. 
-const ( - // FromLocal indicates the packet was logged as it traversed the FromLocal path: - // i.e.: A packet from the local system into the TUN. - FromLocal Path = 0 - // FromPeer indicates the packet was logged upon reception from a remote peer. - FromPeer Path = 1 - // SynthesizedToLocal indicates the packet was generated from within tailscaled, - // and is being routed to the local machine's network stack. - SynthesizedToLocal Path = 2 - // SynthesizedToPeer indicates the packet was generated from within tailscaled, - // and is being routed to a remote Wireguard peer. - SynthesizedToPeer Path = 3 - - // PathDisco indicates the packet is information about a disco frame. - PathDisco Path = 254 -) - -// New creates a new capture sink. -func New() *Sink { +// newSink creates a new capture sink. +func newSink() packet.CaptureSink { ctx, c := context.WithCancel(context.Background()) return &Sink{ ctx: ctx, @@ -126,6 +128,10 @@ func (s *Sink) RegisterOutput(w io.Writer) (unregister func()) { } } +func (s *Sink) CaptureCallback() packet.CaptureCallback { + return s.LogPacket +} + // NumOutputs returns the number of outputs registered with the sink. func (s *Sink) NumOutputs() int { s.mu.Lock() @@ -174,7 +180,7 @@ func customDataLen(meta packet.CaptureMeta) int { // LogPacket is called to insert a packet into the capture. // // This function does not take ownership of the provided data slice. 
-func (s *Sink) LogPacket(path Path, when time.Time, data []byte, meta packet.CaptureMeta) { +func (s *Sink) LogPacket(path packet.CapturePath, when time.Time, data []byte, meta packet.CaptureMeta) { select { case <-s.ctx.Done(): return diff --git a/feature/capture/dissector/dissector.go b/feature/capture/dissector/dissector.go new file mode 100644 index 0000000000000..ab2f6c2ec1607 --- /dev/null +++ b/feature/capture/dissector/dissector.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package dissector contains the Lua dissector for Tailscale packets. +package dissector + +import ( + _ "embed" +) + +//go:embed ts-dissector.lua +var Lua string diff --git a/wgengine/capture/ts-dissector.lua b/feature/capture/dissector/ts-dissector.lua similarity index 100% rename from wgengine/capture/ts-dissector.lua rename to feature/capture/dissector/ts-dissector.lua diff --git a/feature/condregister/condregister.go b/feature/condregister/condregister.go new file mode 100644 index 0000000000000..f9025095147f1 --- /dev/null +++ b/feature/condregister/condregister.go @@ -0,0 +1,7 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// The condregister package registers all conditional features guarded +// by build tags. It is one central package that callers can empty import +// to ensure all conditional features are registered. 
+package condregister diff --git a/feature/condregister/maybe_capture.go b/feature/condregister/maybe_capture.go new file mode 100644 index 0000000000000..0c68331f101cd --- /dev/null +++ b/feature/condregister/maybe_capture.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_capture + +package condregister + +import _ "tailscale.com/feature/capture" diff --git a/feature/condregister/maybe_relayserver.go b/feature/condregister/maybe_relayserver.go new file mode 100644 index 0000000000000..3360dd0627cc1 --- /dev/null +++ b/feature/condregister/maybe_relayserver.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_relayserver + +package condregister + +import _ "tailscale.com/feature/relayserver" diff --git a/feature/condregister/maybe_taildrop.go b/feature/condregister/maybe_taildrop.go new file mode 100644 index 0000000000000..5fd7b5f8c9a00 --- /dev/null +++ b/feature/condregister/maybe_taildrop.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_taildrop + +package condregister + +import _ "tailscale.com/feature/taildrop" diff --git a/feature/condregister/maybe_tap.go b/feature/condregister/maybe_tap.go new file mode 100644 index 0000000000000..eca4fc3ac84af --- /dev/null +++ b/feature/condregister/maybe_tap.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux && !ts_omit_tap + +package condregister + +import _ "tailscale.com/feature/tap" diff --git a/feature/condregister/maybe_tpm.go b/feature/condregister/maybe_tpm.go new file mode 100644 index 0000000000000..caa57fef11d73 --- /dev/null +++ b/feature/condregister/maybe_tpm.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !ts_omit_tpm + 
+package condregister + +import _ "tailscale.com/feature/tpm" diff --git a/feature/condregister/maybe_wakeonlan.go b/feature/condregister/maybe_wakeonlan.go new file mode 100644 index 0000000000000..14cae605d1468 --- /dev/null +++ b/feature/condregister/maybe_wakeonlan.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_wakeonlan + +package condregister + +import _ "tailscale.com/feature/wakeonlan" diff --git a/feature/feature.go b/feature/feature.go new file mode 100644 index 0000000000000..5976d7f5a5d0d --- /dev/null +++ b/feature/feature.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package feature tracks which features are linked into the binary. +package feature + +import "reflect" + +var in = map[string]bool{} + +// Register notes that the named feature is linked into the binary. +func Register(name string) { + if _, ok := in[name]; ok { + panic("duplicate feature registration for " + name) + } + in[name] = true +} + +// Hook is a func that can only be set once. +// +// It is not safe for concurrent use. +type Hook[Func any] struct { + f Func + ok bool +} + +// IsSet reports whether the hook has been set. +func (h *Hook[Func]) IsSet() bool { + return h.ok +} + +// Set sets the hook function, panicking if it's already been set +// or f is the zero value. +// +// It's meant to be called in init. +func (h *Hook[Func]) Set(f Func) { + if h.ok { + panic("Set on already-set feature hook") + } + if reflect.ValueOf(f).IsZero() { + panic("Set with zero value") + } + h.f = f + h.ok = true +} + +// Get returns the hook function, or panics if it hasn't been set. +// Use IsSet to check if it's been set. +func (h *Hook[Func]) Get() Func { + if !h.ok { + panic("Get on unset feature hook, without IsSet") + } + return h.f +} + +// GetOk returns the hook function and true if it has been set, +// otherwise its zero value and false. 
+func (h *Hook[Func]) GetOk() (f Func, ok bool) { + return h.f, h.ok +} + +// Hooks is a slice of funcs. +// +// As opposed to a single Hook, this is meant to be used when +// multiple parties are able to install the same hook. +type Hooks[Func any] []Func + +// Add adds a hook to the list of hooks. +// +// Add should only be called during early program +// startup before Tailscale has started. +// It is not safe for concurrent use. +func (h *Hooks[Func]) Add(f Func) { + if reflect.ValueOf(f).IsZero() { + panic("Add with zero value") + } + *h = append(*h, f) +} diff --git a/feature/relayserver/relayserver.go b/feature/relayserver/relayserver.go new file mode 100644 index 0000000000000..96d21138edfbc --- /dev/null +++ b/feature/relayserver/relayserver.go @@ -0,0 +1,195 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package relayserver registers the relay server feature and implements its +// associated ipnext.Extension. +package relayserver + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "net/netip" + "sync" + + "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/net/udprelay" + "tailscale.com/net/udprelay/endpoint" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/ptr" + "tailscale.com/util/httpm" +) + +// featureName is the name of the feature implemented by this package. +// It is also the [extension] name and the log prefix. +const featureName = "relayserver" + +func init() { + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newExtension) + ipnlocal.RegisterPeerAPIHandler("/v0/relay/endpoint", handlePeerAPIRelayAllocateEndpoint) +} + +// newExtension is an [ipnext.NewExtensionFn] that creates a new relay server +// extension. It is registered with [ipnext.RegisterExtension] if the package is +// imported. 
+func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { + return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil +} + +// extension is an [ipnext.Extension] managing the relay server on platforms +// that import this package. +type extension struct { + logf logger.Logf + + mu sync.Mutex // guards the following fields + shutdown bool + port *int // ipn.Prefs.RelayServerPort, nil if disabled + hasNodeAttrRelayServer bool // tailcfg.NodeAttrRelayServer + server relayServer // lazily initialized +} + +// relayServer is the interface of [udprelay.Server]. +type relayServer interface { + AllocateEndpoint(discoA key.DiscoPublic, discoB key.DiscoPublic) (endpoint.ServerEndpoint, error) + Close() error +} + +// Name implements [ipnext.Extension]. +func (e *extension) Name() string { + return featureName +} + +// Init implements [ipnext.Extension] by registering callbacks and providers +// for the duration of the extension's lifetime. +func (e *extension) Init(host ipnext.Host) error { + profile, prefs := host.Profiles().CurrentProfileState() + e.profileStateChanged(profile, prefs, false) + host.Hooks().ProfileStateChange.Add(e.profileStateChanged) + host.Hooks().OnSelfChange.Add(e.selfNodeViewChanged) + return nil +} + +func (e *extension) selfNodeViewChanged(nodeView tailcfg.NodeView) { + e.mu.Lock() + defer e.mu.Unlock() + e.hasNodeAttrRelayServer = nodeView.HasCap(tailcfg.NodeAttrRelayServer) + if !e.hasNodeAttrRelayServer && e.server != nil { + e.server.Close() + e.server = nil + } +} + +func (e *extension) profileStateChanged(_ ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + newPort, ok := prefs.RelayServerPort().GetOk() + enableOrDisableServer := ok != (e.port != nil) + portChanged := ok && e.port != nil && newPort != *e.port + if enableOrDisableServer || portChanged || !sameNode { + if e.server != nil { + e.server.Close() + e.server = nil + } + e.port = nil + if ok { + 
e.port = ptr.To(newPort) + } + } +} + +// Shutdown implements [ipnext.Extension]. +func (e *extension) Shutdown() error { + e.mu.Lock() + defer e.mu.Unlock() + e.shutdown = true + if e.server != nil { + e.server.Close() + e.server = nil + } + return nil +} + +func (e *extension) relayServerOrInit() (relayServer, error) { + e.mu.Lock() + defer e.mu.Unlock() + if e.shutdown { + return nil, errors.New("relay server is shutdown") + } + if e.server != nil { + return e.server, nil + } + if e.port == nil { + return nil, errors.New("relay server is not configured") + } + if !e.hasNodeAttrRelayServer { + return nil, errors.New("no relay:server node attribute") + } + if !envknob.UseWIPCode() { + return nil, errors.New("TAILSCALE_USE_WIP_CODE envvar is not set") + } + var err error + e.server, _, err = udprelay.NewServer(*e.port, []netip.Addr{netip.MustParseAddr("127.0.0.1")}) + if err != nil { + return nil, err + } + return e.server, nil +} + +func handlePeerAPIRelayAllocateEndpoint(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + e, ok := ipnlocal.GetExt[*extension](h.LocalBackend()) + if !ok { + http.Error(w, "relay failed to initialize", http.StatusServiceUnavailable) + return + } + + httpErrAndLog := func(message string, code int) { + http.Error(w, message, code) + h.Logf("relayserver: request from %v returned code %d: %s", h.RemoteAddr(), code, message) + } + + if !h.PeerCaps().HasCapability(tailcfg.PeerCapabilityRelay) { + httpErrAndLog("relay not permitted", http.StatusForbidden) + return + } + + if r.Method != httpm.POST { + httpErrAndLog("only POST method is allowed", http.StatusMethodNotAllowed) + return + } + + var allocateEndpointReq struct { + DiscoKeys []key.DiscoPublic + } + err := json.NewDecoder(io.LimitReader(r.Body, 512)).Decode(&allocateEndpointReq) + if err != nil { + httpErrAndLog(err.Error(), http.StatusBadRequest) + return + } + if len(allocateEndpointReq.DiscoKeys) != 2 { + httpErrAndLog("2 disco public keys must be supplied", 
http.StatusBadRequest) + return + } + + rs, err := e.relayServerOrInit() + if err != nil { + httpErrAndLog(err.Error(), http.StatusServiceUnavailable) + return + } + ep, err := rs.AllocateEndpoint(allocateEndpointReq.DiscoKeys[0], allocateEndpointReq.DiscoKeys[1]) + if err != nil { + httpErrAndLog(err.Error(), http.StatusInternalServerError) + return + } + err = json.NewEncoder(w).Encode(&ep) + if err != nil { + httpErrAndLog(err.Error(), http.StatusInternalServerError) + } +} diff --git a/feature/relayserver/relayserver_test.go b/feature/relayserver/relayserver_test.go new file mode 100644 index 0000000000000..cc7f05f67fbdd --- /dev/null +++ b/feature/relayserver/relayserver_test.go @@ -0,0 +1,126 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package relayserver + +import ( + "errors" + "testing" + + "tailscale.com/ipn" + "tailscale.com/net/udprelay/endpoint" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +type fakeRelayServer struct{} + +func (f *fakeRelayServer) Close() error { return nil } + +func (f *fakeRelayServer) AllocateEndpoint(_, _ key.DiscoPublic) (endpoint.ServerEndpoint, error) { + return endpoint.ServerEndpoint{}, errors.New("fake relay server") +} + +func Test_extension_profileStateChanged(t *testing.T) { + prefsWithPortOne := ipn.Prefs{RelayServerPort: ptr.To(1)} + prefsWithNilPort := ipn.Prefs{RelayServerPort: nil} + + type fields struct { + server relayServer + port *int + } + type args struct { + prefs ipn.PrefsView + sameNode bool + } + tests := []struct { + name string + fields fields + args args + wantPort *int + wantNilServer bool + }{ + { + name: "no changes non-nil server", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(1), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(1), + wantNilServer: false, + }, + { + name: "prefs port nil", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(1), + }, + args: args{ + 
prefs: prefsWithNilPort.View(), + sameNode: true, + }, + wantPort: nil, + wantNilServer: true, + }, + { + name: "prefs port changed", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(2), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: true, + }, + wantPort: ptr.To(1), + wantNilServer: true, + }, + { + name: "sameNode false", + fields: fields{ + server: &fakeRelayServer{}, + port: ptr.To(1), + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: false, + }, + wantPort: ptr.To(1), + wantNilServer: true, + }, + { + name: "prefs port non-nil extension port nil", + fields: fields{ + server: nil, + port: nil, + }, + args: args{ + prefs: prefsWithPortOne.View(), + sameNode: false, + }, + wantPort: ptr.To(1), + wantNilServer: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &extension{ + port: tt.fields.port, + server: tt.fields.server, + } + e.profileStateChanged(ipn.LoginProfileView{}, tt.args.prefs, tt.args.sameNode) + if tt.wantNilServer != (e.server == nil) { + t.Errorf("wantNilServer: %v != (e.server == nil): %v", tt.wantNilServer, e.server == nil) + } + if (tt.wantPort == nil) != (e.port == nil) { + t.Errorf("(tt.wantPort == nil): %v != (e.port == nil): %v", tt.wantPort == nil, e.port == nil) + } else if tt.wantPort != nil && *tt.wantPort != *e.port { + t.Errorf("wantPort: %d != *e.port: %d", *tt.wantPort, *e.port) + } + }) + } +} diff --git a/taildrop/delete.go b/feature/taildrop/delete.go similarity index 97% rename from taildrop/delete.go rename to feature/taildrop/delete.go index aaef34df1a7e4..e9c8d7f1c90fa 100644 --- a/taildrop/delete.go +++ b/feature/taildrop/delete.go @@ -47,7 +47,7 @@ type deleteFile struct { inserted time.Time } -func (d *fileDeleter) Init(m *Manager, eventHook func(string)) { +func (d *fileDeleter) Init(m *manager, eventHook func(string)) { d.logf = m.opts.Logf d.clock = m.opts.Clock d.dir = m.opts.Dir @@ -81,7 +81,7 @@ func (d *fileDeleter) Init(m *Manager, 
eventHook func(string)) { // Only enqueue the file for deletion if there is no active put. nameID := strings.TrimSuffix(de.Name(), partialSuffix) if i := strings.LastIndexByte(nameID, '.'); i > 0 { - key := incomingFileKey{ClientID(nameID[i+len("."):]), nameID[:i]} + key := incomingFileKey{clientID(nameID[i+len("."):]), nameID[:i]} m.incomingFiles.LoadFunc(key, func(_ *incomingFile, loaded bool) { if !loaded { d.Insert(de.Name()) diff --git a/taildrop/delete_test.go b/feature/taildrop/delete_test.go similarity index 99% rename from taildrop/delete_test.go rename to feature/taildrop/delete_test.go index 5fa4b9c374fdf..7a58de55c2492 100644 --- a/taildrop/delete_test.go +++ b/feature/taildrop/delete_test.go @@ -69,7 +69,7 @@ func TestDeleter(t *testing.T) { } eventHook := func(event string) { eventsChan <- event } - var m Manager + var m manager var fd fileDeleter m.opts.Logf = t.Logf m.opts.Clock = tstime.DefaultClock{Clock: clock} @@ -142,7 +142,7 @@ func TestDeleter(t *testing.T) { // Test that the asynchronous full scan of the taildrop directory does not occur // on a cold start if taildrop has never received any files. func TestDeleterInitWithoutTaildrop(t *testing.T) { - var m Manager + var m manager var fd fileDeleter m.opts.Logf = t.Logf m.opts.Dir = t.TempDir() diff --git a/feature/taildrop/doc.go b/feature/taildrop/doc.go new file mode 100644 index 0000000000000..8980a217096c0 --- /dev/null +++ b/feature/taildrop/doc.go @@ -0,0 +1,5 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package taildrop registers the taildrop (file sending) feature. 
+package taildrop diff --git a/feature/taildrop/ext.go b/feature/taildrop/ext.go new file mode 100644 index 0000000000000..c11fe3af427a1 --- /dev/null +++ b/feature/taildrop/ext.go @@ -0,0 +1,454 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "cmp" + "context" + "errors" + "fmt" + "io" + "maps" + "os" + "path/filepath" + "runtime" + "slices" + "strings" + "sync" + "sync/atomic" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/cmd/tailscaled/tailscaledhooks" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/types/empty" + "tailscale.com/types/logger" + "tailscale.com/util/osshare" + "tailscale.com/util/set" +) + +func init() { + ipnext.RegisterExtension("taildrop", newExtension) + + if runtime.GOOS == "windows" { + tailscaledhooks.UninstallSystemDaemonWindows.Add(func() { + // Remove file sharing from Windows shell. + osshare.SetFileSharingEnabled(false, logger.Discard) + }) + } +} + +func newExtension(logf logger.Logf, b ipnext.SafeBackend) (ipnext.Extension, error) { + e := &Extension{ + sb: b, + stateStore: b.Sys().StateStore.Get(), + logf: logger.WithPrefix(logf, "taildrop: "), + } + e.setPlatformDefaultDirectFileRoot() + return e, nil +} + +// Extension implements Taildrop. +type Extension struct { + logf logger.Logf + sb ipnext.SafeBackend + stateStore ipn.StateStore + host ipnext.Host // from Init + + // directFileRoot, if non-empty, means to write received files + // directly to this directory, without staging them in an + // intermediate buffered directory for "pick-up" later. If + // empty, the files are received in a daemon-owned location + // and the localapi is used to enumerate, download, and delete + // them. 
This is used on macOS where the GUI lifetime is the + same as the Network Extension lifetime and we can thus avoid + double-copying files by writing them to the right location + immediately. + It's also used on several NAS platforms (Synology, TrueNAS, etc) + but in that case DoFinalRename is also set true, which moves the + *.partial file to its final name on completion. + directFileRoot string + + // FileOps abstracts platform-specific file operations needed for file transfers. + // This is currently being used for Android to use the Storage Access Framework. + FileOps FileOps + + nodeBackendForTest ipnext.NodeBackend // if non-nil, pretend we're this node state for tests + + mu sync.Mutex // Lock order: lb.mu > e.mu + backendState ipn.State + selfUID tailcfg.UserID + capFileSharing bool + fileWaiters set.HandleSet[context.CancelFunc] // of wake-up funcs + mgr atomic.Pointer[manager] // mutex held to write; safe to read without lock; + // outgoingFiles keeps track of Taildrop outgoing files keyed to their OutgoingFile.ID + outgoingFiles map[string]*ipn.OutgoingFile +} + +// SafDirectoryPrefix is used to determine if the directory is managed via SAF. +const SafDirectoryPrefix = "content://" + +// PutMode controls how manager.PutFile writes files to storage. +// +// PutModeDirect – write files directly to a filesystem path (default). +// PutModeAndroidSAF – use Android’s Storage Access Framework (SAF), where +// the OS manages the underlying directory permissions. +type PutMode int + +const ( + PutModeDirect PutMode = iota + PutModeAndroidSAF +) + +// FileOps defines platform-specific file operations. +type FileOps interface { + OpenFileWriter(filename string) (io.WriteCloser, string, error) + + // RenamePartialFile finalizes a partial file. + // It returns the new SAF URI as a string and an error. 
+ RenamePartialFile(partialUri, targetDirUri, targetName string) (string, error) +} + +func (e *Extension) Name() string { + return "taildrop" +} + +func (e *Extension) Init(h ipnext.Host) error { + e.host = h + + osshare.SetFileSharingEnabled(false, e.logf) + + h.Hooks().ProfileStateChange.Add(e.onChangeProfile) + h.Hooks().OnSelfChange.Add(e.onSelfChange) + h.Hooks().MutateNotifyLocked.Add(e.setNotifyFilesWaiting) + h.Hooks().SetPeerStatus.Add(e.setPeerStatus) + h.Hooks().BackendStateChange.Add(e.onBackendStateChange) + + // TODO(nickkhyl): remove this after the profileManager refactoring. + // See tailscale/tailscale#15974. + profile, prefs := h.Profiles().CurrentProfileState() + e.onChangeProfile(profile, prefs, false) + return nil +} + +func (e *Extension) onBackendStateChange(st ipn.State) { + e.mu.Lock() + defer e.mu.Unlock() + e.backendState = st +} + +func (e *Extension) onSelfChange(self tailcfg.NodeView) { + e.mu.Lock() + defer e.mu.Unlock() + + e.selfUID = 0 + if self.Valid() { + e.selfUID = self.User() + } + e.capFileSharing = self.Valid() && self.CapMap().Contains(tailcfg.CapabilityFileSharing) + osshare.SetFileSharingEnabled(e.capFileSharing, e.logf) +} + +func (e *Extension) setMgrLocked(mgr *manager) { + if old := e.mgr.Swap(mgr); old != nil { + old.Shutdown() + } +} + +func (e *Extension) onChangeProfile(profile ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + + uid := profile.UserProfile().ID + activeLogin := profile.UserProfile().LoginName + + if uid == 0 { + e.setMgrLocked(nil) + e.outgoingFiles = nil + return + } + + if sameNode && e.manager() != nil { + return + } + + // If we have a netmap, create a taildrop manager. 
+ fileRoot, isDirectFileMode := e.fileRoot(uid, activeLogin) + if fileRoot == "" { + e.logf("no Taildrop directory configured") + } + mode := PutModeDirect + if e.directFileRoot != "" && strings.HasPrefix(e.directFileRoot, SafDirectoryPrefix) { + mode = PutModeAndroidSAF + } + e.setMgrLocked(managerOptions{ + Logf: e.logf, + Clock: tstime.DefaultClock{Clock: e.sb.Clock()}, + State: e.stateStore, + Dir: fileRoot, + DirectFileMode: isDirectFileMode, + FileOps: e.FileOps, + Mode: mode, + SendFileNotify: e.sendFileNotify, + }.New()) +} + +// fileRoot returns where to store Taildrop files for the given user and whether +// to write received files directly to this directory, without staging them in +// an intermediate buffered directory for "pick-up" later. +// +// It is safe to call this with b.mu held but it does not require it or acquire +// it itself. +func (e *Extension) fileRoot(uid tailcfg.UserID, activeLogin string) (root string, isDirect bool) { + if v := e.directFileRoot; v != "" { + return v, true + } + varRoot := e.sb.TailscaleVarRoot() + if varRoot == "" { + e.logf("Taildrop disabled; no state directory") + return "", false + } + + if activeLogin == "" { + e.logf("taildrop: no active login; can't select a target directory") + return "", false + } + + baseDir := fmt.Sprintf("%s-uid-%d", + strings.ReplaceAll(activeLogin, "@", "-"), + uid) + dir := filepath.Join(varRoot, "files", baseDir) + if err := os.MkdirAll(dir, 0700); err != nil { + e.logf("Taildrop disabled; error making directory: %v", err) + return "", false + } + return dir, false +} + +// hasCapFileSharing reports whether the current node has the file sharing +// capability. +func (e *Extension) hasCapFileSharing() bool { + e.mu.Lock() + defer e.mu.Unlock() + return e.capFileSharing +} + +// manager returns the active Manager, or nil. +// +// Methods on a nil Manager are safe to call. 
+func (e *Extension) manager() *manager { + return e.mgr.Load() +} + +func (e *Extension) Clock() tstime.Clock { + return e.sb.Clock() +} + +func (e *Extension) Shutdown() error { + e.manager().Shutdown() // no-op on nil receiver + return nil +} + +func (e *Extension) sendFileNotify() { + mgr := e.manager() + if mgr == nil { + return + } + + var n ipn.Notify + + e.mu.Lock() + for _, wakeWaiter := range e.fileWaiters { + wakeWaiter() + } + n.IncomingFiles = mgr.IncomingFiles() + e.mu.Unlock() + + e.host.SendNotifyAsync(n) +} + +func (e *Extension) setNotifyFilesWaiting(n *ipn.Notify) { + if e.manager().HasFilesWaiting() { + n.FilesWaiting = &empty.Message{} + } +} + +func (e *Extension) setPeerStatus(ps *ipnstate.PeerStatus, p tailcfg.NodeView, nb ipnext.NodeBackend) { + ps.TaildropTarget = e.taildropTargetStatus(p, nb) +} + +func (e *Extension) removeFileWaiter(handle set.Handle) { + e.mu.Lock() + defer e.mu.Unlock() + delete(e.fileWaiters, handle) +} + +func (e *Extension) addFileWaiter(wakeWaiter context.CancelFunc) set.Handle { + e.mu.Lock() + defer e.mu.Unlock() + return e.fileWaiters.Add(wakeWaiter) +} + +func (e *Extension) WaitingFiles() ([]apitype.WaitingFile, error) { + return e.manager().WaitingFiles() +} + +// AwaitWaitingFiles is like WaitingFiles but blocks while ctx is not done, +// waiting for any files to be available. +// +// On return, exactly one of the results will be non-empty or non-nil, +// respectively. +func (e *Extension) AwaitWaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { + if ff, err := e.WaitingFiles(); err != nil || len(ff) > 0 { + return ff, err + } + if err := ctx.Err(); err != nil { + return nil, err + } + for { + gotFile, gotFileCancel := context.WithCancel(context.Background()) + defer gotFileCancel() + + handle := e.addFileWaiter(gotFileCancel) + defer e.removeFileWaiter(handle) + + // Now that we've registered ourselves, check again, in case + // of race. 
Otherwise there's a small window where we could + // miss a file arrival and wait forever. + if ff, err := e.WaitingFiles(); err != nil || len(ff) > 0 { + return ff, err + } + + select { + case <-gotFile.Done(): + if ff, err := e.WaitingFiles(); err != nil || len(ff) > 0 { + return ff, err + } + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (e *Extension) DeleteFile(name string) error { + return e.manager().DeleteFile(name) +} + +func (e *Extension) OpenFile(name string) (rc io.ReadCloser, size int64, err error) { + return e.manager().OpenFile(name) +} + +func (e *Extension) nodeBackend() ipnext.NodeBackend { + if e.nodeBackendForTest != nil { + return e.nodeBackendForTest + } + return e.host.NodeBackend() +} + +// FileTargets lists nodes that the current node can send files to. +func (e *Extension) FileTargets() ([]*apitype.FileTarget, error) { + var ret []*apitype.FileTarget + + e.mu.Lock() + st := e.backendState + self := e.selfUID + e.mu.Unlock() + + if st != ipn.Running { + return nil, errors.New("not connected to the tailnet") + } + if !e.hasCapFileSharing() { + return nil, errors.New("file sharing not enabled by Tailscale admin") + } + nb := e.nodeBackend() + peers := nb.AppendMatchingPeers(nil, func(p tailcfg.NodeView) bool { + if !p.Valid() || p.Hostinfo().OS() == "tvOS" { + return false + } + if self == p.User() { + return true + } + if nb.PeerHasCap(p, tailcfg.PeerCapabilityFileSharingTarget) { + // Explicitly noted in the netmap ACL caps as a target. 
+ return true + } + return false + }) + for _, p := range peers { + peerAPI := nb.PeerAPIBase(p) + if peerAPI == "" { + continue + } + ret = append(ret, &apitype.FileTarget{ + Node: p.AsStruct(), + PeerAPIURL: peerAPI, + }) + } + slices.SortFunc(ret, func(a, b *apitype.FileTarget) int { + return cmp.Compare(a.Node.Name, b.Node.Name) + }) + return ret, nil +} + +func (e *Extension) taildropTargetStatus(p tailcfg.NodeView, nb ipnext.NodeBackend) ipnstate.TaildropTargetStatus { + e.mu.Lock() + st := e.backendState + selfUID := e.selfUID + capFileSharing := e.capFileSharing + e.mu.Unlock() + + if st != ipn.Running { + return ipnstate.TaildropTargetIpnStateNotRunning + } + + if !capFileSharing { + return ipnstate.TaildropTargetMissingCap + } + if !p.Valid() { + return ipnstate.TaildropTargetNoPeerInfo + } + if !p.Online().Get() { + return ipnstate.TaildropTargetOffline + } + if p.Hostinfo().OS() == "tvOS" { + return ipnstate.TaildropTargetUnsupportedOS + } + if selfUID != p.User() { + // Different user must have the explicit file sharing target capability + if !nb.PeerHasCap(p, tailcfg.PeerCapabilityFileSharingTarget) { + return ipnstate.TaildropTargetOwnedByOtherUser + } + } + if !nb.PeerHasPeerAPI(p) { + return ipnstate.TaildropTargetNoPeerAPI + } + return ipnstate.TaildropTargetAvailable +} + +// updateOutgoingFiles updates e.outgoingFiles to reflect the given updates and +// sends an ipn.Notify with the full list of outgoingFiles. 
+func (e *Extension) updateOutgoingFiles(updates map[string]*ipn.OutgoingFile) { + e.mu.Lock() + if e.outgoingFiles == nil { + e.outgoingFiles = make(map[string]*ipn.OutgoingFile, len(updates)) + } + maps.Copy(e.outgoingFiles, updates) + outgoingFiles := make([]*ipn.OutgoingFile, 0, len(e.outgoingFiles)) + for _, file := range e.outgoingFiles { + outgoingFiles = append(outgoingFiles, file) + } + e.mu.Unlock() + slices.SortFunc(outgoingFiles, func(a, b *ipn.OutgoingFile) int { + t := a.Started.Compare(b.Started) + if t != 0 { + return t + } + return strings.Compare(a.Name, b.Name) + }) + + e.host.SendNotifyAsync(ipn.Notify{OutgoingFiles: outgoingFiles}) +} diff --git a/feature/taildrop/integration_test.go b/feature/taildrop/integration_test.go new file mode 100644 index 0000000000000..75896a95b2b54 --- /dev/null +++ b/feature/taildrop/integration_test.go @@ -0,0 +1,196 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop_test + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "testing" + "time" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" +) + +// TODO(bradfitz): add test where control doesn't send tailcfg.CapabilityFileSharing +// and verify that we get the "file sharing not enabled by Tailscale admin" error. + +// TODO(bradfitz): add test between different users with the peercap to permit that? + +func TestTaildropIntegration(t *testing.T) { + testTaildropIntegration(t, false) +} + +func TestTaildropIntegration_Fresh(t *testing.T) { + testTaildropIntegration(t, true) +} + +// freshProfiles is whether to start the test right away +// with a fresh profile. If false, tailscaled is started, stopped, +// and restarted again to simulate a real-world scenario where +// the first profile already existed. 
+// +// This exercises an ipnext hook ordering issue we hit earlier. +func testTaildropIntegration(t *testing.T, freshProfiles bool) { + tstest.Parallel(t) + controlOpt := integration.ConfigureControl(func(s *testcontrol.Server) { + s.AllNodesSameUser = true // required for Taildrop + }) + env := integration.NewTestEnv(t, controlOpt) + + // Create two nodes: + n1 := integration.NewTestNode(t, env) + d1 := n1.StartDaemon() + + n2 := integration.NewTestNode(t, env) + d2 := n2.StartDaemon() + + awaitUp := func() { + t.Helper() + n1.AwaitListening() + t.Logf("n1 is listening") + n2.AwaitListening() + t.Logf("n2 is listening") + n1.MustUp() + t.Logf("n1 is up") + n2.MustUp() + t.Logf("n2 is up") + n1.AwaitRunning() + t.Logf("n1 is running") + n2.AwaitRunning() + t.Logf("n2 is running") + } + awaitUp() + + if !freshProfiles { + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) + d1 = n1.StartDaemon() + d2 = n2.StartDaemon() + awaitUp() + } + + var peerStableID tailcfg.StableNodeID + + if err := tstest.WaitFor(5*time.Second, func() error { + st := n1.MustStatus() + if len(st.Peer) == 0 { + return errors.New("no peers") + } + if len(st.Peer) > 1 { + return fmt.Errorf("got %d peers; want 1", len(st.Peer)) + } + peer := st.Peer[st.Peers()[0]] + peerStableID = peer.ID + if peer.ID == st.Self.ID { + return errors.New("peer is self") + } + + if len(st.TailscaleIPs) == 0 { + return errors.New("no Tailscale IPs") + } + + return nil + }); err != nil { + t.Fatal(err) + } + + const timeout = 30 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + c1 := n1.LocalClient() + c2 := n2.LocalClient() + + wantNoWaitingFiles := func(c *local.Client) { + t.Helper() + files, err := c.WaitingFiles(ctx) + if err != nil { + t.Fatalf("WaitingFiles: %v", err) + } + if len(files) != 0 { + t.Fatalf("WaitingFiles: got %d files; want 0", len(files)) + } + } + + // Verify c2 has no files. 
+ wantNoWaitingFiles(c2) + + gotFile := make(chan bool, 1) + go func() { + v, err := c2.AwaitWaitingFiles(t.Context(), timeout) + if err != nil { + return + } + if len(v) != 0 { + gotFile <- true + } + }() + + fileContents := []byte("hello world this is a file") + + n2ID := n2.MustStatus().Self.ID + t.Logf("n2 self.ID = %q; n1's peer[0].ID = %q", n2ID, peerStableID) + t.Logf("Doing PushFile ...") + err := c1.PushFile(ctx, n2.MustStatus().Self.ID, int64(len(fileContents)), "test.txt", bytes.NewReader(fileContents)) + if err != nil { + t.Fatalf("PushFile from n1->n2: %v", err) + } + t.Logf("PushFile done") + + select { + case <-gotFile: + t.Logf("n2 saw AwaitWaitingFiles wake up") + case <-ctx.Done(): + t.Fatalf("n2 timeout waiting for AwaitWaitingFiles") + } + + files, err := c2.WaitingFiles(ctx) + if err != nil { + t.Fatalf("c2.WaitingFiles: %v", err) + } + if len(files) != 1 { + t.Fatalf("c2.WaitingFiles: got %d files; want 1", len(files)) + } + got := files[0] + want := apitype.WaitingFile{ + Name: "test.txt", + Size: int64(len(fileContents)), + } + if got != want { + t.Fatalf("c2.WaitingFiles: got %+v; want %+v", got, want) + } + + // Download the file. + rc, size, err := c2.GetWaitingFile(ctx, got.Name) + if err != nil { + t.Fatalf("c2.GetWaitingFile: %v", err) + } + if size != int64(len(fileContents)) { + t.Fatalf("c2.GetWaitingFile: got size %d; want %d", size, len(fileContents)) + } + gotBytes, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("c2.GetWaitingFile: %v", err) + } + if !bytes.Equal(gotBytes, fileContents) { + t.Fatalf("c2.GetWaitingFile: got %q; want %q", gotBytes, fileContents) + } + + // Now delete it. 
+ if err := c2.DeleteWaitingFile(ctx, got.Name); err != nil { + t.Fatalf("c2.DeleteWaitingFile: %v", err) + } + wantNoWaitingFiles(c2) + + d1.MustCleanShutdown(t) + d2.MustCleanShutdown(t) +} diff --git a/feature/taildrop/localapi.go b/feature/taildrop/localapi.go new file mode 100644 index 0000000000000..8a3904f9f0198 --- /dev/null +++ b/feature/taildrop/localapi.go @@ -0,0 +1,458 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "maps" + "mime" + "mime/multipart" + "net/http" + "net/http/httputil" + "net/url" + "strconv" + "strings" + "time" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/localapi" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" + "tailscale.com/util/httphdr" + "tailscale.com/util/mak" + "tailscale.com/util/progresstracking" + "tailscale.com/util/rands" +) + +func init() { + localapi.Register("file-put/", serveFilePut) + localapi.Register("files/", serveFiles) + localapi.Register("file-targets", serveFileTargets) +} + +var ( + metricFilePutCalls = clientmetric.NewCounter("localapi_file_put") +) + +// serveFilePut sends a file to another node. +// +// It's sometimes possible for clients to do this themselves, without +// tailscaled, except in the case of tailscaled running in +// userspace-networking ("netstack") mode, in which case tailscaled +// needs to a do a netstack dial out. +// +// Instead, the CLI also goes through tailscaled so it doesn't need to be +// aware of the network mode in use. +// +// macOS/iOS have always used this localapi method to simplify the GUI +// clients. +// +// The Windows client currently (2021-11-30) uses the peerapi (/v0/put/) +// directly, as the Windows GUI always runs in tun mode anyway. 
+// +// In addition to single file PUTs, this endpoint accepts multipart file +// POSTS encoded as multipart/form-data.The first part should be an +// application/json file that contains a manifest consisting of a JSON array of +// OutgoingFiles which we can use for tracking progress even before reading the +// file parts. +// +// URL format: +// +// - PUT /localapi/v0/file-put/:stableID/:escaped-filename +// - POST /localapi/v0/file-put/:stableID +func serveFilePut(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + metricFilePutCalls.Add(1) + + if !h.PermitWrite { + http.Error(w, "file access denied", http.StatusForbidden) + return + } + + if r.Method != "PUT" && r.Method != "POST" { + http.Error(w, "want PUT to put file", http.StatusBadRequest) + return + } + + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "misconfigured taildrop extension", http.StatusInternalServerError) + return + } + + fts, err := ext.FileTargets() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + upath, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/file-put/") + if !ok { + http.Error(w, "misconfigured", http.StatusInternalServerError) + return + } + var peerIDStr, filenameEscaped string + if r.Method == "PUT" { + ok := false + peerIDStr, filenameEscaped, ok = strings.Cut(upath, "/") + if !ok { + http.Error(w, "bogus URL", http.StatusBadRequest) + return + } + } else { + peerIDStr = upath + } + peerID := tailcfg.StableNodeID(peerIDStr) + + var ft *apitype.FileTarget + for _, x := range fts { + if x.Node.StableID == peerID { + ft = x + break + } + } + if ft == nil { + http.Error(w, "node not found", http.StatusNotFound) + return + } + dstURL, err := url.Parse(ft.PeerAPIURL) + if err != nil { + http.Error(w, "bogus peer URL", http.StatusInternalServerError) + return + } + + // Periodically report progress of outgoing files. 
+ outgoingFiles := make(map[string]*ipn.OutgoingFile) + t := time.NewTicker(1 * time.Second) + progressUpdates := make(chan ipn.OutgoingFile) + defer close(progressUpdates) + + go func() { + defer t.Stop() + defer ext.updateOutgoingFiles(outgoingFiles) + for { + select { + case u, ok := <-progressUpdates: + if !ok { + return + } + outgoingFiles[u.ID] = &u + case <-t.C: + ext.updateOutgoingFiles(outgoingFiles) + } + } + }() + + switch r.Method { + case "PUT": + file := ipn.OutgoingFile{ + ID: rands.HexString(30), + PeerID: peerID, + Name: filenameEscaped, + DeclaredSize: r.ContentLength, + } + singleFilePut(h, r.Context(), progressUpdates, w, r.Body, dstURL, file) + case "POST": + multiFilePost(h, progressUpdates, w, r, peerID, dstURL) + default: + http.Error(w, "want PUT to put file", http.StatusBadRequest) + return + } +} + +func multiFilePost(h *localapi.Handler, progressUpdates chan (ipn.OutgoingFile), w http.ResponseWriter, r *http.Request, peerID tailcfg.StableNodeID, dstURL *url.URL) { + _, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + if err != nil { + http.Error(w, fmt.Sprintf("invalid Content-Type for multipart POST: %s", err), http.StatusBadRequest) + return + } + + ww := &multiFilePostResponseWriter{} + defer func() { + if err := ww.Flush(w); err != nil { + h.Logf("error: multiFilePostResponseWriter.Flush(): %s", err) + } + }() + + outgoingFilesByName := make(map[string]ipn.OutgoingFile) + first := true + mr := multipart.NewReader(r.Body, params["boundary"]) + for { + part, err := mr.NextPart() + if err == io.EOF { + // No more parts. 
+ return + } else if err != nil { + http.Error(ww, fmt.Sprintf("failed to decode multipart/form-data: %s", err), http.StatusBadRequest) + return + } + + if first { + first = false + if part.Header.Get("Content-Type") != "application/json" { + http.Error(ww, "first MIME part must be a JSON map of filename -> size", http.StatusBadRequest) + return + } + + var manifest []ipn.OutgoingFile + err := json.NewDecoder(part).Decode(&manifest) + if err != nil { + http.Error(ww, fmt.Sprintf("invalid manifest: %s", err), http.StatusBadRequest) + return + } + + for _, file := range manifest { + outgoingFilesByName[file.Name] = file + progressUpdates <- file + } + + continue + } + + if !singleFilePut(h, r.Context(), progressUpdates, ww, part, dstURL, outgoingFilesByName[part.FileName()]) { + return + } + + if ww.statusCode >= 400 { + // put failed, stop immediately + h.Logf("error: singleFilePut: failed with status %d", ww.statusCode) + return + } + } +} + +// multiFilePostResponseWriter is a buffering http.ResponseWriter that can be +// reused across multiple singleFilePut calls and then flushed to the client +// when all files have been PUT. 
+type multiFilePostResponseWriter struct { + header http.Header + statusCode int + body *bytes.Buffer +} + +func (ww *multiFilePostResponseWriter) Header() http.Header { + if ww.header == nil { + ww.header = make(http.Header) + } + return ww.header +} + +func (ww *multiFilePostResponseWriter) WriteHeader(statusCode int) { + ww.statusCode = statusCode +} + +func (ww *multiFilePostResponseWriter) Write(p []byte) (int, error) { + if ww.body == nil { + ww.body = bytes.NewBuffer(nil) + } + return ww.body.Write(p) +} + +func (ww *multiFilePostResponseWriter) Flush(w http.ResponseWriter) error { + if ww.header != nil { + maps.Copy(w.Header(), ww.header) + } + if ww.statusCode > 0 { + w.WriteHeader(ww.statusCode) + } + if ww.body != nil { + _, err := io.Copy(w, ww.body) + return err + } + return nil +} + +func singleFilePut( + h *localapi.Handler, + ctx context.Context, + progressUpdates chan (ipn.OutgoingFile), + w http.ResponseWriter, + body io.Reader, + dstURL *url.URL, + outgoingFile ipn.OutgoingFile, +) bool { + outgoingFile.Started = time.Now() + body = progresstracking.NewReader(body, 1*time.Second, func(n int, err error) { + outgoingFile.Sent = int64(n) + progressUpdates <- outgoingFile + }) + + fail := func() { + outgoingFile.Finished = true + outgoingFile.Succeeded = false + progressUpdates <- outgoingFile + } + + // Before we PUT a file we check to see if there are any existing partial file and if so, + // we resume the upload from where we left off by sending the remaining file instead of + // the full file. 
+ var offset int64 + var resumeDuration time.Duration + remainingBody := io.Reader(body) + client := &http.Client{ + Transport: h.LocalBackend().Dialer().PeerAPITransport(), + Timeout: 10 * time.Second, + } + req, err := http.NewRequestWithContext(ctx, "GET", dstURL.String()+"/v0/put/"+outgoingFile.Name, nil) + if err != nil { + http.Error(w, "bogus peer URL", http.StatusInternalServerError) + fail() + return false + } + resp, err := client.Do(req) + if resp != nil { + defer resp.Body.Close() + } + switch { + case err != nil: + h.Logf("could not fetch remote hashes: %v", err) + case resp.StatusCode == http.StatusMethodNotAllowed || resp.StatusCode == http.StatusNotFound: + // noop; implies older peerapi without resume support + case resp.StatusCode != http.StatusOK: + h.Logf("fetch remote hashes status code: %d", resp.StatusCode) + default: + resumeStart := time.Now() + dec := json.NewDecoder(resp.Body) + offset, remainingBody, err = resumeReader(body, func() (out blockChecksum, err error) { + err = dec.Decode(&out) + return out, err + }) + if err != nil { + h.Logf("reader could not be fully resumed: %v", err) + } + resumeDuration = time.Since(resumeStart).Round(time.Millisecond) + } + + outReq, err := http.NewRequestWithContext(ctx, "PUT", "http://peer/v0/put/"+outgoingFile.Name, remainingBody) + if err != nil { + http.Error(w, "bogus outreq", http.StatusInternalServerError) + fail() + return false + } + outReq.ContentLength = outgoingFile.DeclaredSize + if offset > 0 { + h.Logf("resuming put at offset %d after %v", offset, resumeDuration) + rangeHdr, _ := httphdr.FormatRange([]httphdr.Range{{Start: offset, Length: 0}}) + outReq.Header.Set("Range", rangeHdr) + if outReq.ContentLength >= 0 { + outReq.ContentLength -= offset + } + } + + rp := httputil.NewSingleHostReverseProxy(dstURL) + rp.Transport = h.LocalBackend().Dialer().PeerAPITransport() + rp.ServeHTTP(w, outReq) + + outgoingFile.Finished = true + outgoingFile.Succeeded = true + progressUpdates <- 
outgoingFile + + return true +} + +func serveFiles(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "file access denied", http.StatusForbidden) + return + } + + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "misconfigured taildrop extension", http.StatusInternalServerError) + return + } + + suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/files/") + if !ok { + http.Error(w, "misconfigured", http.StatusInternalServerError) + return + } + if suffix == "" { + if r.Method != "GET" { + http.Error(w, "want GET to list files", http.StatusBadRequest) + return + } + ctx := r.Context() + var wfs []apitype.WaitingFile + if s := r.FormValue("waitsec"); s != "" && s != "0" { + d, err := strconv.Atoi(s) + if err != nil { + http.Error(w, "invalid waitsec", http.StatusBadRequest) + return + } + deadline := time.Now().Add(time.Duration(d) * time.Second) + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, deadline) + defer cancel() + wfs, err = ext.AwaitWaitingFiles(ctx) + if err != nil && ctx.Err() == nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } else { + var err error + wfs, err = ext.WaitingFiles() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(wfs) + return + } + name, err := url.PathUnescape(suffix) + if err != nil { + http.Error(w, "bad filename", http.StatusBadRequest) + return + } + if r.Method == "DELETE" { + if err := ext.DeleteFile(name); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusNoContent) + return + } + rc, size, err := ext.OpenFile(name) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer rc.Close() + w.Header().Set("Content-Length", fmt.Sprint(size)) + 
w.Header().Set("Content-Type", "application/octet-stream") + io.Copy(w, rc) +} + +func serveFileTargets(h *localapi.Handler, w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != "GET" { + http.Error(w, "want GET to list targets", http.StatusBadRequest) + return + } + + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "misconfigured taildrop extension", http.StatusInternalServerError) + return + } + + fts, err := ext.FileTargets() + if err != nil { + localapi.WriteErrorJSON(w, err) + return + } + mak.NonNilSliceForJSON(&fts) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(fts) +} diff --git a/cmd/tailscaled/taildrop.go b/feature/taildrop/paths.go similarity index 85% rename from cmd/tailscaled/taildrop.go rename to feature/taildrop/paths.go index 39fe54373bdda..22d01160cff8e 100644 --- a/cmd/tailscaled/taildrop.go +++ b/feature/taildrop/paths.go @@ -1,35 +1,44 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build go1.19 - -package main +package taildrop import ( "fmt" "os" "path/filepath" - "tailscale.com/ipn/ipnlocal" - "tailscale.com/types/logger" "tailscale.com/version/distro" ) -func configureTaildrop(logf logger.Logf, lb *ipnlocal.LocalBackend) { +// SetDirectFileRoot sets the directory where received files are written. +// +// This must be called before Tailscale is started. +func (e *Extension) SetDirectFileRoot(root string) { + e.directFileRoot = root +} + +// SetFileOps sets the platform specific file operations. This is used +// to call Android's Storage Access Framework APIs. +func (e *Extension) SetFileOps(fileOps FileOps) { + e.FileOps = fileOps +} + +func (e *Extension) setPlatformDefaultDirectFileRoot() { dg := distro.Get() + switch dg { case distro.Synology, distro.TrueNAS, distro.QNAP, distro.Unraid: // See if they have a "Taildrop" share. 
// See https://github.com/tailscale/tailscale/issues/2179#issuecomment-982821319 path, err := findTaildropDir(dg) if err != nil { - logf("%s Taildrop support: %v", dg, err) + e.logf("%s Taildrop support: %v", dg, err) } else { - logf("%s Taildrop: using %v", dg, path) - lb.SetDirectFileRoot(path) + e.logf("%s Taildrop: using %v", dg, path) + e.directFileRoot = path } } - } func findTaildropDir(dg distro.Distro) (string, error) { diff --git a/feature/taildrop/peerapi.go b/feature/taildrop/peerapi.go new file mode 100644 index 0000000000000..b75ce33b864b4 --- /dev/null +++ b/feature/taildrop/peerapi.go @@ -0,0 +1,169 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/tstime" + "tailscale.com/util/clientmetric" + "tailscale.com/util/httphdr" +) + +func init() { + ipnlocal.RegisterPeerAPIHandler("/v0/put/", handlePeerPut) +} + +var ( + metricPutCalls = clientmetric.NewCounter("peerapi_put") +) + +// canPutFile reports whether h can put a file ("Taildrop") to this node. +func canPutFile(h ipnlocal.PeerAPIHandler) bool { + if h.Peer().UnsignedPeerAPIOnly() { + // Unsigned peers can't send files. + return false + } + return h.IsSelfUntagged() || h.PeerCaps().HasCapability(tailcfg.PeerCapabilityFileSharingSend) +} + +func handlePeerPut(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + ext, ok := ipnlocal.GetExt[*Extension](h.LocalBackend()) + if !ok { + http.Error(w, "miswired", http.StatusInternalServerError) + return + } + handlePeerPutWithBackend(h, ext, w, r) +} + +// extensionForPut is the subset of taildrop extension that taildrop +// file put needs. This is pulled out for testability. 
+type extensionForPut interface {
+	manager() *manager
+	hasCapFileSharing() bool
+	Clock() tstime.Clock
+}
+
+// handlePeerPutWithBackend serves the peerapi /v0/put/ endpoint using ext,
+// which is injectable for tests.
+//
+//   - GET /v0/put/           lists this peer's partial (in-progress) files.
+//   - GET /v0/put/:filename  streams block hashes for resumable uploads.
+//   - PUT /v0/put/:filename  receives a file, honoring a single-range
+//     Range header to resume at an offset.
+func handlePeerPutWithBackend(h ipnlocal.PeerAPIHandler, ext extensionForPut, w http.ResponseWriter, r *http.Request) {
+	if r.Method == "PUT" {
+		metricPutCalls.Add(1)
+	}
+
+	taildropMgr := ext.manager()
+	if taildropMgr == nil {
+		h.Logf("taildrop: no taildrop manager")
+		http.Error(w, "failed to get taildrop manager", http.StatusInternalServerError)
+		return
+	}
+
+	// Both the sender's identity/caps and our own file-sharing
+	// capability gate Taildrop.
+	if !canPutFile(h) {
+		http.Error(w, ErrNoTaildrop.Error(), http.StatusForbidden)
+		return
+	}
+	if !ext.hasCapFileSharing() {
+		http.Error(w, ErrNoTaildrop.Error(), http.StatusForbidden)
+		return
+	}
+	rawPath := r.URL.EscapedPath()
+	prefix, ok := strings.CutPrefix(rawPath, "/v0/put/")
+	if !ok {
+		http.Error(w, "misconfigured internals", http.StatusForbidden)
+		return
+	}
+	baseName, err := url.PathUnescape(prefix)
+	if err != nil {
+		http.Error(w, ErrInvalidFileName.Error(), http.StatusBadRequest)
+		return
+	}
+	enc := json.NewEncoder(w)
+	switch r.Method {
+	case "GET":
+		id := clientID(h.Peer().StableID())
+		if prefix == "" {
+			// List all the partial files.
+			files, err := taildropMgr.PartialFiles(id)
+			if err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				return
+			}
+			if err := enc.Encode(files); err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				h.Logf("json.Encoder.Encode error: %v", err)
+				return
+			}
+		} else {
+			// Stream all the block hashes for the specified file.
+			next, closeHash, err := taildropMgr.HashPartialFile(id, baseName)
+			if err != nil {
+				http.Error(w, err.Error(), http.StatusInternalServerError)
+				return
+			}
+			defer closeHash()
+			for {
+				switch cs, err := next(); {
+				case err == io.EOF:
+					return
+				case err != nil:
+					http.Error(w, err.Error(), http.StatusInternalServerError)
+					h.Logf("HashPartialFile.next error: %v", err)
+					return
+				default:
+					if err := enc.Encode(cs); err != nil {
+						http.Error(w, err.Error(), http.StatusInternalServerError)
+						h.Logf("json.Encoder.Encode error: %v", err)
+						return
+					}
+				}
+			}
+		}
+	case "PUT":
+		t0 := ext.Clock().Now()
+		id := clientID(h.Peer().StableID())
+
+		var offset int64
+		if rangeHdr := r.Header.Get("Range"); rangeHdr != "" {
+			// Only a single range of the form "start-" (Length == 0,
+			// meaning "to the end") is accepted for resuming.
+			ranges, ok := httphdr.ParseRange(rangeHdr)
+			if !ok || len(ranges) != 1 || ranges[0].Length != 0 {
+				http.Error(w, "invalid Range header", http.StatusBadRequest)
+				return
+			}
+			offset = ranges[0].Start
+		}
+		// id is already a clientID; the previous
+		// clientID(fmt.Sprint(id)) round-trip was redundant.
+		n, err := taildropMgr.PutFile(id, baseName, r.Body, offset, r.ContentLength)
+		switch err {
+		case nil:
+			d := ext.Clock().Since(t0).Round(time.Second / 10)
+			h.Logf("got put of %s in %v from %v/%v", approxSize(n), d, h.RemoteAddr().Addr(), h.Peer().ComputedName)
+			io.WriteString(w, "{}\n")
+		case ErrNoTaildrop:
+			http.Error(w, err.Error(), http.StatusForbidden)
+		case ErrInvalidFileName:
+			http.Error(w, err.Error(), http.StatusBadRequest)
+		case ErrFileExists:
+			http.Error(w, err.Error(), http.StatusConflict)
+		default:
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+		}
+	default:
+		http.Error(w, "expected method GET or PUT", http.StatusMethodNotAllowed)
+	}
+}
+
+// approxSize returns a short human-readable approximation of n bytes
+// for log messages.
+func approxSize(n int64) string {
+	if n <= 1<<10 {
+		return "<=1KB"
+	}
+	if n <= 1<<20 {
+		return "<=1MB"
+	}
+	return fmt.Sprintf("~%dMB", n>>20)
+}
diff --git a/feature/taildrop/peerapi_test.go b/feature/taildrop/peerapi_test.go
new file mode 100644
index 0000000000000..1a003b6eddca7
--- /dev/null
+++ b/feature/taildrop/peerapi_test.go
@@ -0,0 +1,574 @@
+// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "fmt" + "io" + "io/fs" + "math/rand" + "net/http" + "net/http/httptest" + "net/netip" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/types/logger" +) + +// peerAPIHandler serves the PeerAPI for a source specific client. +type peerAPIHandler struct { + remoteAddr netip.AddrPort + isSelf bool // whether peerNode is owned by same user as this node + selfNode tailcfg.NodeView // this node; always non-nil + peerNode tailcfg.NodeView // peerNode is who's making the request +} + +func (h *peerAPIHandler) IsSelfUntagged() bool { + return !h.selfNode.IsTagged() && !h.peerNode.IsTagged() && h.isSelf +} +func (h *peerAPIHandler) Peer() tailcfg.NodeView { return h.peerNode } +func (h *peerAPIHandler) Self() tailcfg.NodeView { return h.selfNode } +func (h *peerAPIHandler) RemoteAddr() netip.AddrPort { return h.remoteAddr } +func (h *peerAPIHandler) LocalBackend() *ipnlocal.LocalBackend { panic("unexpected") } +func (h *peerAPIHandler) Logf(format string, a ...any) { + //h.logf(format, a...) 
+} + +func (h *peerAPIHandler) PeerCaps() tailcfg.PeerCapMap { + return nil +} + +type fakeExtension struct { + logf logger.Logf + capFileSharing bool + clock tstime.Clock + taildrop *manager +} + +func (lb *fakeExtension) manager() *manager { + return lb.taildrop +} +func (lb *fakeExtension) Clock() tstime.Clock { return lb.clock } +func (lb *fakeExtension) hasCapFileSharing() bool { + return lb.capFileSharing +} + +type peerAPITestEnv struct { + taildrop *manager + ph *peerAPIHandler + rr *httptest.ResponseRecorder + logBuf tstest.MemLogger +} + +type check func(*testing.T, *peerAPITestEnv) + +func checks(vv ...check) []check { return vv } + +func httpStatus(wantStatus int) check { + return func(t *testing.T, e *peerAPITestEnv) { + if res := e.rr.Result(); res.StatusCode != wantStatus { + t.Errorf("HTTP response code = %v; want %v", res.Status, wantStatus) + } + } +} + +func bodyContains(sub string) check { + return func(t *testing.T, e *peerAPITestEnv) { + if body := e.rr.Body.String(); !strings.Contains(body, sub) { + t.Errorf("HTTP response body does not contain %q; got: %s", sub, body) + } + } +} + +func fileHasSize(name string, size int) check { + return func(t *testing.T, e *peerAPITestEnv) { + root := e.taildrop.Dir() + if root == "" { + t.Errorf("no rootdir; can't check whether %q has size %v", name, size) + return + } + path := filepath.Join(root, name) + if fi, err := os.Stat(path); err != nil { + t.Errorf("fileHasSize(%q, %v): %v", name, size, err) + } else if fi.Size() != int64(size) { + t.Errorf("file %q has size %v; want %v", name, fi.Size(), size) + } + } +} + +func fileHasContents(name string, want string) check { + return func(t *testing.T, e *peerAPITestEnv) { + root := e.taildrop.Dir() + if root == "" { + t.Errorf("no rootdir; can't check contents of %q", name) + return + } + path := filepath.Join(root, name) + got, err := os.ReadFile(path) + if err != nil { + t.Errorf("fileHasContents: %v", err) + return + } + if string(got) != want { + 
t.Errorf("file contents = %q; want %q", got, want) + } + } +} + +func hexAll(v string) string { + var sb strings.Builder + for i := range len(v) { + fmt.Fprintf(&sb, "%%%02x", v[i]) + } + return sb.String() +} + +func TestHandlePeerAPI(t *testing.T) { + tests := []struct { + name string + isSelf bool // the peer sending the request is owned by us + capSharing bool // self node has file sharing capability + debugCap bool // self node has debug capability + omitRoot bool // don't configure + reqs []*http.Request + checks []check + }{ + { + name: "reject_non_owner_put", + isSelf: false, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(http.StatusForbidden), + bodyContains("Taildrop disabled"), + ), + }, + { + name: "owner_without_cap", + isSelf: true, + capSharing: false, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(http.StatusForbidden), + bodyContains("Taildrop disabled"), + ), + }, + { + name: "owner_with_cap_no_rootdir", + omitRoot: true, + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(http.StatusForbidden), + bodyContains("Taildrop disabled; no storage directory"), + ), + }, + { + name: "bad_method", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("POST", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(405), + bodyContains("expected method GET or PUT"), + ), + }, + { + name: "put_zero_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasSize("foo", 0), + fileHasContents("foo", ""), + ), + }, + { + name: "put_non_zero_length_content_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents"))}, + 
checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasSize("foo", len("contents")), + fileHasContents("foo", "contents"), + ), + }, + { + name: "put_non_zero_length_chunked", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", struct{ io.Reader }{strings.NewReader("contents")})}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasSize("foo", len("contents")), + fileHasContents("foo", "contents"), + ), + }, + { + name: "bad_filename_partial", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.partial", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_deleted", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.deleted", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_dot", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/.", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_empty", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_slash", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo/bar", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_dot", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("."), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_slash", + isSelf: true, + capSharing: true, + reqs: 
[]*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("/"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_backslash", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("\\"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_dotdot", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(".."), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "bad_filename_encoded_dotdot_out", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("foo/../../../../../etc/passwd"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_spaces_and_caps", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("Foo Bar.dat"), strings.NewReader("baz"))}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasContents("Foo Bar.dat", "baz"), + ), + }, + { + name: "put_unicode", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3"), strings.NewReader("ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"))}, + checks: checks( + httpStatus(200), + bodyContains("{}"), + fileHasContents("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3", "ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"), + ), + }, + { + name: "put_invalid_utf8", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+(hexAll("😜")[:3]), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_null", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%00", nil)}, + checks: checks( 
+ httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_non_printable", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%01", nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_colon", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("nul:"), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "put_invalid_surrounding_whitespace", + isSelf: true, + capSharing: true, + reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(" foo "), nil)}, + checks: checks( + httpStatus(400), + bodyContains("invalid filename"), + ), + }, + { + name: "duplicate_zero_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", nil), + httptest.NewRequest("PUT", "/v0/put/foo", nil), + }, + checks: checks( + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 0}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, + ), + }, + { + name: "duplicate_non_zero_length_content_length", + isSelf: true, + capSharing: true, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), + }, + checks: checks( + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 8}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, + ), + 
}, + { + name: "duplicate_different_files", + isSelf: true, + capSharing: true, + reqs: []*http.Request{ + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("fizz")), + httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("buzz")), + }, + checks: checks( + httpStatus(200), + func(t *testing.T, env *peerAPITestEnv) { + got, err := env.taildrop.WaitingFiles() + if err != nil { + t.Fatalf("WaitingFiles error: %v", err) + } + want := []apitype.WaitingFile{{Name: "foo", Size: 4}, {Name: "foo (1)", Size: 4}} + if diff := cmp.Diff(got, want); diff != "" { + t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) + } + }, + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + selfNode := &tailcfg.Node{ + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.100.100.101/32"), + }, + } + if tt.debugCap { + selfNode.CapMap = tailcfg.NodeCapMap{tailcfg.CapabilityDebug: nil} + } + var rootDir string + if !tt.omitRoot { + rootDir = t.TempDir() + } + + var e peerAPITestEnv + e.taildrop = managerOptions{ + Logf: e.logBuf.Logf, + Dir: rootDir, + }.New() + + ext := &fakeExtension{ + logf: e.logBuf.Logf, + capFileSharing: tt.capSharing, + clock: &tstest.Clock{}, + taildrop: e.taildrop, + } + e.ph = &peerAPIHandler{ + isSelf: tt.isSelf, + selfNode: selfNode.View(), + peerNode: (&tailcfg.Node{ + ComputedName: "some-peer-name", + }).View(), + } + for _, req := range tt.reqs { + e.rr = httptest.NewRecorder() + if req.Host == "example.com" { + req.Host = "100.100.100.101:12345" + } + handlePeerPutWithBackend(e.ph, ext, e.rr, req) + } + for _, f := range tt.checks { + f(t, &e) + } + if t.Failed() && rootDir != "" { + t.Logf("Contents of %s:", rootDir) + des, _ := fs.ReadDir(os.DirFS(rootDir), ".") + for _, de := range des { + fi, err := de.Info() + if err != nil { + t.Log(err) + } else { + t.Logf(" %v %5d %s", fi.Mode(), fi.Size(), de.Name()) + } + } + } + }) + } +} + +// Windows likes to hold on to file descriptors for some indeterminate +// 
amount of time after you close them and not let you delete them for +// a bit. So test that we work around that sufficiently. +func TestFileDeleteRace(t *testing.T) { + dir := t.TempDir() + taildropMgr := managerOptions{ + Logf: t.Logf, + Dir: dir, + }.New() + + ph := &peerAPIHandler{ + isSelf: true, + peerNode: (&tailcfg.Node{ + ComputedName: "some-peer-name", + }).View(), + selfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("100.100.100.101/32")}, + }).View(), + } + fakeLB := &fakeExtension{ + logf: t.Logf, + capFileSharing: true, + clock: &tstest.Clock{}, + taildrop: taildropMgr, + } + buf := make([]byte, 2<<20) + for range 30 { + rr := httptest.NewRecorder() + handlePeerPutWithBackend(ph, fakeLB, rr, httptest.NewRequest("PUT", "http://100.100.100.101:123/v0/put/foo.txt", bytes.NewReader(buf[:rand.Intn(len(buf))]))) + if res := rr.Result(); res.StatusCode != 200 { + t.Fatal(res.Status) + } + wfs, err := taildropMgr.WaitingFiles() + if err != nil { + t.Fatal(err) + } + if len(wfs) != 1 { + t.Fatalf("waiting files = %d; want 1", len(wfs)) + } + + if err := taildropMgr.DeleteFile("foo.txt"); err != nil { + t.Fatal(err) + } + wfs, err = taildropMgr.WaitingFiles() + if err != nil { + t.Fatal(err) + } + if len(wfs) != 0 { + t.Fatalf("waiting files = %d; want 0", len(wfs)) + } + } +} diff --git a/taildrop/resume.go b/feature/taildrop/resume.go similarity index 76% rename from taildrop/resume.go rename to feature/taildrop/resume.go index f7bee3d95c466..211a1ff6b68dd 100644 --- a/taildrop/resume.go +++ b/feature/taildrop/resume.go @@ -19,29 +19,29 @@ var ( hashAlgorithm = "sha256" ) -// BlockChecksum represents the checksum for a single block. -type BlockChecksum struct { - Checksum Checksum `json:"checksum"` +// blockChecksum represents the checksum for a single block. 
+type blockChecksum struct { + Checksum checksum `json:"checksum"` Algorithm string `json:"algo"` // always "sha256" for now Size int64 `json:"size"` // always (64<<10) for now } -// Checksum is an opaque checksum that is comparable. -type Checksum struct{ cs [sha256.Size]byte } +// checksum is an opaque checksum that is comparable. +type checksum struct{ cs [sha256.Size]byte } -func hash(b []byte) Checksum { - return Checksum{sha256.Sum256(b)} +func hash(b []byte) checksum { + return checksum{sha256.Sum256(b)} } -func (cs Checksum) String() string { +func (cs checksum) String() string { return hex.EncodeToString(cs.cs[:]) } -func (cs Checksum) AppendText(b []byte) ([]byte, error) { +func (cs checksum) AppendText(b []byte) ([]byte, error) { return hex.AppendEncode(b, cs.cs[:]), nil } -func (cs Checksum) MarshalText() ([]byte, error) { +func (cs checksum) MarshalText() ([]byte, error) { return hex.AppendEncode(nil, cs.cs[:]), nil } -func (cs *Checksum) UnmarshalText(b []byte) error { +func (cs *checksum) UnmarshalText(b []byte) error { if len(b) != 2*len(cs.cs) { return fmt.Errorf("invalid hex length: %d", len(b)) } @@ -51,7 +51,7 @@ func (cs *Checksum) UnmarshalText(b []byte) error { // PartialFiles returns a list of partial files in [Handler.Dir] // that were sent (or is actively being sent) by the provided id. -func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) { +func (m *manager) PartialFiles(id clientID) (ret []string, err error) { if m == nil || m.opts.Dir == "" { return nil, ErrNoTaildrop } @@ -72,11 +72,11 @@ func (m *Manager) PartialFiles(id ClientID) (ret []string, err error) { // starting from the beginning of the file. // It returns (BlockChecksum{}, io.EOF) when the stream is complete. // It is the caller's responsibility to call close. 
-func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (BlockChecksum, error), close func() error, err error) { +func (m *manager) HashPartialFile(id clientID, baseName string) (next func() (blockChecksum, error), close func() error, err error) { if m == nil || m.opts.Dir == "" { return nil, nil, ErrNoTaildrop } - noopNext := func() (BlockChecksum, error) { return BlockChecksum{}, io.EOF } + noopNext := func() (blockChecksum, error) { return blockChecksum{}, io.EOF } noopClose := func() error { return nil } dstFile, err := joinDir(m.opts.Dir, baseName) @@ -92,25 +92,25 @@ func (m *Manager) HashPartialFile(id ClientID, baseName string) (next func() (Bl } b := make([]byte, blockSize) // TODO: Pool this? - next = func() (BlockChecksum, error) { + next = func() (blockChecksum, error) { switch n, err := io.ReadFull(f, b); { case err != nil && err != io.EOF && err != io.ErrUnexpectedEOF: - return BlockChecksum{}, redactError(err) + return blockChecksum{}, redactError(err) case n == 0: - return BlockChecksum{}, io.EOF + return blockChecksum{}, io.EOF default: - return BlockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil + return blockChecksum{hash(b[:n]), hashAlgorithm, int64(n)}, nil } } close = f.Close return next, close, nil } -// ResumeReader reads and discards the leading content of r +// resumeReader reads and discards the leading content of r // that matches the content based on the checksums that exist. // It returns the number of bytes consumed, // and returns an [io.Reader] representing the remaining content. 
-func ResumeReader(r io.Reader, hashNext func() (BlockChecksum, error)) (int64, io.Reader, error) { +func resumeReader(r io.Reader, hashNext func() (blockChecksum, error)) (int64, io.Reader, error) { if hashNext == nil { return 0, r, nil } diff --git a/taildrop/resume_test.go b/feature/taildrop/resume_test.go similarity index 92% rename from taildrop/resume_test.go rename to feature/taildrop/resume_test.go index d366340eb6efa..dac3c657bfb58 100644 --- a/taildrop/resume_test.go +++ b/feature/taildrop/resume_test.go @@ -19,7 +19,7 @@ func TestResume(t *testing.T) { defer func() { blockSize = oldBlockSize }() blockSize = 256 - m := ManagerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() + m := managerOptions{Logf: t.Logf, Dir: t.TempDir()}.New() defer m.Shutdown() rn := rand.New(rand.NewSource(0)) @@ -32,7 +32,7 @@ func TestResume(t *testing.T) { next, close, err := m.HashPartialFile("", "foo") must.Do(err) defer close() - offset, r, err := ResumeReader(r, next) + offset, r, err := resumeReader(r, next) must.Do(err) must.Do(close()) // Windows wants the file handle to be closed to rename it. @@ -51,7 +51,7 @@ func TestResume(t *testing.T) { next, close, err := m.HashPartialFile("", "bar") must.Do(err) defer close() - offset, r, err := ResumeReader(r, next) + offset, r, err := resumeReader(r, next) must.Do(err) must.Do(close()) // Windows wants the file handle to be closed to rename it. diff --git a/taildrop/retrieve.go b/feature/taildrop/retrieve.go similarity index 95% rename from taildrop/retrieve.go rename to feature/taildrop/retrieve.go index 3e37b492adc0a..6fb97519363bc 100644 --- a/taildrop/retrieve.go +++ b/feature/taildrop/retrieve.go @@ -20,7 +20,7 @@ import ( // HasFilesWaiting reports whether any files are buffered in [Handler.Dir]. // This always returns false when [Handler.DirectFileMode] is false. 
-func (m *Manager) HasFilesWaiting() (has bool) { +func (m *manager) HasFilesWaiting() (has bool) { if m == nil || m.opts.Dir == "" || m.opts.DirectFileMode { return false } @@ -61,7 +61,7 @@ func (m *Manager) HasFilesWaiting() (has bool) { // WaitingFiles returns the list of files that have been sent by a // peer that are waiting in [Handler.Dir]. // This always returns nil when [Handler.DirectFileMode] is false. -func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { +func (m *manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { if m == nil || m.opts.Dir == "" { return nil, ErrNoTaildrop } @@ -94,7 +94,7 @@ func (m *Manager) WaitingFiles() (ret []apitype.WaitingFile, err error) { // DeleteFile deletes a file of the given baseName from [Handler.Dir]. // This method is only allowed when [Handler.DirectFileMode] is false. -func (m *Manager) DeleteFile(baseName string) error { +func (m *manager) DeleteFile(baseName string) error { if m == nil || m.opts.Dir == "" { return ErrNoTaildrop } @@ -151,7 +151,7 @@ func touchFile(path string) error { // OpenFile opens a file of the given baseName from [Handler.Dir]. // This method is only allowed when [Handler.DirectFileMode] is false. 
-func (m *Manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { +func (m *manager) OpenFile(baseName string) (rc io.ReadCloser, size int64, err error) { if m == nil || m.opts.Dir == "" { return nil, 0, ErrNoTaildrop } diff --git a/feature/taildrop/send.go b/feature/taildrop/send.go new file mode 100644 index 0000000000000..59a1701da6f0d --- /dev/null +++ b/feature/taildrop/send.go @@ -0,0 +1,367 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "crypto/sha256" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "time" + + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/tstime" + "tailscale.com/version/distro" +) + +type incomingFileKey struct { + id clientID + name string // e.g., "foo.jpeg" +} + +type incomingFile struct { + clock tstime.DefaultClock + + started time.Time + size int64 // or -1 if unknown; never 0 + w io.Writer // underlying writer + sendFileNotify func() // called when done + partialPath string // non-empty in direct mode + finalPath string // not used in direct mode + + mu sync.Mutex + copied int64 + done bool + lastNotify time.Time +} + +func (f *incomingFile) Write(p []byte) (n int, err error) { + n, err = f.w.Write(p) + + var needNotify bool + defer func() { + if needNotify { + f.sendFileNotify() + } + }() + if n > 0 { + f.mu.Lock() + defer f.mu.Unlock() + f.copied += int64(n) + now := f.clock.Now() + if f.lastNotify.IsZero() || now.Sub(f.lastNotify) > time.Second { + f.lastNotify = now + needNotify = true + } + } + return n, err +} + +// PutFile stores a file into [manager.Dir] from a given client id. +// The baseName must be a base filename without any slashes. +// The length is the expected length of content to read from r, +// it may be negative to indicate that it is unknown. +// It returns the length of the entire file. +// +// If there is a failure reading from r, then the partial file is not deleted +// for some period of time. 
The [manager.PartialFiles] and [manager.HashPartialFile] +methods may be used to list all partial files and to compute the hash for a +// specific partial file. This allows the client to determine whether to resume +// a partial file. While resuming, PutFile may be called again with a non-zero +// offset to specify where to resume receiving data. +func (m *manager) PutFile(id clientID, baseName string, r io.Reader, offset, length int64) (int64, error) { + switch { + case m == nil || m.opts.Dir == "": + return 0, ErrNoTaildrop + case !envknob.CanTaildrop(): + return 0, ErrNoTaildrop + case distro.Get() == distro.Unraid && !m.opts.DirectFileMode: + return 0, ErrNotAccessible + } + + // Compute dstPath & avoid mid-upload deletion + var dstPath string + if m.opts.Mode == PutModeDirect { + var err error + dstPath, err = joinDir(m.opts.Dir, baseName) + if err != nil { + return 0, err + } + } else { + // In SAF mode, we simply use the baseName as the destination "path" + // (the actual directory is managed by SAF). + dstPath = baseName + } + m.deleter.Remove(filepath.Base(dstPath)) // avoid deleting the partial file while receiving + + // Check whether there is an in-progress transfer for the file. 
+ partialFileKey := incomingFileKey{id, baseName} + inFile, loaded := m.incomingFiles.LoadOrInit(partialFileKey, func() *incomingFile { + return &incomingFile{ + clock: m.opts.Clock, + started: m.opts.Clock.Now(), + size: length, + sendFileNotify: m.opts.SendFileNotify, + } + }) + if loaded { + return 0, ErrFileExists + } + defer m.incomingFiles.Delete(partialFileKey) + + // Open writer & populate inFile paths + wc, partialPath, err := m.openWriterAndPaths(id, m.opts.Mode, inFile, baseName, dstPath, offset) + if err != nil { + return 0, m.redactAndLogError("Create", err) + } + defer func() { + wc.Close() + if err != nil { + m.deleter.Insert(filepath.Base(partialPath)) // mark partial file for eventual deletion + } + }() + + // Record that we have started to receive at least one file. + // This is used by the deleter upon a cold-start to scan the directory + // for any files that need to be deleted. + if st := m.opts.State; st != nil { + if b, _ := st.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { + if werr := st.WriteState(ipn.TaildropReceivedKey, []byte{1}); werr != nil { + m.opts.Logf("WriteState error: %v", werr) // non-fatal error + } + } + } + + // Copy the contents of the file to the writer. 
+ copyLength, err := io.Copy(wc, r) + if err != nil { + return 0, m.redactAndLogError("Copy", err) + } + if length >= 0 && copyLength != length { + return 0, m.redactAndLogError("Copy", fmt.Errorf("copied %d bytes; expected %d", copyLength, length)) + } + if err := wc.Close(); err != nil { + return 0, m.redactAndLogError("Close", err) + } + + fileLength := offset + copyLength + + inFile.mu.Lock() + inFile.done = true + inFile.mu.Unlock() + + // Finalize rename + switch m.opts.Mode { + case PutModeDirect: + var finalDst string + finalDst, err = m.finalizeDirect(inFile, partialPath, dstPath, fileLength) + if err != nil { + return 0, m.redactAndLogError("Rename", err) + } + inFile.finalPath = finalDst + + case PutModeAndroidSAF: + if err = m.finalizeSAF(partialPath, baseName); err != nil { + return 0, m.redactAndLogError("Rename", err) + } + } + + m.totalReceived.Add(1) + m.opts.SendFileNotify() + return fileLength, nil +} + +// openWriterAndPaths opens the correct writer, seeks/truncates if needed, +// and sets inFile.partialPath & inFile.finalPath for later cleanup/rename. +// The caller is responsible for closing the file on completion. 
+func (m *manager) openWriterAndPaths( + id clientID, + mode PutMode, + inFile *incomingFile, + baseName string, + dstPath string, + offset int64, +) (wc io.WriteCloser, partialPath string, err error) { + switch mode { + + case PutModeDirect: + partialPath = dstPath + id.partialSuffix() + f, err := os.OpenFile(partialPath, os.O_CREATE|os.O_RDWR, 0o666) + if err != nil { + return nil, "", m.redactAndLogError("Create", err) + } + if offset != 0 { + curr, err := f.Seek(0, io.SeekEnd) + if err != nil { + f.Close() + return nil, "", m.redactAndLogError("Seek", err) + } + if offset < 0 || offset > curr { + f.Close() + return nil, "", m.redactAndLogError("Seek", fmt.Errorf("offset %d out of range", offset)) + } + if _, err := f.Seek(offset, io.SeekStart); err != nil { + f.Close() + return nil, "", m.redactAndLogError("Seek", err) + } + if err := f.Truncate(offset); err != nil { + f.Close() + return nil, "", m.redactAndLogError("Truncate", err) + } + } + inFile.w = f + wc = f + inFile.partialPath = partialPath + inFile.finalPath = dstPath + return wc, partialPath, nil + + case PutModeAndroidSAF: + if m.opts.FileOps == nil { + return nil, "", m.redactAndLogError("Create (SAF)", fmt.Errorf("missing FileOps")) + } + writer, uri, err := m.opts.FileOps.OpenFileWriter(baseName) + if err != nil { + return nil, "", m.redactAndLogError("Create (SAF)", fmt.Errorf("failed to open file for writing via SAF")) + } + if writer == nil || uri == "" { + return nil, "", fmt.Errorf("invalid SAF writer or URI") + } + // SAF mode does not support resuming, so enforce offset == 0. 
+ if offset != 0 { + writer.Close() + return nil, "", m.redactAndLogError("Seek", fmt.Errorf("resuming is not supported in SAF mode")) + } + inFile.w = writer + wc = writer + partialPath = uri + inFile.partialPath = uri + inFile.finalPath = baseName + return wc, partialPath, nil + + default: + return nil, "", fmt.Errorf("unsupported PutMode: %v", mode) + } +} + +// finalizeDirect atomically renames or dedups the partial file, retrying +// under new names up to 10 times. It returns the final path that succeeded. +func (m *manager) finalizeDirect( + inFile *incomingFile, + partialPath string, + initialDst string, + fileLength int64, +) (string, error) { + var ( + once sync.Once + cachedSum [sha256.Size]byte + cacheErr error + computeSum = func() ([sha256.Size]byte, error) { + once.Do(func() { cachedSum, cacheErr = sha256File(partialPath) }) + return cachedSum, cacheErr + } + ) + + dstPath := initialDst + const maxRetries = 10 + for i := 0; i < maxRetries; i++ { + // Atomically rename the partial file as the destination file if it doesn't exist. + // Otherwise, it returns the length of the current destination file. + // The operation is atomic. + lengthOnDisk, err := func() (int64, error) { + m.renameMu.Lock() + defer m.renameMu.Unlock() + fi, statErr := os.Stat(dstPath) + if os.IsNotExist(statErr) { + // dst missing → rename partial into place + return -1, os.Rename(partialPath, dstPath) + } + if statErr != nil { + return -1, statErr + } + return fi.Size(), nil + }() + if err != nil { + return "", err + } + if lengthOnDisk < 0 { + // successfully moved + inFile.finalPath = dstPath + return dstPath, nil + } + + // Avoid the final rename if a destination file has the same contents. + // + // Note: this is best effort and copying files from iOS from the Media Library + // results in processing on the iOS side which means the size and shas of the + // same file can be different. 
+ if lengthOnDisk == fileLength { + partSum, err := computeSum() + if err != nil { + return "", err + } + dstSum, err := sha256File(dstPath) + if err != nil { + return "", err + } + if partSum == dstSum { + // same content → drop the partial + if err := os.Remove(partialPath); err != nil { + return "", err + } + inFile.finalPath = dstPath + return dstPath, nil + } + } + + // Choose a new destination filename and try again. + dstPath = nextFilename(dstPath) + } + + return "", fmt.Errorf("too many retries trying to rename a partial file %q", initialDst) +} + +// finalizeSAF retries RenamePartialFile up to 10 times, generating a new +// name on each failure until the SAF URI changes. +func (m *manager) finalizeSAF( + partialPath, finalName string, +) error { + if m.opts.FileOps == nil { + return fmt.Errorf("missing FileOps for SAF finalize") + } + const maxTries = 10 + name := finalName + for i := 0; i < maxTries; i++ { + newURI, err := m.opts.FileOps.RenamePartialFile(partialPath, m.opts.Dir, name) + if err != nil { + return err + } + if newURI != "" && newURI != name { + return nil + } + name = nextFilename(name) + } + return fmt.Errorf("failed to finalize SAF file after %d retries", maxTries) +} + +func (m *manager) redactAndLogError(stage string, err error) error { + err = redactError(err) + m.opts.Logf("put %s error: %v", stage, err) + return err +} + +func sha256File(file string) (out [sha256.Size]byte, err error) { + h := sha256.New() + f, err := os.Open(file) + if err != nil { + return out, err + } + defer f.Close() + if _, err := io.Copy(h, f); err != nil { + return out, err + } + return [sha256.Size]byte(h.Sum(nil)), nil +} diff --git a/feature/taildrop/send_test.go b/feature/taildrop/send_test.go new file mode 100644 index 0000000000000..8edb704172fc5 --- /dev/null +++ b/feature/taildrop/send_test.go @@ -0,0 +1,128 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "bytes" + "fmt" + "io" 
+ "os" + "path/filepath" + "testing" + + "tailscale.com/tstime" +) + +// nopWriteCloser is a no-op io.WriteCloser wrapping a bytes.Buffer. +type nopWriteCloser struct{ *bytes.Buffer } + +func (nwc nopWriteCloser) Close() error { return nil } + +// mockFileOps implements just enough of the FileOps interface for SAF tests. +type mockFileOps struct { + writes *bytes.Buffer + renameOK bool +} + +func (m *mockFileOps) OpenFileWriter(name string) (io.WriteCloser, string, error) { + m.writes = new(bytes.Buffer) + return nopWriteCloser{m.writes}, "uri://" + name + ".partial", nil +} + +func (m *mockFileOps) RenamePartialFile(partialPath, dir, finalName string) (string, error) { + if !m.renameOK { + m.renameOK = true + return "uri://" + finalName, nil + } + return "", io.ErrUnexpectedEOF +} + +func TestPutFile(t *testing.T) { + const content = "hello, world" + + tests := []struct { + name string + mode PutMode + setup func(t *testing.T) (*manager, string, *mockFileOps) + wantFile string + }{ + { + name: "PutModeDirect", + mode: PutModeDirect, + setup: func(t *testing.T) (*manager, string, *mockFileOps) { + dir := t.TempDir() + opts := managerOptions{ + Logf: t.Logf, + Clock: tstime.DefaultClock{}, + State: nil, + Dir: dir, + Mode: PutModeDirect, + DirectFileMode: true, + SendFileNotify: func() {}, + } + mgr := opts.New() + return mgr, dir, nil + }, + wantFile: "file.txt", + }, + { + name: "PutModeAndroidSAF", + mode: PutModeAndroidSAF, + setup: func(t *testing.T) (*manager, string, *mockFileOps) { + // SAF still needs a non-empty Dir to pass the guard. 
+ dir := t.TempDir() + mops := &mockFileOps{} + opts := managerOptions{ + Logf: t.Logf, + Clock: tstime.DefaultClock{}, + State: nil, + Dir: dir, + Mode: PutModeAndroidSAF, + FileOps: mops, + DirectFileMode: true, + SendFileNotify: func() {}, + } + mgr := opts.New() + return mgr, dir, mops + }, + wantFile: "file.txt", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mgr, dir, mops := tc.setup(t) + id := clientID(fmt.Sprint(0)) + reader := bytes.NewReader([]byte(content)) + + n, err := mgr.PutFile(id, "file.txt", reader, 0, int64(len(content))) + if err != nil { + t.Fatalf("PutFile(%s) error: %v", tc.name, err) + } + if n != int64(len(content)) { + t.Errorf("wrote %d bytes; want %d", n, len(content)) + } + + switch tc.mode { + case PutModeDirect: + path := filepath.Join(dir, tc.wantFile) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("ReadFile error: %v", err) + } + if got := string(data); got != content { + t.Errorf("file contents = %q; want %q", got, content) + } + + case PutModeAndroidSAF: + if mops.writes == nil { + t.Fatal("SAF writer was never created") + } + if got := mops.writes.String(); got != content { + t.Errorf("SAF writes = %q; want %q", got, content) + } + } + }) + } +} diff --git a/taildrop/taildrop.go b/feature/taildrop/taildrop.go similarity index 92% rename from taildrop/taildrop.go rename to feature/taildrop/taildrop.go index 4d14787afbf54..2dfa415bbf0cc 100644 --- a/taildrop/taildrop.go +++ b/feature/taildrop/taildrop.go @@ -18,6 +18,7 @@ import ( "path" "path/filepath" "regexp" + "sort" "strconv" "strings" "sync" @@ -53,20 +54,20 @@ const ( deletedSuffix = ".deleted" ) -// ClientID is an opaque identifier for file resumption. +// clientID is an opaque identifier for file resumption. // A client can only list and resume partial files for its own ID. // It must contain any filesystem specific characters (e.g., slashes). 
-type ClientID string // e.g., "n12345CNTRL" +type clientID string // e.g., "n12345CNTRL" -func (id ClientID) partialSuffix() string { +func (id clientID) partialSuffix() string { if id == "" { return partialSuffix } return "." + string(id) + partialSuffix // e.g., ".n12345CNTRL.partial" } -// ManagerOptions are options to configure the [Manager]. -type ManagerOptions struct { +// managerOptions are options to configure the [manager]. +type managerOptions struct { Logf logger.Logf // may be nil Clock tstime.DefaultClock // may be nil State ipn.StateStore // may be nil @@ -90,6 +91,10 @@ type ManagerOptions struct { // copy them out, and then delete them. DirectFileMode bool + FileOps FileOps + + Mode PutMode + // SendFileNotify is called periodically while a file is actively // receiving the contents for the file. There is a final call // to the function when reception completes. @@ -97,9 +102,9 @@ type ManagerOptions struct { SendFileNotify func() } -// Manager manages the state for receiving and managing taildropped files. -type Manager struct { - opts ManagerOptions +// manager manages the state for receiving and managing taildropped files. +type manager struct { + opts managerOptions // incomingFiles is a map of files actively being received. incomingFiles syncs.Map[incomingFileKey, *incomingFile] @@ -119,27 +124,27 @@ type Manager struct { // New initializes a new taildrop manager. // It may spawn asynchronous goroutines to delete files, // so the Shutdown method must be called for resource cleanup. -func (opts ManagerOptions) New() *Manager { +func (opts managerOptions) New() *manager { if opts.Logf == nil { opts.Logf = logger.Discard } if opts.SendFileNotify == nil { opts.SendFileNotify = func() {} } - m := &Manager{opts: opts} + m := &manager{opts: opts} m.deleter.Init(m, func(string) {}) m.emptySince.Store(-1) // invalidate this cache return m } // Dir returns the directory. 
-func (m *Manager) Dir() string { +func (m *manager) Dir() string { return m.opts.Dir } // Shutdown shuts down the Manager. // It blocks until all spawned goroutines have stopped running. -func (m *Manager) Shutdown() { +func (m *manager) Shutdown() { if m != nil { m.deleter.shutdown() m.deleter.group.Wait() @@ -221,7 +226,7 @@ func rangeDir(dir string, fn func(fs.DirEntry) bool) error { } // IncomingFiles returns a list of active incoming files. -func (m *Manager) IncomingFiles() []ipn.PartialFile { +func (m *manager) IncomingFiles() []ipn.PartialFile { // Make sure we always set n.IncomingFiles non-nil so it gets encoded // in JSON to clients. They distinguish between empty and non-nil // to know whether a Notify should be able about files. @@ -239,6 +244,11 @@ func (m *Manager) IncomingFiles() []ipn.PartialFile { }) f.mu.Unlock() } + + sort.Slice(files, func(i, j int) bool { + return files[i].Started.Before(files[j].Started) + }) + return files } @@ -312,12 +322,12 @@ var ( rxNumberSuffix = regexp.MustCompile(` \([0-9]+\)`) ) -// NextFilename returns the next filename in a sequence. +// nextFilename returns the next filename in a sequence. // It is used for construction a new filename if there is a conflict. // // For example, "Foo.jpg" becomes "Foo (1).jpg" and // "Foo (1).jpg" becomes "Foo (2).jpg". 
-func NextFilename(name string) string { +func nextFilename(name string) string { ext := rxExtensionSuffix.FindString(strings.TrimPrefix(name, ".")) name = strings.TrimSuffix(name, ext) var n uint64 diff --git a/taildrop/taildrop_test.go b/feature/taildrop/taildrop_test.go similarity index 94% rename from taildrop/taildrop_test.go rename to feature/taildrop/taildrop_test.go index df4783c303203..da0bd2f430579 100644 --- a/taildrop/taildrop_test.go +++ b/feature/taildrop/taildrop_test.go @@ -59,10 +59,10 @@ func TestNextFilename(t *testing.T) { } for _, tt := range tests { - if got := NextFilename(tt.in); got != tt.want { + if got := nextFilename(tt.in); got != tt.want { t.Errorf("NextFilename(%q) = %q, want %q", tt.in, got, tt.want) } - if got2 := NextFilename(tt.want); got2 != tt.want2 { + if got2 := nextFilename(tt.want); got2 != tt.want2 { t.Errorf("NextFilename(%q) = %q, want %q", tt.want, got2, tt.want2) } } diff --git a/feature/taildrop/target_test.go b/feature/taildrop/target_test.go new file mode 100644 index 0000000000000..57c96a77a4802 --- /dev/null +++ b/feature/taildrop/target_test.go @@ -0,0 +1,73 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package taildrop + +import ( + "fmt" + "testing" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/tailcfg" +) + +func TestFileTargets(t *testing.T) { + e := new(Extension) + + _, err := e.FileTargets() + if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { + t.Errorf("before connect: got %q; want %q", got, want) + } + + e.nodeBackendForTest = testNodeBackend{peers: nil} + + _, err = e.FileTargets() + if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { + t.Errorf("non-running netmap: got %q; want %q", got, want) + } + + e.backendState = ipn.Running + _, err = e.FileTargets() + if got, want := fmt.Sprint(err), "file sharing not enabled by Tailscale admin"; got != want { + t.Errorf("without cap: got 
%q; want %q", got, want) + } + + e.capFileSharing = true + got, err := e.FileTargets() + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Fatalf("unexpected %d peers", len(got)) + } + + var nodeID tailcfg.NodeID = 1234 + peer := &tailcfg.Node{ + ID: nodeID, + Hostinfo: (&tailcfg.Hostinfo{OS: "tvOS"}).View(), + } + e.nodeBackendForTest = testNodeBackend{peers: []tailcfg.NodeView{peer.View()}} + + got, err = e.FileTargets() + if err != nil { + t.Fatal(err) + } + if len(got) != 0 { + t.Fatalf("unexpected %d peers", len(got)) + } +} + +type testNodeBackend struct { + ipnext.NodeBackend + peers []tailcfg.NodeView +} + +func (t testNodeBackend) AppendMatchingPeers(peers []tailcfg.NodeView, f func(tailcfg.NodeView) bool) []tailcfg.NodeView { + for _, p := range t.peers { + if f(p) { + peers = append(peers, p) + } + } + return peers +} diff --git a/net/tstun/tap_linux.go b/feature/tap/tap_linux.go similarity index 66% rename from net/tstun/tap_linux.go rename to feature/tap/tap_linux.go index c721e6e2734b5..58ac00593d3a8 100644 --- a/net/tstun/tap_linux.go +++ b/feature/tap/tap_linux.go @@ -1,11 +1,11 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ts_omit_tap - -package tstun +// Package tap registers Tailscale's experimental (demo) Linux TAP (Layer 2) support. +package tap import ( + "bytes" "fmt" "net" "net/netip" @@ -20,10 +20,15 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "tailscale.com/net/netaddr" "tailscale.com/net/packet" + "tailscale.com/net/tsaddr" + "tailscale.com/net/tstun" + "tailscale.com/syncs" "tailscale.com/types/ipproto" + "tailscale.com/types/logger" "tailscale.com/util/multierr" ) @@ -33,15 +38,19 @@ import ( // For now just hard code it. 
var ourMAC = net.HardwareAddr{0x30, 0x2D, 0x66, 0xEC, 0x7A, 0x93} -func init() { createTAP = createTAPLinux } +const tapDebug = tstun.TAPDebug + +func init() { + tstun.CreateTAP.Set(createTAPLinux) +} -func createTAPLinux(tapName, bridgeName string) (tun.Device, error) { +func createTAPLinux(logf logger.Logf, tapName, bridgeName string) (tun.Device, error) { fd, err := unix.Open("/dev/net/tun", unix.O_RDWR, 0) if err != nil { return nil, err } - dev, err := openDevice(fd, tapName, bridgeName) + dev, err := openDevice(logf, fd, tapName, bridgeName) if err != nil { unix.Close(fd) return nil, err @@ -50,7 +59,7 @@ func createTAPLinux(tapName, bridgeName string) (tun.Device, error) { return dev, nil } -func openDevice(fd int, tapName, bridgeName string) (tun.Device, error) { +func openDevice(logf logger.Logf, fd int, tapName, bridgeName string) (tun.Device, error) { ifr, err := unix.NewIfreq(tapName) if err != nil { return nil, err @@ -71,7 +80,7 @@ func openDevice(fd int, tapName, bridgeName string) (tun.Device, error) { } } - return newTAPDevice(fd, tapName) + return newTAPDevice(logf, fd, tapName) } type etherType [2]byte @@ -82,7 +91,10 @@ var ( etherTypeIPv6 = etherType{0x86, 0xDD} ) -const ipv4HeaderLen = 20 +const ( + ipv4HeaderLen = 20 + ethernetFrameSize = 14 // 2 six byte MACs, 2 bytes ethertype +) const ( consumePacket = true @@ -91,7 +103,7 @@ const ( // handleTAPFrame handles receiving a raw TAP ethernet frame and reports whether // it's been handled (that is, whether it should NOT be passed to wireguard). -func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { +func (t *tapDevice) handleTAPFrame(ethBuf []byte) bool { if len(ethBuf) < ethernetFrameSize { // Corrupt. Ignore. @@ -154,7 +166,7 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { // If the client's asking about their own IP, tell them it's // their own MAC. TODO(bradfitz): remove String allocs. 
- if net.IP(req.ProtocolAddressTarget()).String() == theClientIP { + if net.IP(req.ProtocolAddressTarget()).String() == t.clientIPv4.Load() { copy(res.HardwareAddressSender(), ethSrcMAC) } else { copy(res.HardwareAddressSender(), ourMAC[:]) @@ -164,8 +176,7 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { copy(res.HardwareAddressTarget(), req.HardwareAddressSender()) copy(res.ProtocolAddressTarget(), req.ProtocolAddressSender()) - // TODO(raggi): reduce allocs! - n, err := t.tdev.Write([][]byte{buf}, 0) + n, err := t.WriteEthernet(buf) if tapDebug { t.logf("tap: wrote ARP reply %v, %v", n, err) } @@ -175,14 +186,22 @@ func (t *Wrapper) handleTAPFrame(ethBuf []byte) bool { } } -// TODO(bradfitz): remove these hard-coded values and move from a /24 to a /10 CGNAT as the range. -const theClientIP = "100.70.145.3" // TODO: make dynamic from netmap -const routerIP = "100.70.145.1" // must be in same netmask (currently hack at /24) as theClientIP +var ( + // routerIP is the IP address of the DHCP server. + routerIP = net.ParseIP(tsaddr.TailscaleServiceIPString) + // cgnatNetMask is the netmask of the 100.64.0.0/10 CGNAT range. + cgnatNetMask = net.IPMask(net.ParseIP("255.192.0.0").To4()) +) + +// parsedPacketPool holds a pool of Parsed structs for use in filtering. +// This is needed because escape analysis cannot see that parsed packets +// do not escape through {Pre,Post}Filter{In,Out}. +var parsedPacketPool = sync.Pool{New: func() any { return new(packet.Parsed) }} // handleDHCPRequest handles receiving a raw TAP ethernet frame and reports whether // it's been handled as a DHCP request. That is, it reports whether the frame should // be ignored by the caller and not passed on. 
-func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { +func (t *tapDevice) handleDHCPRequest(ethBuf []byte) bool { const udpHeader = 8 if len(ethBuf) < ethernetFrameSize+ipv4HeaderLen+udpHeader { if tapDebug { @@ -207,7 +226,7 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { if p.IPProto != ipproto.UDP || p.Src.Port() != 68 || p.Dst.Port() != 67 { // Not a DHCP request. if tapDebug { - t.logf("tap: DHCP wrong meta") + t.logf("tap: DHCP wrong meta: %+v", p) } return passOnPacket } @@ -225,17 +244,22 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { } switch dp.MessageType() { case dhcpv4.MessageTypeDiscover: + ips := t.clientIPv4.Load() + if ips == "" { + t.logf("tap: DHCP no client IP") + return consumePacket + } offer, err := dhcpv4.New( dhcpv4.WithReply(dp), dhcpv4.WithMessageType(dhcpv4.MessageTypeOffer), - dhcpv4.WithRouter(net.ParseIP(routerIP)), // the default route - dhcpv4.WithDNS(net.ParseIP("100.100.100.100")), - dhcpv4.WithServerIP(net.ParseIP("100.100.100.100")), // TODO: what is this? - dhcpv4.WithOption(dhcpv4.OptServerIdentifier(net.ParseIP("100.100.100.100"))), - dhcpv4.WithYourIP(net.ParseIP(theClientIP)), + dhcpv4.WithRouter(routerIP), // the default route + dhcpv4.WithDNS(routerIP), + dhcpv4.WithServerIP(routerIP), // TODO: what is this? + dhcpv4.WithOption(dhcpv4.OptServerIdentifier(routerIP)), + dhcpv4.WithYourIP(net.ParseIP(ips)), dhcpv4.WithLeaseTime(3600), // hour works //dhcpv4.WithHwAddr(ethSrcMAC), - dhcpv4.WithNetmask(net.IPMask(net.ParseIP("255.255.255.0").To4())), // TODO: wrong + dhcpv4.WithNetmask(cgnatNetMask), //dhcpv4.WithTransactionID(dp.TransactionID), ) if err != nil { @@ -250,22 +274,26 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - // TODO(raggi): reduce allocs! 
- n, err := t.tdev.Write([][]byte{pkt}, 0) + n, err := t.WriteEthernet(pkt) if tapDebug { t.logf("tap: wrote DHCP OFFER %v, %v", n, err) } case dhcpv4.MessageTypeRequest: + ips := t.clientIPv4.Load() + if ips == "" { + t.logf("tap: DHCP no client IP") + return consumePacket + } ack, err := dhcpv4.New( dhcpv4.WithReply(dp), dhcpv4.WithMessageType(dhcpv4.MessageTypeAck), - dhcpv4.WithDNS(net.ParseIP("100.100.100.100")), - dhcpv4.WithRouter(net.ParseIP(routerIP)), // the default route - dhcpv4.WithServerIP(net.ParseIP("100.100.100.100")), // TODO: what is this? - dhcpv4.WithOption(dhcpv4.OptServerIdentifier(net.ParseIP("100.100.100.100"))), - dhcpv4.WithYourIP(net.ParseIP(theClientIP)), // Hello world - dhcpv4.WithLeaseTime(3600), // hour works - dhcpv4.WithNetmask(net.IPMask(net.ParseIP("255.255.255.0").To4())), + dhcpv4.WithDNS(routerIP), + dhcpv4.WithRouter(routerIP), // the default route + dhcpv4.WithServerIP(routerIP), // TODO: what is this? + dhcpv4.WithOption(dhcpv4.OptServerIdentifier(routerIP)), + dhcpv4.WithYourIP(net.ParseIP(ips)), // Hello world + dhcpv4.WithLeaseTime(3600), // hour works + dhcpv4.WithNetmask(cgnatNetMask), ) if err != nil { t.logf("error building DHCP ack: %v", err) @@ -278,8 +306,7 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { netip.AddrPortFrom(netaddr.IPv4(100, 100, 100, 100), 67), // src netip.AddrPortFrom(netaddr.IPv4(255, 255, 255, 255), 68), // dst ) - // TODO(raggi): reduce allocs! 
- n, err := t.tdev.Write([][]byte{pkt}, 0) + n, err := t.WriteEthernet(pkt) if tapDebug { t.logf("tap: wrote DHCP ACK %v, %v", n, err) } @@ -291,6 +318,16 @@ func (t *Wrapper) handleDHCPRequest(ethBuf []byte) bool { return consumePacket } +func writeEthernetFrame(buf []byte, srcMAC, dstMAC net.HardwareAddr, proto tcpip.NetworkProtocolNumber) { + // Ethernet header + eth := header.Ethernet(buf) + eth.Encode(&header.EthernetFields{ + SrcAddr: tcpip.LinkAddress(srcMAC), + DstAddr: tcpip.LinkAddress(dstMAC), + Type: proto, + }) +} + func packLayer2UDP(payload []byte, srcMAC, dstMAC net.HardwareAddr, src, dst netip.AddrPort) []byte { buf := make([]byte, header.EthernetMinimumSize+header.UDPMinimumSize+header.IPv4MinimumSize+len(payload)) payloadStart := len(buf) - len(payload) @@ -300,12 +337,7 @@ func packLayer2UDP(payload []byte, srcMAC, dstMAC net.HardwareAddr, src, dst net dstB := dst.Addr().As4() dstIP := tcpip.AddrFromSlice(dstB[:]) // Ethernet header - eth := header.Ethernet(buf) - eth.Encode(&header.EthernetFields{ - SrcAddr: tcpip.LinkAddress(srcMAC), - DstAddr: tcpip.LinkAddress(dstMAC), - Type: ipv4.ProtocolNumber, - }) + writeEthernetFrame(buf, srcMAC, dstMAC, ipv4.ProtocolNumber) // IP header ipbuf := buf[header.EthernetMinimumSize:] ip := header.IPv4(ipbuf) @@ -342,17 +374,18 @@ func run(prog string, args ...string) error { return nil } -func (t *Wrapper) destMAC() [6]byte { +func (t *tapDevice) destMAC() [6]byte { return t.destMACAtomic.Load() } -func newTAPDevice(fd int, tapName string) (tun.Device, error) { +func newTAPDevice(logf logger.Logf, fd int, tapName string) (tun.Device, error) { err := unix.SetNonblock(fd, true) if err != nil { return nil, err } file := os.NewFile(uintptr(fd), "/dev/tap") d := &tapDevice{ + logf: logf, file: file, events: make(chan tun.Event), name: tapName, @@ -360,20 +393,22 @@ func newTAPDevice(fd int, tapName string) (tun.Device, error) { return d, nil } -var ( - _ setWrapperer = &tapDevice{} -) - type tapDevice struct { - 
file *os.File - events chan tun.Event - name string - wrapper *Wrapper - closeOnce sync.Once + file *os.File + logf func(format string, args ...any) + events chan tun.Event + name string + closeOnce sync.Once + clientIPv4 syncs.AtomicValue[string] + + destMACAtomic syncs.AtomicValue[[6]byte] } -func (t *tapDevice) setWrapper(wrapper *Wrapper) { - t.wrapper = wrapper +var _ tstun.SetIPer = (*tapDevice)(nil) + +func (t *tapDevice) SetIP(ipV4, ipV6TODO netip.Addr) error { + t.clientIPv4.Store(ipV4.String()) + return nil } func (t *tapDevice) File() *os.File { @@ -384,36 +419,63 @@ func (t *tapDevice) Name() (string, error) { return t.name, nil } +// Read reads an IP packet from the TAP device. It strips the ethernet frame header. func (t *tapDevice) Read(buffs [][]byte, sizes []int, offset int) (int, error) { + n, err := t.ReadEthernet(buffs, sizes, offset) + if err != nil || n == 0 { + return n, err + } + // Strip the ethernet frame header. + copy(buffs[0][offset:], buffs[0][offset+ethernetFrameSize:offset+sizes[0]]) + sizes[0] -= ethernetFrameSize + return 1, nil +} + +// ReadEthernet reads a raw ethernet frame from the TAP device. +func (t *tapDevice) ReadEthernet(buffs [][]byte, sizes []int, offset int) (int, error) { n, err := t.file.Read(buffs[0][offset:]) if err != nil { return 0, err } + if t.handleTAPFrame(buffs[0][offset : offset+n]) { + return 0, nil + } sizes[0] = n return 1, nil } +// WriteEthernet writes a raw ethernet frame to the TAP device. +func (t *tapDevice) WriteEthernet(buf []byte) (int, error) { + return t.file.Write(buf) +} + +// ethBufPool holds a pool of bytes.Buffers for use in [tapDevice.Write]. +var ethBufPool = syncs.Pool[*bytes.Buffer]{New: func() *bytes.Buffer { return new(bytes.Buffer) }} + +// Write writes a raw IP packet to the TAP device. It adds the ethernet frame header. 
func (t *tapDevice) Write(buffs [][]byte, offset int) (int, error) { errs := make([]error, 0) wrote := 0 + m := t.destMAC() + dstMac := net.HardwareAddr(m[:]) + buf := ethBufPool.Get() + defer ethBufPool.Put(buf) for _, buff := range buffs { - if offset < ethernetFrameSize { - errs = append(errs, fmt.Errorf("[unexpected] weird offset %d for TAP write", offset)) - return 0, multierr.New(errs...) - } - eth := buff[offset-ethernetFrameSize:] - dst := t.wrapper.destMAC() - copy(eth[:6], dst[:]) - copy(eth[6:12], ourMAC[:]) - et := etherTypeIPv4 - if buff[offset]>>4 == 6 { - et = etherTypeIPv6 + buf.Reset() + buf.Grow(header.EthernetMinimumSize + len(buff) - offset) + + var ebuf [14]byte + switch buff[offset] >> 4 { + case 4: + writeEthernetFrame(ebuf[:], ourMAC, dstMac, ipv4.ProtocolNumber) + case 6: + writeEthernetFrame(ebuf[:], ourMAC, dstMac, ipv6.ProtocolNumber) + default: + continue } - eth[12], eth[13] = et[0], et[1] - if tapDebug { - t.wrapper.logf("tap: tapWrite off=%v % x", offset, buff) - } - _, err := t.file.Write(buff[offset-ethernetFrameSize:]) + buf.Write(ebuf[:]) + buf.Write(buff[offset:]) + _, err := t.WriteEthernet(buf.Bytes()) if err != nil { errs = append(errs, err) } else { @@ -428,8 +490,7 @@ func (t *tapDevice) MTU() (int, error) { if err != nil { return 0, err } - err = unix.IoctlIfreq(int(t.file.Fd()), unix.SIOCGIFMTU, ifr) - if err != nil { + if err := unix.IoctlIfreq(int(t.file.Fd()), unix.SIOCGIFMTU, ifr); err != nil { return 0, err } return int(ifr.Uint32()), nil diff --git a/feature/tpm/tpm.go b/feature/tpm/tpm.go new file mode 100644 index 0000000000000..18e56ae891ee1 --- /dev/null +++ b/feature/tpm/tpm.go @@ -0,0 +1,83 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tpm implements support for TPM 2.0 devices. 
+package tpm + +import ( + "slices" + "sync" + + "github.com/google/go-tpm/tpm2" + "github.com/google/go-tpm/tpm2/transport" + "tailscale.com/feature" + "tailscale.com/hostinfo" + "tailscale.com/tailcfg" +) + +var infoOnce = sync.OnceValue(info) + +func init() { + feature.Register("tpm") + hostinfo.RegisterHostinfoNewHook(func(hi *tailcfg.Hostinfo) { + hi.TPM = infoOnce() + }) +} + +//lint:ignore U1000 used in Linux and Windows builds only +func infoFromCapabilities(tpm transport.TPM) *tailcfg.TPMInfo { + info := new(tailcfg.TPMInfo) + toStr := func(s *string) func(*tailcfg.TPMInfo, uint32) { + return func(info *tailcfg.TPMInfo, value uint32) { + *s += propToString(value) + } + } + for _, cap := range []struct { + prop tpm2.TPMPT + apply func(info *tailcfg.TPMInfo, value uint32) + }{ + {tpm2.TPMPTManufacturer, toStr(&info.Manufacturer)}, + {tpm2.TPMPTVendorString1, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString2, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString3, toStr(&info.Vendor)}, + {tpm2.TPMPTVendorString4, toStr(&info.Vendor)}, + {tpm2.TPMPTRevision, func(info *tailcfg.TPMInfo, value uint32) { info.SpecRevision = int(value) }}, + {tpm2.TPMPTVendorTPMType, func(info *tailcfg.TPMInfo, value uint32) { info.Model = int(value) }}, + {tpm2.TPMPTFirmwareVersion1, func(info *tailcfg.TPMInfo, value uint32) { info.FirmwareVersion += uint64(value) << 32 }}, + {tpm2.TPMPTFirmwareVersion2, func(info *tailcfg.TPMInfo, value uint32) { info.FirmwareVersion += uint64(value) }}, + } { + resp, err := tpm2.GetCapability{ + Capability: tpm2.TPMCapTPMProperties, + Property: uint32(cap.prop), + PropertyCount: 1, + }.Execute(tpm) + if err != nil { + continue + } + props, err := resp.CapabilityData.Data.TPMProperties() + if err != nil { + continue + } + if len(props.TPMProperty) == 0 { + continue + } + cap.apply(info, props.TPMProperty[0].Value) + } + return info +} + +// propToString converts TPM_PT property value, which is a uint32, into a +// string of up to 4 ASCII characters. 
This encoding applies only to some +// properties, see +// https://trustedcomputinggroup.org/resource/tpm-library-specification/ Part +// 2, section 6.13. +func propToString(v uint32) string { + chars := []byte{ + byte(v >> 24), + byte(v >> 16), + byte(v >> 8), + byte(v), + } + // Delete any non-printable ASCII characters. + return string(slices.DeleteFunc(chars, func(b byte) bool { return b < ' ' || b > '~' })) +} diff --git a/feature/tpm/tpm_linux.go b/feature/tpm/tpm_linux.go new file mode 100644 index 0000000000000..a90c0e153962f --- /dev/null +++ b/feature/tpm/tpm_linux.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "github.com/google/go-tpm/tpm2/transport/linuxtpm" + "tailscale.com/tailcfg" +) + +func info() *tailcfg.TPMInfo { + t, err := linuxtpm.Open("/dev/tpm0") + if err != nil { + return nil + } + defer t.Close() + return infoFromCapabilities(t) +} diff --git a/feature/tpm/tpm_other.go b/feature/tpm/tpm_other.go new file mode 100644 index 0000000000000..ba7c67621eafb --- /dev/null +++ b/feature/tpm/tpm_other.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !windows + +package tpm + +import "tailscale.com/tailcfg" + +func info() *tailcfg.TPMInfo { + return nil +} diff --git a/feature/tpm/tpm_test.go b/feature/tpm/tpm_test.go new file mode 100644 index 0000000000000..fc0fc178c1a79 --- /dev/null +++ b/feature/tpm/tpm_test.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import "testing" + +func TestPropToString(t *testing.T) { + for prop, want := range map[uint32]string{ + 0: "", + 0x4D534654: "MSFT", + 0x414D4400: "AMD", + 0x414D440D: "AMD", + } { + if got := propToString(prop); got != want { + t.Errorf("propToString(0x%x): got %q, want %q", prop, got, want) + } + } +} diff --git a/feature/tpm/tpm_windows.go 
b/feature/tpm/tpm_windows.go new file mode 100644 index 0000000000000..578d687af5739 --- /dev/null +++ b/feature/tpm/tpm_windows.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tpm + +import ( + "github.com/google/go-tpm/tpm2/transport/windowstpm" + "tailscale.com/tailcfg" +) + +func info() *tailcfg.TPMInfo { + t, err := windowstpm.Open() + if err != nil { + return nil + } + defer t.Close() + return infoFromCapabilities(t) +} diff --git a/feature/wakeonlan/wakeonlan.go b/feature/wakeonlan/wakeonlan.go new file mode 100644 index 0000000000000..96c424084dcc6 --- /dev/null +++ b/feature/wakeonlan/wakeonlan.go @@ -0,0 +1,243 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package wakeonlan registers the Wake-on-LAN feature. +package wakeonlan + +import ( + "encoding/json" + "log" + "net" + "net/http" + "runtime" + "sort" + "strings" + "unicode" + + "github.com/kortschak/wol" + "tailscale.com/envknob" + "tailscale.com/feature" + "tailscale.com/hostinfo" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" +) + +func init() { + feature.Register("wakeonlan") + ipnlocal.RegisterC2N("POST /wol", handleC2NWoL) + ipnlocal.RegisterPeerAPIHandler("/v0/wol", handlePeerAPIWakeOnLAN) + hostinfo.RegisterHostinfoNewHook(func(h *tailcfg.Hostinfo) { + h.WoLMACs = getWoLMACs() + }) +} + +func handleC2NWoL(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + r.ParseForm() + var macs []net.HardwareAddr + for _, macStr := range r.Form["mac"] { + mac, err := net.ParseMAC(macStr) + if err != nil { + http.Error(w, "bad 'mac' param", http.StatusBadRequest) + return + } + macs = append(macs, mac) + } + var res struct { + SentTo []string + Errors []string + } + st := b.NetMon().InterfaceState() + if st == nil { + res.Errors = append(res.Errors, "no interface state") + writeJSON(w, &res) + return + } + var password []byte // 
TODO(bradfitz): support? does anything use WoL passwords? + for _, mac := range macs { + for ifName, ips := range st.InterfaceIPs { + for _, ip := range ips { + if ip.Addr().IsLoopback() || ip.Addr().Is6() { + continue + } + local := &net.UDPAddr{ + IP: ip.Addr().AsSlice(), + Port: 0, + } + remote := &net.UDPAddr{ + IP: net.IPv4bcast, + Port: 0, + } + if err := wol.Wake(mac, password, local, remote); err != nil { + res.Errors = append(res.Errors, err.Error()) + } else { + res.SentTo = append(res.SentTo, ifName) + } + break // one per interface is enough + } + } + } + sort.Strings(res.SentTo) + writeJSON(w, &res) +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(v) +} + +func canWakeOnLAN(h ipnlocal.PeerAPIHandler) bool { + if h.Peer().UnsignedPeerAPIOnly() { + return false + } + return h.IsSelfUntagged() || h.PeerCaps().HasCapability(tailcfg.PeerCapabilityWakeOnLAN) +} + +var metricWakeOnLANCalls = clientmetric.NewCounter("peerapi_wol") + +func handlePeerAPIWakeOnLAN(h ipnlocal.PeerAPIHandler, w http.ResponseWriter, r *http.Request) { + metricWakeOnLANCalls.Add(1) + if !canWakeOnLAN(h) { + http.Error(w, "no WoL access", http.StatusForbidden) + return + } + if r.Method != "POST" { + http.Error(w, "bad method", http.StatusMethodNotAllowed) + return + } + macStr := r.FormValue("mac") + if macStr == "" { + http.Error(w, "missing 'mac' param", http.StatusBadRequest) + return + } + mac, err := net.ParseMAC(macStr) + if err != nil { + http.Error(w, "bad 'mac' param", http.StatusBadRequest) + return + } + var password []byte // TODO(bradfitz): support? does anything use WoL passwords? 
+ st := h.LocalBackend().NetMon().InterfaceState() + if st == nil { + http.Error(w, "failed to get interfaces state", http.StatusInternalServerError) + return + } + var res struct { + SentTo []string + Errors []string + } + for ifName, ips := range st.InterfaceIPs { + for _, ip := range ips { + if ip.Addr().IsLoopback() || ip.Addr().Is6() { + continue + } + local := &net.UDPAddr{ + IP: ip.Addr().AsSlice(), + Port: 0, + } + remote := &net.UDPAddr{ + IP: net.IPv4bcast, + Port: 0, + } + if err := wol.Wake(mac, password, local, remote); err != nil { + res.Errors = append(res.Errors, err.Error()) + } else { + res.SentTo = append(res.SentTo, ifName) + } + break // one per interface is enough + } + } + sort.Strings(res.SentTo) + writeJSON(w, res) +} + +// TODO(bradfitz): this is all too simplistic and static. It needs to run +// continuously in response to netmon events (USB ethernet adapters might get +// plugged in) and look for the media type/status/etc. Right now on macOS it +// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't +// have any media. We should only report the one that's actually connected. +// But it works for now (2023-10-05) for fleshing out the rest. + +var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 + +// getWoLMACs returns up to 10 MAC address of the local machine to send +// wake-on-LAN packets to in order to wake it up. The returned MACs are in +// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). +// +// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS +// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is +// set to a MAC address, that sole MAC address is returned. 
+func getWoLMACs() (macs []string) { + switch runtime.GOOS { + case "ios", "android": + return nil + } + if s := wakeMAC(); s != "" { + switch s { + case "auto": + ifs, _ := net.Interfaces() + for _, iface := range ifs { + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagBroadcast == 0 || + iface.Flags&net.FlagRunning == 0 || + iface.Flags&net.FlagUp == 0 { + continue + } + if keepMAC(iface.Name, iface.HardwareAddr) { + macs = append(macs, iface.HardwareAddr.String()) + } + if len(macs) == 10 { + break + } + } + return macs + case "false", "off": // fast path before ParseMAC error + return nil + } + mac, err := net.ParseMAC(s) + if err != nil { + log.Printf("invalid MAC %q", s) + return nil + } + return []string{mac.String()} + } + return nil +} + +var ignoreWakeOUI = map[[3]byte]bool{ + {0x00, 0x15, 0x5d}: true, // Hyper-V + {0x00, 0x50, 0x56}: true, // VMware + {0x00, 0x1c, 0x14}: true, // VMware + {0x00, 0x05, 0x69}: true, // VMware + {0x00, 0x0c, 0x29}: true, // VMware + {0x00, 0x1c, 0x42}: true, // Parallels + {0x08, 0x00, 0x27}: true, // VirtualBox + {0x00, 0x21, 0xf6}: true, // VirtualBox + {0x00, 0x14, 0x4f}: true, // VirtualBox + {0x00, 0x0f, 0x4b}: true, // VirtualBox + {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant +} + +func keepMAC(ifName string, mac []byte) bool { + if len(mac) != 6 { + return false + } + base := strings.TrimRightFunc(ifName, unicode.IsNumber) + switch runtime.GOOS { + case "darwin": + switch base { + case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": + return false + } + } + if mac[0] == 0x02 && mac[1] == 0x42 { + // Docker container. 
+ return false + } + oui := [3]byte{mac[0], mac[1], mac[2]} + if ignoreWakeOUI[oui] { + return false + } + return true +} diff --git a/flake.lock b/flake.lock index 8c4aa7dfc2c73..05b0f303e6433 100644 --- a/flake.lock +++ b/flake.lock @@ -36,11 +36,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1724748588, - "narHash": "sha256-NlpGA4+AIf1dKNq76ps90rxowlFXUsV9x7vK/mN37JM=", + "lastModified": 1743938762, + "narHash": "sha256-UgFYn8sGv9B8PoFpUfCa43CjMZBl1x/ShQhRDHBFQdI=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "a6292e34000dc93d43bccf78338770c1c5ec8a99", + "rev": "74a40410369a1c35ee09b8a1abee6f4acbedc059", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 95d5c3035c7a9..2f920bfd40ba5 100644 --- a/flake.nix +++ b/flake.nix @@ -68,14 +68,14 @@ # you're an end user you should be prepared for this flake to not # build periodically. tailscale = pkgs: - pkgs.buildGo123Module rec { + pkgs.buildGo124Module rec { name = "tailscale"; src = ./.; vendorHash = pkgs.lib.fileContents ./go.mod.sri; nativeBuildInputs = pkgs.lib.optionals pkgs.stdenv.isLinux [pkgs.makeWrapper]; ldflags = ["-X tailscale.com/version.gitCommitStamp=${tailscaleRev}"]; - CGO_ENABLED = 0; + env.CGO_ENABLED = 0; subPackages = ["cmd/tailscale" "cmd/tailscaled"]; doCheck = false; @@ -118,7 +118,7 @@ gotools graphviz perl - go_1_23 + go_1_24 yarn # qemu and e2fsprogs are needed for natlab @@ -130,4 +130,4 @@ in flake-utils.lib.eachDefaultSystem (system: flakeForSystem nixpkgs system); } -# nix-direnv cache busting line: sha256-xO1DuLWi6/lpA9ubA2ZYVJM+CkVNA5IaVGZxX9my0j0= +# nix-direnv cache busting line: sha256-av4kr09rjNRmag94ziNjJuI/cg8b8lAD3Tk24t/ezH4= diff --git a/go.mod b/go.mod index 464db8313b5fd..f346b1e4095dd 100644 --- a/go.mod +++ b/go.mod @@ -1,28 +1,28 @@ module tailscale.com -go 1.23.1 +go 1.24.0 require ( filippo.io/mkcert v1.4.4 - fyne.io/systray v1.11.0 + fyne.io/systray v1.11.1-0.20250317195939-bcf6eed85e7a + github.com/Kodeworks/golang-image-ico 
v0.0.0-20141118225523-73f0f4cfade9 github.com/akutz/memconn v0.1.0 github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa github.com/andybalholm/brotli v1.1.0 github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be github.com/atotto/clipboard v0.1.4 - github.com/aws/aws-sdk-go-v2 v1.24.1 - github.com/aws/aws-sdk-go-v2/config v1.26.5 - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.64 - github.com/aws/aws-sdk-go-v2/service/s3 v1.33.0 + github.com/aws/aws-sdk-go-v2 v1.36.0 + github.com/aws/aws-sdk-go-v2/config v1.29.5 + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 + github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 github.com/bramvdbogaerde/go-scp v1.4.0 github.com/cilium/ebpf v0.15.0 github.com/coder/websocket v1.8.12 github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf + github.com/creachadair/taskgroup v0.13.2 github.com/creack/pty v1.1.23 - github.com/dave/courtney v0.4.0 - github.com/dave/patsy v0.0.0-20210517141501-957256f50cba github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e github.com/distribution/reference v0.6.0 @@ -32,35 +32,38 @@ require ( github.com/evanw/esbuild v0.19.11 github.com/fogleman/gg v1.3.0 github.com/frankban/quicktest v1.14.6 - github.com/fxamacker/cbor/v2 v2.6.0 - github.com/gaissmai/bart v0.11.1 - github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 + github.com/fxamacker/cbor/v2 v2.7.0 + github.com/gaissmai/bart v0.18.0 + github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874 github.com/go-logr/zapr v1.3.0 github.com/go-ole/go-ole v1.3.0 + github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da github.com/golang/snappy v0.0.4 
github.com/golangci/golangci-lint v1.57.1 github.com/google/go-cmp v0.6.0 - github.com/google/go-containerregistry v0.18.0 + github.com/google/go-containerregistry v0.20.2 + github.com/google/go-tpm v0.9.4 github.com/google/gopacket v1.1.19 github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 github.com/google/uuid v1.6.0 github.com/goreleaser/nfpm/v2 v2.33.1 + github.com/hashicorp/go-hclog v1.6.2 + github.com/hashicorp/raft v1.7.2 github.com/hdevalence/ed25519consensus v0.2.0 - github.com/illarion/gonotify/v2 v2.0.3 - github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c + github.com/illarion/gonotify/v3 v3.0.2 + github.com/inetaf/tcpproxy v0.0.0-20250203165043-ded522cbd03f github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 github.com/jellydator/ttlcache/v3 v3.1.0 - github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 github.com/jsimonetti/rtnetlink v1.4.0 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 - github.com/klauspost/compress v1.17.4 + github.com/klauspost/compress v1.17.11 github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 github.com/mdlayher/genetlink v1.3.2 - github.com/mdlayher/netlink v1.7.2 + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 github.com/mdlayher/sdnotify v1.0.0 github.com/miekg/dns v1.1.58 github.com/mitchellh/go-ps v1.0.0 @@ -69,88 +72,91 @@ require ( github.com/pkg/sftp v1.13.6 github.com/prometheus-community/pro-bing v0.4.0 github.com/prometheus/client_golang v1.19.1 - github.com/prometheus/common v0.48.0 + github.com/prometheus/common v0.55.0 github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff github.com/safchain/ethtool v0.3.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/studio-b12/gowebdav v0.9.0 github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e - github.com/tailscale/depaware 
v0.0.0-20210622194025-720c4b409502 + github.com/tailscale/depaware v0.0.0-20250112153213-b748de04d81b github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41 - github.com/tailscale/golang-x-crypto v0.0.0-20240604161659-3fde5e568aa4 + github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869 github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a - github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba + github.com/tailscale/mkctr v0.0.0-20250228050937-c75ea1476830 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 - github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 - github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 + github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc + github.com/tailscale/setec v0.0.0-20250205144240-8898a29c3fbb + github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 - github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc + github.com/tailscale/wireguard-go v0.0.0-20250304000100-91a0587fb251 github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e github.com/tc-hib/winres v0.2.1 github.com/tcnksm/go-httpstat v0.2.0 github.com/toqueteos/webbrowser v1.2.0 - github.com/u-root/u-root v0.12.0 + github.com/u-root/u-root v0.14.0 github.com/vishvananda/netns v0.0.4 go.uber.org/zap v1.27.0 - go4.org/mem v0.0.0-20220726221520-4f986261bf13 + go4.org/mem v0.0.0-20240501181205-ae6ca9944745 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba - golang.org/x/crypto v0.25.0 - golang.org/x/exp v0.0.0-20240119083558-1b970713d09a - golang.org/x/mod v0.19.0 - golang.org/x/net v0.27.0 - golang.org/x/oauth2 v0.16.0 - golang.org/x/sync v0.7.0 - golang.org/x/sys v0.22.0 - golang.org/x/term v0.22.0 - golang.org/x/time v0.5.0 - golang.org/x/tools v0.23.0 + golang.org/x/crypto v0.37.0 + golang.org/x/exp 
v0.0.0-20250210185358-939b2ce775ac + golang.org/x/mod v0.23.0 + golang.org/x/net v0.36.0 + golang.org/x/oauth2 v0.26.0 + golang.org/x/sync v0.13.0 + golang.org/x/sys v0.32.0 + golang.org/x/term v0.31.0 + golang.org/x/time v0.10.0 + golang.org/x/tools v0.30.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard/windows v0.5.3 gopkg.in/square/go-jose.v2 v2.6.0 - gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 + gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 honnef.co/go/tools v0.5.1 - k8s.io/api v0.30.3 - k8s.io/apimachinery v0.30.3 - k8s.io/apiserver v0.30.3 - k8s.io/client-go v0.30.3 - sigs.k8s.io/controller-runtime v0.18.4 - sigs.k8s.io/controller-tools v0.15.1-0.20240618033008-7824932b0cab + k8s.io/api v0.32.0 + k8s.io/apimachinery v0.32.0 + k8s.io/apiserver v0.32.0 + k8s.io/client-go v0.32.0 + sigs.k8s.io/controller-runtime v0.19.4 + sigs.k8s.io/controller-tools v0.17.0 sigs.k8s.io/yaml v1.4.0 software.sslmate.com/src/go-pkcs12 v0.4.0 ) require ( + 9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f // indirect github.com/4meepo/tagalign v1.3.3 // indirect github.com/Antonboom/testifylint v1.2.0 // indirect github.com/GaijinEntertainment/go-exhaustruct/v3 v3.2.0 // indirect github.com/Masterminds/sprig v2.22.0+incompatible // indirect - github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/OpenPeeDeeP/depguard/v2 v2.2.0 // indirect github.com/alecthomas/go-check-sumtype v0.1.4 // indirect github.com/alexkohler/nakedret/v2 v2.0.4 // indirect - github.com/bits-and-blooms/bitset v1.13.0 // indirect + github.com/armon/go-metrics v0.4.1 // indirect github.com/bombsimon/wsl/v4 v4.2.1 // indirect github.com/butuzov/mirror v1.1.0 // indirect github.com/catenacyber/perfsprint v0.7.1 // indirect github.com/ccojocar/zxcvbn-go v1.0.2 // indirect github.com/ckaznocha/intrange v0.1.0 // indirect - github.com/cyphar/filepath-securejoin v0.2.4 // indirect - github.com/dave/astrid 
v0.0.0-20170323122508-8c2895878b14 // indirect - github.com/dave/brenda v1.1.0 // indirect - github.com/docker/go-connections v0.4.0 // indirect + github.com/cyphar/filepath-securejoin v0.3.6 // indirect + github.com/docker/go-connections v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/ghostiam/protogetter v0.3.5 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 // indirect - github.com/gobuffalo/flect v1.0.2 // indirect + github.com/gobuffalo/flect v1.0.3 // indirect github.com/goccy/go-yaml v1.12.0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golangci/plugin-module-register v0.1.1 // indirect github.com/google/gnostic-models v0.6.9-0.20230804172637-c7be7c783f49 // indirect - github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd // indirect github.com/gorilla/securecookie v1.1.2 // indirect + github.com/hashicorp/go-immutable-radix v1.3.1 // indirect + github.com/hashicorp/go-metrics v0.5.4 // indirect + github.com/hashicorp/go-msgpack/v2 v2.1.2 // indirect + github.com/hashicorp/golang-lru v0.6.0 // indirect github.com/jjti/go-spancheck v0.5.3 // indirect github.com/karamaru-alpha/copyloopvar v1.0.8 // indirect github.com/macabu/inamedparam v0.1.3 // indirect @@ -160,12 +166,14 @@ require ( github.com/ykadowak/zerologlint v0.1.5 // indirect go-simpler.org/musttag v0.9.0 // indirect go-simpler.org/sloglint v0.5.0 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0 // indirect - go.opentelemetry.io/otel v1.22.0 // indirect - go.opentelemetry.io/otel/metric v1.22.0 // indirect - go.opentelemetry.io/otel/trace v1.22.0 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 // indirect + go.opentelemetry.io/otel v1.33.0 // indirect + go.opentelemetry.io/otel/metric v1.33.0 
// indirect + go.opentelemetry.io/otel/trace v1.33.0 // indirect go.uber.org/automaxprocs v1.5.3 // indirect golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect + gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect ) require ( @@ -183,26 +191,26 @@ require ( github.com/Masterminds/semver v1.5.0 // indirect github.com/Masterminds/semver/v3 v3.2.1 // indirect github.com/Masterminds/sprig/v3 v3.2.3 // indirect - github.com/ProtonMail/go-crypto v1.0.0 // indirect + github.com/ProtonMail/go-crypto v1.1.3 // indirect github.com/alexkohler/prealloc v1.0.0 // indirect github.com/alingse/asasalint v0.0.11 // indirect github.com/ashanbrown/forbidigo v1.6.0 // indirect github.com/ashanbrown/makezero v1.1.1 // indirect - github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 // indirect - github.com/aws/aws-sdk-go-v2/credentials v1.16.16 // indirect - github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 // indirect - github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 // indirect - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 // indirect - github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect - github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 // indirect - github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 // indirect - github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 // indirect - github.com/aws/smithy-go v1.19.0 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.58 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 // indirect + 
github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 // indirect + github.com/aws/smithy-go v1.22.2 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/bkielbasa/cyclop v1.2.1 // indirect github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb // indirect @@ -211,37 +219,36 @@ require ( github.com/breml/errchkjson v0.3.6 // indirect github.com/butuzov/ireturn v0.3.0 // indirect github.com/cavaliergopher/cpio v1.0.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charithe/durationcheck v0.0.10 // indirect github.com/chavacava/garif v0.1.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect - github.com/containerd/stargz-snapshotter/estargz v0.15.1 // indirect + github.com/containerd/stargz-snapshotter/estargz v0.16.3 // indirect github.com/curioswitch/go-reassign v0.2.0 // indirect github.com/daixiang0/gci v0.12.3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect - github.com/docker/cli v25.0.0+incompatible // indirect + github.com/docker/cli v27.4.1+incompatible // indirect github.com/docker/distribution v2.8.3+incompatible // indirect - github.com/docker/docker 
v26.1.4+incompatible // indirect - github.com/docker/docker-credential-helpers v0.8.1 // indirect + github.com/docker/docker v27.4.1+incompatible // indirect + github.com/docker/docker-credential-helpers v0.8.2 // indirect github.com/emicklei/go-restful/v3 v3.11.2 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/ettle/strcase v0.2.0 // indirect - github.com/evanphx/json-patch v5.6.0+incompatible // indirect github.com/evanphx/json-patch/v5 v5.9.0 // indirect - github.com/fatih/color v1.17.0 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/fatih/structtag v1.2.0 // indirect github.com/firefart/nonamedreturns v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 github.com/fzipp/gocyclo v0.6.0 // indirect github.com/go-critic/go-critic v0.11.2 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect - github.com/go-git/go-billy/v5 v5.5.0 // indirect - github.com/go-git/go-git/v5 v5.11.0 // indirect + github.com/go-git/go-billy/v5 v5.6.1 // indirect + github.com/go-git/go-git/v5 v5.13.1 // indirect github.com/go-logr/logr v1.4.2 // indirect - github.com/go-openapi/jsonpointer v0.20.2 // indirect + github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.4 // indirect - github.com/go-openapi/swag v0.22.7 // indirect + github.com/go-openapi/swag v0.23.0 // indirect github.com/go-toolsmith/astcast v1.1.0 // indirect github.com/go-toolsmith/astcopy v1.1.0 // indirect github.com/go-toolsmith/astequal v1.2.0 // indirect @@ -266,7 +273,7 @@ require ( github.com/gordonklaus/ineffassign v0.1.0 // indirect github.com/goreleaser/chglog v0.5.0 // indirect github.com/goreleaser/fileglob v1.3.0 // indirect - github.com/gorilla/csrf v1.7.2 + github.com/gorilla/csrf v1.7.3 github.com/gostaticanalysis/analysisutil v0.7.1 // indirect github.com/gostaticanalysis/comment v1.4.2 // indirect github.com/gostaticanalysis/forcetypeassert v0.1.0 // indirect @@ -322,34 +329,34 @@ require ( 
github.com/nunnatsa/ginkgolinter v0.16.1 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0-rc6 // indirect + github.com/opencontainers/image-spec v1.1.0 // indirect github.com/pelletier/go-toml/v2 v2.2.0 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/polyfloyd/go-errorlint v1.4.8 // indirect - github.com/prometheus/client_model v0.5.0 - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus/client_model v0.6.1 + github.com/prometheus/procfs v0.15.1 // indirect github.com/quasilyte/go-ruleguard v0.4.2 // indirect github.com/quasilyte/gogrep v0.5.0 // indirect github.com/quasilyte/regex/syntax v0.0.0-20210819130434-b3f0c404a727 // indirect github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567 // indirect github.com/rivo/uniseg v0.4.4 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/ryancurrah/gomodguard v1.3.1 // indirect github.com/ryanrolds/sqlclosecheck v0.5.1 // indirect github.com/sanposhiho/wastedassign/v2 v2.0.7 // indirect github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.25.0 // indirect github.com/securego/gosec/v2 v2.19.0 // indirect - github.com/sergi/go-diff v1.3.1 // indirect + github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c // indirect github.com/shopspring/decimal v1.3.1 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sivchari/tenv v1.7.1 // indirect - github.com/skeema/knownhosts v1.2.1 // indirect + 
github.com/skeema/knownhosts v1.3.0 // indirect github.com/sonatard/noctx v0.0.2 // indirect github.com/sourcegraph/go-diff v0.7.0 // indirect github.com/spf13/afero v1.11.0 // indirect @@ -361,7 +368,7 @@ require ( github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect github.com/stbenjam/no-sprintf-host-port v0.1.1 // indirect github.com/stretchr/objx v0.5.2 // indirect - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 github.com/subosito/gotenv v1.4.2 // indirect github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c // indirect github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 @@ -371,12 +378,12 @@ require ( github.com/timonwong/loggercheck v0.9.4 // indirect github.com/tomarrell/wrapcheck/v2 v2.8.3 // indirect github.com/tommy-muehle/go-mnd/v2 v2.5.1 // indirect - github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e // indirect + github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect github.com/ulikunitz/xz v0.5.11 // indirect github.com/ultraware/funlen v0.1.0 // indirect github.com/ultraware/whitespace v0.1.0 // indirect github.com/uudashr/gocognit v1.1.2 // indirect - github.com/vbatts/tar-split v0.11.5 // indirect + github.com/vbatts/tar-split v0.11.6 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/yagipy/maintidx v1.0.0 // indirect @@ -385,23 +392,22 @@ require ( gitlab.com/digitalxero/go-conventional-commit v1.0.7 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f // indirect - golang.org/x/image v0.18.0 // indirect - golang.org/x/text v0.16.0 // indirect + golang.org/x/image v0.24.0 // indirect + golang.org/x/text v0.24.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/appengine v1.6.8 // indirect - google.golang.org/protobuf v1.33.0 // indirect + google.golang.org/protobuf v1.35.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect 
gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 howett.net/plist v1.0.0 // indirect - k8s.io/apiextensions-apiserver v0.30.3 // indirect + k8s.io/apiextensions-apiserver v0.32.0 k8s.io/klog/v2 v2.130.1 // indirect - k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 // indirect - k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 + k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f // indirect + k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 mvdan.cc/gofumpt v0.6.0 // indirect mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14 // indirect - sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect - sigs.k8s.io/structured-merge-diff/v4 v4.4.1 // indirect + sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 // indirect + sigs.k8s.io/structured-merge-diff/v4 v4.4.2 // indirect ) diff --git a/go.mod.sri b/go.mod.sri index 4abb3c5165d2f..6c8357e0468ba 100644 --- a/go.mod.sri +++ b/go.mod.sri @@ -1 +1 @@ -sha256-xO1DuLWi6/lpA9ubA2ZYVJM+CkVNA5IaVGZxX9my0j0= +sha256-av4kr09rjNRmag94ziNjJuI/cg8b8lAD3Tk24t/ezH4= \ No newline at end of file diff --git a/go.sum b/go.sum index 549f559d001fd..bdbae11bb2e3a 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ 4d63.com/gocheckcompilerdirectives v1.2.1/go.mod h1:yjDJSxmDTtIHHCqX0ufRYZDL6vQtMG7tJdKVeWwsqvs= 4d63.com/gochecknoglobals v0.2.1 h1:1eiorGsgHOFOuoOiJDy2psSrQbRdIHrlge0IJIkUgDc= 4d63.com/gochecknoglobals v0.2.1/go.mod h1:KRE8wtJB3CXCsb1xy421JfTHIIbmT3U5ruxw2Qu8fSU= +9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f h1:1C7nZuxUMNz7eiQALRfiqNOm04+m3edWlRff/BYHf0Q= +9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f/go.mod h1:hHyrZRryGqVdqrknjq5OWDLGCTJ2NeEvtrpR96mjraM= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -41,8 +43,8 @@ 
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/mkcert v1.4.4 h1:8eVbbwfVlaqUM7OwuftKc2nuYOoTDQWqsoXmzoXZdbc= filippo.io/mkcert v1.4.4/go.mod h1:VyvOchVuAye3BoUsPUOOofKygVwLV2KQMVFJNRq+1dA= -fyne.io/systray v1.11.0 h1:D9HISlxSkx+jHSniMBR6fCFOUjk1x/OOOJLa9lJYAKg= -fyne.io/systray v1.11.0/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= +fyne.io/systray v1.11.1-0.20250317195939-bcf6eed85e7a h1:I8mEKo5sawHu8CqYf3FSjIl9b3puXasFVn2D/hrCneY= +fyne.io/systray v1.11.1-0.20250317195939-bcf6eed85e7a/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= github.com/4meepo/tagalign v1.3.3 h1:ZsOxcwGD/jP4U/aw7qeWu58i7dwYemfy5Y+IF1ACoNw= github.com/4meepo/tagalign v1.3.3/go.mod h1:Q9c1rYMZJc9dPRkbQPpcBNCLEmY2njbAsXhQOZFE2dE= github.com/Abirdcfly/dupword v0.0.14 h1:3U4ulkc8EUo+CaT105/GJ1BQwtgyj6+VaBVbAX11Ba8= @@ -61,12 +63,15 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c h1:pxW6RcqyfI9/kWtOwnv/G+AzdKuy2ZrqINhenH4HyNs= github.com/BurntSushi/toml v1.4.1-0.20240526193622-a339e1f7089c/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= -github.com/DataDog/zstd v1.4.5 h1:EndNeuB0l9syBZhut0wns3gV1hL8zX8LIu6ZiVHWLIQ= -github.com/DataDog/zstd v1.4.5/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo= +github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/DataDog/zstd v1.5.2 h1:vUG4lAyuPCXO0TLbXvPv7EB7cNK1QV/luu55UHLrrn8= +github.com/DataDog/zstd v1.5.2/go.mod h1:g4AWEaM3yOg3HYfnJ3YIawPnVdXJh9QME85blwSAmyw= github.com/Djarvur/go-err113 v0.1.0 h1:uCRZZOdMQ0TZPHYTdYpoC0bLYJKPEHPUJ8MeAa51lNU= github.com/Djarvur/go-err113 v0.1.0/go.mod 
h1:4UJr5HIiMZrwgkSPdsjy2uOQExX/WEILpIrO9UPGuXs= github.com/GaijinEntertainment/go-exhaustruct/v3 v3.2.0 h1:sATXp1x6/axKxz2Gjxv8MALP0bXaNRfQinEwyfMcx8c= github.com/GaijinEntertainment/go-exhaustruct/v3 v3.2.0/go.mod h1:Nl76DrGNJTA1KJ0LePKBw/vznBX1EHbAZX8mwjR82nI= +github.com/Kodeworks/golang-image-ico v0.0.0-20141118225523-73f0f4cfade9 h1:1ltqoej5GtaWF8jaiA49HwsZD459jqm9YFz9ZtMFpQA= +github.com/Kodeworks/golang-image-ico v0.0.0-20141118225523-73f0f4cfade9/go.mod h1:7uhhqiBaR4CpN0k9rMjOtjpcfGd6DG2m04zQxKnWQ0I= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= github.com/Masterminds/semver v1.5.0 h1:H65muMkzWKEuNDnfl9d70GUjFniHKHRbFPGBuZ3QEww= @@ -79,12 +84,12 @@ github.com/Masterminds/sprig v2.22.0+incompatible/go.mod h1:y6hNFY5UBTIWBxnzTeuN github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= -github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= -github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/OpenPeeDeeP/depguard/v2 v2.2.0 h1:vDfG60vDtIuf0MEOhmLlLLSzqaRM8EMcgJPdp74zmpA= github.com/OpenPeeDeeP/depguard/v2 v2.2.0/go.mod h1:CIzddKRvLBC4Au5aYP/i3nyaWQ+ClszLIuVocRiCYFQ= -github.com/ProtonMail/go-crypto v1.0.0 h1:LRuvITjQWX+WIfr930YHG2HNfjR1uOfyf5vE0kC2U78= -github.com/ProtonMail/go-crypto v1.0.0/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0= +github.com/ProtonMail/go-crypto v1.1.3 h1:nRBOetoydLeUb4nHajyO2bKqMLfWQ/ZPwkXqXxPxCFk= 
+github.com/ProtonMail/go-crypto v1.1.3/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f h1:tCbYj7/299ekTTXpdwKYF8eBlsYsDVoggDAuAjoK66k= github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f/go.mod h1:gcr0kNtGBqin9zDW9GOHcVntrwnjrK+qdJ06mWYBybw= github.com/ProtonMail/gopenpgp/v2 v2.7.1 h1:Awsg7MPc2gD3I7IFac2qE3Gdls0lZW8SzrFZ3k1oz0s= @@ -114,6 +119,8 @@ github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1 github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJA= +github.com/armon/go-metrics v0.4.1/go.mod h1:E6amYzXo6aW1tqzoZGT755KkbgrJsSdpwZ+3JqfkOG4= github.com/armon/go-proxyproto v0.0.0-20210323213023-7e956b284f0a/go.mod h1:QmP9hvJ91BbJmGVGSbutW19IC0Q9phDCLGaomwTJbgU= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= @@ -123,65 +130,50 @@ github.com/ashanbrown/makezero v1.1.1 h1:iCQ87C0V0vSyO+M9E/FZYbu65auqH0lnsOkf5Fc github.com/ashanbrown/makezero v1.1.1/go.mod h1:i1bJLCRSCHOcOa9Y6MyF2FTfMZMFdHvxKHxgO5Z1axI= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= -github.com/aws/aws-sdk-go-v2 v1.18.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= -github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= -github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod 
h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= -github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10/go.mod h1:VeTZetY5KRJLuD/7fkQXMU6Mw7H5m/KP2J5Iy9osMno= -github.com/aws/aws-sdk-go-v2/config v1.18.22/go.mod h1:mN7Li1wxaPxSSy4Xkr6stFuinJGf3VZW3ZSNvO0q6sI= -github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= -github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= -github.com/aws/aws-sdk-go-v2/credentials v1.13.21/go.mod h1:90Dk1lJoMyspa/EDUrldTxsPns0wn6+KpRKpdAWc0uA= -github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= -github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.3/go.mod h1:4Q0UFP0YJf0NrsEuEYHpM9fTSEVnD16Z3uyEF7J9JGM= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.64 h1:9QJQs36z61YB8nxGwRDfWXEDYbU6H7jdI6zFiAX1vag= -github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.64/go.mod h1:4Q7R9MFpXRdjO3YnAfUTdnuENs32WzBkASt6VxSYDYQ= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27/go.mod h1:UrHnn3QV/d0pBZ6QBAEQcqFLf8FAzLmoUfPVIueOvoM= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= 
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= -github.com/aws/aws-sdk-go-v2/internal/ini v1.3.34/go.mod h1:Etz2dj6UHYuw+Xw830KfzCfWGMzqvUTCjUj5b76GVDc= -github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= -github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25 h1:AzwRi5OKKwo4QNqPf7TjeO+tK8AyOK3GVSwmRPo7/Cs= -github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.25/go.mod h1:SUbB4wcbSEyCvqBxv/O/IBf93RbEze7U7OnoTlpPB+g= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28 h1:vGWm5vTpMr39tEZfQeDiDAMgk+5qsnvRny3FjLpnH5w= -github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.28/go.mod h1:spfrICMD6wCAhjhzHuy6DOZZ+LAIY10UxhUmLzpJTTs= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.27/go.mod h1:EOwBD4J4S5qYszS5/3DpkejfuK+Z5/1uzICfPaZLtqw= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2 h1:NbWkRxEEIRSCqxhsHQuMiTH7yo+JZW1gp8v3elSVMTQ= -github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.14.2/go.mod h1:4tfW5l4IAB32VWCDEBxCRtR9T4BWy4I4kr1spr8NgZM= -github.com/aws/aws-sdk-go-v2/service/s3 v1.33.0 h1:L5h2fymEdVJYvn6hYO8Jx48YmC6xVmjmgHJV3oGKgmc= -github.com/aws/aws-sdk-go-v2/service/s3 v1.33.0/go.mod 
h1:J9kLNzEiHSeGMyN7238EjJmBpCniVzFda75Gxl/NqB8= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 h1:zAxi9p3wsZMIaVCdoiQp2uZ9k1LsZvmAnoTBeZPXom0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8/go.mod h1:3XkePX5dSaxveLAYY7nsbsZZrKxCyEuE5pM4ziFxyGg= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58 h1:/BsEGAyMai+KdXS+CMHlLhB5miAO19wOqE6tj8azWPM= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.58/go.mod h1:KHM3lfl/sAJBCoLI1Lsg5w4SD2VDYWwQi7vxbKhw7TI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 
h1:8IwBjuLdqIO1dGB+dZ9zJEl8wzY3bVYxcs0Xyu/Lsc0= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31/go.mod h1:8tMBcuVjL4kP/ECEIWTCWtwV2kj6+ouEKl4cqR4iWLw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 h1:siiQ+jummya9OLPDEyHVb2dLW4aOMe22FGDd0sAfuSw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5/go.mod h1:iHVx2J9pWzITdP5MJY6qWfG34TfD9EA+Qi3eV6qQCXw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 h1:tkVNm99nkJnFo1H9IIQb5QkCiPcvCDn3Pos+IeTbGRA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12/go.mod h1:dIVlquSPUMqEJtx2/W17SM2SuESRaVEhEV9alcMqxjw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 h1:JBod0SnNqcWQ0+uAyzeRFG1zCHotW8DukumYYyNy0zo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3/go.mod h1:FHSHmyEUkzRbaFFqqm6bkLAOQHgqhsLmfCahvCBMiyA= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= -github.com/aws/aws-sdk-go-v2/service/sso v1.12.9/go.mod h1:ouy2P4z6sJN70fR3ka3wD3Ro3KezSxU6eKGQI2+2fjI= -github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= -github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.14.9/go.mod h1:AFvkxc8xfBe8XA+5St5XIHHrQQtkxqrRincx4hmMHOk= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 
h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= -github.com/aws/aws-sdk-go-v2/service/sts v1.18.10/go.mod h1:BgQOMsg8av8jset59jelyPW7NoZcZXLVpDsXunGDrk8= -github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= -github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= -github.com/aws/smithy-go v1.13.5/go.mod h1:Tg+OJXh4MB2R/uN61Ko2f6hTZwB/ZYGOtib8J3gBHzA= -github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= -github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= 
-github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/bkielbasa/cyclop v1.2.1 h1:AeF71HZDob1P2/pRm1so9cd1alZnrpyc4q2uP2l0gJY= github.com/bkielbasa/cyclop v1.2.1/go.mod h1:K/dT/M0FPAiYjBgQGau7tz+3TMh4FWAEqlMhzFWCrgM= github.com/blakesmith/ar v0.0.0-20190502131153-809d4375e1fb h1:m935MPodAbYS46DG4pJSv7WO+VECIWUQ7OJYSoTrMh4= @@ -200,7 +192,6 @@ github.com/butuzov/ireturn v0.3.0 h1:hTjMqWw3y5JC3kpnC5vXmFJAWI/m31jaCYQqzkS6PL0 github.com/butuzov/ireturn v0.3.0/go.mod h1:A09nIiwiqzN/IoVo9ogpa0Hzi9fex1kd9PSD6edP5ZA= github.com/butuzov/mirror v1.1.0 h1:ZqX54gBVMXu78QLoiqdwpl2mgmoOJTk7s4p4o+0avZI= github.com/butuzov/mirror v1.1.0/go.mod h1:8Q0BdQU6rC6WILDiBM60DBfvV78OLJmMmixe7GF45AE= -github.com/bwesterb/go-ristretto v1.2.3/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= github.com/caarlos0/go-rpmutils v0.2.1-0.20211112020245-2cd62ff89b11 h1:IRrDwVlWQr6kS1U8/EtyA1+EHcc4yl8pndcqXWrEamg= github.com/caarlos0/go-rpmutils v0.2.1-0.20211112020245-2cd62ff89b11/go.mod h1:je2KZ+LxaCNvCoKg32jtOIULcFogJKcL1ZWUaIBjKj0= github.com/caarlos0/testfs v0.4.4 h1:3PHvzHi5Lt+g332CiShwS8ogTgS3HjrmzZxCm6JCDr8= @@ -212,13 +203,13 @@ github.com/cavaliergopher/cpio v1.0.1/go.mod h1:pBdaqQjnvXxdS/6CvNDwIANIFSP0xRKI github.com/ccojocar/zxcvbn-go v1.0.2 h1:na/czXU8RrhXO4EZme6eQJLR4PzcGsahsBOAwU6I3Vg= github.com/ccojocar/zxcvbn-go v1.0.2/go.mod h1:g1qkXtUSvHP8lhHp5GrSmTz6uWALGRMQdw6Qnz/hi60= github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= -github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= -github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod 
h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/charithe/durationcheck v0.0.10 h1:wgw73BiocdBDQPik+zcEoBG/ob8uyBHf2iyoHGPf5w4= github.com/charithe/durationcheck v0.0.10/go.mod h1:bCWXb7gYRysD1CU3C+u4ceO49LoGOY1C1L6uouGNreQ= github.com/chavacava/garif v0.1.0 h1:2JHa3hbYf5D9dsgseMKAmc/MZ109otzgNFk5s87H9Pc= @@ -228,10 +219,11 @@ github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5P github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk= github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso= +github.com/circonus-labs/circonus-gometrics v2.3.1+incompatible/go.mod h1:nmEj6Dob7S7YxXgwXpfOuvO54S+tGdZdw9fuRZt25Ag= +github.com/circonus-labs/circonusllhist v0.1.3/go.mod h1:kMXHVDlOchFAehlya5ePtbp5jckzBHf4XRpQvBOLI+I= github.com/ckaznocha/intrange v0.1.0 h1:ZiGBhvrdsKpoEfzh9CjBfDSZof6QB0ORY5tXasUtiew= github.com/ckaznocha/intrange v0.1.0/go.mod h1:Vwa9Ekex2BrEQMg6zlrWwbs/FtYw7eS5838Q7UjK7TQ= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cloudflare/circl v1.3.3/go.mod h1:5XYMA4rFBvNIrhs50XuiBJ15vF2pZn4nnUKZrLbUZFA= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= 
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= @@ -239,30 +231,26 @@ github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NA github.com/coder/websocket v1.8.12/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= -github.com/containerd/stargz-snapshotter/estargz v0.15.1 h1:eXJjw9RbkLFgioVaTG+G/ZW/0kEe2oEKCdS/ZxIyoCU= -github.com/containerd/stargz-snapshotter/estargz v0.15.1/go.mod h1:gr2RNwukQ/S9Nv33Lt6UC7xEx58C+LHRdoqbEKjz1Kk= +github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRccTampEyKpjpOnS3CyiV1Ebr8= +github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/creachadair/mds v0.17.1 h1:lXQbTGKmb3nE3aK6OEp29L1gCx6B5ynzlQ6c1KOBurc= +github.com/creachadair/mds v0.17.1/go.mod h1:4b//mUiL8YldH6TImXjmW45myzTLNS1LLjOmrk888eg= +github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc= +github.com/creachadair/taskgroup v0.13.2/go.mod h1:i3V1Zx7H8RjwljUEeUWYT30Lmb9poewSb2XI1yTwD0g= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.23 
h1:4M6+isWdcStXEf15G/RbrMPOQj1dZ7HPZCGwE4kOeP0= github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/curioswitch/go-reassign v0.2.0 h1:G9UZyOcpk/d7Gd6mqYgd8XYWFMw/znxwGDUstnC9DIo= github.com/curioswitch/go-reassign v0.2.0/go.mod h1:x6OpXuWvgfQaMGks2BZybTngWjT84hqJfKoO8Tt/Roc= -github.com/cyphar/filepath-securejoin v0.2.4 h1:Ugdm7cg7i6ZK6x3xDF1oEu1nfkyfH53EtKeQYTC3kyg= -github.com/cyphar/filepath-securejoin v0.2.4/go.mod h1:aPGpWjXOXUn2NCNjFvBE6aRxGGx79pTxQpKOJNYHHl4= +github.com/cyphar/filepath-securejoin v0.3.6 h1:4d9N5ykBnSp5Xn2JkhocYDkOpURL/18CYMpo6xB9uWM= +github.com/cyphar/filepath-securejoin v0.3.6/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI= github.com/daixiang0/gci v0.12.3 h1:yOZI7VAxAGPQmkb1eqt5g/11SUlwoat1fSblGLmdiQc= github.com/daixiang0/gci v0.12.3/go.mod h1:xtHP9N7AHdNvtRNfcx9gwTDfw7FRJx4bZUsiEfiNNAI= -github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14 h1:YI1gOOdmMk3xodBao7fehcvoZsEeOyy/cfhlpCSPgM4= -github.com/dave/astrid v0.0.0-20170323122508-8c2895878b14/go.mod h1:Sth2QfxfATb/nW4EsrSi2KyJmbcniZ8TgTaji17D6ms= -github.com/dave/brenda v1.1.0 h1:Sl1LlwXnbw7xMhq3y2x11McFu43AjDcwkllxxgZ3EZw= -github.com/dave/brenda v1.1.0/go.mod h1:4wCUr6gSlu5/1Tk7akE5X7UorwiQ8Rij0SKH3/BGMOM= -github.com/dave/courtney v0.4.0 h1:Vb8hi+k3O0h5++BR96FIcX0x3NovRbnhGd/dRr8inBk= -github.com/dave/courtney v0.4.0/go.mod h1:3WSU3yaloZXYAxRuWt8oRyVb9SaRiMBt5Kz/2J227tM= -github.com/dave/patsy v0.0.0-20210517141501-957256f50cba h1:1o36L4EKbZzazMk8iGC4kXpVnZ6TPxR2mZ9qVKjNNAs= -github.com/dave/patsy v0.0.0-20210517141501-957256f50cba/go.mod h1:qfR88CgEGLoiqDaE+xxDCi5QA5v4vUoW0UCX2Nd5Tlc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -277,24 +265,24 @@ github.com/distribution/reference 
v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5Qvfr github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= -github.com/docker/cli v25.0.0+incompatible h1:zaimaQdnX7fYWFqzN88exE9LDEvRslexpFowZBX6GoQ= -github.com/docker/cli v25.0.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/cli v27.4.1+incompatible h1:VzPiUlRJ/xh+otB75gva3r05isHMo5wXDfPRi5/b4hI= +github.com/docker/cli v27.4.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v26.1.4+incompatible h1:vuTpXDuoga+Z38m1OZHzl7NKisKWaWlhjQk7IDPSLsU= -github.com/docker/docker v26.1.4+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/docker-credential-helpers v0.8.1 h1:j/eKUktUltBtMzKqmfLB0PAgqYyMHOp5vfsD1807oKo= -github.com/docker/docker-credential-helpers v0.8.1/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= -github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= -github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/docker/docker v27.4.1+incompatible h1:ZJvcY7gfwHn1JF48PfbyXg7Jyt9ZCWDW+GGXOIxEwp4= +github.com/docker/docker v27.4.1+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/docker-credential-helpers v0.8.2 h1:bX3YxiGzFP5sOXWc3bTPEXdEaZSeVMrFgOr3T+zrFAo= +github.com/docker/docker-credential-helpers v0.8.2/go.mod h1:P3ci7E3lwkZg6XiHdRKft1KckHiO9a2rNtyFbZ/ry9M= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= 
+github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dsnet/try v0.0.3 h1:ptR59SsrcFUYbT/FhAbKTV6iLkeD6O18qfIWRml2fqI= github.com/dsnet/try v0.0.3/go.mod h1:WBM8tRpUmnXXhY1U6/S8dt6UWdHTQ7y8A5YSkRCkq40= github.com/elastic/crd-ref-docs v0.0.12 h1:F3seyncbzUz3rT3d+caeYWhumb5ojYQ6Bl0Z+zOp16M= github.com/elastic/crd-ref-docs v0.0.12/go.mod h1:X83mMBdJt05heJUYiS3T0yJ/JkCuliuhSUNav5Gjo/U= -github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a h1:mATvB/9r/3gvcejNsXKSkQ6lcIaNec2nyfOdlTBR2lU= -github.com/elazarl/goproxy v0.0.0-20230808193330-2592e75ae04a/go.mod h1:Ro8st/ElPeALwNFlcTpWmkr6IoMFfkjXAvTHpevnDsM= +github.com/elazarl/goproxy v1.2.3 h1:xwIyKHbaP5yfT6O9KIeYJR5549MXRQkoQMRXGztz8YQ= +github.com/elazarl/goproxy v1.2.3/go.mod h1:YfEbZtqP4AetfO6d40vWchF3znWX7C7Vd6ZMfdL8z64= github.com/emicklei/go-restful/v3 v3.11.2 h1:1onLa9DcsMYO9P+CXaL0dStDqQ2EHHXLiz+BtnqkLAU= github.com/emicklei/go-restful/v3 v3.11.2/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= @@ -311,8 +299,9 @@ github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0 github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= github.com/evanw/esbuild v0.19.11 h1:mbPO1VJ/df//jjUd+p/nRLYCpizXxXb2w/zZMShxa2k= github.com/evanw/esbuild v0.19.11/go.mod h1:D2vIQZqV/vIf/VRHtViaUtViZmG7o+kKmlBfVQuRi48= -github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= -github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= 
+github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/fatih/structtag v1.2.0 h1:/OdNE99OxoI/PqaW/SuSK9uxxT3f/tcSZgon/ssNSx4= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= @@ -321,37 +310,39 @@ github.com/firefart/nonamedreturns v1.0.4 h1:abzI1p7mAEPYuR4A+VLKn4eNDOycjYo2phm github.com/firefart/nonamedreturns v1.0.4/go.mod h1:TDhe/tjI1BXo48CmYbUduTV7BdIga8MAO/xbKdcVsGI= github.com/fogleman/gg v1.3.0 h1:/7zJX8F6AaYQc57WQCyN9cAIz+4bCJGO9B+dyW29am8= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= +github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= -github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/fzipp/gocyclo v0.6.0 h1:lsblElZG7d3ALtGMx9fmxeTKZaLLpU8mET09yN4BBLo= github.com/fzipp/gocyclo v0.6.0/go.mod h1:rXPyn8fnlpa0R2csP/31uerbiVBugk5whMdlyaLkLoA= -github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= -github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= +github.com/gaissmai/bart v0.18.0 
h1:jQLBT/RduJu0pv/tLwXE+xKPgtWJejbxuXAR+wLJafo= +github.com/gaissmai/bart v0.18.0/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= github.com/ghostiam/protogetter v0.3.5 h1:+f7UiF8XNd4w3a//4DnusQ2SZjPkUjxkMEfjbxOK4Ug= github.com/ghostiam/protogetter v0.3.5/go.mod h1:7lpeDnEJ1ZjL/YtyoN99ljO4z0pd3H0d18/t2dPBxHw= github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I= github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo= -github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= -github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4= +github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= +github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU= github.com/go-critic/go-critic v0.11.2 h1:81xH/2muBphEgPtcwH1p6QD+KzXl2tMSi3hXjBSxDnM= github.com/go-critic/go-critic v0.11.2/go.mod h1:OePaicfjsf+KPy33yq4gzv6CO7TEQ9Rom6ns1KsJnl8= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= -github.com/go-git/go-billy/v5 v5.5.0 h1:yEY4yhzCDuMGSv83oGxiBotRzhwhNr8VZyphhiu+mTU= -github.com/go-git/go-billy/v5 v5.5.0/go.mod h1:hmexnoNsr2SJU1Ju67OaNz5ASJY3+sHgFRpCtpDCKow= +github.com/go-git/go-billy/v5 v5.6.1 h1:u+dcrgaguSSkbjzHwelEjc0Yj300NUevrrPphk/SoRA= +github.com/go-git/go-billy/v5 v5.6.1/go.mod h1:0AsLr1z2+Uksi4NlElmMblP5rPcDZNRCD8ujZCRR2BE= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= -github.com/go-git/go-git/v5 v5.11.0 h1:XIZc1p+8YzypNr34itUfSvYJcv+eYdTnTvOZ2vD3cA4= -github.com/go-git/go-git/v5 v5.11.0/go.mod 
h1:6GFcX2P3NM7FPBfpePbpLd21XxsgdAt+lKqXmCUiUCY= +github.com/go-git/go-git/v5 v5.13.1 h1:DAQ9APonnlvSWpvolXWIuV6Q6zXy2wHbN4cVlNR5Q+M= +github.com/go-git/go-git/v5 v5.13.1/go.mod h1:qryJB4cSBoq3FRoBRf5A77joojuBcmPJ0qu3XXXVixc= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= -github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= -github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874 h1:F8d1AJ6M9UQCavhwmO6ZsrYLfG8zVFWfEfMS2MXPkSY= +github.com/go-json-experiment/json v0.0.0-20250223041408-d3c622f1b874/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -367,12 +358,12 @@ github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= -github.com/go-openapi/jsonpointer v0.20.2 h1:mQc3nmndL8ZBzStEo3JYF8wzmeWffDH4VbXz58sAx6Q= -github.com/go-openapi/jsonpointer v0.20.2/go.mod h1:bHen+N0u1KEO3YlmqOjTT9Adn1RfD91Ar825/PuiRVs= +github.com/go-openapi/jsonpointer v0.21.0 h1:YgdVicSA9vH5RiHs9TZW5oyafXZFc6+2Vc1rr/O9oNQ= +github.com/go-openapi/jsonpointer 
v0.21.0/go.mod h1:IUyH9l/+uyhIYQ/PXVA41Rexl+kOkAPDdXEYns6fzUY= github.com/go-openapi/jsonreference v0.20.4 h1:bKlDxQxQJgwpUSgOENiMPzCTBVuc7vTdXSSgNeAhojU= github.com/go-openapi/jsonreference v0.20.4/go.mod h1:5pZJyJP2MnYCpoeoMAql78cCHauHj0V9Lhc506VOpw4= -github.com/go-openapi/swag v0.22.7 h1:JWrc1uc/P9cSomxfnsFSVWoE1FW6bNbrVPmpQYpCcR8= -github.com/go-openapi/swag v0.22.7/go.mod h1:Gl91UqO+btAM0plGGxHqJcQZ1ZTy6jbmridBTsDy8A0= +github.com/go-openapi/swag v0.23.0 h1:vsEVJDUo2hPJ2tu0/Xc+4noaxyEffXNIs3cOULZ+GrE= +github.com/go-openapi/swag v0.23.0/go.mod h1:esZ8ITTYEsH1V2trKHjAN8Ai7xHb8RV+YSZ577vPjgQ= github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= @@ -383,7 +374,8 @@ github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7 github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= -github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= +github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/go-toolsmith/astcast v1.1.0 h1:+JN9xZV1A+Re+95pgnMgDboWNVnIMMQXwfBwLRPgSC8= github.com/go-toolsmith/astcast v1.1.0/go.mod h1:qdcuFWeGGS2xX5bLM/c3U9lewg7+Zu4mr+xPwZIB4ZU= github.com/go-toolsmith/astcopy v1.1.0 h1:YGwBN0WM+ekI/6SS6+52zLDEf8Yvp3n2seZITCUBt5s= @@ -407,8 +399,10 @@ github.com/go-viper/mapstructure/v2 v2.0.0-alpha.1 h1:TQcrn6Wq+sKGkpyPvppOz99zsM github.com/go-viper/mapstructure/v2 
v2.0.0-alpha.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-xmlfmt/xmlfmt v1.1.2 h1:Nea7b4icn8s57fTx1M5AI4qQT5HEM3rVUO8MuE6g80U= github.com/go-xmlfmt/xmlfmt v1.1.2/go.mod h1:aUCEOzzezBEjDBbFBoSiya/gduyIiWYRP6CnSFIV8AM= -github.com/gobuffalo/flect v1.0.2 h1:eqjPGSo2WmjgY2XlpGwo2NXgL3RucAKo4k4qQMNA5sA= -github.com/gobuffalo/flect v1.0.2/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= +github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 h1:cf60tHxREO3g1nroKr2osU3JWZsJzkfi7rEg+oAB0Lo= +github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737/go.mod h1:MIS0jDzbU/vuM9MC4YnBITCv+RYuTRq8dJzmCrFsK9g= +github.com/gobuffalo/flect v1.0.3 h1:xeWBM2nui+qnVvNM4S3foBhCAL2XgPU+a7FdpelbTq4= +github.com/gobuffalo/flect v1.0.3/go.mod h1:A5msMlrHtLqh9umBSnvabjsMrCcCpAyzglnDvkbYKHs= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= github.com/goccy/go-yaml v1.12.0 h1:/1WHjnMsI1dlIBQutrvSMGZRQufVO3asrHfTwfACoPM= @@ -490,8 +484,12 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-containerregistry v0.18.0 h1:ShE7erKNPqRh5ue6Z9DUOlk04WsnFWPO6YGr3OxnfoQ= -github.com/google/go-containerregistry v0.18.0/go.mod h1:u0qB2l7mvtWVR5kNcbFIhFY1hLbf8eeGapA+vbFDCtQ= +github.com/google/go-containerregistry v0.20.2 h1:B1wPJ1SN/S7pB+ZAimcciVD+r+yV/l/DSArMxlbwseo= +github.com/google/go-containerregistry v0.20.2/go.mod h1:z38EKdKh4h7IP2gSfUUqEvalZBqs6AoLeWfUy34nQC8= +github.com/google/go-tpm v0.9.4 h1:awZRf9FwOeTunQmHoDYSHJps3ie6f1UlhS1fOdPEt1I= +github.com/google/go-tpm v0.9.4/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= 
+github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba h1:qJEJcuLzH5KDR0gKc0zcktin6KSAwL7+jWKBYceddTc= +github.com/google/go-tpm-tools v0.3.13-0.20230620182252-4639ecce2aba/go.mod h1:EFYHy8/1y2KfgTAsx7Luu7NGhoxtuVHnNo8jE7FikKc= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -510,8 +508,8 @@ github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd h1:gbpYu9NMq8jhDVbvlGkMFWCjLFlqqEZjEmObmhUy6Vo= -github.com/google/pprof v0.0.0-20240409012703-83162a5b38cd/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw= +github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= +github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/rpmpack v0.5.0 h1:L16KZ3QvkFGpYhmp23iQip+mx1X39foEsqszjMNBm8A= github.com/google/rpmpack v0.5.0/go.mod h1:uqVAUVQLq8UY2hCDfmJ/+rtO3aw7qyhc90rCVEabEfI= @@ -530,8 +528,8 @@ github.com/goreleaser/fileglob v1.3.0 h1:/X6J7U8lbDpQtBvGcwwPS6OpzkNVlVEsFUVRx9+ github.com/goreleaser/fileglob v1.3.0/go.mod h1:Jx6BoXv3mbYkEzwm9THo7xbr5egkAraxkGorbJb4RxU= github.com/goreleaser/nfpm/v2 v2.33.1 h1:EkdAzZyVhAI9JC1vjmjjbmnNzyH1J6Cu4JCsA7YcQuc= github.com/goreleaser/nfpm/v2 v2.33.1/go.mod 
h1:8wwWWvJWmn84xo/Sqiv0aMvEGTHlHZTXTEuVSgQpkIM= -github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= -github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3 h1:BHWt6FTLZAb2HtWT5KDBf6qgpZzvtbp9QWDRKZMXJC0= +github.com/gorilla/csrf v1.7.3/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gostaticanalysis/analysisutil v0.7.1 h1:ZMCjoue3DtDWQ5WyU16YbjbQEQ3VuzwxALrpYd+HeKk= @@ -547,15 +545,32 @@ github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY= github.com/gostaticanalysis/testutil v0.4.0/go.mod h1:bLIoPefWXrRi/ssLFWX1dx7Repi5x3CuviD3dgAZaBU= github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rHAxPBD8KFhJpmcqms= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= +github.com/hashicorp/go-cleanhttp v0.5.0/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= +github.com/hashicorp/go-hclog v1.6.2 h1:NOtoftovWkDheyUM/8JW3QMiXyxJK3uHRK7wV04nD2I= +github.com/hashicorp/go-hclog v1.6.2/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-immutable-radix v1.0.0/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= +github.com/hashicorp/go-immutable-radix v1.3.1/go.mod 
h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-metrics v0.5.4 h1:8mmPiIJkTPPEbAiV97IxdAGNdRdaWwVap1BU6elejKY= +github.com/hashicorp/go-metrics v0.5.4/go.mod h1:CG5yz4NZ/AI/aQt9Ucm/vdBnbh7fvmv4lxZ350i+QQI= +github.com/hashicorp/go-msgpack/v2 v2.1.2 h1:4Ee8FTp834e+ewB71RDrQ0VKpyFdrKOjvYtnQ/ltVj0= +github.com/hashicorp/go-msgpack/v2 v2.1.2/go.mod h1:upybraOAblm4S7rx0+jeNy+CWWhzywQsSRV5033mMu4= +github.com/hashicorp/go-retryablehttp v0.5.3/go.mod h1:9B5zBasrRhHXnJnui7y6sL7es7NDiJgTc6Er0maI1Xs= +github.com/hashicorp/go-uuid v1.0.0 h1:RS8zrF7PhGwyNPOtxSClXXj9HA8feRnJzgnI1RJCSnM= +github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro= github.com/hashicorp/go-version v1.2.1/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= +github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4= +github.com/hashicorp/golang-lru v0.6.0/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/raft v1.7.2 h1:pyvxhfJ4R8VIAlHKvLoKQWElZspsCVT6YWuxVxsPAgc= +github.com/hashicorp/raft v1.7.2/go.mod h1:DfvCGFxpAUPE0L4Uc8JLlTPtc3GzSbdH0MTJCLgnmJQ= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= github.com/hdevalence/ed25519consensus v0.2.0/go.mod h1:w3BHWjwJbFU29IRHL1Iqkw3sus+7FctEyM4RqDxYNzo= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= @@ -563,18 +578,18 @@ 
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSo github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= -github.com/hugelgupf/vmtest v0.0.0-20240102225328-693afabdd27f h1:ov45/OzrJG8EKbGjn7jJZQJTN7Z1t73sFYNIRd64YlI= -github.com/hugelgupf/vmtest v0.0.0-20240102225328-693afabdd27f/go.mod h1:JoDrYMZpDPYo6uH9/f6Peqms3zNNWT2XiGgioMOIGuI= +github.com/hugelgupf/vmtest v0.0.0-20240216064925-0561770280a1 h1:jWoR2Yqg8tzM0v6LAiP7i1bikZJu3gxpgvu3g1Lw+a0= +github.com/hugelgupf/vmtest v0.0.0-20240216064925-0561770280a1/go.mod h1:B63hDJMhTupLWCHwopAyEo7wRFowx9kOc8m8j1sfOqE= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/illarion/gonotify/v2 v2.0.3 h1:B6+SKPo/0Sw8cRJh1aLzNEeNVFfzE3c6N+o+vyxM+9A= -github.com/illarion/gonotify/v2 v2.0.3/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= +github.com/illarion/gonotify/v3 v3.0.2 h1:O7S6vcopHexutmpObkeWsnzMJt/r1hONIEogeVNmJMk= +github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U= github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= -github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c h1:gYfYE403/nlrGNYj6BEOs9ucLCAGB9gstlSk92DttTg= -github.com/inetaf/tcpproxy v0.0.0-20240214030015-3ce58045626c/go.mod h1:Di7LXRyUcnvAcLicFhtM9/MlZl/TNgRSDHORM2c6CMI= +github.com/inetaf/tcpproxy 
v0.0.0-20250203165043-ded522cbd03f h1:hPcDyz0u+Zo14n0fpJggxL9JMAmZIK97TVLcLJLPMDI= +github.com/inetaf/tcpproxy v0.0.0-20250203165043-ded522cbd03f/go.mod h1:Di7LXRyUcnvAcLicFhtM9/MlZl/TNgRSDHORM2c6CMI= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= @@ -596,13 +611,11 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGw github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= -github.com/josharian/native v1.0.1-0.20221213033349-c1e37c09b531/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= -github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk= -github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= github.com/jsimonetti/rtnetlink v1.4.0 h1:Z1BF0fRgcETPEa0Kt0MRk3yV5+kF1FWTni6KUFKrq2I= github.com/jsimonetti/rtnetlink v1.4.0/go.mod h1:5W1jDvWdnthFJ7fxYX1GMK07BUpI4oskfOqvPteYS6E= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.11/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.12 
h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -627,8 +640,8 @@ github.com/kisielk/errcheck v1.7.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkHAIKE/contextcheck v1.1.4 h1:B6zAaLhOEEcjvUgIYEqystmnFk1Oemn8bvJhbt0GMb8= github.com/kkHAIKE/contextcheck v1.1.4/go.mod h1:1+i/gWqokIa+dm31mqGLZhZJ7Uh44DJGZVmr6QRBNJg= -github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= -github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -675,8 +688,12 @@ github.com/matoous/godox v0.0.0-20230222163458-006bad1f9d26 h1:gWg6ZQ4JhDfJPqlo2 github.com/matoous/godox v0.0.0-20230222163458-006bad1f9d26/go.mod h1:1BELzlh859Sh1c6+90blK8lbYy0kwQf1bYlBhBysy1s= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= github.com/matryer/is v1.4.0/go.mod h1:8I/i5uYgLzgsgEloJE1U6xx5HkBQpAZvepWuujKwMRU= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= 
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= @@ -686,8 +703,8 @@ github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -743,14 +760,14 @@ github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.17.1 h1:V++EzdbhI4ZV4ev0UTIj0PzhzOcReJFyJaLjtSF55M8= -github.com/onsi/ginkgo/v2 v2.17.1/go.mod h1:llBI3WDLL9Z6taip6f33H76YcWtJv+7R3HigUjbIBOs= -github.com/onsi/gomega v1.33.1 h1:dsYjIxxSR755MDmKVsaFQTE22ChNBcuuTWgkUDSubOk= -github.com/onsi/gomega v1.33.1/go.mod 
h1:U4R44UsT+9eLIaYRB2a5qajjtQYn0hauxvRm16AVYg0= +github.com/onsi/ginkgo/v2 v2.21.0 h1:7rg/4f3rB88pb5obDgNZrNHrQ4e6WpjonchcpuBRnZM= +github.com/onsi/ginkgo/v2 v2.21.0/go.mod h1:7Du3c42kxCUegi0IImZ1wUQzMBVecgIHjR1C+NkhLQo= +github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= +github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0-rc6 h1:XDqvyKsJEbRtATzkgItUqBA7QHk58yxX1Ov9HERHNqU= -github.com/opencontainers/image-spec v1.1.0-rc6/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= @@ -758,11 +775,12 @@ github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJ github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= +github.com/pascaldekloe/goe v0.1.0 h1:cBOtyMzM9HTpWjXfbbunk26uA6nG3a8n06Wieeh0MwY= +github.com/pascaldekloe/goe v0.1.0/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= github.com/pelletier/go-toml/v2 v2.2.0 h1:QLgLl2yMN7N+ruc31VynXs1vhMZa7CeHHejIeBAsoHo= github.com/pelletier/go-toml/v2 v2.2.0/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/peterbourgon/ff/v3 v3.4.0 
h1:QBvM/rizZM1cB0p0lGMdmR7HxZeI/ZrBWB4DqLkMUBc= github.com/peterbourgon/ff/v3 v3.4.0/go.mod h1:zjJVUhx+twciwfDl0zBcFzl4dW8axCRyXE/eKY9RztQ= -github.com/pierrec/lz4/v4 v4.1.14/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= @@ -787,8 +805,10 @@ github.com/prometheus-community/pro-bing v0.4.0 h1:YMbv+i08gQz97OZZBwLyvmmQEEzyf github.com/prometheus-community/pro-bing v0.4.0/go.mod h1:b7wRYZtCcPmt4Sz319BykUU241rWLe1VFXyiyWK/dH4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= +github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M= github.com/prometheus/client_golang v1.11.0/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= +github.com/prometheus/client_golang v1.11.1/go.mod h1:Z6t4BnS23TR94PD6BsDNk8yVqroYurpAkEiz0P2BEV0= github.com/prometheus/client_golang v1.12.1/go.mod h1:3Z9XVyYiZYEO+YQWt3RD2R3jrbd179Rt297l4aS6nDY= github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= @@ -796,21 +816,23 @@ github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1: github.com/prometheus/client_model v0.0.0-20190129233127-fd36f4220a90/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.2.0/go.mod 
h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.4.1/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4= +github.com/prometheus/common v0.9.1/go.mod h1:yhUN8i9wzaXS3w1O07YhxHEBxD+W35wd8bs7vj7HSQ4= github.com/prometheus/common v0.10.0/go.mod h1:Tlit/dnDKsSWFlCLTWaA1cyBgKHSMdTB80sz/V91rCo= github.com/prometheus/common v0.26.0/go.mod h1:M7rCNAaPfAosfx8veZJCuw84e35h3Cfd9VFqTh1DIvc= github.com/prometheus/common v0.32.1/go.mod h1:vu+V0TpY+O6vW9J44gczi3Ap/oXXR10b+M/gUGO4Hls= -github.com/prometheus/common v0.48.0 h1:QO8U2CdOzSn1BBsmXJXduaaW+dY/5QLjfB8svtSzKKE= -github.com/prometheus/common v0.48.0/go.mod h1:0/KsvlIEfPQCQ5I2iNSAWKPZziNCvRs5EC6ILDTlAPc= +github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= +github.com/prometheus/common v0.55.0/go.mod h1:2SECS4xJG1kd8XF9IcM1gMX6510RAEL65zxzNImwdc8= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.2/go.mod h1:TjEm7ze935MbeOT/UhFTIMYKhuLP4wbCsTZCD3I8kEA= +github.com/prometheus/procfs v0.0.8/go.mod h1:7Qr8sr6344vo1JqZ6HhLceV9o3AJ1Ff+GxbHq6oeK9A= github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/procfs v0.7.3/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= 
+github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff h1:X1Tly81aZ22DA1fxBdfvR3iw8+yFoUBUHMEd+AX/ZXI= github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff/go.mod h1:FvE8dtQ1Ww63IlyKBn1V4s+zMwF9kHkVNkQBR1pM4CU= github.com/quasilyte/go-ruleguard v0.4.2 h1:htXcXDK6/rO12kiTHKfHuqR4kr3Y4M0J0rOL6CH/BYs= @@ -826,8 +848,8 @@ github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/ryancurrah/gomodguard v1.3.1 h1:fH+fUg+ngsQO0ruZXXHnA/2aNllWA1whly4a6UvyzGE= github.com/ryancurrah/gomodguard v1.3.1/go.mod h1:DGFHzEhi6iJ0oIDfMuo3TgrS+L9gZvrEfmjjuelnRU0= @@ -846,8 +868,8 @@ github.com/sashamelentyev/usestdlibvars v1.25.0/go.mod h1:9nl0jgOfHKWNFS43Ojw0i7 github.com/securego/gosec/v2 v2.19.0 h1:gl5xMkOI0/E6Hxx0XCY2XujA3V7SNSefA8sC+3f1gnk= github.com/securego/gosec/v2 v2.19.0/go.mod h1:hOkDcHz9J/XIgIlPDXalxjeVYsHxoWUc5zJSHxcB8YM= github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo= -github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= 
-github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c h1:W65qqJCIOVP4jpqPQ0YvHYKwcMEMVWIzWC5iNQQfBTU= github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c/go.mod h1:/PevMnwAxekIXwN8qQyfc5gl2NlkB3CQlkizAbOkeBs= github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= @@ -865,8 +887,8 @@ github.com/sivchari/containedctx v1.0.3 h1:x+etemjbsh2fB5ewm5FeLNi5bUjK0V8n0RB+W github.com/sivchari/containedctx v1.0.3/go.mod h1:c1RDvCbnJLtH4lLcYD/GqwiBSSf4F5Qk0xld2rBqzJ4= github.com/sivchari/tenv v1.7.1 h1:PSpuD4bu6fSmtWMxSGWcvqUUgIn7k3yOJhOIzVWn8Ak= github.com/sivchari/tenv v1.7.1/go.mod h1:64yStXKSOxDfX47NlhVwND4dHwfZDdbp2Lyl018Icvg= -github.com/skeema/knownhosts v1.2.1 h1:SHWdIUa82uGZz+F+47k8SY4QhhI291cXCpopT1lK2AQ= -github.com/skeema/knownhosts v1.2.1/go.mod h1:xYbVRSPxqBZFrdmDyMmsOs+uX1UZC3nTN3ThzgDxUwo= +github.com/skeema/knownhosts v1.3.0 h1:AM+y0rI04VksttfwjkSTNQorvGqmwATnvnAHpSgc0LY= +github.com/skeema/knownhosts v1.3.0/go.mod h1:sPINvnADmT/qYH1kfv+ePMmOBTH6Tbl7b5LvTDjFK7M= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= github.com/smartystreets/assertions v1.13.1 h1:Ef7KhSmjZcK6AVf9YbJdvPYG9avaF0ZxudX+ThRdWfU= @@ -906,11 +928,13 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod 
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/studio-b12/gowebdav v0.9.0 h1:1j1sc9gQnNxbXXM4M/CebPOX4aXYtr7MojAVcN4dHjU= github.com/studio-b12/gowebdav v0.9.0/go.mod h1:bHA7t77X/QFExdeAnDzK6vKM34kEZAcE1OX4MfiwjkE= github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8= @@ -919,30 +943,32 @@ github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c h1:+aPplB github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c/go.mod h1:SbErYREK7xXdsRiigaQiQkI9McGRzYMvlKYaP3Nimdk= github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e h1:PtWT87weP5LWHEY//SWsYkSO3RWRZo4OSWagh3YD2vQ= github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e/go.mod h1:XrBNfAFN+pwoWuksbFS9Ccxnopa15zJGgXRFN90l3K4= -github.com/tailscale/depaware v0.0.0-20210622194025-720c4b409502 h1:34icjjmqJ2HPjrSuJYEkdZ+0ItmGQAQ75cRHIiftIyE= -github.com/tailscale/depaware v0.0.0-20210622194025-720c4b409502/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= +github.com/tailscale/depaware v0.0.0-20250112153213-b748de04d81b h1:ewWb4cA+YO9/3X+v5UhdV+eKFsNBOPcGRh39Glshx/4= +github.com/tailscale/depaware v0.0.0-20250112153213-b748de04d81b/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= github.com/tailscale/go-winio 
v0.0.0-20231025203758-c4f33415bf55 h1:Gzfnfk2TWrk8Jj4P4c1a3CtQyMaTVCznlkLZI++hok4= github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55/go.mod h1:4k4QO+dQ3R5FofL+SanAUZe+/QfeK0+OIuwDIRu2vSg= github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41 h1:/V2rCMMWcsjYaYO2MeovLw+ClP63OtXgCF2Y1eb8+Ns= github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41/go.mod h1:/roCdA6gg6lQyw/Oz6gIIGu3ggJKYhF+WC/AQReE5XQ= -github.com/tailscale/golang-x-crypto v0.0.0-20240604161659-3fde5e568aa4 h1:rXZGgEa+k2vJM8xT0PoSKfVXwFGPQ3z3CJfmnHJkZZw= -github.com/tailscale/golang-x-crypto v0.0.0-20240604161659-3fde5e568aa4/go.mod h1:ikbF+YT089eInTp9f2vmvy4+ZVnW5hzX1q2WknxSprQ= +github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869 h1:SRL6irQkKGQKKLzvQP/ke/2ZuB7Py5+XuqtOgSj+iMM= +github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869/go.mod h1:ikbF+YT089eInTp9f2vmvy4+ZVnW5hzX1q2WknxSprQ= github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05 h1:4chzWmimtJPxRs2O36yuGRW3f9SYV+bMTTvMBI0EKio= github.com/tailscale/goupnp v1.0.1-0.20210804011211-c64d0f06ea05/go.mod h1:PdCqy9JzfWMJf1H5UJW2ip33/d4YkoKN0r67yKH1mG8= github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= -github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba h1:uNo1VCm/xg4alMkIKo8RWTKNx5y1otfVOcKbp+irkL4= -github.com/tailscale/mkctr v0.0.0-20240628074852-17ca944da6ba/go.mod h1:DxnqIXBplij66U2ZkL688xy07q97qQ83P+TVueLiHq4= +github.com/tailscale/mkctr v0.0.0-20250228050937-c75ea1476830 h1:SwZ72kr1oRzzSPA5PYB4hzPh22UI0nm0dapn3bHaUPs= +github.com/tailscale/mkctr v0.0.0-20250228050937-c75ea1476830/go.mod h1:qTslktI+Qh9hXo7ZP8xLkl5V8AxUMfxG0xLtkCFLxnw= github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4ZoF094vE6iYTLDl0qCiKzYXlL6UeWObU= github.com/tailscale/netlink 
v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= -github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= -github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= -github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= -github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= +github.com/tailscale/setec v0.0.0-20250205144240-8898a29c3fbb h1:Rtklwm6HUlCtf/MR2MB9iY4FoA16acWWlC5pLrTVa90= +github.com/tailscale/setec v0.0.0-20250205144240-8898a29c3fbb/go.mod h1:R8iCVJnbOB05pGexHK/bKHneIRHpZ3jLl7wMQ0OM/jw= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= -github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc h1:cezaQN9pvKVaw56Ma5qr/G646uKIYP0yQf+OyWN/okc= -github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20250304000100-91a0587fb251 h1:h/41LFTrwMxB9Xvvug0kRdQCU5TlV1+pAMQw0ZtDE3U= +github.com/tailscale/wireguard-go v0.0.0-20250304000100-91a0587fb251/go.mod 
h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= @@ -961,18 +987,21 @@ github.com/timakin/bodyclose v0.0.0-20230421092635-574207250966 h1:quvGphlmUVU+n github.com/timakin/bodyclose v0.0.0-20230421092635-574207250966/go.mod h1:27bSVNWSBOHm+qRp1T9qzaIpsWEP6TbUnei/43HK+PQ= github.com/timonwong/loggercheck v0.9.4 h1:HKKhqrjcVj8sxL7K77beXh0adEm6DLjV/QOGeMXEVi4= github.com/timonwong/loggercheck v0.9.4/go.mod h1:caz4zlPcgvpEkXgVnAJGowHAMW2NwHaNlpS8xDbVhTg= +github.com/tink-crypto/tink-go/v2 v2.1.0 h1:QXFBguwMwTIaU17EgZpEJWsUSc60b1BAGTzBIoMdmok= +github.com/tink-crypto/tink-go/v2 v2.1.0/go.mod h1:y1TnYFt1i2eZVfx4OGc+C+EMp4CoKWAw2VSEuoicHHI= github.com/tomarrell/wrapcheck/v2 v2.8.3 h1:5ov+Cbhlgi7s/a42BprYoxsr73CbdMUTzE3bRDFASUs= github.com/tomarrell/wrapcheck/v2 v2.8.3/go.mod h1:g9vNIyhb5/9TQgumxQyOEqDHsmGYcGsVMOx/xGkqdMo= github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw= github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw= github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ= github.com/toqueteos/webbrowser v1.2.0/go.mod h1:XWoZq4cyp9WeUeak7w7LXRUQf1F1ATJMir8RTqb4ayM= -github.com/u-root/gobusybox/src v0.0.0-20231228173702-b69f654846aa h1:unMPGGK/CRzfg923allsikmvk2l7beBeFPUNC4RVX/8= -github.com/u-root/gobusybox/src v0.0.0-20231228173702-b69f654846aa/go.mod h1:Zj4Tt22fJVn/nz/y6Ergm1SahR9dio1Zm/D2/S0TmXM= -github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= -github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= -github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e 
h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= -github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= +github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= +github.com/u-root/gobusybox/src v0.0.0-20240225013946-a274a8d5d83a h1:eg5FkNoQp76ZsswyGZ+TjYqA/rhKefxK8BW7XOlQsxo= +github.com/u-root/gobusybox/src v0.0.0-20240225013946-a274a8d5d83a/go.mod h1:e/8TmrdreH0sZOw2DFKBaUV7bvDWRq6SeM9PzkuVM68= +github.com/u-root/u-root v0.14.0 h1:Ka4T10EEML7dQ5XDvO9c3MBN8z4nuSnGjcd1jmU2ivg= +github.com/u-root/u-root v0.14.0/go.mod h1:hAyZorapJe4qzbLWlAkmSVCJGbfoU9Pu4jpJ1WMluqE= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8= github.com/ulikunitz/xz v0.5.11/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/ultraware/funlen v0.1.0 h1:BuqclbkY6pO+cvxoq7OsktIXZpgBSkYTQtmwhAK81vI= @@ -981,8 +1010,8 @@ github.com/ultraware/whitespace v0.1.0 h1:O1HKYoh0kIeqE8sFqZf1o0qbORXUCOQFrlaQyZ github.com/ultraware/whitespace v0.1.0/go.mod h1:/se4r3beMFNmewJ4Xmz0nMQ941GJt+qmSHGP9emHYe0= github.com/uudashr/gocognit v1.1.2 h1:l6BAEKJqQH2UpKAPKdMfZf5kE4W/2xk8pfU1OVLvniI= github.com/uudashr/gocognit v1.1.2/go.mod h1:aAVdLURqcanke8h3vg35BC++eseDm66Z7KmchI5et4k= -github.com/vbatts/tar-split v0.11.5 h1:3bHCTIheBm1qFTcgh9oPu+nNBtX+XJIupG/vacinCts= -github.com/vbatts/tar-split v0.11.5/go.mod h1:yZbwRsSeGjusneWgA781EKej9HF8vme8okylkAeNKLk= +github.com/vbatts/tar-split v0.11.6 h1:4SjTW5+PU11n6fZenf2IPoV8/tz3AaYHMWjf23envGs= +github.com/vbatts/tar-split v0.11.6/go.mod h1:dqKNtesIOr2j2Qv3W/cHjnvk9I8+G7oAkFDFN6TCBEI= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= 
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= @@ -1022,22 +1051,24 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0 h1:sv9kVfal0MK0wBMCOGr+HeJm9v803BkJxGrk2au7j08= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.47.0/go.mod h1:SK2UL73Zy1quvRPonmOmRDiWk1KBV3LyIeeIxcEApWw= -go.opentelemetry.io/otel v1.22.0 h1:xS7Ku+7yTFvDfDraDIJVpw7XPyuHlB9MCiqqX5mcJ6Y= -go.opentelemetry.io/otel v1.22.0/go.mod h1:eoV4iAi3Ea8LkAEI9+GFT44O6T/D0GWAVFyZVCC6pMI= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0 h1:9M3+rhx7kZCIQQhQRYaZCdNu1V73tm4TvXs2ntl98C4= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.22.0/go.mod h1:noq80iT8rrHP1SfybmPiRGc9dc5M8RPmGvtwo7Oo7tc= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0 h1:FyjCyI9jVEfqhUh2MoSkmolPjfh5fp2hnV0b0irxH4Q= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.22.0/go.mod h1:hYwym2nDEeZfG/motx0p7L7J1N1vyzIThemQsb4g2qY= -go.opentelemetry.io/otel/metric v1.22.0 h1:lypMQnGyJYeuYPhOM/bgjbFM6WE44W1/T45er4d8Hhg= -go.opentelemetry.io/otel/metric v1.22.0/go.mod h1:evJGjVpZv0mQ5QBRJoBF64yMuOf4xCWdXjK8pzFvliY= -go.opentelemetry.io/otel/sdk v1.22.0 h1:6coWHw9xw7EfClIC/+O31R8IY3/+EiRFHevmHafB2Gw= -go.opentelemetry.io/otel/sdk v1.22.0/go.mod h1:iu7luyVGYovrRpe2fmj3CVKouQNdTOkxtLzPvPz1DOc= -go.opentelemetry.io/otel/trace v1.22.0 h1:Hg6pPujv0XG9QaVbGOBVHunyuLcCC3jN7WEhPx83XD0= -go.opentelemetry.io/otel/trace v1.22.0/go.mod h1:RbbHXVqKES9QhzZq/fE5UnOSILqRt40a21sPw2He1xo= -go.opentelemetry.io/proto/otlp v1.0.0 
h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= -go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 h1:yd02MEjBdJkG3uabWP9apV+OuWRIXGDuJEUJbOHmCFU= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0/go.mod h1:umTcuxiv1n/s/S6/c2AT/g2CQ7u5C59sHDNmfSwgz7Q= +go.opentelemetry.io/otel v1.33.0 h1:/FerN9bax5LoK51X/sI0SVYrjSE0/yUL7DpxW4K3FWw= +go.opentelemetry.io/otel v1.33.0/go.mod h1:SUUkR6csvUQl+yjReHu5uM3EtVV7MBm5FHKRlNx4I8I= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 h1:3Q/xZUyC1BBkualc9ROb4G8qkH90LXEIICcs5zv1OYY= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0/go.mod h1:s75jGIWA9OfCMzF0xr+ZgfrB5FEbbV7UuYo32ahUiFI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 h1:j9+03ymgYhPKmeXGk5Zu+cIZOlVzd9Zv7QIiyItjFBU= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0/go.mod h1:Y5+XiUG4Emn1hTfciPzGPJaSI+RpDts6BnCIir0SLqk= +go.opentelemetry.io/otel/metric v1.33.0 h1:r+JOocAyeRVXD8lZpjdQjzMadVZp2M4WmQ+5WtEnklQ= +go.opentelemetry.io/otel/metric v1.33.0/go.mod h1:L9+Fyctbp6HFTddIxClbQkjtubW6O9QS3Ann/M82u6M= +go.opentelemetry.io/otel/sdk v1.28.0 h1:b9d7hIry8yZsgtbmM0DKyPWMMUMlK9NEKuIG4aBqWyE= +go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= +go.opentelemetry.io/otel/trace v1.33.0 h1:cCJuF7LRjUFso9LPnEAHJDB2pqzp+hbO8eu1qqW2d/s= +go.opentelemetry.io/otel/trace v1.33.0/go.mod h1:uIcdVUZMpTAmz0tI1z04GoVSezK37CbGV4fr1f2nBck= +go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= +go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= go.uber.org/automaxprocs v1.5.3 
h1:kWazyxZUrS3Gs4qUpbwo5kEIMGe/DAvi5Z4tl2NW4j8= go.uber.org/automaxprocs v1.5.3/go.mod h1:eRbA25aqJrxAbsLO0xy5jVwPt7FQnRgjW+efnwa1WM0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -1046,8 +1077,8 @@ go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= -go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= -go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -1060,10 +1091,8 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/crypto v0.3.1-0.20221117191849-2c476679df9a/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= -golang.org/x/crypto v0.7.0/go.mod h1:pYwdfH91IfpZVANVyUOhSIPZaFoJGxTFbZhFTx+dXZU= -golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= -golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= 
+golang.org/x/crypto v0.37.0 h1:kJNSjF/Xp7kU0iB2Z+9viTPMW4EqqsrywMXLJOOsXSE= +golang.org/x/crypto v0.37.0/go.mod h1:vg+k43peMZ0pUMhYmVAWysMK35e6ioLh3wB8ZCAfbVc= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -1074,16 +1103,16 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= -golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac h1:l5+whBCLH3iH2ZNHYLbAe58bo7yrN4mVcnkHDYz5vvs= +golang.org/x/exp v0.0.0-20250210185358-939b2ce775ac/go.mod h1:hH+7mtFmImwwcMvScyxUhjuVHR3HGaDPMn9rMSUUbxo= golang.org/x/exp/typeparams v0.0.0-20220428152302-39d4317da171/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20230203172020-98cc5a0785f9/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod 
h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.18.0 h1:jGzIakQa/ZXI1I0Fxvaa9W7yP25TqT6cHIHn+6CqvSQ= -golang.org/x/image v0.18.0/go.mod h1:4yyo5vMFQjVjUcVk4jEQcU9MGy/rulF5WvUILseCM2E= +golang.org/x/image v0.24.0 h1:AN7zRgVsbvmTfNyqIbbOraYL8mSwcKncEj8ofjgzcMQ= +golang.org/x/image v0.24.0/go.mod h1:4b/ITuLfqYq1hqZcjofwctIhi7sZh2WaCjvsBNjjya8= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= @@ -1111,8 +1140,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= -golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.23.0 h1:Zb7khfcRGKk+kqfxFaP5tZqCnDZMjC5VtUBs87Hr6QM= +golang.org/x/mod v0.23.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1152,17 +1181,16 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod 
h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.36.0 h1:vWF2fRbw4qslQsQzgFqZff+BItCvGFQqKzKIzx1rmoA= +golang.org/x/net v0.36.0/go.mod h1:bFmbeoIPfrw4sMHNhb4J9f6+tPziuGjq7Jk/38fxi1I= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20191202225959-858c2ad4c8b6/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= -golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.26.0 h1:afQXWNNaeC4nvZ0Ed9XvCCzXM6UHJG7iCg0W4fPqSBE= +golang.org/x/oauth2 v0.26.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1176,8 +1204,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync 
v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.13.0 h1:AauUjRAJ9OSnvULf/ARrrVywoJDy0YS2AwQ98I37610= +golang.org/x/sync v0.13.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1196,6 +1224,7 @@ golang.org/x/sys v0.0.0-20191204072324-ce4227a45e2e/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191228213918-04cbcbbfeed8/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200113162924-86b910548bc1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200122134326-e047566fdf82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200212091648-12a6c2dcc1e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -1221,12 +1250,14 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210603081109-ebe580a85c40/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys 
v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211105183446-c75c47738b0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220615213510-4f61da869c0c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220622161953-175b2fd9d664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220702020025-31831981b65f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -1234,22 +1265,19 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.32.0 h1:s77OFDvIQeibCmezSnk/q6iAfkdiQaJi4VzroCFrN20= +golang.org/x/sys v0.32.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= -golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= -golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.31.0 h1:erwDkOK1Msy6offm1mOgvspSkslFnIGsFnxOKoufg3o= +golang.org/x/term v0.31.0/go.mod h1:R4BeIy7D95HzImkxGkTW1UQTtP54tio2RyHz7PwK0aw= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1257,18 +1285,16 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod 
h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.24.0 h1:dd5Bzh4yt5KYA8f9CJHCP4FB4D51c2c6JvN37xJJkJ0= +golang.org/x/text v0.24.0/go.mod h1:L8rBsPeo2pSS+xqN0d5u2ikmjtmoJbDBT1b7nHvFCdU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= +golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -1333,8 +1359,8 @@ golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= 
golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= -golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= +golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY= +golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -1369,8 +1395,6 @@ google.golang.org/appengine v1.5.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7 google.golang.org/appengine v1.6.1/go.mod h1:i06prIuMbXzDqacNJfV5OdTW448YApPu5ww/cMBSeb0= google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.6/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= -google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= google.golang.org/genproto v0.0.0-20190418145605-e7d98fc518a7/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE= @@ -1400,11 +1424,11 @@ google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7Fc google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto 
v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20240102182953-50ed04b92917 h1:nz5NESFLZbJGPFxDT/HCn+V1mZ8JGNoY4nUpmW/Y2eg= -google.golang.org/genproto/googleapis/api v0.0.0-20240116215550-a9fa1716bcac h1:OZkkudMUu9LVQMCoRUbI/1p5VCo9BOrlvkqMvWtqa6s= -google.golang.org/genproto/googleapis/api v0.0.0-20240116215550-a9fa1716bcac/go.mod h1:B5xPO//w8qmBDjGReYLpR6UJPnkldGkCSMoH/2vxJeg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240116215550-a9fa1716bcac h1:nUQEQmH/csSvFECKYRv6HWEyypysidKl2I6Qpsglq/0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240116215550-a9fa1716bcac/go.mod h1:daQN87bsDqDoe316QbbvX60nMoJQa4r6Ds0ZuoAe5yA= +google.golang.org/genproto v0.0.0-20240123012728-ef4313101c80 h1:KAeGQVN3M9nD0/bQXnr/ClcEMJ968gUXJQ9pwfSynuQ= +google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7 h1:YcyjlL1PRr2Q17/I0dPk2JmYS5CDXfcdb2Z3YRioEbw= +google.golang.org/genproto/googleapis/api v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:OCdP9MfskevB/rbYvHTsXTtKC+3bHWajPdoKgjcYkfo= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7 h1:2035KHhUv+EpyB+hWgJnaWKJOdX1E95w2S8Rr4uWKTs= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240826202546-f6391c0de4c7/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -1417,8 +1441,8 @@ google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKa google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= 
google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= -google.golang.org/grpc v1.61.0 h1:TOvOcuXn30kRao+gfcvsebNEa5iZIiLkisYEkf7R7o0= -google.golang.org/grpc v1.61.0/go.mod h1:VUbo7IFqmF1QtCAstipjG0GIoq49KvMe9+h1jFLBNJs= +google.golang.org/grpc v1.65.0 h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= +google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -1431,8 +1455,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= -google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA= +google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -1440,6 +1464,8 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c 
h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/evanphx/json-patch.v4 v4.12.0 h1:n6jtcsulIzXPJaxegRbvFNNrZDjbij7ny3gmSPG+6V4= +gopkg.in/evanphx/json-patch.v4 v4.12.0/go.mod h1:p8EYWUEYMpynmqDbY58zCKCFZw8pRWMG4EsWvDvM72M= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= @@ -1464,8 +1490,8 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= -gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= -gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= @@ -1477,22 +1503,22 @@ honnef.co/go/tools v0.5.1 h1:4bH5o3b5ZULQ4UrBmP+63W9r7qIkqJClEA9ko5YKx+I= honnef.co/go/tools v0.5.1/go.mod h1:e9irvo83WDG9/irijV44wr3tbhcFeRnfpVlRqVwpzMs= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0/go.mod 
h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= -k8s.io/api v0.30.3 h1:ImHwK9DCsPA9uoU3rVh4QHAHHK5dTSv1nxJUapx8hoQ= -k8s.io/api v0.30.3/go.mod h1:GPc8jlzoe5JG3pb0KJCSLX5oAFIW3/qNJITlDj8BH04= -k8s.io/apiextensions-apiserver v0.30.3 h1:oChu5li2vsZHx2IvnGP3ah8Nj3KyqG3kRSaKmijhB9U= -k8s.io/apiextensions-apiserver v0.30.3/go.mod h1:uhXxYDkMAvl6CJw4lrDN4CPbONkF3+XL9cacCT44kV4= -k8s.io/apimachinery v0.30.3 h1:q1laaWCmrszyQuSQCfNB8cFgCuDAoPszKY4ucAjDwHc= -k8s.io/apimachinery v0.30.3/go.mod h1:iexa2somDaxdnj7bha06bhb43Zpa6eWH8N8dbqVjTUc= -k8s.io/apiserver v0.30.3 h1:QZJndA9k2MjFqpnyYv/PH+9PE0SHhx3hBho4X0vE65g= -k8s.io/apiserver v0.30.3/go.mod h1:6Oa88y1CZqnzetd2JdepO0UXzQX4ZnOekx2/PtEjrOg= -k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= -k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/api v0.32.0 h1:OL9JpbvAU5ny9ga2fb24X8H6xQlVp+aJMFlgtQjR9CE= +k8s.io/api v0.32.0/go.mod h1:4LEwHZEf6Q/cG96F3dqR965sYOfmPM7rq81BLgsE0p0= +k8s.io/apiextensions-apiserver v0.32.0 h1:S0Xlqt51qzzqjKPxfgX1xh4HBZE+p8KKBq+k2SWNOE0= +k8s.io/apiextensions-apiserver v0.32.0/go.mod h1:86hblMvN5yxMvZrZFX2OhIHAuFIMJIZ19bTvzkP+Fmw= +k8s.io/apimachinery v0.32.0 h1:cFSE7N3rmEEtv4ei5X6DaJPHHX0C+upp+v5lVPiEwpg= +k8s.io/apimachinery v0.32.0/go.mod h1:GpHVgxoKlTxClKcteaeuF1Ul/lDVb74KpZcxcmLDElE= +k8s.io/apiserver v0.32.0 h1:VJ89ZvQZ8p1sLeiWdRJpRD6oLozNZD2+qVSLi+ft5Qs= +k8s.io/apiserver v0.32.0/go.mod h1:HFh+dM1/BE/Hm4bS4nTXHVfN6Z6tFIZPi649n83b4Ag= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= -k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340 h1:BZqlfIlq5YbRMFko6/PM7FjZpUb45WallggurYhKGag= -k8s.io/kube-openapi v0.0.0-20240228011516-70dd3763d340/go.mod 
h1:yD4MZYeKMBwQKVht279WycxKyM84kkAx2DPrTXaeb98= -k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1J8+AsQnQCKsi8A= -k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f h1:GA7//TjRY9yWGy1poLzYYJJ4JRdzg3+O6e8I+e+8T5Y= +k8s.io/kube-openapi v0.0.0-20241105132330-32ad38e42d3f/go.mod h1:R/HEjbvWI0qdfb8viZUeVZm0X6IZnxAydC7YU42CMw4= +k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738 h1:M3sRQVHv7vB20Xc2ybTt7ODCeFj6JSWYFzOFnYeS6Ro= +k8s.io/utils v0.0.0-20241104100929-3ea5e8cea738/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= mvdan.cc/gofumpt v0.6.0 h1:G3QvahNDmpD+Aek/bNOLrFR2XC6ZAdo62dZu65gmwGo= mvdan.cc/gofumpt v0.6.0/go.mod h1:4L0wf+kgIPZtcCWXynNS2e6bhmj73umwnuXSZarixzA= mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14 h1:zCr3iRRgdk5eIikZNDphGcM6KGVTx3Yu+/Uu9Es254w= @@ -1500,14 +1526,14 @@ mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14/go.mod h1:ZzZjEpJDOmx8TdVU6u rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= -sigs.k8s.io/controller-runtime v0.18.4 h1:87+guW1zhvuPLh1PHybKdYFLU0YJp4FhJRmiHvm5BZw= -sigs.k8s.io/controller-runtime v0.18.4/go.mod h1:TVoGrfdpbA9VRFaRnKgk9P5/atA0pMwq+f+msb9M8Sg= -sigs.k8s.io/controller-tools v0.15.1-0.20240618033008-7824932b0cab h1:Fq4VD28nejtsijBNTeRRy9Tt3FVwq+o6NB7fIxja8uY= -sigs.k8s.io/controller-tools v0.15.1-0.20240618033008-7824932b0cab/go.mod h1:egedX5jq2KrZ3A2zaOz3e2DSsh5BhFyyjvNcBRIQel8= -sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= -sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= -sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= 
-sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= +sigs.k8s.io/controller-runtime v0.19.4 h1:SUmheabttt0nx8uJtoII4oIP27BVVvAKFvdvGFwV/Qo= +sigs.k8s.io/controller-runtime v0.19.4/go.mod h1:iRmWllt8IlaLjvTTDLhRBXIEtkCK6hwVBJJsYS9Ajf4= +sigs.k8s.io/controller-tools v0.17.0 h1:KaEQZbhrdY6J3zLBHplt+0aKUp8PeIttlhtF2UDo6bI= +sigs.k8s.io/controller-tools v0.17.0/go.mod h1:SKoWY8rwGWDzHtfnhmOwljn6fViG0JF7/xmnxpklgjo= +sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3 h1:/Rv+M11QRah1itp8VhT6HoVx1Ray9eB4DBr+K+/sCJ8= +sigs.k8s.io/json v0.0.0-20241010143419-9aa6b5e7a4b3/go.mod h1:18nIHnGi6636UCz6m8i4DhaJ65T6EruyzmoQqI2BVDo= +sigs.k8s.io/structured-merge-diff/v4 v4.4.2 h1:MdmvkGuXi/8io6ixD5wud3vOLwc1rj0aNqRlpuvjmwA= +sigs.k8s.io/structured-merge-diff/v4 v4.4.2/go.mod h1:N8f93tFZh9U6vpxwRArLiikrE5/2tiu1w1AGfACIGE4= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= sigs.k8s.io/yaml v1.4.0/go.mod h1:Ejl7/uTz7PSA4eKMyQCUTnhZYNmLIl+5c2lQPGR2BPY= software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= diff --git a/go.toolchain.branch b/go.toolchain.branch index 47469a20ad6e9..5e1cd0620554a 100644 --- a/go.toolchain.branch +++ b/go.toolchain.branch @@ -1 +1 @@ -tailscale.go1.23 +tailscale.go1.24 diff --git a/go.toolchain.rev b/go.toolchain.rev index 5d87594c25a31..e8ede337c11cb 100644 --- a/go.toolchain.rev +++ b/go.toolchain.rev @@ -1 +1 @@ -bf15628b759344c6fc7763795a405ba65b8be5d7 +982da8f24fa0504f2214f24b0d68b2febd5983f8 diff --git a/gokrazy/build.go b/gokrazy/build.go index 2392af0cb30e1..c1ee1cbeb1974 100644 --- a/gokrazy/build.go +++ b/gokrazy/build.go @@ -11,7 +11,6 @@ package main import ( "bytes" - "cmp" "encoding/json" "errors" "flag" @@ -30,7 +29,6 @@ import ( var ( app = flag.String("app", "tsapp", "appliance name; one of the subdirectories of gokrazy/") bucket = flag.String("bucket", "tskrazy-import", "S3 bucket to upload disk image to while making AMI") - goArch = 
flag.String("arch", cmp.Or(os.Getenv("GOARCH"), "amd64"), "GOARCH architecture to build for: arm64 or amd64") build = flag.Bool("build", false, "if true, just build locally and stop, without uploading") ) @@ -54,6 +52,26 @@ func findMkfsExt4() (string, error) { return "", errors.New("No mkfs.ext4 found on system") } +var conf gokrazyConfig + +// gokrazyConfig is the subset of gokrazy/internal/config.Struct +// that we care about. +type gokrazyConfig struct { + // Environment is os.Environment pairs to use when + // building userspace. + // See https://gokrazy.org/userguide/instance-config/#environment + Environment []string +} + +func (c *gokrazyConfig) GOARCH() string { + for _, e := range c.Environment { + if v, ok := strings.CutPrefix(e, "GOARCH="); ok { + return v + } + } + return "" +} + func main() { flag.Parse() @@ -61,6 +79,19 @@ func main() { log.Fatalf("--app must be non-empty name such as 'tsapp' or 'natlabapp'") } + confJSON, err := os.ReadFile(filepath.Join(*app, "config.json")) + if err != nil { + log.Fatalf("reading config.json: %v", err) + } + if err := json.Unmarshal(confJSON, &conf); err != nil { + log.Fatalf("unmarshaling config.json: %v", err) + } + switch conf.GOARCH() { + case "amd64", "arm64": + default: + log.Fatalf("config.json GOARCH %q must be amd64 or arm64", conf.GOARCH()) + } + if err := buildImage(); err != nil { log.Fatalf("build image: %v", err) } @@ -106,7 +137,6 @@ func buildImage() error { // Build the tsapp.img var buf bytes.Buffer cmd := exec.Command("go", "run", - "-exec=env GOOS=linux GOARCH="+*goArch+" ", "github.com/gokrazy/tools/cmd/gok", "--parent_dir="+dir, "--instance="+*app, @@ -253,13 +283,13 @@ func waitForImportSnapshot(importTaskID string) (snapID string, err error) { func makeAMI(name, ebsSnapID string) (ami string, err error) { var arch string - switch *goArch { + switch conf.GOARCH() { case "arm64": arch = "arm64" case "amd64": arch = "x86_64" default: - return "", fmt.Errorf("unknown arch %q", *goArch) + return 
"", fmt.Errorf("unknown arch %q", conf.GOARCH()) } out, err := exec.Command("aws", "ec2", "register-image", "--name", name, diff --git a/gokrazy/go.mod b/gokrazy/go.mod index a9ba5a07d1fb4..f7483f41d5d46 100644 --- a/gokrazy/go.mod +++ b/gokrazy/go.mod @@ -1,13 +1,13 @@ module tailscale.com/gokrazy -go 1.23.1 +go 1.23 -require github.com/gokrazy/tools v0.0.0-20240730192548-9f81add3a91e +require github.com/gokrazy/tools v0.0.0-20250128200151-63160424957c require ( github.com/breml/rootcerts v0.2.10 // indirect github.com/donovanhide/eventsource v0.0.0-20210830082556-c59027999da0 // indirect - github.com/gokrazy/internal v0.0.0-20240629150625-a0f1dee26ef5 // indirect + github.com/gokrazy/internal v0.0.0-20250126213949-423a5b587b57 // indirect github.com/gokrazy/updater v0.0.0-20230215172637-813ccc7f21e2 // indirect github.com/google/renameio/v2 v2.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -15,9 +15,5 @@ require ( github.com/spf13/pflag v1.0.5 // indirect golang.org/x/mod v0.11.0 // indirect golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.20.0 // indirect + golang.org/x/sys v0.28.0 // indirect ) - -replace github.com/gokrazy/gokrazy => github.com/tailscale/gokrazy v0.0.0-20240812224643-6b21ddf64678 - -replace github.com/gokrazy/tools => github.com/tailscale/gokrazy-tools v0.0.0-20240730192548-9f81add3a91e diff --git a/gokrazy/go.sum b/gokrazy/go.sum index dfac8ca37d101..170d15b3db19c 100644 --- a/gokrazy/go.sum +++ b/gokrazy/go.sum @@ -3,8 +3,10 @@ github.com/breml/rootcerts v0.2.10/go.mod h1:24FDtzYMpqIeYC7QzaE8VPRQaFZU5TIUDly github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/donovanhide/eventsource v0.0.0-20210830082556-c59027999da0 h1:C7t6eeMaEQVy6e8CarIhscYQlNmw5e3G36y7l7Y21Ao= github.com/donovanhide/eventsource v0.0.0-20210830082556-c59027999da0/go.mod h1:56wL82FO0bfMU5RvfXoIwSOP2ggqqxT+tAfNEIyxuHw= -github.com/gokrazy/internal v0.0.0-20240629150625-a0f1dee26ef5 
h1:XDklMxV0pE5jWiNaoo5TzvWfqdoiRRScmr4ZtDzE4Uw= -github.com/gokrazy/internal v0.0.0-20240629150625-a0f1dee26ef5/go.mod h1:t3ZirVhcs9bH+fPAJuGh51rzT7sVCZ9yfXvszf0ZjF0= +github.com/gokrazy/internal v0.0.0-20250126213949-423a5b587b57 h1:f5bEvO4we3fbfiBkECrrUgWQ8OH6J3SdB2Dwxid/Yx4= +github.com/gokrazy/internal v0.0.0-20250126213949-423a5b587b57/go.mod h1:SJG1KwuJQXFEoBgryaNCkMbdISyovDgZd0xmXJRZmiw= +github.com/gokrazy/tools v0.0.0-20250128200151-63160424957c h1:iEbS8GrNOn671ze8J/AfrYFEVzf8qMx8aR5K0VxPK2w= +github.com/gokrazy/tools v0.0.0-20250128200151-63160424957c/go.mod h1:f2vZhnaPzy92+Bjpx1iuZHK7VuaJx6SNCWQWmu23HZA= github.com/gokrazy/updater v0.0.0-20230215172637-813ccc7f21e2 h1:kBY5R1tSf+EYZ+QaSrofLaVJtBqYsVNVBWkdMq3Smcg= github.com/gokrazy/updater v0.0.0-20230215172637-813ccc7f21e2/go.mod h1:PYOvzGOL4nlBmuxu7IyKQTFLaxr61+WPRNRzVtuYOHw= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= @@ -19,14 +21,12 @@ github.com/spf13/cobra v1.6.1 h1:o94oiPyS4KD1mPy2fmcYYHHfCxLqYjJOhGsCHFZtEzA= github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUqzrY= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/tailscale/gokrazy-tools v0.0.0-20240730192548-9f81add3a91e h1:3/xIc1QCvnKL7BCLng9od98HEvxCadjvqiI/bN+Twso= -github.com/tailscale/gokrazy-tools v0.0.0-20240730192548-9f81add3a91e/go.mod h1:eTZ0QsugEPFU5UAQ/87bKMkPxQuTNa7+iFAIahOFwRg= golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU= golang.org/x/mod v0.11.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys 
v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= +golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum b/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum index 9123439ed88bf..ae814f31698f4 100644 --- a/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum +++ b/gokrazy/natlabapp.arm64/builddir/tailscale.com/go.sum @@ -4,32 +4,58 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials 
v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod 
h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod 
h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= @@ -46,10 +72,14 @@ github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 h1:KbX3Z3CgiYlbaavUq3Cj9/MjpO+88S7/AGXzynVDv84= +github.com/go-json-experiment/json 
v0.0.0-20250103232110-6a9a0fde9288/go.mod h1:BWmvoE1Xia34f3l/ibJweyhrT+aROb/FQ6d+37F0e2s= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -62,6 +92,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 h1:fiJdrgVBkjZ5B1HJ2WQwNOaXB+QyYcNXTA3t1XYLz0M= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= @@ -70,6 +102,8 @@ github.com/illarion/gonotify v1.0.1 h1:F1d+0Fgbq/sDWjj/r66ekjDG+IDeecQKUFH4wNwso github.com/illarion/gonotify v1.0.1/go.mod h1:zt5pmDofZpU1f8aqlK0+95eQhoEAn/d4G4B/FjVW4jE= github.com/illarion/gonotify/v2 v2.0.2 h1:oDH5yvxq9oiQGWUeut42uShcWzOy/hsT9E7pvO95+kQ= github.com/illarion/gonotify/v2 v2.0.2/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= +github.com/illarion/gonotify/v2 v2.0.3 h1:B6+SKPo/0Sw8cRJh1aLzNEeNVFfzE3c6N+o+vyxM+9A= +github.com/illarion/gonotify/v2 v2.0.3/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp 
v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/jellydator/ttlcache/v3 v3.1.0 h1:0gPFG0IHHP6xyUyXq+JaD8fwkDCqgqwohXNJBcYE71g= @@ -84,6 +118,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a/go.mod h1:YTtCCM3ryyfiu4F7t8HQ1mxvp1UBdWM2r6Xa+nGWvDk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -96,6 +132,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -126,12 +164,18 @@ github.com/tailscale/netlink 
v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1 h1:ycpNCSYwzZ7x4G4ioPNtKQmIY0G/3o4pVf8wCZq6blY= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 h1:BcEJP2ewTIK2ZCsqgl6YGpuO6+oKqqag5HHb7ehljKw= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9 h1:81P7rjnikHKTJ75EkjppvbwUfKHDHYk6LJpO5PZy8pA= 
github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= @@ -144,6 +188,8 @@ github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= @@ -152,42 +198,66 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= 
golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07 h1:Z+Zg+aXJYq6f4TK2E4H+vZkQ4dJAWnInXDR6hM9znxo= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0 
h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab h1:BMkEEWYOjkvOX7+YKOGbp6jCyQ5pR2j0Ah47p1Vdsx4= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod 
h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= k8s.io/client-go v0.30.1 h1:uC/Ir6A3R46wdkgCV3vbLyNOYyCJ8oZnjtJGKfytl/Q= k8s.io/client-go v0.30.1/go.mod h1:wrAqLNs2trwiCH/wxxmT/x3hKVH9PuV0GGW0oDoHVqc= k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= diff --git a/gokrazy/natlabapp.arm64/config.json b/gokrazy/natlabapp.arm64/config.json index 2577f61a56d31..2ba9a20f9510f 100644 --- a/gokrazy/natlabapp.arm64/config.json +++ b/gokrazy/natlabapp.arm64/config.json @@ -20,6 +20,10 @@ } } }, + "Environment": [ + "GOOS=linux", + "GOARCH=arm64" + ], "KernelPackage": "github.com/gokrazy/kernel.arm64", "FirmwarePackage": "github.com/gokrazy/kernel.arm64", "EEPROMPackage": "", diff --git a/gokrazy/natlabapp/builddir/tailscale.com/go.sum b/gokrazy/natlabapp/builddir/tailscale.com/go.sum index baa378c46708e..25f15059d3af6 100644 --- a/gokrazy/natlabapp/builddir/tailscale.com/go.sum +++ b/gokrazy/natlabapp/builddir/tailscale.com/go.sum @@ -4,32 +4,58 @@ github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod 
h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 
h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= 
github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/coder/websocket v1.8.12 h1:5bUXkEPPIbewrnkU8LTCLVaxi4N4J8ahufH2vlo4NAo= @@ -46,10 +72,14 @@ github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod 
h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 h1:KbX3Z3CgiYlbaavUq3Cj9/MjpO+88S7/AGXzynVDv84= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288/go.mod h1:BWmvoE1Xia34f3l/ibJweyhrT+aROb/FQ6d+37F0e2s= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -62,6 +92,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 h1:fiJdrgVBkjZ5B1HJ2WQwNOaXB+QyYcNXTA3t1XYLz0M= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= 
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= @@ -86,6 +118,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a/go.mod h1:YTtCCM3ryyfiu4F7t8HQ1mxvp1UBdWM2r6Xa+nGWvDk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -98,6 +132,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 
h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -128,14 +164,20 @@ github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4 github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1 h1:ycpNCSYwzZ7x4G4ioPNtKQmIY0G/3o4pVf8wCZq6blY= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc h1:cezaQN9pvKVaw56Ma5qr/G646uKIYP0yQf+OyWN/okc= github.com/tailscale/wireguard-go v0.0.0-20240905161824-799c1978fafc/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= 
+github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 h1:BcEJP2ewTIK2ZCsqgl6YGpuO6+oKqqag5HHb7ehljKw= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9 h1:81P7rjnikHKTJ75EkjppvbwUfKHDHYk6LJpO5PZy8pA= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= @@ -148,6 +190,8 @@ github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= @@ -156,42 +200,66 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem 
v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07 h1:Z+Zg+aXJYq6f4TK2E4H+vZkQ4dJAWnInXDR6hM9znxo= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod 
h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab h1:BMkEEWYOjkvOX7+YKOGbp6jCyQ5pR2j0Ah47p1Vdsx4= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gvisor.dev/gvisor 
v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= k8s.io/client-go v0.30.1 h1:uC/Ir6A3R46wdkgCV3vbLyNOYyCJ8oZnjtJGKfytl/Q= k8s.io/client-go v0.30.1/go.mod h1:wrAqLNs2trwiCH/wxxmT/x3hKVH9PuV0GGW0oDoHVqc= k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= diff --git a/gokrazy/natlabapp/config.json b/gokrazy/natlabapp/config.json index 902f14acdb940..1968b2aac79f8 100644 --- a/gokrazy/natlabapp/config.json +++ b/gokrazy/natlabapp/config.json @@ -20,6 +20,10 @@ } } }, + "Environment": [ + "GOOS=linux", + "GOARCH=amd64" + ], "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "", "EEPROMPackage": "", diff --git a/gokrazy/tsapp/builddir/tailscale.com/go.sum b/gokrazy/tsapp/builddir/tailscale.com/go.sum index b3b73e2d0e764..2ffef7bf7ba22 100644 --- a/gokrazy/tsapp/builddir/tailscale.com/go.sum +++ b/gokrazy/tsapp/builddir/tailscale.com/go.sum @@ -4,48 +4,80 @@ github.com/anmitsu/go-shlex 
v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFI github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2 v1.36.0 h1:b1wM5CcE65Ujwn565qcwgtOTT1aT4ADOHHgglKjG7fk= +github.com/aws/aws-sdk-go-v2 v1.36.0/go.mod h1:5PMILGVKiW32oDzjj6RU52yrNrDPUHcbZQYr1sM7qmM= github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= 
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31 h1:lWm9ucLSRFiI4dQQafLrEOmEDGry3Swrz0BIRdiHJqQ= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.31/go.mod h1:Huu6GG0YTfbPphQkDSo4dEGmQRTKb9k9G7RdtyQWxuI= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31 h1:ACxDklUKKXb48+eg5ROZXi1vDgfMyfIA/WyvqHcHI0o= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.31/go.mod h1:yadnfsDwqXeVaohbGc/RaD287PuyRw2wugkh5ZL2J6k= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2 h1:D4oz8/CzT9bAEYtVhSBmFj2dNOtaHOtMKc2vHBwYizA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.2/go.mod h1:Za3IHqTQ+yNcRHxu1OFucBh0ACZT4j4VQFF0BqpZcLY= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12 h1:O+8vD2rGjfihBewr5bT+QUfYUHIxCVgG61LHoT59shM= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.12/go.mod 
h1:usVdWJaosa66NMvmCrr08NcWDBRv4E6+YFG2pUdw1Lk= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 h1:a8HvP/+ew3tKwSXqL3BCSjiuicr+XTU2eFYeogV9GJE= github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13 h1:3LXNnmtH3TURctC23hnC0p/39Q5gre3FI7BNOiDcVWc= +github.com/aws/aws-sdk-go-v2/service/sts v1.33.13/go.mod h1:7Yn+p66q/jt38qMoVfNvjbm3D89mGBnkwDcijgtih8w= github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= +github.com/aws/smithy-go v1.22.2 h1:6D9hW43xKFrRx/tXXfAlIZc4JI+yQe6snnWcQyxSyLQ= +github.com/aws/smithy-go v1.22.2/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= github.com/bits-and-blooms/bitset v1.13.0/go.mod 
h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/creack/pty v1.1.21 h1:1/QdRyBaHHJP61QkWMXlOIBfsgdDeeKfK8SYVUWJKf0= github.com/creack/pty v1.1.21/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/creack/pty v1.1.23 h1:4M6+isWdcStXEf15G/RbrMPOQj1dZ7HPZCGwE4kOeP0= +github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q= github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A= github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA= github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= +github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv5E= +github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gaissmai/bart v0.11.1 h1:5Uv5XwsaFBRo4E5VBcb9TzY8B7zxFf+U7isDxqOrRfc= github.com/gaissmai/bart v0.11.1/go.mod h1:KHeYECXQiBjTzQz/om2tqn3sZF1J7hw9m6z41ftj3fg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0 h1:ymLjT4f35nQbASLnvxEde4XOBL+Sn7rFuV+FOJqkljg= github.com/go-json-experiment/json v0.0.0-20231102232822-2e55bd4e08b0/go.mod h1:6daplAwHHGbUGib4990V3Il26O0OC4aRyvewaaAihaA= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288 h1:KbX3Z3CgiYlbaavUq3Cj9/MjpO+88S7/AGXzynVDv84= +github.com/go-json-experiment/json v0.0.0-20250103232110-6a9a0fde9288/go.mod 
h1:BWmvoE1Xia34f3l/ibJweyhrT+aROb/FQ6d+37F0e2s= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= @@ -58,12 +90,16 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/csrf v1.7.2 h1:oTUjx0vyf2T+wkrx09Trsev1TE+/EbDAeHtSTbtC2eI= github.com/gorilla/csrf v1.7.2/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30 h1:fiJdrgVBkjZ5B1HJ2WQwNOaXB+QyYcNXTA3t1XYLz0M= +github.com/gorilla/csrf v1.7.3-0.20250123201450-9dd6af1f6d30/go.mod h1:F1Fj3KG23WYHE6gozCmBAezKookxbIvUJT+121wTuLk= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= github.com/hdevalence/ed25519consensus v0.2.0/go.mod h1:w3BHWjwJbFU29IRHL1Iqkw3sus+7FctEyM4RqDxYNzo= github.com/illarion/gonotify v1.0.1 h1:F1d+0Fgbq/sDWjj/r66ekjDG+IDeecQKUFH4wNwsoio= github.com/illarion/gonotify v1.0.1/go.mod h1:zt5pmDofZpU1f8aqlK0+95eQhoEAn/d4G4B/FjVW4jE= +github.com/illarion/gonotify/v2 v2.0.3 h1:B6+SKPo/0Sw8cRJh1aLzNEeNVFfzE3c6N+o+vyxM+9A= +github.com/illarion/gonotify/v2 v2.0.3/go.mod h1:38oIJTgFqupkEydkkClkbL6i5lXV/bxdH9do5TALPEE= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2 h1:9K06NfxkBh25x56yVhWWlKFE8YpicaSfHwoV8SFbueA= github.com/insomniacslk/dhcp v0.0.0-20231206064809-8c70d406f6d2/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/jellydator/ttlcache/v3 v3.1.0 
h1:0gPFG0IHHP6xyUyXq+JaD8fwkDCqgqwohXNJBcYE71g= @@ -78,6 +114,8 @@ github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNU github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ= github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a/go.mod h1:YTtCCM3ryyfiu4F7t8HQ1mxvp1UBdWM2r6Xa+nGWvDk= github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= @@ -90,6 +128,8 @@ github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy5 github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= @@ -116,14 +156,22 @@ github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29X github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod 
h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQSPhaUPjUQwozcRdDdSxxqhNgNZ3drZFk= github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= +github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4ZoF094vE6iYTLDl0qCiKzYXlL6UeWObU= +github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4 h1:Gz0rz40FvFVLTBk/K8UNAenb36EbDSnh+q7Z9ldcC8w= github.com/tailscale/peercred v0.0.0-20240214030740-b535050b2aa4/go.mod h1:phI29ccmHQBc+wvroosENp1IF9195449VDnFDhJ4rJU= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1 h1:tdUdyPqJ0C97SJfjB9tW6EylTtreyee9C44de+UBG0g= github.com/tailscale/web-client-prebuilt v0.0.0-20240226180453-5db17b287bf1/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1 h1:ycpNCSYwzZ7x4G4ioPNtKQmIY0G/3o4pVf8wCZq6blY= github.com/tailscale/wireguard-go v0.0.0-20240705152531-2f5d148bcfe1/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98 h1:RNpJrXfI5u6e+uzyIzvmnXbhmhdRkVf//90sMBH3lso= github.com/tailscale/wireguard-go v0.0.0-20240731203015-71393c576b98/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= 
+github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19 h1:BcEJP2ewTIK2ZCsqgl6YGpuO6+oKqqag5HHb7ehljKw= +github.com/tailscale/wireguard-go v0.0.0-20250107165329-0b8b35511f19/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9 h1:81P7rjnikHKTJ75EkjppvbwUfKHDHYk6LJpO5PZy8pA= github.com/tailscale/xnet v0.0.0-20240117122442-62b9a7c569f9/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= @@ -136,6 +184,8 @@ github.com/u-root/u-root v0.12.0 h1:K0AuBFriwr0w/PGS3HawiAw89e3+MU7ks80GpghAsNs= github.com/u-root/u-root v0.12.0/go.mod h1:FYjTOh4IkIZHhjsd17lb8nYW6udgXdJhG1c0r6u0arI= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e h1:BA9O3BmlTmpjbvajAwzWx4Wo2TRVdpPXZEeemGQcajw= github.com/u-root/uio v0.0.0-20240118234441-a3c409a6018e/go.mod h1:eLL9Nub3yfAho7qB0MzZizFhTU2QkLeoVsWdHtDW264= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= @@ -144,42 +194,66 @@ github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem 
v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.24.0 h1:mnl8DM0o513X8fdIkmyFE/5hTYxbwYOjDS/+rK6qpRI= golang.org/x/crypto v0.24.0/go.mod h1:Z1PMYSOR5nyMcyAVAIQSKCDwalqy85Aqn1x3Ws4L5DM= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07 h1:Z+Zg+aXJYq6f4TK2E4H+vZkQ4dJAWnInXDR6hM9znxo= +golang.org/x/crypto v0.32.1-0.20250118192723-a8ea4be81f07/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a h1:Q8/wZp0KX97QFTc2ywcOE0YRjZPVIx+MXInMzdvQqcA= golang.org/x/exp v0.0.0-20240119083558-1b970713d09a/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8 h1:yqrTHse8TCMW1M1ZCP+VAR/l0kKxwaAIqN/il7x4voA= +golang.org/x/exp v0.0.0-20250106191152-7588d65b2ba8/go.mod h1:tujkw807nyEEAamNbDrEGzRav+ilXA7PCRAd6xsmwiU= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= golang.org/x/net v0.26.0/go.mod h1:5YKkiSynbBIh3p6iOc/vibscux0x38BZDkn8sCUPxHE= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.16.0 h1:aDkGMBSYxElaoP81NpoUoz2oo2R2wHdZpGToUxfyQrQ= golang.org/x/oauth2 v0.16.0/go.mod h1:hqZ+0LWXsiVoZpeld6jVt06P3adbS2Uu911W1SsJv2o= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod 
h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.21.0 h1:rF+pYz3DAGSQAxAu1CbC7catZg4ebC4UIeIhKxBZvws= golang.org/x/sys v0.21.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab h1:BMkEEWYOjkvOX7+YKOGbp6jCyQ5pR2j0Ah47p1Vdsx4= +golang.org/x/sys v0.29.1-0.20250107080300-1c14dcadc3ab/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.21.0 h1:WVXCp+/EBEHOj53Rvu+7KiT/iElMrO8ACK16SMZ3jaA= golang.org/x/term v0.21.0/go.mod h1:ooXLefLobQVslOqselCNF4SxFAaoS6KujMbsGzSDmX0= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= +golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= +golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= gvisor.dev/gvisor 
v0.0.0-20240306221502-ee1e1f6070e3 h1:/8/t5pz/mgdRXhYOIeqqYhFAQLE4DDGegc0Y4ZjyFJM= gvisor.dev/gvisor v0.0.0-20240306221502-ee1e1f6070e3/go.mod h1:NQHVAzMwvZ+Qe3ElSiHmq9RUm1MdNHpUZ52fiEqvn+0= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987 h1:TU8z2Lh3Bbq77w0t1eG8yRlLcNHzZu3x6mhoH2Mk0c8= gvisor.dev/gvisor v0.0.0-20240722211153-64c016c92987/go.mod h1:sxc3Uvk/vHcd3tj7/DHVBoR5wvWT/MmRq2pj7HRJnwU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= k8s.io/client-go v0.30.1 h1:uC/Ir6A3R46wdkgCV3vbLyNOYyCJ8oZnjtJGKfytl/Q= k8s.io/client-go v0.30.1/go.mod h1:wrAqLNs2trwiCH/wxxmT/x3hKVH9PuV0GGW0oDoHVqc= k8s.io/client-go v0.30.3 h1:bHrJu3xQZNXIi8/MoxYtZBBWQQXwy16zqJwloXXfD3k= k8s.io/client-go v0.30.3/go.mod h1:8d4pf8vYu665/kUbsxWAQ/JDBNWqfFeZnvFiVdmx89U= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= sigs.k8s.io/yaml v1.4.0 h1:Mk1wCc2gy/F0THH0TAp1QYyJNzRm2KCLy3o5ASXVI5E= diff --git a/gokrazy/tsapp/config.json b/gokrazy/tsapp/config.json index 33dd98a962043..b88be53a456a8 100644 --- a/gokrazy/tsapp/config.json +++ b/gokrazy/tsapp/config.json @@ -27,6 +27,10 @@ } } }, + "Environment": [ + "GOOS=linux", + "GOARCH=amd64" + ], "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "github.com/tailscale/gokrazy-kernel", "InternalCompatibilityFlags": {} diff --git a/health/health.go b/health/health.go index 216535d17c484..1ec2bcc9b0dd1 100644 --- a/health/health.go +++ b/health/health.go @@ -22,6 +22,7 @@ import ( "tailscale.com/envknob" "tailscale.com/metrics" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/opt" 
"tailscale.com/util/cibuild" "tailscale.com/util/mak" @@ -73,6 +74,8 @@ type Tracker struct { // mu should not be held during init. initOnce sync.Once + testClock tstime.Clock // nil means use time.Now / tstime.StdClock{} + // mu guards everything that follows. mu sync.Mutex @@ -80,13 +83,13 @@ type Tracker struct { warnableVal map[*Warnable]*warningState // pendingVisibleTimers contains timers for Warnables that are unhealthy, but are // not visible to the user yet, because they haven't been unhealthy for TimeToVisible - pendingVisibleTimers map[*Warnable]*time.Timer + pendingVisibleTimers map[*Warnable]tstime.TimerController // sysErr maps subsystems to their current error (or nil if the subsystem is healthy) // Deprecated: using Warnables should be preferred sysErr map[Subsystem]error watchers set.HandleSet[func(*Warnable, *UnhealthyState)] // opt func to run if error state changes - timer *time.Timer + timer tstime.TimerController latestVersion *tailcfg.ClientVersion // or nil checkForUpdates bool @@ -115,6 +118,20 @@ type Tracker struct { metricHealthMessage *metrics.MultiLabelMap[metricHealthMessageLabel] } +func (t *Tracker) now() time.Time { + if t.testClock != nil { + return t.testClock.Now() + } + return time.Now() +} + +func (t *Tracker) clock() tstime.Clock { + if t.testClock != nil { + return t.testClock + } + return tstime.StdClock{} +} + // Subsystem is the name of a subsystem whose health can be monitored. // // Deprecated: Registering a Warnable using Register() and updating its health state @@ -128,9 +145,6 @@ const ( // SysDNS is the name of the net/dns subsystem. SysDNS = Subsystem("dns") - // SysDNSOS is the name of the net/dns OSConfigurator subsystem. - SysDNSOS = Subsystem("dns-os") - // SysDNSManager is the name of the net/dns manager subsystem. 
SysDNSManager = Subsystem("dns-manager") @@ -141,7 +155,7 @@ const ( var subsystemsWarnables = map[Subsystem]*Warnable{} func init() { - for _, s := range []Subsystem{SysRouter, SysDNS, SysDNSOS, SysDNSManager, SysTKA} { + for _, s := range []Subsystem{SysRouter, SysDNS, SysDNSManager, SysTKA} { w := Register(&Warnable{ Code: WarnableCode(s), Severity: SeverityMedium, @@ -217,9 +231,11 @@ type Warnable struct { // TODO(angott): turn this into a SeverityFunc, which allows the Warnable to change its severity based on // the Args of the unhappy state, just like we do in the Text function. Severity Severity - // DependsOn is a set of Warnables that this Warnable depends, on and need to be healthy - // before this Warnable can also be healthy again. The GUI can use this information to ignore + // DependsOn is a set of Warnables that this Warnable depends on and need to be healthy + // before this Warnable is relevant. The GUI can use this information to ignore // this Warnable if one of its dependencies is unhealthy. + // That is, if any of these Warnables are unhealthy, then this Warnable is not relevant + // and should be considered healthy to bother the user about. DependsOn []*Warnable // MapDebugFlag is a MapRequest.DebugFlag that is sent to control when this Warnable is unhealthy @@ -312,11 +328,11 @@ func (ws *warningState) Equal(other *warningState) bool { // IsVisible returns whether the Warnable should be visible to the user, based on the TimeToVisible // field of the Warnable and the BrokenSince time when the Warnable became unhealthy. -func (w *Warnable) IsVisible(ws *warningState) bool { +func (w *Warnable) IsVisible(ws *warningState, clockNow func() time.Time) bool { if ws == nil || w.TimeToVisible == 0 { return true } - return time.Since(ws.BrokenSince) >= w.TimeToVisible + return clockNow().Sub(ws.BrokenSince) >= w.TimeToVisible } // SetMetricsRegistry sets up the metrics for the Tracker. 
It takes @@ -334,7 +350,7 @@ func (t *Tracker) SetMetricsRegistry(reg *usermetric.Registry) { ) t.metricHealthMessage.Set(metricHealthMessageLabel{ - Type: "warning", + Type: MetricLabelWarning, }, expvar.Func(func() any { if t.nil() { return 0 @@ -346,6 +362,18 @@ func (t *Tracker) SetMetricsRegistry(reg *usermetric.Registry) { })) } +// IsUnhealthy reports whether the current state is unhealthy because the given +// warnable is set. +func (t *Tracker) IsUnhealthy(w *Warnable) bool { + if t.nil() { + return false + } + t.mu.Lock() + defer t.mu.Unlock() + _, exists := t.warnableVal[w] + return exists +} + // SetUnhealthy sets a warningState for the given Warnable with the provided Args, and should be // called when a Warnable becomes unhealthy, or its unhealthy status needs to be updated. // SetUnhealthy takes ownership of args. The args can be nil if no additional information is @@ -366,7 +394,7 @@ func (t *Tracker) setUnhealthyLocked(w *Warnable, args Args) { // If we already have a warningState for this Warnable with an earlier BrokenSince time, keep that // BrokenSince time. - brokenSince := time.Now() + brokenSince := t.now() if existingWS := t.warnableVal[w]; existingWS != nil { brokenSince = existingWS.BrokenSince } @@ -385,24 +413,25 @@ func (t *Tracker) setUnhealthyLocked(w *Warnable, args Args) { // If the Warnable has been unhealthy for more than its TimeToVisible, the callback should be // executed immediately. Otherwise, the callback should be enqueued to run once the Warnable // becomes visible. - if w.IsVisible(ws) { - go cb(w, w.unhealthyState(ws)) + if w.IsVisible(ws, t.now) { + cb(w, w.unhealthyState(ws)) continue } // The time remaining until the Warnable will be visible to the user is the TimeToVisible // minus the time that has already passed since the Warnable became unhealthy. 
- visibleIn := w.TimeToVisible - time.Since(brokenSince) - mak.Set(&t.pendingVisibleTimers, w, time.AfterFunc(visibleIn, func() { + visibleIn := w.TimeToVisible - t.now().Sub(brokenSince) + var tc tstime.TimerController = t.clock().AfterFunc(visibleIn, func() { t.mu.Lock() defer t.mu.Unlock() // Check if the Warnable is still unhealthy, as it could have become healthy between the time // the timer was set for and the time it was executed. if t.warnableVal[w] != nil { - go cb(w, w.unhealthyState(ws)) + cb(w, w.unhealthyState(ws)) delete(t.pendingVisibleTimers, w) } - })) + }) + mak.Set(&t.pendingVisibleTimers, w, tc) } } } @@ -432,7 +461,7 @@ func (t *Tracker) setHealthyLocked(w *Warnable) { } for _, cb := range t.watchers { - go cb(w, nil) + cb(w, nil) } } @@ -466,6 +495,16 @@ func (t *Tracker) AppendWarnableDebugFlags(base []string) []string { // The provided callback function will be executed in its own goroutine. The returned function can be used // to unregister the callback. func (t *Tracker) RegisterWatcher(cb func(w *Warnable, r *UnhealthyState)) (unregister func()) { + return t.registerSyncWatcher(func(w *Warnable, r *UnhealthyState) { + go cb(w, r) + }) +} + +// registerSyncWatcher adds a function that will be called whenever the health +// state of any Warnable changes. The provided callback function will be +// executed synchronously. Call RegisterWatcher to register any callbacks that +// won't return from execution immediately. 
+func (t *Tracker) registerSyncWatcher(cb func(w *Warnable, r *UnhealthyState)) (unregister func()) { if t.nil() { return func() {} } @@ -477,7 +516,7 @@ func (t *Tracker) RegisterWatcher(cb func(w *Warnable, r *UnhealthyState)) (unre } handle := t.watchers.Add(cb) if t.timer == nil { - t.timer = time.AfterFunc(time.Minute, t.timerSelfCheck) + t.timer = t.clock().AfterFunc(time.Minute, t.timerSelfCheck) } return func() { t.mu.Lock() @@ -510,22 +549,12 @@ func (t *Tracker) SetDNSHealth(err error) { t.setErr(SysDNS, err) } // Deprecated: Warnables should be preferred over Subsystem errors. func (t *Tracker) DNSHealth() error { return t.get(SysDNS) } -// SetDNSOSHealth sets the state of the net/dns.OSConfigurator -// -// Deprecated: Warnables should be preferred over Subsystem errors. -func (t *Tracker) SetDNSOSHealth(err error) { t.setErr(SysDNSOS, err) } - // SetDNSManagerHealth sets the state of the Linux net/dns manager's // discovery of the /etc/resolv.conf situation. // // Deprecated: Warnables should be preferred over Subsystem errors. func (t *Tracker) SetDNSManagerHealth(err error) { t.setErr(SysDNSManager, err) } -// DNSOSHealth returns the net/dns.OSConfigurator error state. -// -// Deprecated: Warnables should be preferred over Subsystem errors. -func (t *Tracker) DNSOSHealth() error { return t.get(SysDNSOS) } - // SetTKAHealth sets the health of the tailnet key authority. // // Deprecated: Warnables should be preferred over Subsystem errors. 
@@ -651,10 +680,10 @@ func (t *Tracker) GotStreamedMapResponse() { } t.mu.Lock() defer t.mu.Unlock() - t.lastStreamedMapResponse = time.Now() + t.lastStreamedMapResponse = t.now() if !t.inMapPoll { t.inMapPoll = true - t.inMapPollSince = time.Now() + t.inMapPollSince = t.now() } t.selfCheckLocked() } @@ -671,7 +700,7 @@ func (t *Tracker) SetOutOfPollNetMap() { return } t.inMapPoll = false - t.lastMapPollEndedAt = time.Now() + t.lastMapPollEndedAt = t.now() t.selfCheckLocked() } @@ -713,7 +742,7 @@ func (t *Tracker) NoteMapRequestHeard(mr *tailcfg.MapRequest) { // against SetMagicSockDERPHome and // SetDERPRegionConnectedState - t.lastMapRequestHeard = time.Now() + t.lastMapRequestHeard = t.now() t.selfCheckLocked() } @@ -751,7 +780,7 @@ func (t *Tracker) NoteDERPRegionReceivedFrame(region int) { } t.mu.Lock() defer t.mu.Unlock() - mak.Set(&t.derpRegionLastFrame, region, time.Now()) + mak.Set(&t.derpRegionLastFrame, region, t.now()) t.selfCheckLocked() } @@ -810,9 +839,9 @@ func (t *Tracker) SetIPNState(state string, wantRunning bool) { // The first time we see wantRunning=true and it used to be false, it means the user requested // the backend to start. We store this timestamp and use it to silence some warnings that are // expected during startup. - t.ipnWantRunningLastTrue = time.Now() + t.ipnWantRunningLastTrue = t.now() t.setUnhealthyLocked(warmingUpWarnable, nil) - time.AfterFunc(warmingUpWarnableDuration, func() { + t.clock().AfterFunc(warmingUpWarnableDuration, func() { t.mu.Lock() t.updateWarmingUpWarnableLocked() t.mu.Unlock() @@ -949,10 +978,13 @@ func (t *Tracker) Strings() []string { func (t *Tracker) stringsLocked() []string { result := []string{} for w, ws := range t.warnableVal { - if !w.IsVisible(ws) { + if !w.IsVisible(ws, t.now) { // Do not append invisible warnings. 
continue } + if t.isEffectivelyHealthyLocked(w) { + continue + } if ws.Args == nil { result = append(result, w.Text(Args{})) } else { @@ -1018,7 +1050,7 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { t.setHealthyLocked(localLogWarnable) } - now := time.Now() + now := t.now() // How long we assume we'll have heard a DERP frame or a MapResponse // KeepAlive by. @@ -1028,8 +1060,10 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { recentlyOn := now.Sub(t.ipnWantRunningLastTrue) < 5*time.Second homeDERP := t.derpHomeRegion - if recentlyOn { + if recentlyOn || !t.inMapPoll { // If user just turned Tailscale on, don't warn for a bit. + // Also, if we're not in a map poll, that means we don't yet + // have a DERPMap or aren't in a state where we even want t.setHealthyLocked(noDERPHomeWarnable) t.setHealthyLocked(noDERPConnectionWarnable) t.setHealthyLocked(derpTimeoutWarnable) @@ -1051,11 +1085,15 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { ArgDuration: d.Round(time.Second).String(), }) } - } else { + } else if homeDERP != 0 { t.setUnhealthyLocked(noDERPConnectionWarnable, Args{ ArgDERPRegionID: fmt.Sprint(homeDERP), ArgDERPRegionName: t.derpRegionNameLocked(homeDERP), }) + } else { + // No DERP home yet determined yet. There's probably some + // other problem or things are just starting up. + t.setHealthyLocked(noDERPConnectionWarnable) } if !t.ipnWantRunning { @@ -1174,7 +1212,7 @@ func (t *Tracker) updateBuiltinWarnablesLocked() { // updateWarmingUpWarnableLocked ensures the warmingUpWarnable is healthy if wantRunning has been set to true // for more than warmingUpWarnableDuration. 
func (t *Tracker) updateWarmingUpWarnableLocked() { - if !t.ipnWantRunningLastTrue.IsZero() && time.Now().After(t.ipnWantRunningLastTrue.Add(warmingUpWarnableDuration)) { + if !t.ipnWantRunningLastTrue.IsZero() && t.now().After(t.ipnWantRunningLastTrue.Add(warmingUpWarnableDuration)) { t.setHealthyLocked(warmingUpWarnable) } } @@ -1286,12 +1324,14 @@ func (t *Tracker) LastNoiseDialWasRecent() bool { t.mu.Lock() defer t.mu.Unlock() - now := time.Now() + now := t.now() dur := now.Sub(t.lastNoiseDial) t.lastNoiseDial = now return dur < 2*time.Minute } +const MetricLabelWarning = "warning" + type metricHealthMessageLabel struct { // TODO: break down by warnable.severity as well? Type string diff --git a/health/health_test.go b/health/health_test.go index 8107c1cf09db5..aa39045817ce2 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -7,11 +7,15 @@ import ( "fmt" "reflect" "slices" + "strconv" "testing" "time" "tailscale.com/tailcfg" + "tailscale.com/tstest" "tailscale.com/types/opt" + "tailscale.com/util/usermetric" + "tailscale.com/version" ) func TestAppendWarnableDebugFlags(t *testing.T) { @@ -254,9 +258,15 @@ func TestCheckDependsOnAppearsInUnhealthyState(t *testing.T) { } ht.SetUnhealthy(w2, Args{ArgError: "w2 is also unhealthy now"}) us2, ok := ht.CurrentState().Warnings[w2.Code] + if ok { + t.Fatalf("Saw w2 being unhealthy but it shouldn't be, as it depends on unhealthy w1") + } + ht.SetHealthy(w1) + us2, ok = ht.CurrentState().Warnings[w2.Code] if !ok { - t.Fatalf("Expected an UnhealthyState for w2, got nothing") + t.Fatalf("w2 wasn't unhealthy; want it to be unhealthy now that w1 is back healthy") } + wantDependsOn = slices.Concat([]WarnableCode{w1.Code}, wantDependsOn) if !reflect.DeepEqual(us2.DependsOn, wantDependsOn) { t.Fatalf("Expected DependsOn = %v in the unhealthy state, got: %v", wantDependsOn, us2.DependsOn) @@ -273,7 +283,7 @@ func TestShowUpdateWarnable(t *testing.T) { wantShow bool }{ { - desc: "nil CientVersion", + desc: "nil 
ClientVersion", check: true, cv: nil, wantWarnable: nil, @@ -348,3 +358,181 @@ func TestShowUpdateWarnable(t *testing.T) { }) } } + +func TestHealthMetric(t *testing.T) { + unstableBuildWarning := 0 + if version.IsUnstableBuild() { + unstableBuildWarning = 1 + } + + tests := []struct { + desc string + check bool + apply opt.Bool + cv *tailcfg.ClientVersion + wantMetricCount int + }{ + // When running in dev, and not initialising the client, there will be two warnings + // by default: + // - is-using-unstable-version (except on the release branch) + // - wantrunning-false + { + desc: "base-warnings", + check: true, + cv: nil, + wantMetricCount: unstableBuildWarning + 1, + }, + // with: update-available + { + desc: "update-warning", + check: true, + cv: &tailcfg.ClientVersion{RunningLatest: false, LatestVersion: "1.2.3"}, + wantMetricCount: unstableBuildWarning + 2, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + tr := &Tracker{ + checkForUpdates: tt.check, + applyUpdates: tt.apply, + latestVersion: tt.cv, + } + tr.SetMetricsRegistry(&usermetric.Registry{}) + if val := tr.metricHealthMessage.Get(metricHealthMessageLabel{Type: MetricLabelWarning}).String(); val != strconv.Itoa(tt.wantMetricCount) { + t.Fatalf("metric value: %q, want: %q", val, strconv.Itoa(tt.wantMetricCount)) + } + for _, w := range tr.CurrentState().Warnings { + t.Logf("warning: %v", w) + } + }) + } +} + +// TestNoDERPHomeWarnable checks that we don't +// complain about no DERP home if we're not in a +// map poll. +func TestNoDERPHomeWarnable(t *testing.T) { + t.Skip("TODO: fix https://github.com/tailscale/tailscale/issues/14798 to make this test not deadlock") + clock := tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(123, 0), + FollowRealTime: false, + }) + ht := &Tracker{ + testClock: clock, + } + ht.SetIPNState("NeedsLogin", true) + + // Advance 30 seconds to get past the "recentlyLoggedIn" check. 
+ clock.Advance(30 * time.Second) + ht.updateBuiltinWarnablesLocked() + + // Advance to get past the TimeToVisible delay. + clock.Advance(noDERPHomeWarnable.TimeToVisible * 2) + + ht.updateBuiltinWarnablesLocked() + if ws, ok := ht.CurrentState().Warnings[noDERPHomeWarnable.Code]; ok { + t.Fatalf("got unexpected noDERPHomeWarnable warnable: %v", ws) + } +} + +// TestNoDERPHomeWarnableManual is like TestNoDERPHomeWarnable +// but doesn't use tstest.Clock so avoids the deadlock +// I hit: https://github.com/tailscale/tailscale/issues/14798 +func TestNoDERPHomeWarnableManual(t *testing.T) { + ht := &Tracker{} + ht.SetIPNState("NeedsLogin", true) + + // Avoid wantRunning: + ht.ipnWantRunningLastTrue = ht.ipnWantRunningLastTrue.Add(-10 * time.Second) + ht.updateBuiltinWarnablesLocked() + + ws, ok := ht.warnableVal[noDERPHomeWarnable] + if ok { + t.Fatalf("got unexpected noDERPHomeWarnable warnable: %v", ws) + } +} + +func TestControlHealth(t *testing.T) { + ht := Tracker{} + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + ht.SetControlHealth([]string{"Test message"}) + state := ht.CurrentState() + warning, ok := state.Warnings["control-health"] + + if !ok { + t.Fatal("no warning found in current state with code 'control-health'") + } + if got, want := warning.Title, "Coordination server reports an issue"; got != want { + t.Errorf("warning.Title = %q, want %q", got, want) + } + if got, want := warning.Severity, SeverityMedium; got != want { + t.Errorf("warning.Severity = %s, want %s", got, want) + } + if got, want := warning.Text, "The coordination server is reporting an health issue: Test message"; got != want { + t.Errorf("warning.Text = %q, want %q", got, want) + } +} + +func TestControlHealthNotifiesOnChange(t *testing.T) { + ht := Tracker{} + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + gotNotified := false + ht.registerSyncWatcher(func(_ *Warnable, _ *UnhealthyState) { + gotNotified = true + }) + + 
ht.SetControlHealth([]string{"Test message"}) + + if !gotNotified { + t.Errorf("watcher did not get called, want it to be called") + } +} + +func TestControlHealthNoNotifyOnUnchanged(t *testing.T) { + ht := Tracker{} + ht.SetIPNState("NeedsLogin", true) + ht.GotStreamedMapResponse() + + // Set up an existing control health issue + ht.SetControlHealth([]string{"Test message"}) + + // Now register our watcher + gotNotified := false + ht.registerSyncWatcher(func(_ *Warnable, _ *UnhealthyState) { + gotNotified = true + }) + + // Send the same control health message again - should not notify + ht.SetControlHealth([]string{"Test message"}) + + if gotNotified { + t.Errorf("watcher got called, want it to not be called") + } +} + +func TestControlHealthIgnoredOutsideMapPoll(t *testing.T) { + ht := Tracker{} + ht.SetIPNState("NeedsLogin", true) + + gotNotified := false + ht.registerSyncWatcher(func(_ *Warnable, _ *UnhealthyState) { + gotNotified = true + }) + + ht.SetControlHealth([]string{"Test message"}) + + state := ht.CurrentState() + _, ok := state.Warnings["control-health"] + + if ok { + t.Error("got a warning with code 'control-health', want none") + } + + if gotNotified { + t.Error("watcher got called, want it to not be called") + } +} diff --git a/health/state.go b/health/state.go index 17a646794b252..c06f6ef59c8ed 100644 --- a/health/state.go +++ b/health/state.go @@ -86,10 +86,15 @@ func (t *Tracker) CurrentState() *State { wm := map[WarnableCode]UnhealthyState{} for w, ws := range t.warnableVal { - if !w.IsVisible(ws) { + if !w.IsVisible(ws, t.now) { // Skip invisible Warnables. continue } + if t.isEffectivelyHealthyLocked(w) { + // Skip Warnables that are unhealthy if they have dependencies + // that are unhealthy. + continue + } wm[w.Code] = *w.unhealthyState(ws) } @@ -97,3 +102,23 @@ func (t *Tracker) CurrentState() *State { Warnings: wm, } } + +// isEffectivelyHealthyLocked reports whether w is effectively healthy. 
+// That means it's either actually healthy or it has a dependency +// that's unhealthy, so we should treat w as healthy to not spam users +// with multiple warnings when only the root cause is relevant. +func (t *Tracker) isEffectivelyHealthyLocked(w *Warnable) bool { + if _, ok := t.warnableVal[w]; !ok { + // Warnable not found in the tracker. So healthy. + return true + } + for _, d := range w.DependsOn { + if !t.isEffectivelyHealthyLocked(d) { + // If one of our deps is unhealthy, we're healthy. + return true + } + } + // If we have no unhealthy deps and had warnableVal set, + // we're unhealthy. + return false +} diff --git a/hostinfo/hostinfo.go b/hostinfo/hostinfo.go index 3233a422dd6c3..3e8f2f994791e 100644 --- a/hostinfo/hostinfo.go +++ b/hostinfo/hostinfo.go @@ -21,22 +21,31 @@ import ( "go4.org/mem" "tailscale.com/envknob" "tailscale.com/tailcfg" + "tailscale.com/types/lazy" "tailscale.com/types/opt" "tailscale.com/types/ptr" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version" "tailscale.com/version/distro" ) var started = time.Now() +var newHooks []func(*tailcfg.Hostinfo) + +// RegisterHostinfoNewHook registers a callback to be called on a non-nil +// [tailcfg.Hostinfo] before it is returned by [New]. +func RegisterHostinfoNewHook(f func(*tailcfg.Hostinfo)) { + newHooks = append(newHooks, f) +} + // New returns a partially populated Hostinfo for the current host. 
func New() *tailcfg.Hostinfo { - hostname, _ := os.Hostname() + hostname, _ := Hostname() hostname = dnsname.FirstLabel(hostname) - return &tailcfg.Hostinfo{ + hi := &tailcfg.Hostinfo{ IPNVersion: version.Long(), Hostname: hostname, App: appTypeCached(), @@ -57,8 +66,11 @@ func New() *tailcfg.Hostinfo { Cloud: string(cloudenv.Get()), NoLogsNoSupport: envknob.NoLogsNoSupport(), AllowsUpdate: envknob.AllowsRemoteUpdate(), - WoLMACs: getWoLMACs(), } + for _, f := range newHooks { + f(hi) + } + return hi } // non-nil on some platforms @@ -231,12 +243,11 @@ func desktop() (ret opt.Bool) { } seenDesktop := false - lineread.File("/proc/net/unix", func(line []byte) error { - seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S(" @/tmp/dbus-")) + for lr := range lineiter.File("/proc/net/unix") { + line, _ := lr.Value() seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S(".X11-unix")) seenDesktop = seenDesktop || mem.Contains(mem.B(line), mem.S("/wayland-1")) - return nil - }) + } ret.Set(seenDesktop) // Only cache after a minute - compositors might not have started yet. @@ -305,21 +316,21 @@ func inContainer() opt.Bool { ret.Set(true) return ret } - lineread.File("/proc/1/cgroup", func(line []byte) error { + for lr := range lineiter.File("/proc/1/cgroup") { + line, _ := lr.Value() if mem.Contains(mem.B(line), mem.S("/docker/")) || mem.Contains(mem.B(line), mem.S("/lxc/")) { ret.Set(true) - return io.EOF // arbitrary non-nil error to stop loop + break } - return nil - }) - lineread.File("/proc/mounts", func(line []byte) error { + } + for lr := range lineiter.File("/proc/mounts") { + line, _ := lr.Value() if mem.Contains(mem.B(line), mem.S("lxcfs /proc/cpuinfo fuse.lxcfs")) { ret.Set(true) - return io.EOF + break } - return nil - }) + } return ret } @@ -487,5 +498,32 @@ func IsNATLabGuestVM() bool { return false } -// NAT Lab VMs have a unique MAC address prefix. 
-// See +const copyV86DeviceModel = "copy-v86" + +var isV86Cache lazy.SyncValue[bool] + +// IsInVM86 reports whether we're running in the copy/v86 wasm emulator, +// https://github.com/copy/v86/. +func IsInVM86() bool { + return isV86Cache.Get(func() bool { + return New().DeviceModel == copyV86DeviceModel + }) +} + +type hostnameQuery func() (string, error) + +var hostnameFn atomic.Value // of func() (string, error) + +// SetHostNameFn sets a custom function for querying the system hostname. +func SetHostnameFn(fn hostnameQuery) { + hostnameFn.Store(fn) +} + +// Hostname returns the system hostname using the function +// set by SetHostNameFn. We will fallback to os.Hostname. +func Hostname() (string, error) { + if fn, ok := hostnameFn.Load().(hostnameQuery); ok && fn != nil { + return fn() + } + return os.Hostname() +} diff --git a/hostinfo/hostinfo_linux.go b/hostinfo/hostinfo_linux.go index 53d4187bc0c67..66484a3588027 100644 --- a/hostinfo/hostinfo_linux.go +++ b/hostinfo/hostinfo_linux.go @@ -12,7 +12,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/types/ptr" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version/distro" ) @@ -106,15 +106,18 @@ func linuxVersionMeta() (meta versionMeta) { } m := map[string]string{} - lineread.File(propFile, func(line []byte) error { + for lr := range lineiter.File(propFile) { + line, err := lr.Value() + if err != nil { + break + } eq := bytes.IndexByte(line, '=') if eq == -1 { - return nil + continue } k, v := string(line[:eq]), strings.Trim(string(line[eq+1:]), `"'`) m[k] = v - return nil - }) + } if v := m["VERSION_CODENAME"]; v != "" { meta.DistroCodeName = v diff --git a/hostinfo/hostinfo_linux_test.go b/hostinfo/hostinfo_linux_test.go index c8bd2abbeb230..0286fadf329ab 100644 --- a/hostinfo/hostinfo_linux_test.go +++ b/hostinfo/hostinfo_linux_test.go @@ -35,8 +35,12 @@ remotes/origin/QTSFW_5.0.0` } } -func TestInContainer(t *testing.T) { - if got := inContainer(); !got.EqualBool(false) 
{ - t.Errorf("inContainer = %v; want false due to absence of ts_package_container build tag", got) +func TestPackageTypeNotContainer(t *testing.T) { + var got string + if packageType != nil { + got = packageType() + } + if got == "container" { + t.Fatal("packageType = container; should only happen if build tag ts_package_container is set") } } diff --git a/hostinfo/hostinfo_plan9.go b/hostinfo/hostinfo_plan9.go new file mode 100644 index 0000000000000..f9aa30e51769f --- /dev/null +++ b/hostinfo/hostinfo_plan9.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package hostinfo + +import ( + "bytes" + "os" + "strings" + + "tailscale.com/tailcfg" + "tailscale.com/types/lazy" +) + +func init() { + RegisterHostinfoNewHook(func(hi *tailcfg.Hostinfo) { + if isPlan9V86() { + hi.DeviceModel = copyV86DeviceModel + } + }) +} + +var isPlan9V86Cache lazy.SyncValue[bool] + +// isPlan9V86 reports whether we're running in the wasm +// environment (https://github.com/copy/v86/). 
+func isPlan9V86() bool { + return isPlan9V86Cache.Get(func() bool { + v, _ := os.ReadFile("/dev/cputype") + s, _, _ := strings.Cut(string(v), " ") + if s != "PentiumIV/Xeon" { + return false + } + + v, _ = os.ReadFile("/dev/config") + v, _, _ = bytes.Cut(v, []byte{'\n'}) + return string(v) == "# pcvm - small kernel used to run in vm" + }) +} diff --git a/hostinfo/hostinfo_test.go b/hostinfo/hostinfo_test.go index 9fe32e0449be1..15b6971b6ccd0 100644 --- a/hostinfo/hostinfo_test.go +++ b/hostinfo/hostinfo_test.go @@ -5,6 +5,7 @@ package hostinfo import ( "encoding/json" + "os" "strings" "testing" ) @@ -49,3 +50,31 @@ func TestEtcAptSourceFileIsDisabled(t *testing.T) { }) } } + +func TestCustomHostnameFunc(t *testing.T) { + want := "custom-hostname" + SetHostnameFn(func() (string, error) { + return want, nil + }) + + got, err := Hostname() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got != want { + t.Errorf("got %q, want %q", got, want) + } + + SetHostnameFn(os.Hostname) + got, err = Hostname() + want, _ = os.Hostname() + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != want { + t.Errorf("got %q, want %q", got, want) + } + +} diff --git a/hostinfo/wol.go b/hostinfo/wol.go deleted file mode 100644 index 3a30af2fe3a37..0000000000000 --- a/hostinfo/wol.go +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package hostinfo - -import ( - "log" - "net" - "runtime" - "strings" - "unicode" - - "tailscale.com/envknob" -) - -// TODO(bradfitz): this is all too simplistic and static. It needs to run -// continuously in response to netmon events (USB ethernet adapaters might get -// plugged in) and look for the media type/status/etc. Right now on macOS it -// still detects a half dozen "up" en0, en1, en2, en3 etc interfaces that don't -// have any media. We should only report the one that's actually connected. 
-// But it works for now (2023-10-05) for fleshing out the rest. - -var wakeMAC = envknob.RegisterString("TS_WAKE_MAC") // mac address, "false" or "auto". for https://github.com/tailscale/tailscale/issues/306 - -// getWoLMACs returns up to 10 MAC address of the local machine to send -// wake-on-LAN packets to in order to wake it up. The returned MACs are in -// lowercase hex colon-separated form ("xx:xx:xx:xx:xx:xx"). -// -// If TS_WAKE_MAC=auto, it tries to automatically find the MACs based on the OS -// type and interface properties. (TODO(bradfitz): incomplete) If TS_WAKE_MAC is -// set to a MAC address, that sole MAC address is returned. -func getWoLMACs() (macs []string) { - switch runtime.GOOS { - case "ios", "android": - return nil - } - if s := wakeMAC(); s != "" { - switch s { - case "auto": - ifs, _ := net.Interfaces() - for _, iface := range ifs { - if iface.Flags&net.FlagLoopback != 0 { - continue - } - if iface.Flags&net.FlagBroadcast == 0 || - iface.Flags&net.FlagRunning == 0 || - iface.Flags&net.FlagUp == 0 { - continue - } - if keepMAC(iface.Name, iface.HardwareAddr) { - macs = append(macs, iface.HardwareAddr.String()) - } - if len(macs) == 10 { - break - } - } - return macs - case "false", "off": // fast path before ParseMAC error - return nil - } - mac, err := net.ParseMAC(s) - if err != nil { - log.Printf("invalid MAC %q", s) - return nil - } - return []string{mac.String()} - } - return nil -} - -var ignoreWakeOUI = map[[3]byte]bool{ - {0x00, 0x15, 0x5d}: true, // Hyper-V - {0x00, 0x50, 0x56}: true, // VMware - {0x00, 0x1c, 0x14}: true, // VMware - {0x00, 0x05, 0x69}: true, // VMware - {0x00, 0x0c, 0x29}: true, // VMware - {0x00, 0x1c, 0x42}: true, // Parallels - {0x08, 0x00, 0x27}: true, // VirtualBox - {0x00, 0x21, 0xf6}: true, // VirtualBox - {0x00, 0x14, 0x4f}: true, // VirtualBox - {0x00, 0x0f, 0x4b}: true, // VirtualBox - {0x52, 0x54, 0x00}: true, // VirtualBox/Vagrant -} - -func keepMAC(ifName string, mac []byte) bool { - if len(mac) != 6 
{ - return false - } - base := strings.TrimRightFunc(ifName, unicode.IsNumber) - switch runtime.GOOS { - case "darwin": - switch base { - case "llw", "awdl", "utun", "bridge", "lo", "gif", "stf", "anpi", "ap": - return false - } - } - if mac[0] == 0x02 && mac[1] == 0x42 { - // Docker container. - return false - } - oui := [3]byte{mac[0], mac[1], mac[2]} - if ignoreWakeOUI[oui] { - return false - } - return true -} diff --git a/internal/client/tailscale/tailscale.go b/internal/client/tailscale/tailscale.go new file mode 100644 index 0000000000000..cba7228bbc8b3 --- /dev/null +++ b/internal/client/tailscale/tailscale.go @@ -0,0 +1,83 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tailscale provides a minimal control plane API client for internal +// use. A full client for 3rd party use is available at +// tailscale.com/client/tailscale/v2. The internal client is provided to avoid +// having to import that whole package. +package tailscale + +import ( + "errors" + "io" + "net/http" + + tsclient "tailscale.com/client/tailscale" +) + +// maxSize is the maximum read size (10MB) of responses from the server. +const maxReadSize = 10 << 20 + +func init() { + tsclient.I_Acknowledge_This_API_Is_Unstable = true +} + +// AuthMethod is an alias to tailscale.com/client/tailscale. +type AuthMethod = tsclient.AuthMethod + +// Device is an alias to tailscale.com/client/tailscale. +type Device = tsclient.Device + +// DeviceFieldsOpts is an alias to tailscale.com/client/tailscale. +type DeviceFieldsOpts = tsclient.DeviceFieldsOpts + +// Key is an alias to tailscale.com/client/tailscale. +type Key = tsclient.Key + +// KeyCapabilities is an alias to tailscale.com/client/tailscale. +type KeyCapabilities = tsclient.KeyCapabilities + +// KeyDeviceCapabilities is an alias to tailscale.com/client/tailscale. 
+type KeyDeviceCapabilities = tsclient.KeyDeviceCapabilities + +// KeyDeviceCreateCapabilities is an alias to tailscale.com/client/tailscale. +type KeyDeviceCreateCapabilities = tsclient.KeyDeviceCreateCapabilities + +// ErrResponse is an alias to tailscale.com/client/tailscale. +type ErrResponse = tsclient.ErrResponse + +// NewClient is an alias to tailscale.com/client/tailscale. +func NewClient(tailnet string, auth AuthMethod) *Client { + return &Client{ + Client: tsclient.NewClient(tailnet, auth), + } +} + +// Client is a wrapper of tailscale.com/client/tailscale. +type Client struct { + *tsclient.Client +} + +// HandleErrorResponse is an alias to tailscale.com/client/tailscale. +func HandleErrorResponse(b []byte, resp *http.Response) error { + return tsclient.HandleErrorResponse(b, resp) +} + +// SendRequest add the authentication key to the request and sends it. It +// receives the response and reads up to 10MB of it. +func SendRequest(c *Client, req *http.Request) ([]byte, *http.Response, error) { + resp, err := c.Do(req) + if err != nil { + return nil, resp, err + } + defer resp.Body.Close() + + // Read response. Limit the response to 10MB. + // This limit is carried over from client/tailscale/tailscale.go. + body := io.LimitReader(resp.Body, maxReadSize+1) + b, err := io.ReadAll(body) + if len(b) > maxReadSize { + err = errors.New("API response too large") + } + return b, resp, err +} diff --git a/internal/client/tailscale/vip_service.go b/internal/client/tailscale/vip_service.go new file mode 100644 index 0000000000000..64fcfdf5e86d6 --- /dev/null +++ b/internal/client/tailscale/vip_service.go @@ -0,0 +1,105 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tailscale + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + "tailscale.com/tailcfg" + "tailscale.com/util/httpm" +) + +// VIPService is a Tailscale VIPService with Tailscale API JSON representation. 
+type VIPService struct { + // Name is a VIPService name in form svc:. + Name tailcfg.ServiceName `json:"name,omitempty"` + // Addrs are the IP addresses of the VIP Service. There are two addresses: + // the first is IPv4 and the second is IPv6. + // When creating a new VIP Service, the IP addresses are optional: if no + // addresses are specified then they will be selected. If an IPv4 address is + // specified at index 0, then that address will attempt to be used. An IPv6 + // address can not be specified upon creation. + Addrs []string `json:"addrs,omitempty"` + // Comment is an optional text string for display in the admin panel. + Comment string `json:"comment,omitempty"` + // Annotations are optional key-value pairs that can be used to store arbitrary metadata. + Annotations map[string]string `json:"annotations,omitempty"` + // Ports are the ports of a VIPService that will be configured via Tailscale serve config. + // If set, any node wishing to advertise this VIPService must have this port configured via Tailscale serve. + Ports []string `json:"ports,omitempty"` + // Tags are optional ACL tags that will be applied to the VIPService. + Tags []string `json:"tags,omitempty"` +} + +// GetVIPService retrieves a VIPService by its name. It returns 404 if the VIPService is not found. +func (client *Client) GetVIPService(ctx context.Context, name tailcfg.ServiceName) (*VIPService, error) { + path := client.BuildTailnetURL("vip-services", name.String()) + req, err := http.NewRequestWithContext(ctx, httpm.GET, path, nil) + if err != nil { + return nil, fmt.Errorf("error creating new HTTP request: %w", err) + } + b, resp, err := SendRequest(client, req) + if err != nil { + return nil, fmt.Errorf("error making Tailscale API request: %w", err) + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. 
+ if resp.StatusCode != http.StatusOK { + return nil, HandleErrorResponse(b, resp) + } + svc := &VIPService{} + if err := json.Unmarshal(b, svc); err != nil { + return nil, err + } + return svc, nil +} + +// CreateOrUpdateVIPService creates or updates a VIPService by its name. Caller must ensure that, if the +// VIPService already exists, the VIPService is fetched first to ensure that any auto-allocated IP addresses are not +// lost during the update. If the VIPService was created without any IP addresses explicitly set (so that they were +// auto-allocated by Tailscale) any subsequent request to this function that does not set any IP addresses will error. +func (client *Client) CreateOrUpdateVIPService(ctx context.Context, svc *VIPService) error { + data, err := json.Marshal(svc) + if err != nil { + return err + } + path := client.BuildTailnetURL("vip-services", svc.Name.String()) + req, err := http.NewRequestWithContext(ctx, httpm.PUT, path, bytes.NewBuffer(data)) + if err != nil { + return fmt.Errorf("error creating new HTTP request: %w", err) + } + b, resp, err := SendRequest(client, req) + if err != nil { + return fmt.Errorf("error making Tailscale API request: %w", err) + } + // If status code was not successful, return the error. + // TODO: Change the check for the StatusCode to include other 2XX success codes. + if resp.StatusCode != http.StatusOK { + return HandleErrorResponse(b, resp) + } + return nil +} + +// DeleteVIPService deletes a VIPService by its name. It returns an error if the VIPService +// does not exist or if the deletion fails. 
+func (client *Client) DeleteVIPService(ctx context.Context, name tailcfg.ServiceName) error { + path := client.BuildTailnetURL("vip-services", name.String()) + req, err := http.NewRequestWithContext(ctx, httpm.DELETE, path, nil) + if err != nil { + return fmt.Errorf("error creating new HTTP request: %w", err) + } + b, resp, err := SendRequest(client, req) + if err != nil { + return fmt.Errorf("error making Tailscale API request: %w", err) + } + // If status code was not successful, return the error. + if resp.StatusCode != http.StatusOK { + return HandleErrorResponse(b, resp) + } + return nil +} diff --git a/ipn/auditlog/auditlog.go b/ipn/auditlog/auditlog.go new file mode 100644 index 0000000000000..0460bc4e2c655 --- /dev/null +++ b/ipn/auditlog/auditlog.go @@ -0,0 +1,468 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package auditlog provides a mechanism for logging audit events. +package auditlog + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "sort" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/rands" + "tailscale.com/util/set" +) + +// transaction represents an audit log that has not yet been sent to the control plane. +type transaction struct { + // EventID is the unique identifier for the event being logged. + // This is used on the client side only and is not sent to control. + EventID string `json:",omitempty"` + // Retries is the number of times the logger has attempted to send this log. + // This is used on the client side only and is not sent to control. + Retries int `json:",omitempty"` + + // Action is the action to be logged. It must correspond to a known action in the control plane. + Action tailcfg.ClientAuditAction `json:",omitempty"` + // Details is an opaque string specific to the action being logged. Empty strings may not + // be valid depending on the action being logged. 
+ Details string `json:",omitempty"` + // TimeStamp is the time at which the audit log was generated on the node. + TimeStamp time.Time `json:",omitzero"` +} + +// Transport provides a means for a client to send audit logs to a consumer (typically the control plane). +type Transport interface { + // SendAuditLog sends an audit log to a consumer of audit logs. + // Errors should be checked with [IsRetryableError] for retryability. + SendAuditLog(context.Context, tailcfg.AuditLogRequest) error +} + +// LogStore provides a means for a [Logger] to persist logs to disk or memory. +type LogStore interface { + // Save saves the given data to a persistent store. Save will overwrite existing data + // for the given key. + save(key ipn.ProfileID, txns []*transaction) error + + // Load retrieves the data from a persistent store. Returns a nil slice and + // no error if no data exists for the given key. + load(key ipn.ProfileID) ([]*transaction, error) +} + +// Opts contains the configuration options for a [Logger]. +type Opts struct { + // RetryLimit is the maximum number of attempts the logger will make to send a log before giving up. + RetryLimit int + // Store is the persistent store used to save logs to disk. Must be non-nil. + Store LogStore + // Logf is the logger used to log messages from the audit logger. Must be non-nil. + Logf logger.Logf +} + +// IsRetryableError returns true if the given error is retryable +// See [controlclient.apiResponseError]. Potentially retryable errors implement the Retryable() method. +func IsRetryableError(err error) bool { + var retryable interface{ Retryable() bool } + return errors.As(err, &retryable) && retryable.Retryable() +} + +type backoffOpts struct { + min, max time.Duration + multiplier float64 +} + +// .5, 1, 2, 4, 8, 10, 10, 10, 10, 10... 
+var defaultBackoffOpts = backoffOpts{ + min: time.Millisecond * 500, + max: 10 * time.Second, + multiplier: 2, +} + +// Logger provides a queue-based mechanism for submitting audit logs to the control plane - or +// another suitable consumer. Logs are stored to disk and retried until they are successfully sent, +// or until they permanently fail. +// +// Each individual profile/controlclient tuple should construct and manage a unique [Logger] instance. +type Logger struct { + logf logger.Logf + retryLimit int // the maximum number of attempts to send a log before giving up. + flusher chan struct{} // channel used to signal a flush operation. + done chan struct{} // closed when the flush worker exits. + ctx context.Context // canceled when the logger is stopped. + ctxCancel context.CancelFunc // cancels ctx. + backoffOpts // backoff settings for retry operations. + + // mu protects the fields below. + mu sync.Mutex + store LogStore // persistent storage for unsent logs. + profileID ipn.ProfileID // empty if [Logger.SetProfileID] has not been called. + transport Transport // nil until [Logger.Start] is called. +} + +// NewLogger creates a new [Logger] with the given options. +func NewLogger(opts Opts) *Logger { + ctx, cancel := context.WithCancel(context.Background()) + + al := &Logger{ + retryLimit: opts.RetryLimit, + logf: opts.Logf, + store: opts.Store, + flusher: make(chan struct{}, 1), + done: make(chan struct{}), + ctx: ctx, + ctxCancel: cancel, + backoffOpts: defaultBackoffOpts, + } + al.logf("created") + return al +} + +// FlushAndStop synchronously flushes all pending logs and stops the audit logger. +// This will block until a final flush operation completes or context is done. +// If the logger is already stopped, this will return immediately. All unsent +// logs will be persisted to the store. +func (al *Logger) FlushAndStop(ctx context.Context) { + al.stop() + al.flush(ctx) +} + +// SetProfileID sets the profileID for the logger. 
This must be called before any logs can be enqueued. +// The profileID of a logger cannot be changed once set. +func (al *Logger) SetProfileID(profileID ipn.ProfileID) error { + al.mu.Lock() + defer al.mu.Unlock() + // It's not an error to call SetProfileID more than once + // with the same [ipn.ProfileID]. + if al.profileID != "" && al.profileID != profileID { + return errors.New("profileID cannot be changed once set") + } + + al.profileID = profileID + return nil +} + +// Start starts the audit logger with the given transport. +// It returns an error if the logger is already started. +func (al *Logger) Start(t Transport) error { + al.mu.Lock() + defer al.mu.Unlock() + + if al.transport != nil { + return errors.New("already started") + } + + al.transport = t + pending, err := al.storedCountLocked() + if err != nil { + al.logf("[unexpected] failed to restore logs: %v", err) + } + go al.flushWorker() + if pending > 0 { + al.flushAsync() + } + return nil +} + +// ErrAuditLogStorageFailure is returned when the logger fails to persist logs to the store. +var ErrAuditLogStorageFailure = errors.New("audit log storage failure") + +// Enqueue queues an audit log to be sent to the control plane (or another suitable consumer/transport). +// This will return an error if the underlying store fails to save the log or we fail to generate a unique +// eventID for the log. +func (al *Logger) Enqueue(action tailcfg.ClientAuditAction, details string) error { + txn := &transaction{ + Action: action, + Details: details, + TimeStamp: time.Now(), + } + // Generate a suitably random eventID for the transaction. + txn.EventID = fmt.Sprint(txn.TimeStamp, rands.HexString(16)) + return al.enqueue(txn) +} + +// flushAsync requests an asynchronous flush. +// It is a no-op if a flush is already pending. 
+func (al *Logger) flushAsync() { + select { + case al.flusher <- struct{}{}: + default: + } +} + +func (al *Logger) flushWorker() { + defer close(al.done) + + var retryDelay time.Duration + retry := time.NewTimer(0) + retry.Stop() + + for { + select { + case <-al.ctx.Done(): + return + case <-al.flusher: + err := al.flush(al.ctx) + switch { + case errors.Is(err, context.Canceled): + // The logger was stopped, no need to retry. + return + case err != nil: + retryDelay = max(al.backoffOpts.min, min(retryDelay*time.Duration(al.backoffOpts.multiplier), al.backoffOpts.max)) + al.logf("retrying after %v, %v", retryDelay, err) + retry.Reset(retryDelay) + default: + retryDelay = 0 + retry.Stop() + } + case <-retry.C: + al.flushAsync() + } + } +} + +// flush attempts to send all pending logs to the control plane. +// l.mu must not be held. +func (al *Logger) flush(ctx context.Context) error { + al.mu.Lock() + pending, err := al.store.load(al.profileID) + t := al.transport + al.mu.Unlock() + + if err != nil { + // This will catch nil profileIDs + return fmt.Errorf("failed to restore pending logs: %w", err) + } + if len(pending) == 0 { + return nil + } + if t == nil { + return errors.New("no transport") + } + + complete, unsent := al.sendToTransport(ctx, pending, t) + al.markTransactionsDone(complete) + + al.mu.Lock() + defer al.mu.Unlock() + if err = al.appendToStoreLocked(unsent); err != nil { + al.logf("[unexpected] failed to persist logs: %v", err) + } + + if len(unsent) != 0 { + return fmt.Errorf("failed to send %d logs", len(unsent)) + } + + if len(complete) != 0 { + al.logf("complete %d audit log transactions", len(complete)) + } + return nil +} + +// sendToTransport sends all pending logs to the control plane. Returns a pair of slices +// containing the logs that were successfully sent (or failed permanently) and those that were not. +// +// This may require multiple round trips to the control plane and can be a long running transaction. 
+func (al *Logger) sendToTransport(ctx context.Context, pending []*transaction, t Transport) (complete []*transaction, unsent []*transaction) { + for i, txn := range pending { + req := tailcfg.AuditLogRequest{ + Action: tailcfg.ClientAuditAction(txn.Action), + Details: txn.Details, + Timestamp: txn.TimeStamp, + } + + if err := t.SendAuditLog(ctx, req); err != nil { + switch { + case errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded): + // The context is done. All further attempts will fail. + unsent = append(unsent, pending[i:]...) + return complete, unsent + case IsRetryableError(err) && txn.Retries+1 < al.retryLimit: + // We permit a maximum number of retries for each log. All retriable + // errors should be transient and we should be able to send the log eventually, but + // we don't want logs to be persisted indefinitely. + txn.Retries++ + unsent = append(unsent, txn) + default: + complete = append(complete, txn) + al.logf("failed permanently: %v", err) + } + } else { + // No error - we're done. + complete = append(complete, txn) + } + } + + return complete, unsent +} + +func (al *Logger) stop() { + al.mu.Lock() + t := al.transport + al.mu.Unlock() + + if t == nil { + // No transport means no worker goroutine and done will not be + // closed if we cancel the context. + return + } + + al.ctxCancel() + <-al.done + al.logf("stopped for profileID: %v", al.profileID) +} + +// appendToStoreLocked persists logs to the store. This will deduplicate +// logs so it is safe to call this with the same logs multiple times, to +// requeue failed transactions for example. +// +// l.mu must be held. +func (al *Logger) appendToStoreLocked(txns []*transaction) error { + if len(txns) == 0 { + return nil + } + + if al.profileID == "" { + return errors.New("no logId set") + } + + persisted, err := al.store.load(al.profileID) + if err != nil { + al.logf("[unexpected] append failed to restore logs: %v", err) + } + + // The order is important here. 
We want the latest transactions first, which will + // ensure when we dedup, the new transactions are seen and the older transactions + // are discarded. + txnsOut := append(txns, persisted...) + txnsOut = deduplicateAndSort(txnsOut) + + return al.store.save(al.profileID, txnsOut) +} + +// storedCountLocked returns the number of logs persisted to the store. +// al.mu must be held. +func (al *Logger) storedCountLocked() (int, error) { + persisted, err := al.store.load(al.profileID) + return len(persisted), err +} + +// markTransactionsDone removes logs from the store that are complete (sent or failed permanently). +// al.mu must not be held. +func (al *Logger) markTransactionsDone(sent []*transaction) { + al.mu.Lock() + defer al.mu.Unlock() + + ids := set.Set[string]{} + for _, txn := range sent { + ids.Add(txn.EventID) + } + + persisted, err := al.store.load(al.profileID) + if err != nil { + al.logf("[unexpected] markTransactionsDone failed to restore logs: %v", err) + } + var unsent []*transaction + for _, txn := range persisted { + if !ids.Contains(txn.EventID) { + unsent = append(unsent, txn) + } + } + al.store.save(al.profileID, unsent) +} + +// deduplicateAndSort removes duplicate logs from the given slice and sorts them by timestamp. +// The first log entry in the slice will be retained, subsequent logs with the same EventID will be discarded. +func deduplicateAndSort(txns []*transaction) []*transaction { + seen := set.Set[string]{} + deduped := make([]*transaction, 0, len(txns)) + for _, txn := range txns { + if !seen.Contains(txn.EventID) { + deduped = append(deduped, txn) + seen.Add(txn.EventID) + } + } + // Sort logs by timestamp - oldest to newest. This will put the oldest logs at + // the front of the queue. 
+ sort.Slice(deduped, func(i, j int) bool { + return deduped[i].TimeStamp.Before(deduped[j].TimeStamp) + }) + return deduped +} + +func (al *Logger) enqueue(txn *transaction) error { + al.mu.Lock() + defer al.mu.Unlock() + + if err := al.appendToStoreLocked([]*transaction{txn}); err != nil { + return fmt.Errorf("%w: %w", ErrAuditLogStorageFailure, err) + } + + // al.transport is nil if the logger is stopped. + if al.transport != nil { + al.flushAsync() + } + + return nil +} + +var _ LogStore = (*logStateStore)(nil) + +// logStateStore is a concrete implementation of [LogStore] +// using [ipn.StateStore] as the underlying storage. +type logStateStore struct { + store ipn.StateStore +} + +// NewLogStore creates a new LogStateStore with the given [ipn.StateStore]. +func NewLogStore(store ipn.StateStore) LogStore { + return &logStateStore{ + store: store, + } +} + +func (s *logStateStore) generateKey(key ipn.ProfileID) string { + return "auditlog-" + string(key) +} + +// Save saves the given logs to an [ipn.StateStore]. This overwrites +// any existing entries for the given key. +func (s *logStateStore) save(key ipn.ProfileID, txns []*transaction) error { + if key == "" { + return errors.New("empty key") + } + + data, err := json.Marshal(txns) + if err != nil { + return err + } + k := ipn.StateKey(s.generateKey(key)) + return s.store.WriteState(k, data) +} + +// Load retrieves the logs from an [ipn.StateStore]. 
+func (s *logStateStore) load(key ipn.ProfileID) ([]*transaction, error) { + if key == "" { + return nil, errors.New("empty key") + } + + k := ipn.StateKey(s.generateKey(key)) + data, err := s.store.ReadState(k) + + switch { + case errors.Is(err, ipn.ErrStateNotExist): + return nil, nil + case err != nil: + return nil, err + } + + var txns []*transaction + err = json.Unmarshal(data, &txns) + return txns, err +} diff --git a/ipn/auditlog/auditlog_test.go b/ipn/auditlog/auditlog_test.go new file mode 100644 index 0000000000000..041cab3546bd0 --- /dev/null +++ b/ipn/auditlog/auditlog_test.go @@ -0,0 +1,484 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package auditlog + +import ( + "context" + "errors" + "fmt" + "sync" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tstest" +) + +// loggerForTest creates an auditLogger for you and cleans it up +// (and ensures no goroutines are leaked) when the test is done. +func loggerForTest(t *testing.T, opts Opts) *Logger { + t.Helper() + tstest.ResourceCheck(t) + + if opts.Logf == nil { + opts.Logf = t.Logf + } + + if opts.Store == nil { + t.Fatalf("opts.Store must be set") + } + + a := NewLogger(opts) + + t.Cleanup(func() { + a.FlushAndStop(context.Background()) + }) + return a +} + +func TestNonRetryableErrors(t *testing.T) { + errorTests := []struct { + desc string + err error + want bool + }{ + {"DeadlineExceeded", context.DeadlineExceeded, false}, + {"Canceled", context.Canceled, false}, + {"Canceled wrapped", fmt.Errorf("%w: %w", context.Canceled, errors.New("ctx cancelled")), false}, + {"Random error", errors.New("random error"), false}, + } + + for _, tt := range errorTests { + t.Run(tt.desc, func(t *testing.T) { + if IsRetryableError(tt.err) != tt.want { + t.Fatalf("retriable: got %v, want %v", !tt.want, tt.want) + } + }) + } +} + +// TestEnqueueAndFlush enqueues n logs and flushes them. 
+// We expect all logs to be flushed and for no +// logs to remain in the store once FlushAndStop returns. +func TestEnqueueAndFlush(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(nil) + al := loggerForTest(t, Opts{ + RetryLimit: 200, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + wantSent := 10 + + for i := range wantSent { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent := mockTransport.sentCount(); gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndFlushWithFlushCancel calls FlushAndCancel with a cancelled +// context. We expect nothing to be sent and all logs to be stored. +func TestEnqueueAndFlushWithFlushCancel(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + al := loggerForTest(t, Opts{ + RetryLimit: 200, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for i := range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + // Cancel the context before calling FlushAndStop - nothing should get sent. + // This mimics a timeout before flush() has a chance to execute. 
+ ctx, cancel := context.WithCancel(context.Background()) + cancel() + + al.FlushAndStop(ctx) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 10; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestDeduplicateAndSort tests that the most recent log is kept when deduplicating logs +func TestDeduplicateAndSort(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + + logs := []*transaction{ + {EventID: "1", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 1), Retries: 1}, + } + + al.mu.Lock() + defer al.mu.Unlock() + al.appendToStoreLocked(logs) + + // Update the transaction and re-append it + logs[0].Retries = 2 + al.appendToStoreLocked(logs) + + fromStore, err := al.store.load("test") + c.Assert(err, qt.IsNil) + + // We should see only one transaction + if wantStored, gotStored := len(logs), len(fromStore); gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + // We should see the latest transaction + if wantRetryCount, gotRetryCount := 2, fromStore[0].Retries; gotRetryCount != wantRetryCount { + t.Fatalf("retries: got %d, want %d", gotRetryCount, wantRetryCount) + } +} + +func TestChangeProfileId(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + c.Assert(al.SetProfileID("test"), qt.IsNil) + + // Calling SetProfileID with the same profile ID must not fail. + c.Assert(al.SetProfileID("test"), qt.IsNil) + + // Changing a profile ID must fail. 
+ c.Assert(al.SetProfileID("test2"), qt.IsNotNil) +} + +// TestSendOnRestore pushes n logs to the persistent store, and ensures they +// are sent as soon as Start is called then checks to ensure the sent logs no +// longer exist in the store. +func TestSendOnRestore(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(nil) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + al.SetProfileID("test") + + wantTotal := 10 + + for range 10 { + al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + } + + c.Assert(al.Start(mockTransport), qt.IsNil) + + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), wantTotal; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestFailureExhaustion enqueues n logs, with the transport in a failable state. +// We then set it to a non-failing state, call FlushAndStop and expect all logs to be sent. 
+func TestFailureExhaustion(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 1, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndFailNoRetry enqueues a set of logs, all of which will fail and are not +// retriable. We then call FlushAndStop and expect all to be unsent. +func TestEnqueueAndFailNoRetry(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&nonRetriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + for i := range 10 { + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log %d", i)) + c.Assert(err, qt.IsNil) + } + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 0; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueAndRetry enqueues a set of logs, all of which will fail and are retriable. 
+// Mid-test, we set the transport to not-fail and expect the queue to flush properly +// We set the backoff parameters to 0 seconds so retries are immediate. +func TestEnqueueAndRetry(t *testing.T) { + c := qt.New(t) + mockTransport := newMockTransport(&retriableError) + + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + al.backoffOpts = backoffOpts{ + min: 1 * time.Millisecond, + max: 4 * time.Millisecond, + multiplier: 2.0, + } + + c.Assert(al.SetProfileID("test"), qt.IsNil) + c.Assert(al.Start(mockTransport), qt.IsNil) + + err := al.Enqueue(tailcfg.AuditNodeDisconnect, fmt.Sprintf("log 1")) + c.Assert(err, qt.IsNil) + + // This will wait for at least 2 retries + gotRetried, wantRetried := mockTransport.waitForSendAttemptsToReach(3), true + if gotRetried != wantRetried { + t.Fatalf("retried: got %v, want %v", gotRetried, wantRetried) + } + + mockTransport.setErrorCondition(nil) + + al.FlushAndStop(context.Background()) + al.mu.Lock() + defer al.mu.Unlock() + + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } + + if gotSent, wantSent := mockTransport.sentCount(), 1; gotSent != wantSent { + t.Fatalf("sent: got %d, want %d", gotSent, wantSent) + } +} + +// TestEnqueueBeforeSetProfileID tests that logs enqueued before SetProfileId are not sent +func TestEnqueueBeforeSetProfileID(t *testing.T) { + c := qt.New(t) + al := loggerForTest(t, Opts{ + RetryLimit: 100, + Logf: t.Logf, + Store: NewLogStore(&mem.Store{}), + }) + + err := al.Enqueue(tailcfg.AuditNodeDisconnect, "log") + c.Assert(err, qt.IsNotNil) + al.FlushAndStop(context.Background()) + + al.mu.Lock() + defer al.mu.Unlock() + gotStored, err := al.storedCountLocked() + c.Assert(err, qt.IsNotNil) + + if wantStored := 0; gotStored != wantStored { + t.Fatalf("stored: got %d, want %d", gotStored, wantStored) + } +} + +// 
TestLogSorting tests that audit logs are persisted sorted by timestamp, oldest to newest +func TestLogSorting(t *testing.T) { + c := qt.New(t) + mockStore := NewLogStore(&mem.Store{}) + + logs := []*transaction{ + {EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 1)}, + {EventID: "1", Details: "log 3", TimeStamp: time.Now().Add(-time.Minute * 2)}, + {EventID: "2", Details: "log 2", TimeStamp: time.Now().Add(-time.Minute * 3)}, + {EventID: "3", Details: "log 1", TimeStamp: time.Now().Add(-time.Minute * 4)}, + } + + wantLogs := []transaction{ + {Details: "log 1"}, + {Details: "log 2"}, + {Details: "log 3"}, + } + + mockStore.save("test", logs) + + gotLogs, err := mockStore.load("test") + c.Assert(err, qt.IsNil) + gotLogs = deduplicateAndSort(gotLogs) + + for i := range gotLogs { + if want, got := wantLogs[i].Details, gotLogs[i].Details; want != got { + t.Fatalf("Details: got %v, want %v", got, want) + } + } +} + +// mock implementations for testing + +// newMockTransport returns a mock transport for testing +// If err is non-nil, SendAuditLog will return this error if the send is attempted +// before the context is cancelled. 
+func newMockTransport(err error) *mockAuditLogTransport { + return &mockAuditLogTransport{ + err: err, + attempts: make(chan int, 1), + } +} + +type mockAuditLogTransport struct { + attempts chan int // channel to notify of send attempts + + mu sync.Mutex + sendAttmpts int // number of attempts to send logs + sendCount int // number of logs sent by the transport + err error // error to return when sending logs +} + +// waitForSendAttemptsToReach blocks until the number of send attempts reaches n +// This should be use only in tests where the transport is expected to retry sending logs +func (t *mockAuditLogTransport) waitForSendAttemptsToReach(n int) bool { + for attempts := range t.attempts { + if attempts >= n { + return true + } + } + return false +} + +func (t *mockAuditLogTransport) setErrorCondition(err error) { + t.mu.Lock() + defer t.mu.Unlock() + t.err = err +} + +func (t *mockAuditLogTransport) sentCount() int { + t.mu.Lock() + defer t.mu.Unlock() + return t.sendCount +} + +func (t *mockAuditLogTransport) SendAuditLog(ctx context.Context, _ tailcfg.AuditLogRequest) (err error) { + t.mu.Lock() + t.sendAttmpts += 1 + defer func() { + a := t.sendAttmpts + t.mu.Unlock() + select { + case t.attempts <- a: + default: + } + }() + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + if t.err != nil { + return t.err + } + t.sendCount += 1 + return nil +} + +var ( + retriableError = mockError{errors.New("retriable error")} + nonRetriableError = mockError{errors.New("permanent failure error")} +) + +type mockError struct { + error +} + +func (e mockError) Retryable() bool { + return e == retriableError +} diff --git a/ipn/auditlog/extension.go b/ipn/auditlog/extension.go new file mode 100644 index 0000000000000..f73681db073c1 --- /dev/null +++ b/ipn/auditlog/extension.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package auditlog + +import ( + "context" + "errors" + "fmt" + "sync" + 
"time" + + "tailscale.com/control/controlclient" + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/tailcfg" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" +) + +// featureName is the name of the feature implemented by this package. +// It is also the [extension] name and the log prefix. +const featureName = "auditlog" + +func init() { + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newExtension) +} + +// extension is an [ipnext.Extension] managing audit logging +// on platforms that import this package. +// As of 2025-03-27, that's only Windows and macOS. +type extension struct { + logf logger.Logf + + // store is the log store shared by all loggers. + // It is created when the first logger is started. + store lazy.SyncValue[LogStore] + + // mu protects all following fields. + mu sync.Mutex + // logger is the current audit logger, or nil if it is not set up, + // such as before the first control client is created, or after + // a profile change and before the new control client is created. + // + // It queues, persists, and sends audit logs to the control client. + logger *Logger +} + +// newExtension is an [ipnext.NewExtensionFn] that creates a new audit log extension. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newExtension(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { + return &extension{logf: logger.WithPrefix(logf, featureName+": ")}, nil +} + +// Name implements [ipnext.Extension]. +func (e *extension) Name() string { + return featureName +} + +// Init implements [ipnext.Extension] by registering callbacks and providers +// for the duration of the extension's lifetime. 
+func (e *extension) Init(h ipnext.Host) error { + h.Hooks().NewControlClient.Add(e.controlClientChanged) + h.Hooks().ProfileStateChange.Add(e.profileChanged) + h.Hooks().AuditLoggers.Add(e.getCurrentLogger) + return nil +} + +// [controlclient.Auto] implements [Transport]. +var _ Transport = (*controlclient.Auto)(nil) + +// startNewLogger creates and starts a new logger for the specified profile +// using the specified [controlclient.Client] as the transport. +// The profileID may be "" if the profile has not been persisted yet. +func (e *extension) startNewLogger(cc controlclient.Client, profileID ipn.ProfileID) (*Logger, error) { + transport, ok := cc.(Transport) + if !ok { + return nil, fmt.Errorf("%T cannot be used as transport", cc) + } + + // Create a new log store if this is the first logger. + // Otherwise, get the existing log store. + store, err := e.store.GetErr(func() (LogStore, error) { + return newDefaultLogStore(e.logf) + }) + if err != nil { + return nil, fmt.Errorf("failed to create audit log store: %w", err) + } + + logger := NewLogger(Opts{ + Logf: e.logf, + RetryLimit: 32, + Store: store, + }) + if err := logger.SetProfileID(profileID); err != nil { + return nil, fmt.Errorf("set profile failed: %w", err) + } + if err := logger.Start(transport); err != nil { + return nil, fmt.Errorf("start failed: %w", err) + } + return logger, nil +} + +func (e *extension) controlClientChanged(cc controlclient.Client, profile ipn.LoginProfileView) (cleanup func()) { + logger, err := e.startNewLogger(cc, profile.ID()) + e.mu.Lock() + e.logger = logger // nil on error + e.mu.Unlock() + if err != nil { + // If we fail to create or start the logger, log the error + // and return a nil cleanup function. There's nothing more + // we can do here. + // + // But [extension.getCurrentLogger] returns [noCurrentLogger] + // when the logger is nil. 
Since [noCurrentLogger] always + // fails with [errNoLogger], operations that must be audited + // but cannot will fail on platforms where the audit logger + // is enabled (i.e., the auditlog package is imported). + e.logf("[unexpected] %v", err) + return nil + } + return func() { + // Stop the logger when the control client shuts down. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + logger.FlushAndStop(ctx) + } +} + +func (e *extension) profileChanged(profile ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) { + e.mu.Lock() + defer e.mu.Unlock() + switch { + case e.logger == nil: + // No-op if we don't have an audit logger. + case sameNode: + // The profile info has changed, but it represents the same node. + // This includes the case where the login has just been completed + // and the profile's [ipn.ProfileID] has been set for the first time. + if err := e.logger.SetProfileID(profile.ID()); err != nil { + e.logf("[unexpected] failed to set profile ID: %v", err) + } + default: + // The profile info has changed, and it represents a different node. + // We won't have an audit logger for the new profile until the new + // control client is created. + // + // We don't expect any auditable actions to be attempted in this state. + // But if they are, they will fail with [errNoLogger]. + e.logger = nil + } +} + +// errNoLogger is an error returned by [noCurrentLogger]. It indicates that +// the logger was unavailable when [ipnlocal.LocalBackend] requested it, +// such as when an auditable action was attempted before [LocalBackend.Start] +// was called for the first time or immediately after a profile change +// and before the new control client was created. +// +// This error is unexpected and should not occur in normal operation. +var errNoLogger = errors.New("[unexpected] no audit logger") + +// noCurrentLogger is an [ipnauth.AuditLogFunc] returned by [extension.getCurrentLogger] +// when the logger is not available. 
It fails with [errNoLogger] on every call. +func noCurrentLogger(_ tailcfg.ClientAuditAction, _ string) error { + return errNoLogger +} + +// getCurrentLogger is an [ipnext.AuditLogProvider] registered with [ipnext.Host]. +// It is called when [ipnlocal.LocalBackend] or an extension needs to audit an action. +// +// It returns a function that enqueues the audit log for the current profile, +// or [noCurrentLogger] if the logger is unavailable. +func (e *extension) getCurrentLogger() ipnauth.AuditLogFunc { + e.mu.Lock() + defer e.mu.Unlock() + if e.logger == nil { + return noCurrentLogger + } + return e.logger.Enqueue +} + +// Shutdown implements [ipnlocal.Extension]. +func (e *extension) Shutdown() error { + return nil +} diff --git a/ipn/auditlog/store.go b/ipn/auditlog/store.go new file mode 100644 index 0000000000000..3b58ffa9318a2 --- /dev/null +++ b/ipn/auditlog/store.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package auditlog + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + + "tailscale.com/ipn/store" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" + "tailscale.com/util/must" +) + +var storeFilePath lazy.SyncValue[string] + +// SetStoreFilePath sets the audit log store file path. +// It is optional on platforms with a default store path, +// but required on platforms without one (e.g., macOS). +// It panics if called more than once or after the store has been created. +func SetStoreFilePath(path string) { + if !storeFilePath.Set(path) { + panic("store file path already set or used") + } +} + +// DefaultStoreFilePath returns the default audit log store file path +// for the current platform, or an error if the platform does not have one. 
+func DefaultStoreFilePath() (string, error) { + switch runtime.GOOS { + case "windows": + return filepath.Join(os.Getenv("ProgramData"), "Tailscale", "audit-log.json"), nil + default: + // The auditlog package must either be omitted from the build, + // have the platform-specific store path set with [SetStoreFilePath] (e.g., on macOS), + // or have the default store path available on the current platform. + return "", fmt.Errorf("[unexpected] no default store path available on %s", runtime.GOOS) + } +} + +// newDefaultLogStore returns a new [LogStore] for the current platform. +func newDefaultLogStore(logf logger.Logf) (LogStore, error) { + path, err := storeFilePath.GetErr(DefaultStoreFilePath) + if err != nil { + // This indicates that the auditlog package was not omitted from the build + // on a platform without a default store path and that [SetStoreFilePath] + // was not called to set a platform-specific store path. + // + // This is not expected to happen, but if it does, let's log it + // and use an in-memory store as a fallback. + logf("[unexpected] failed to get audit log store path: %v", err) + return NewLogStore(must.Get(store.New(logf, "mem:auditlog"))), nil + } + fs, err := store.New(logf, path) + if err != nil { + return nil, fmt.Errorf("failed to create audit log store at %q: %w", path, err) + } + return NewLogStore(fs), nil +} diff --git a/ipn/backend.go b/ipn/backend.go index 76ad1910bf14c..3e956f4734f5f 100644 --- a/ipn/backend.go +++ b/ipn/backend.go @@ -58,21 +58,29 @@ type EngineStatus struct { // to subscribe to. type NotifyWatchOpt uint64 +// NotifyWatchOpt values. +// +// These aren't declared using Go's iota because they're not purely internal to +// the process and iota should not be used for values that are serialized to +// disk or network. In this case, these values come over the network via the +// LocalAPI, a mostly stable API. 
const ( // NotifyWatchEngineUpdates, if set, causes Engine updates to be sent to the // client either regularly or when they change, without having to ask for // each one via Engine.RequestStatus. - NotifyWatchEngineUpdates NotifyWatchOpt = 1 << iota + NotifyWatchEngineUpdates NotifyWatchOpt = 1 << 0 + + NotifyInitialState NotifyWatchOpt = 1 << 1 // if set, the first Notify message (sent immediately) will contain the current State + BrowseToURL + SessionID + NotifyInitialPrefs NotifyWatchOpt = 1 << 2 // if set, the first Notify message (sent immediately) will contain the current Prefs + NotifyInitialNetMap NotifyWatchOpt = 1 << 3 // if set, the first Notify message (sent immediately) will contain the current NetMap - NotifyInitialState // if set, the first Notify message (sent immediately) will contain the current State + BrowseToURL + SessionID - NotifyInitialPrefs // if set, the first Notify message (sent immediately) will contain the current Prefs - NotifyInitialNetMap // if set, the first Notify message (sent immediately) will contain the current NetMap + NotifyNoPrivateKeys NotifyWatchOpt = 1 << 4 // if set, private keys that would normally be sent in updates are zeroed out + NotifyInitialDriveShares NotifyWatchOpt = 1 << 5 // if set, the first Notify message (sent immediately) will contain the current Taildrive Shares + NotifyInitialOutgoingFiles NotifyWatchOpt = 1 << 6 // if set, the first Notify message (sent immediately) will contain the current Taildrop OutgoingFiles - NotifyNoPrivateKeys // if set, private keys that would normally be sent in updates are zeroed out - NotifyInitialDriveShares // if set, the first Notify message (sent immediately) will contain the current Taildrive Shares - NotifyInitialOutgoingFiles // if set, the first Notify message (sent immediately) will contain the current Taildrop OutgoingFiles + NotifyInitialHealthState NotifyWatchOpt = 1 << 7 // if set, the first Notify message (sent immediately) will contain the current 
health.State of the client - NotifyInitialHealthState // if set, the first Notify message (sent immediately) will contain the current health.State of the client + NotifyRateLimit NotifyWatchOpt = 1 << 8 // if set, rate limit spammy netmap updates to every few seconds ) // Notify is a communication from a backend (e.g. tailscaled) to a frontend @@ -100,7 +108,6 @@ type Notify struct { NetMap *netmap.NetworkMap // if non-nil, the new or current netmap Engine *EngineStatus // if non-nil, the new or current wireguard stats BrowseToURL *string // if non-nil, UI should open a browser right now - BackendLogID *string // if non-nil, the public logtail ID used by backend // FilesWaiting if non-nil means that files are buffered in // the Tailscale daemon and ready for local transfer to the @@ -146,7 +153,7 @@ type Notify struct { // any changes to the user in the UI. Health *health.State `json:",omitempty"` - // type is mirrored in xcode/Shared/IPN.swift + // type is mirrored in xcode/IPN/Core/LocalAPI/Model/LocalAPIModel.swift } func (n Notify) String() string { @@ -173,9 +180,6 @@ func (n Notify) String() string { if n.BrowseToURL != nil { sb.WriteString("URL=<...> ") } - if n.BackendLogID != nil { - sb.WriteString("BackendLogID ") - } if n.FilesWaiting != nil { sb.WriteString("FilesWaiting ") } diff --git a/ipn/conf.go b/ipn/conf.go index 6a67f40040c76..2c9fb2fd15f9e 100644 --- a/ipn/conf.go +++ b/ipn/conf.go @@ -32,6 +32,10 @@ type ConfigVAlpha struct { AdvertiseRoutes []netip.Prefix `json:",omitempty"` DisableSNAT opt.Bool `json:",omitempty"` + AdvertiseServices []string `json:",omitempty"` + + AppConnector *AppConnectorPrefs `json:",omitempty"` // advertise app connector; defaults to false (if nil or explicitly set to false) + NetfilterMode *string `json:",omitempty"` // "on", "off", "nodivert" NoStatefulFiltering opt.Bool `json:",omitempty"` @@ -137,5 +141,19 @@ func (c *ConfigVAlpha) ToPrefs() (MaskedPrefs, error) { mp.AutoUpdate = *c.AutoUpdate mp.AutoUpdateSet = 
AutoUpdatePrefsMask{ApplySet: true, CheckSet: true} } + if c.AppConnector != nil { + mp.AppConnector = *c.AppConnector + mp.AppConnectorSet = true + } + // Configfile should be the source of truth for whether this node + // advertises any services. We need to ensure that each reload updates + // currently advertised services as else the transition from 'some + // services are advertised' to 'advertised services are empty/unset in + // conffile' would have no effect (especially given that an empty + // service slice would be omitted from the JSON config). + mp.AdvertiseServicesSet = true + if c.AdvertiseServices != nil { + mp.AdvertiseServices = c.AdvertiseServices + } return mp, nil } diff --git a/ipn/desktop/doc.go b/ipn/desktop/doc.go new file mode 100644 index 0000000000000..64a332792a5a4 --- /dev/null +++ b/ipn/desktop/doc.go @@ -0,0 +1,6 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package desktop facilitates interaction with the desktop environment +// and user sessions. As of 2025-02-06, it is only implemented for Windows. +package desktop diff --git a/ipn/desktop/extension.go b/ipn/desktop/extension.go new file mode 100644 index 0000000000000..f204a90dee048 --- /dev/null +++ b/ipn/desktop/extension.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Both the desktop session manager and multi-user support +// are currently available only on Windows. +// This file does not need to be built for other platforms. + +//go:build windows && !ts_omit_desktop_sessions + +package desktop + +import ( + "cmp" + "fmt" + "sync" + + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy" +) + +// featureName is the name of the feature implemented by this package. +// It is also the [desktopSessionsExt] name and the log prefix. 
+const featureName = "desktop-sessions" + +func init() { + feature.Register(featureName) + ipnext.RegisterExtension(featureName, newDesktopSessionsExt) +} + +// [desktopSessionsExt] implements [ipnext.Extension]. +var _ ipnext.Extension = (*desktopSessionsExt)(nil) + +// desktopSessionsExt extends [LocalBackend] with desktop session management. +// It keeps Tailscale running in the background if Always-On mode is enabled, +// and switches to an appropriate profile when a user signs in or out, +// locks their screen, or disconnects a remote session. +type desktopSessionsExt struct { + logf logger.Logf + sm SessionManager + + host ipnext.Host // or nil, until Init is called + cleanup []func() // cleanup functions to call on shutdown + + // mu protects all following fields. + mu sync.Mutex + sessByID map[SessionID]*Session +} + +// newDesktopSessionsExt returns a new [desktopSessionsExt], +// or an error if a [SessionManager] cannot be created. +// It is registered with [ipnext.RegisterExtension] if the package is imported. +func newDesktopSessionsExt(logf logger.Logf, _ ipnext.SafeBackend) (ipnext.Extension, error) { + logf = logger.WithPrefix(logf, featureName+": ") + sm, err := NewSessionManager(logf) + if err != nil { + return nil, fmt.Errorf("%w: session manager is not available: %w", ipnext.SkipExtension, err) + } + return &desktopSessionsExt{ + logf: logf, + sm: sm, + sessByID: make(map[SessionID]*Session), + }, nil +} + +// Name implements [ipnext.Extension]. +func (e *desktopSessionsExt) Name() string { + return featureName +} + +// Init implements [ipnext.Extension]. 
+func (e *desktopSessionsExt) Init(host ipnext.Host) (err error) { + e.host = host + unregisterSessionCb, err := e.sm.RegisterStateCallback(e.updateDesktopSessionState) + if err != nil { + return fmt.Errorf("session callback registration failed: %w", err) + } + host.Hooks().BackgroundProfileResolvers.Add(e.getBackgroundProfile) + e.cleanup = []func(){unregisterSessionCb} + return nil +} + +// updateDesktopSessionState is a [SessionStateCallback] +// invoked by [SessionManager] once for each existing session +// and whenever the session state changes. It updates the session map +// and switches to the best profile if necessary. +func (e *desktopSessionsExt) updateDesktopSessionState(session *Session) { + e.mu.Lock() + if session.Status != ClosedSession { + e.sessByID[session.ID] = session + } else { + delete(e.sessByID, session.ID) + } + e.mu.Unlock() + + var action string + switch session.Status { + case ForegroundSession: + // The user has either signed in or unlocked their session. + // For remote sessions, this may also mean the user has connected. + // The distinction isn't important for our purposes, + // so let's always say "signed in". + action = "signed in to" + case BackgroundSession: + action = "locked" + case ClosedSession: + action = "signed out from" + default: + panic("unreachable") + } + maybeUsername, _ := session.User.Username() + userIdentifier := cmp.Or(maybeUsername, string(session.User.UserID()), "user") + reason := fmt.Sprintf("%s %s session %v", userIdentifier, action, session.ID) + + e.host.Profiles().SwitchToBestProfileAsync(reason) +} + +// getBackgroundProfile is a [ipnext.ProfileResolver] that works as follows: +// +// If Always-On mode is disabled, it returns no profile. +// +// If AlwaysOn mode is enabled, it returns the current profile unless: +// - The current profile's owner has signed out. +// - Another user has a foreground (i.e. active/unlocked) session. 
+// +// If the current profile owner's session runs in the background and no other user +// has a foreground session, it returns the current profile. This applies +// when a locally signed-in user locks their screen or when a remote user +// disconnects without signing out. +// +// In all other cases, it returns no profile. +func (e *desktopSessionsExt) getBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { + e.mu.Lock() + defer e.mu.Unlock() + + if alwaysOn, _ := syspolicy.GetBoolean(syspolicy.AlwaysOn, false); !alwaysOn { + // If the Always-On mode is disabled, there's no background profile + // as far as the desktop session extension is concerned. + return ipn.LoginProfileView{} + } + + isCurrentProfileOwnerSignedIn := false + var foregroundUIDs []ipn.WindowsUserID + for _, s := range e.sessByID { + switch uid := s.User.UserID(); uid { + case profiles.CurrentProfile().LocalUserID(): + isCurrentProfileOwnerSignedIn = true + if s.Status == ForegroundSession { + // Keep the current profile if the user has a foreground session. + return profiles.CurrentProfile() + } + default: + if s.Status == ForegroundSession { + foregroundUIDs = append(foregroundUIDs, uid) + } + } + } + + // If the current profile is empty and not owned by anyone (e.g., tailscaled just started), + // or if the current profile's owner has no foreground session, switch to the default profile + // of the first user with a foreground session, if any. + for _, uid := range foregroundUIDs { + if profile := profiles.DefaultUserProfile(uid); profile.ID() != "" { + return profile + } + } + + // If no user has a foreground session but the current profile's owner is still signed in, + // keep the current profile even if the session is not in the foreground, + // such as when the screen is locked or a remote session is disconnected. + if len(foregroundUIDs) == 0 && isCurrentProfileOwnerSignedIn { + return profiles.CurrentProfile() + } + + // Otherwise, there's no background profile. 
+ return ipn.LoginProfileView{} +} + +// Shutdown implements [ipnext.Extension]. +func (e *desktopSessionsExt) Shutdown() error { + for _, f := range e.cleanup { + f() + } + e.cleanup = nil + e.host = nil + return e.sm.Close() +} diff --git a/ipn/desktop/mksyscall.go b/ipn/desktop/mksyscall.go new file mode 100644 index 0000000000000..b7af12366b64e --- /dev/null +++ b/ipn/desktop/mksyscall.go @@ -0,0 +1,22 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go mksyscall.go +//go:generate go run golang.org/x/tools/cmd/goimports -w zsyscall_windows.go + +//sys setLastError(dwErrorCode uint32) = kernel32.SetLastError + +//sys registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) [atom==0] = user32.RegisterClassExW +//sys createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, dwStyle uint32, x int32, y int32, nWidth int32, nHeight int32, hWndParent windows.HWND, hMenu windows.Handle, hInstance windows.Handle, lpParam unsafe.Pointer) (hWnd windows.HWND, err error) [hWnd==0] = user32.CreateWindowExW +//sys defWindowProc(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) = user32.DefWindowProcW +//sys sendMessage(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) = user32.SendMessageW +//sys getMessage(lpMsg *_MSG, hwnd windows.HWND, msgMin uint32, msgMax uint32) (ret int32) = user32.GetMessageW +//sys translateMessage(lpMsg *_MSG) (res bool) = user32.TranslateMessage +//sys dispatchMessage(lpMsg *_MSG) (res uintptr) = user32.DispatchMessageW +//sys destroyWindow(hwnd windows.HWND) (err error) [int32(failretval)==0] = user32.DestroyWindow +//sys postQuitMessage(exitCode int32) = user32.PostQuitMessage + +//sys registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flags uint32) (err error) [int32(failretval)==0] = 
wtsapi32.WTSRegisterSessionNotificationEx +//sys unregisterSessionNotification(hServer windows.Handle, hwnd windows.HWND) (err error) [int32(failretval)==0] = wtsapi32.WTSUnRegisterSessionNotificationEx diff --git a/ipn/desktop/session.go b/ipn/desktop/session.go new file mode 100644 index 0000000000000..c95378914321d --- /dev/null +++ b/ipn/desktop/session.go @@ -0,0 +1,58 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "fmt" + + "tailscale.com/ipn/ipnauth" +) + +// SessionID is a unique identifier of a desktop session. +type SessionID uint + +// SessionStatus is the status of a desktop session. +type SessionStatus int + +const ( + // ClosedSession is a session that does not exist, is not yet initialized by the OS, + // or has been terminated. + ClosedSession SessionStatus = iota + // ForegroundSession is a session that a user can interact with, + // such as when attached to a physical console or an active, + // unlocked RDP connection. + ForegroundSession + // BackgroundSession indicates that the session is locked, disconnected, + // or otherwise running without user presence or interaction. + BackgroundSession +) + +// String implements [fmt.Stringer]. +func (s SessionStatus) String() string { + switch s { + case ClosedSession: + return "Closed" + case ForegroundSession: + return "Foreground" + case BackgroundSession: + return "Background" + default: + panic("unreachable") + } +} + +// Session is a state of a desktop session at a given point in time. +type Session struct { + ID SessionID // Identifier of the session; can be reused after the session is closed. + Status SessionStatus // The status of the session, such as foreground or background. + User ipnauth.Actor // User logged into the session. +} + +// Description returns a human-readable description of the session. 
+func (s *Session) Description() string { + if maybeUsername, _ := s.User.Username(); maybeUsername != "" { // best effort + return fmt.Sprintf("Session %d - %q (%s)", s.ID, maybeUsername, s.Status) + } + return fmt.Sprintf("Session %d (%s)", s.ID, s.Status) +} diff --git a/ipn/desktop/sessions.go b/ipn/desktop/sessions.go new file mode 100644 index 0000000000000..8bf7a75e2dc3a --- /dev/null +++ b/ipn/desktop/sessions.go @@ -0,0 +1,60 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "errors" + "runtime" +) + +// ErrNotImplemented is returned by [NewSessionManager] when it is not +// implemented for the current GOOS. +var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) + +// SessionInitCallback is a function that is called once per [Session]. +// It returns an optional cleanup function that is called when the session +// is about to be destroyed, or nil if no cleanup is needed. +// It is not safe to call SessionManager methods from within the callback. +type SessionInitCallback func(session *Session) (cleanup func()) + +// SessionStateCallback is a function that reports the initial or updated +// state of a [Session], such as when it transitions between foreground and background. +// It is guaranteed to be called after all registered [SessionInitCallback] functions +// have completed, and before any cleanup functions are called for the same session. +// It is not safe to call SessionManager methods from within the callback. +type SessionStateCallback func(session *Session) + +// SessionManager is an interface that provides access to desktop sessions on the current platform. +// It is safe for concurrent use. +type SessionManager interface { + // Init explicitly initializes the receiver. + // Unless the receiver is explicitly initialized, it will be lazily initialized + // on the first call to any other method. + // It is safe to call Init multiple times. 
+ Init() error + + // Sessions returns a session snapshot taken at the time of the call. + // Since sessions can be created or destroyed at any time, it may become + // outdated as soon as it is returned. + // + // It is primarily intended for logging and debugging. + // Prefer registering a [SessionInitCallback] or [SessionStateCallback] + // in contexts requiring stronger guarantees. + Sessions() (map[SessionID]*Session, error) + + // RegisterInitCallback registers a [SessionInitCallback] that is called for each existing session + // and for each new session that is created, until the returned unregister function is called. + // If the specified [SessionInitCallback] returns a cleanup function, it is called when the session + // is about to be destroyed. The callback function is guaranteed to be called once and only once + // for each existing and new session. + RegisterInitCallback(cb SessionInitCallback) (unregister func(), err error) + + // RegisterStateCallback registers a [SessionStateCallback] that is called for each existing session + // and every time the state of a session changes, until the returned unregister function is called. + RegisterStateCallback(cb SessionStateCallback) (unregister func(), err error) + + // Close waits for all registered callbacks to complete + // and releases resources associated with the receiver. + Close() error +} diff --git a/ipn/desktop/sessions_notwindows.go b/ipn/desktop/sessions_notwindows.go new file mode 100644 index 0000000000000..da3230a456480 --- /dev/null +++ b/ipn/desktop/sessions_notwindows.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !windows + +package desktop + +import "tailscale.com/types/logger" + +// NewSessionManager returns a new [SessionManager] for the current platform, +// [ErrNotImplemented] if the platform is not supported, or an error if the +// session manager could not be created. 
+func NewSessionManager(logger.Logf) (SessionManager, error) { + return nil, ErrNotImplemented +} diff --git a/ipn/desktop/sessions_windows.go b/ipn/desktop/sessions_windows.go new file mode 100644 index 0000000000000..83b884228c1f8 --- /dev/null +++ b/ipn/desktop/sessions_windows.go @@ -0,0 +1,707 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package desktop + +import ( + "context" + "errors" + "fmt" + "runtime" + "sync" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" + "tailscale.com/ipn/ipnauth" + "tailscale.com/types/logger" + "tailscale.com/util/must" + "tailscale.com/util/set" +) + +// wtsManager is a [SessionManager] implementation for Windows. +type wtsManager struct { + logf logger.Logf + ctx context.Context // cancelled when the manager is closed + ctxCancel context.CancelFunc + + initOnce func() error + watcher *sessionWatcher + + mu sync.Mutex + sessions map[SessionID]*wtsSession + initCbs set.HandleSet[SessionInitCallback] + stateCbs set.HandleSet[SessionStateCallback] +} + +// NewSessionManager returns a new [SessionManager] for the current platform. +func NewSessionManager(logf logger.Logf) (SessionManager, error) { + ctx, ctxCancel := context.WithCancel(context.Background()) + m := &wtsManager{ + logf: logf, + ctx: ctx, + ctxCancel: ctxCancel, + sessions: make(map[SessionID]*wtsSession), + } + m.watcher = newSessionWatcher(m.ctx, m.logf, m.sessionEventHandler) + + m.initOnce = sync.OnceValue(func() error { + if err := waitUntilWTSReady(m.ctx); err != nil { + return fmt.Errorf("WTS is not ready: %w", err) + } + + m.mu.Lock() + defer m.mu.Unlock() + if err := m.watcher.Start(); err != nil { + return fmt.Errorf("failed to start session watcher: %w", err) + } + + var err error + m.sessions, err = enumerateSessions() + return err // may be nil or non-nil + }) + return m, nil +} + +// Init implements [SessionManager]. 
+func (m *wtsManager) Init() error { + return m.initOnce() +} + +// Sessions implements [SessionManager]. +func (m *wtsManager) Sessions() (map[SessionID]*Session, error) { + if err := m.initOnce(); err != nil { + return nil, err + } + + m.mu.Lock() + defer m.mu.Unlock() + sessions := make(map[SessionID]*Session, len(m.sessions)) + for _, s := range m.sessions { + sessions[s.id] = s.AsSession() + } + return sessions, nil +} + +// RegisterInitCallback implements [SessionManager]. +func (m *wtsManager) RegisterInitCallback(cb SessionInitCallback) (unregister func(), err error) { + if err := m.initOnce(); err != nil { + return nil, err + } + if cb == nil { + return nil, errors.New("nil callback") + } + + m.mu.Lock() + defer m.mu.Unlock() + handle := m.initCbs.Add(cb) + + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, s := range m.sessions { + if cleanup := cb(s.AsSession()); cleanup != nil { + s.cleanup = append(s.cleanup, cleanup) + } + } + + return func() { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.initCbs, handle) + }, nil +} + +// RegisterStateCallback implements [SessionManager]. +func (m *wtsManager) RegisterStateCallback(cb SessionStateCallback) (unregister func(), err error) { + if err := m.initOnce(); err != nil { + return nil, err + } + if cb == nil { + return nil, errors.New("nil callback") + } + + m.mu.Lock() + defer m.mu.Unlock() + handle := m.stateCbs.Add(cb) + + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, s := range m.sessions { + cb(s.AsSession()) + } + + return func() { + m.mu.Lock() + defer m.mu.Unlock() + delete(m.stateCbs, handle) + }, nil +} + +func (m *wtsManager) sessionEventHandler(id SessionID, event uint32) { + m.mu.Lock() + defer m.mu.Unlock() + switch event { + case windows.WTS_SESSION_LOGON: + // The session may have been created after we started watching, + // but before the initial enumeration was performed. + // Do not create a new session if it already exists. 
+ if _, _, err := m.getOrCreateSessionLocked(id); err != nil { + m.logf("[unexpected] getOrCreateSessionLocked(%d): %v", id, err) + } + case windows.WTS_SESSION_LOCK: + if err := m.setSessionStatusLocked(id, BackgroundSession); err != nil { + m.logf("[unexpected] setSessionStatusLocked(%d, BackgroundSession): %v", id, err) + } + case windows.WTS_SESSION_UNLOCK: + if err := m.setSessionStatusLocked(id, ForegroundSession); err != nil { + m.logf("[unexpected] setSessionStatusLocked(%d, ForegroundSession): %v", id, err) + } + case windows.WTS_SESSION_LOGOFF: + if err := m.deleteSessionLocked(id); err != nil { + m.logf("[unexpected] deleteSessionLocked(%d): %v", id, err) + } + } +} + +func (m *wtsManager) getOrCreateSessionLocked(id SessionID) (_ *wtsSession, created bool, err error) { + if s, ok := m.sessions[id]; ok { + return s, false, nil + } + + s, err := newWTSSession(id, ForegroundSession) + if err != nil { + return nil, false, err + } + m.sessions[id] = s + + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, cb := range m.initCbs { + if cleanup := cb(session); cleanup != nil { + s.cleanup = append(s.cleanup, cleanup) + } + } + for _, cb := range m.stateCbs { + cb(session) + } + + return s, true, err +} + +func (m *wtsManager) setSessionStatusLocked(id SessionID, status SessionStatus) error { + s, _, err := m.getOrCreateSessionLocked(id) + if err != nil { + return err + } + if s.status == status { + return nil + } + + s.status = status + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks in a separate goroutine? + for _, cb := range m.stateCbs { + cb(session) + } + return nil +} + +func (m *wtsManager) deleteSessionLocked(id SessionID) error { + s, ok := m.sessions[id] + if !ok { + return nil + } + + s.status = ClosedSession + session := s.AsSession() + // TODO(nickkhyl): enqueue callbacks (and [wtsSession.close]!) in a separate goroutine? 
+ for _, cb := range m.stateCbs { + cb(session) + } + + delete(m.sessions, id) + return s.close() +} + +func (m *wtsManager) Close() error { + m.ctxCancel() + + if m.watcher != nil { + err := m.watcher.Stop() + if err != nil { + return err + } + m.watcher = nil + } + + m.mu.Lock() + defer m.mu.Unlock() + m.initCbs = nil + m.stateCbs = nil + errs := make([]error, 0, len(m.sessions)) + for _, s := range m.sessions { + errs = append(errs, s.close()) + } + m.sessions = nil + return errors.Join(errs...) +} + +type wtsSession struct { + id SessionID + user *ipnauth.WindowsActor + + status SessionStatus + + cleanup []func() +} + +func newWTSSession(id SessionID, status SessionStatus) (*wtsSession, error) { + var token windows.Token + if err := windows.WTSQueryUserToken(uint32(id), &token); err != nil { + return nil, err + } + user, err := ipnauth.NewWindowsActorWithToken(token) + if err != nil { + return nil, err + } + return &wtsSession{id, user, status, nil}, nil +} + +// enumerateSessions returns a map of all active WTS sessions. +func enumerateSessions() (map[SessionID]*wtsSession, error) { + const reserved, version uint32 = 0, 1 + var numSessions uint32 + var sessionInfos *windows.WTS_SESSION_INFO + if err := windows.WTSEnumerateSessions(_WTS_CURRENT_SERVER_HANDLE, reserved, version, &sessionInfos, &numSessions); err != nil { + return nil, fmt.Errorf("WTSEnumerateSessions failed: %w", err) + } + defer windows.WTSFreeMemory(uintptr(unsafe.Pointer(sessionInfos))) + + sessions := make(map[SessionID]*wtsSession, numSessions) + for _, si := range unsafe.Slice(sessionInfos, numSessions) { + status := _WTS_CONNECTSTATE_CLASS(si.State).ToSessionStatus() + if status == ClosedSession { + // The session does not exist as far as we're concerned. + // It may be in the process of being created or destroyed, + // or be a special "listener" session, etc. 
+ continue + } + id := SessionID(si.SessionID) + session, err := newWTSSession(id, status) + if err != nil { + continue + } + sessions[id] = session + } + return sessions, nil +} + +func (s *wtsSession) AsSession() *Session { + return &Session{ + ID: s.id, + Status: s.status, + // wtsSession owns the user; don't let the caller close it + User: ipnauth.WithoutClose(s.user), + } +} + +func (m *wtsSession) close() error { + for _, cleanup := range m.cleanup { + cleanup() + } + m.cleanup = nil + + if m.user != nil { + if err := m.user.Close(); err != nil { + return err + } + m.user = nil + } + return nil +} + +type sessionEventHandler func(id SessionID, event uint32) + +// TODO(nickkhyl): implement a sessionWatcher that does not use the message queue. +// One possible approach is to have the tailscaled service register a HandlerEx function +// and stream SERVICE_CONTROL_SESSIONCHANGE events to the tailscaled subprocess +// (the actual tailscaled backend), exposing these events via [sessionWatcher]/[wtsManager]. +// +// See tailscale/corp#26477 for details and tracking. +type sessionWatcher struct { + logf logger.Logf + ctx context.Context // canceled to stop the watcher + ctxCancel context.CancelFunc // cancels the watcher + hWnd windows.HWND // window handle for receiving session change notifications + handler sessionEventHandler // called on session events + + mu sync.Mutex + doneCh chan error // written to when the watcher exits; nil if not started +} + +func newSessionWatcher(ctx context.Context, logf logger.Logf, handler sessionEventHandler) *sessionWatcher { + ctx, cancel := context.WithCancel(ctx) + return &sessionWatcher{logf: logf, ctx: ctx, ctxCancel: cancel, handler: handler} +} + +func (sw *sessionWatcher) Start() error { + sw.mu.Lock() + defer sw.mu.Unlock() + + select { + case <-sw.ctx.Done(): + return fmt.Errorf("sessionWatcher already stopped: %w", sw.ctx.Err()) + default: + } + + if sw.doneCh != nil { + // Already started. 
+ return nil + } + sw.doneCh = make(chan error, 1) + + startedCh := make(chan error, 1) + go sw.run(startedCh, sw.doneCh) + if err := <-startedCh; err != nil { + return err + } + + // Signal the window to unsubscribe from session notifications + // and shut down gracefully when the sessionWatcher is stopped. + context.AfterFunc(sw.ctx, func() { + sendMessage(sw.hWnd, _WM_CLOSE, 0, 0) + }) + return nil +} + +func (sw *sessionWatcher) run(started, done chan<- error) { + runtime.LockOSThread() + defer func() { + runtime.UnlockOSThread() + close(done) + }() + err := sw.createMessageWindow() + started <- err + if err != nil { + return + } + pumpThreadMessages() +} + +// Stop stops the session watcher and waits for it to exit. +func (sw *sessionWatcher) Stop() error { + sw.ctxCancel() + + sw.mu.Lock() + doneCh := sw.doneCh + sw.doneCh = nil + sw.mu.Unlock() + + if doneCh != nil { + return <-doneCh + } + return nil +} + +const watcherWindowClassName = "Tailscale-SessionManager" + +var watcherWindowClassName16 = sync.OnceValue(func() *uint16 { + return must.Get(syscall.UTF16PtrFromString(watcherWindowClassName)) +}) + +var registerSessionManagerWindowClass = sync.OnceValue(func() error { + var hInst windows.Handle + if err := windows.GetModuleHandleEx(0, nil, &hInst); err != nil { + return fmt.Errorf("GetModuleHandle: %w", err) + } + wc := _WNDCLASSEX{ + CbSize: uint32(unsafe.Sizeof(_WNDCLASSEX{})), + HInstance: hInst, + LpfnWndProc: syscall.NewCallback(sessionWatcherWndProc), + LpszClassName: watcherWindowClassName16(), + } + if _, err := registerClassEx(&wc); err != nil { + return fmt.Errorf("RegisterClassEx(%q): %w", watcherWindowClassName, err) + } + return nil +}) + +func (sw *sessionWatcher) createMessageWindow() error { + if err := registerSessionManagerWindowClass(); err != nil { + return err + } + _, err := createWindowEx( + 0, // dwExStyle + watcherWindowClassName16(), // lpClassName + nil, // lpWindowName + 0, // dwStyle + 0, // x + 0, // y + 0, // nWidth + 0, 
// nHeight + _HWND_MESSAGE, // hWndParent; message-only window + 0, // hMenu + 0, // hInstance + unsafe.Pointer(sw), // lpParam + ) + if err != nil { + return fmt.Errorf("CreateWindowEx: %w", err) + } + return nil +} + +func (sw *sessionWatcher) wndProc(hWnd windows.HWND, msg uint32, wParam, lParam uintptr) (result uintptr) { + switch msg { + case _WM_CREATE: + err := registerSessionNotification(_WTS_CURRENT_SERVER_HANDLE, hWnd, _NOTIFY_FOR_ALL_SESSIONS) + if err != nil { + sw.logf("[unexpected] failed to register for session notifications: %v", err) + return ^uintptr(0) + } + sw.logf("registered for session notifications") + case _WM_WTSSESSION_CHANGE: + sw.handler(SessionID(lParam), uint32(wParam)) + return 0 + case _WM_CLOSE: + if err := destroyWindow(hWnd); err != nil { + sw.logf("[unexpected] failed to destroy window: %v", err) + } + return 0 + case _WM_DESTROY: + err := unregisterSessionNotification(_WTS_CURRENT_SERVER_HANDLE, hWnd) + if err != nil { + sw.logf("[unexpected] failed to unregister session notifications callback: %v", err) + } + sw.logf("unregistered from session notifications") + return 0 + case _WM_NCDESTROY: + sw.hWnd = 0 + postQuitMessage(0) // quit the message loop for this thread + } + return defWindowProc(hWnd, msg, wParam, lParam) +} + +func (sw *sessionWatcher) setHandle(hwnd windows.HWND) error { + sw.hWnd = hwnd + setLastError(0) + _, err := setWindowLongPtr(sw.hWnd, _GWLP_USERDATA, uintptr(unsafe.Pointer(sw))) + return err // may be nil or non-nil +} + +func sessionWatcherByHandle(hwnd windows.HWND) *sessionWatcher { + val, _ := getWindowLongPtr(hwnd, _GWLP_USERDATA) + return (*sessionWatcher)(unsafe.Pointer(val)) +} + +func sessionWatcherWndProc(hWnd windows.HWND, msg uint32, wParam, lParam uintptr) (result uintptr) { + if msg == _WM_NCCREATE { + cs := (*_CREATESTRUCT)(unsafe.Pointer(lParam)) + sw := (*sessionWatcher)(unsafe.Pointer(cs.CreateParams)) + if sw == nil { + return 0 + } + if err := sw.setHandle(hWnd); err != nil { + 
return 0 + } + return defWindowProc(hWnd, msg, wParam, lParam) + } + if sw := sessionWatcherByHandle(hWnd); sw != nil { + return sw.wndProc(hWnd, msg, wParam, lParam) + } + return defWindowProc(hWnd, msg, wParam, lParam) +} + +func pumpThreadMessages() { + var msg _MSG + for getMessage(&msg, 0, 0, 0) != 0 { + translateMessage(&msg) + dispatchMessage(&msg) + } +} + +// waitUntilWTSReady waits until the Windows Terminal Services (WTS) is ready. +// This is necessary because the WTS API functions may fail if called before +// the WTS is ready. +// +// https://web.archive.org/web/20250207011738/https://learn.microsoft.com/en-us/windows/win32/api/wtsapi32/nf-wtsapi32-wtsregistersessionnotificationex +func waitUntilWTSReady(ctx context.Context) error { + eventName16, err := windows.UTF16PtrFromString(`Global\TermSrvReadyEvent`) + if err != nil { + return err + } + event, err := windows.OpenEvent(windows.SYNCHRONIZE, false, eventName16) + if err != nil { + return err + } + defer windows.CloseHandle(event) + return waitForContextOrHandle(ctx, event) +} + +// waitForContextOrHandle waits for either the context to be done or a handle to be signaled. +func waitForContextOrHandle(ctx context.Context, handle windows.Handle) error { + contextDoneEvent, cleanup, err := channelToEvent(ctx.Done()) + if err != nil { + return err + } + defer cleanup() + + handles := []windows.Handle{contextDoneEvent, handle} + waitCode, err := windows.WaitForMultipleObjects(handles, false, windows.INFINITE) + if err != nil { + return err + } + + waitCode -= windows.WAIT_OBJECT_0 + if waitCode == 0 { // contextDoneEvent + return ctx.Err() + } + return nil +} + +// channelToEvent returns an auto-reset event that is set when the channel +// becomes receivable, including when the channel is closed. 
+func channelToEvent[T any](c <-chan T) (evt windows.Handle, cleanup func(), err error) { + evt, err = windows.CreateEvent(nil, 0, 0, nil) + if err != nil { + return 0, nil, err + } + + cancel := make(chan struct{}) + + go func() { + select { + case <-cancel: + return + case <-c: + } + windows.SetEvent(evt) + }() + + cleanup = func() { + close(cancel) + windows.CloseHandle(evt) + } + + return evt, cleanup, nil +} + +type _WNDCLASSEX struct { + CbSize uint32 + Style uint32 + LpfnWndProc uintptr + CbClsExtra int32 + CbWndExtra int32 + HInstance windows.Handle + HIcon windows.Handle + HCursor windows.Handle + HbrBackground windows.Handle + LpszMenuName *uint16 + LpszClassName *uint16 + HIconSm windows.Handle +} + +type _CREATESTRUCT struct { + CreateParams uintptr + Instance windows.Handle + Menu windows.Handle + Parent windows.HWND + Cy int32 + Cx int32 + Y int32 + X int32 + Style int32 + Name *uint16 + ClassName *uint16 + ExStyle uint32 +} + +type _POINT struct { + X, Y int32 +} + +type _MSG struct { + HWnd windows.HWND + Message uint32 + WParam uintptr + LParam uintptr + Time uint32 + Pt _POINT +} + +const ( + _WM_CREATE = 1 + _WM_DESTROY = 2 + _WM_CLOSE = 16 + _WM_NCCREATE = 129 + _WM_QUIT = 18 + _WM_NCDESTROY = 130 + + // _WM_WTSSESSION_CHANGE is a message sent to windows that have registered + // for session change notifications, informing them of changes in session state. + // + // https://web.archive.org/web/20250207012421/https://learn.microsoft.com/en-us/windows/win32/termserv/wm-wtssession-change + _WM_WTSSESSION_CHANGE = 0x02B1 +) + +const _GWLP_USERDATA = -21 + +const _HWND_MESSAGE = ^windows.HWND(2) + +// _NOTIFY_FOR_ALL_SESSIONS indicates that the window should receive +// session change notifications for all sessions on the specified server. +const _NOTIFY_FOR_ALL_SESSIONS = 1 + +// _WTS_CURRENT_SERVER_HANDLE indicates that the window should receive +// session change notifications for the host itself rather than a remote server. 
+const _WTS_CURRENT_SERVER_HANDLE = windows.Handle(0) + +// _WTS_CONNECTSTATE_CLASS represents the connection state of a session. +// +// https://web.archive.org/web/20250206082427/https://learn.microsoft.com/en-us/windows/win32/api/wtsapi32/ne-wtsapi32-wts_connectstate_class +type _WTS_CONNECTSTATE_CLASS int32 + +// ToSessionStatus converts cs to a [SessionStatus]. +func (cs _WTS_CONNECTSTATE_CLASS) ToSessionStatus() SessionStatus { + switch cs { + case windows.WTSActive: + return ForegroundSession + case windows.WTSDisconnected: + return BackgroundSession + default: + // The session does not exist as far as we're concerned. + return ClosedSession + } +} + +var ( + procGetWindowLongPtrW *windows.LazyProc + procSetWindowLongPtrW *windows.LazyProc +) + +func init() { + // GetWindowLongPtrW and SetWindowLongPtrW are only available on 64-bit platforms. + // https://web.archive.org/web/20250414195520/https://learn.microsoft.com/en-us/windows/win32/api/winuser/nf-winuser-getwindowlongptrw + if runtime.GOARCH == "386" || runtime.GOARCH == "arm" { + procGetWindowLongPtrW = moduser32.NewProc("GetWindowLongW") + procSetWindowLongPtrW = moduser32.NewProc("SetWindowLongW") + } else { + procGetWindowLongPtrW = moduser32.NewProc("GetWindowLongPtrW") + procSetWindowLongPtrW = moduser32.NewProc("SetWindowLongPtrW") + } +} + +func getWindowLongPtr(hwnd windows.HWND, index int32) (res uintptr, err error) { + r0, _, e1 := syscall.Syscall(procGetWindowLongPtrW.Addr(), 2, uintptr(hwnd), uintptr(index), 0) + res = uintptr(r0) + if res == 0 && e1 != 0 { + err = errnoErr(e1) + } + return +} + +func setWindowLongPtr(hwnd windows.HWND, index int32, newLong uintptr) (res uintptr, err error) { + r0, _, e1 := syscall.Syscall(procSetWindowLongPtrW.Addr(), 3, uintptr(hwnd), uintptr(index), uintptr(newLong)) + res = uintptr(r0) + if res == 0 && e1 != 0 { + err = errnoErr(e1) + } + return +} diff --git a/ipn/desktop/zsyscall_windows.go b/ipn/desktop/zsyscall_windows.go new file mode 100644 index 
0000000000000..535274016f9ca --- /dev/null +++ b/ipn/desktop/zsyscall_windows.go @@ -0,0 +1,139 @@ +// Code generated by 'go generate'; DO NOT EDIT. + +package desktop + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +var _ unsafe.Pointer + +// Do the interface allocations only once for common +// Errno values. +const ( + errnoERROR_IO_PENDING = 997 +) + +var ( + errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL +) + +// errnoErr returns common boxed Errno values, to prevent +// allocations at runtime. +func errnoErr(e syscall.Errno) error { + switch e { + case 0: + return errERROR_EINVAL + case errnoERROR_IO_PENDING: + return errERROR_IO_PENDING + } + // TODO: add more here, after collecting data on the common + // error values see on Windows. (perhaps when running + // all.bat?) + return e +} + +var ( + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + moduser32 = windows.NewLazySystemDLL("user32.dll") + modwtsapi32 = windows.NewLazySystemDLL("wtsapi32.dll") + + procSetLastError = modkernel32.NewProc("SetLastError") + procCreateWindowExW = moduser32.NewProc("CreateWindowExW") + procDefWindowProcW = moduser32.NewProc("DefWindowProcW") + procDestroyWindow = moduser32.NewProc("DestroyWindow") + procDispatchMessageW = moduser32.NewProc("DispatchMessageW") + procGetMessageW = moduser32.NewProc("GetMessageW") + procPostQuitMessage = moduser32.NewProc("PostQuitMessage") + procRegisterClassExW = moduser32.NewProc("RegisterClassExW") + procSendMessageW = moduser32.NewProc("SendMessageW") + procTranslateMessage = moduser32.NewProc("TranslateMessage") + procWTSRegisterSessionNotificationEx = modwtsapi32.NewProc("WTSRegisterSessionNotificationEx") + procWTSUnRegisterSessionNotificationEx = modwtsapi32.NewProc("WTSUnRegisterSessionNotificationEx") +) + +func setLastError(dwErrorCode uint32) { + syscall.Syscall(procSetLastError.Addr(), 1, uintptr(dwErrorCode), 0, 0) + return +} + +func 
createWindowEx(dwExStyle uint32, lpClassName *uint16, lpWindowName *uint16, dwStyle uint32, x int32, y int32, nWidth int32, nHeight int32, hWndParent windows.HWND, hMenu windows.Handle, hInstance windows.Handle, lpParam unsafe.Pointer) (hWnd windows.HWND, err error) { + r0, _, e1 := syscall.Syscall12(procCreateWindowExW.Addr(), 12, uintptr(dwExStyle), uintptr(unsafe.Pointer(lpClassName)), uintptr(unsafe.Pointer(lpWindowName)), uintptr(dwStyle), uintptr(x), uintptr(y), uintptr(nWidth), uintptr(nHeight), uintptr(hWndParent), uintptr(hMenu), uintptr(hInstance), uintptr(lpParam)) + hWnd = windows.HWND(r0) + if hWnd == 0 { + err = errnoErr(e1) + } + return +} + +func defWindowProc(hwnd windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { + r0, _, _ := syscall.Syscall6(procDefWindowProcW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0) + res = uintptr(r0) + return +} + +func destroyWindow(hwnd windows.HWND) (err error) { + r1, _, e1 := syscall.Syscall(procDestroyWindow.Addr(), 1, uintptr(hwnd), 0, 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func dispatchMessage(lpMsg *_MSG) (res uintptr) { + r0, _, _ := syscall.Syscall(procDispatchMessageW.Addr(), 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) + res = uintptr(r0) + return +} + +func getMessage(lpMsg *_MSG, hwnd windows.HWND, msgMin uint32, msgMax uint32) (ret int32) { + r0, _, _ := syscall.Syscall6(procGetMessageW.Addr(), 4, uintptr(unsafe.Pointer(lpMsg)), uintptr(hwnd), uintptr(msgMin), uintptr(msgMax), 0, 0) + ret = int32(r0) + return +} + +func postQuitMessage(exitCode int32) { + syscall.Syscall(procPostQuitMessage.Addr(), 1, uintptr(exitCode), 0, 0) + return +} + +func registerClassEx(windowClass *_WNDCLASSEX) (atom uint16, err error) { + r0, _, e1 := syscall.Syscall(procRegisterClassExW.Addr(), 1, uintptr(unsafe.Pointer(windowClass)), 0, 0) + atom = uint16(r0) + if atom == 0 { + err = errnoErr(e1) + } + return +} + +func sendMessage(hwnd 
windows.HWND, msg uint32, wparam uintptr, lparam uintptr) (res uintptr) { + r0, _, _ := syscall.Syscall6(procSendMessageW.Addr(), 4, uintptr(hwnd), uintptr(msg), uintptr(wparam), uintptr(lparam), 0, 0) + res = uintptr(r0) + return +} + +func translateMessage(lpMsg *_MSG) (res bool) { + r0, _, _ := syscall.Syscall(procTranslateMessage.Addr(), 1, uintptr(unsafe.Pointer(lpMsg)), 0, 0) + res = r0 != 0 + return +} + +func registerSessionNotification(hServer windows.Handle, hwnd windows.HWND, flags uint32) (err error) { + r1, _, e1 := syscall.Syscall(procWTSRegisterSessionNotificationEx.Addr(), 3, uintptr(hServer), uintptr(hwnd), uintptr(flags)) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} + +func unregisterSessionNotification(hServer windows.Handle, hwnd windows.HWND) (err error) { + r1, _, e1 := syscall.Syscall(procWTSUnRegisterSessionNotificationEx.Addr(), 2, uintptr(hServer), uintptr(hwnd), 0) + if int32(r1) == 0 { + err = errnoErr(e1) + } + return +} diff --git a/ipn/doc.go b/ipn/doc.go index 4b3810be1f734..c98c7e8b3599f 100644 --- a/ipn/doc.go +++ b/ipn/doc.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:generate go run tailscale.com/cmd/viewer -type=Prefs,ServeConfig,TCPPortHandler,HTTPHandler,WebServerConfig +//go:generate go run tailscale.com/cmd/viewer -type=LoginProfile,Prefs,ServeConfig,ServiceConfig,TCPPortHandler,HTTPHandler,WebServerConfig // Package ipn implements the interactions between the Tailscale cloud // control plane and the local network stack. diff --git a/ipn/ipn_clone.go b/ipn/ipn_clone.go index de35b60a7927d..65438444e162f 100644 --- a/ipn/ipn_clone.go +++ b/ipn/ipn_clone.go @@ -17,6 +17,29 @@ import ( "tailscale.com/types/ptr" ) +// Clone makes a deep copy of LoginProfile. +// The result aliases no memory with the original. 
+func (src *LoginProfile) Clone() *LoginProfile { + if src == nil { + return nil + } + dst := new(LoginProfile) + *dst = *src + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _LoginProfileCloneNeedsRegeneration = LoginProfile(struct { + ID ProfileID + Name string + NetworkProfile NetworkProfile + Key StateKey + UserProfile tailcfg.UserProfile + NodeID tailcfg.StableNodeID + LocalUserID WindowsUserID + ControlURL string +}{}) + // Clone makes a deep copy of Prefs. // The result aliases no memory with the original. func (src *Prefs) Clone() *Prefs { @@ -27,6 +50,7 @@ func (src *Prefs) Clone() *Prefs { *dst = *src dst.AdvertiseTags = append(src.AdvertiseTags[:0:0], src.AdvertiseTags...) dst.AdvertiseRoutes = append(src.AdvertiseRoutes[:0:0], src.AdvertiseRoutes...) + dst.AdvertiseServices = append(src.AdvertiseServices[:0:0], src.AdvertiseServices...) if src.DriveShares != nil { dst.DriveShares = make([]*drive.Share, len(src.DriveShares)) for i := range dst.DriveShares { @@ -37,6 +61,9 @@ func (src *Prefs) Clone() *Prefs { } } } + if dst.RelayServerPort != nil { + dst.RelayServerPort = ptr.To(*src.RelayServerPort) + } dst.Persist = src.Persist.Clone() return dst } @@ -61,6 +88,7 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct { ForceDaemon bool Egg bool AdvertiseRoutes []netip.Prefix + AdvertiseServices []string NoSNAT bool NoStatefulFiltering opt.Bool NetfilterMode preftype.NetfilterMode @@ -71,6 +99,7 @@ var _PrefsCloneNeedsRegeneration = Prefs(struct { PostureChecking bool NetfilterKind string DriveShares []*drive.Share + RelayServerPort *int AllowSingleHosts marshalAsTrueInJSON Persist *persist.Persist }{}) @@ -103,6 +132,16 @@ func (src *ServeConfig) Clone() *ServeConfig { } } } + if dst.Services != nil { + dst.Services = map[tailcfg.ServiceName]*ServiceConfig{} + for k, v := range src.Services { + if v == nil { + dst.Services[k] = nil + } else { + dst.Services[k] = v.Clone() + 
} + } + } dst.AllowFunnel = maps.Clone(src.AllowFunnel) if dst.Foreground != nil { dst.Foreground = map[string]*ServeConfig{} @@ -121,11 +160,50 @@ func (src *ServeConfig) Clone() *ServeConfig { var _ServeConfigCloneNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig + Services map[tailcfg.ServiceName]*ServiceConfig AllowFunnel map[HostPort]bool Foreground map[string]*ServeConfig ETag string }{}) +// Clone makes a deep copy of ServiceConfig. +// The result aliases no memory with the original. +func (src *ServiceConfig) Clone() *ServiceConfig { + if src == nil { + return nil + } + dst := new(ServiceConfig) + *dst = *src + if dst.TCP != nil { + dst.TCP = map[uint16]*TCPPortHandler{} + for k, v := range src.TCP { + if v == nil { + dst.TCP[k] = nil + } else { + dst.TCP[k] = ptr.To(*v) + } + } + } + if dst.Web != nil { + dst.Web = map[HostPort]*WebServerConfig{} + for k, v := range src.Web { + if v == nil { + dst.Web[k] = nil + } else { + dst.Web[k] = v.Clone() + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ServiceConfigCloneNeedsRegeneration = ServiceConfig(struct { + TCP map[uint16]*TCPPortHandler + Web map[HostPort]*WebServerConfig + Tun bool +}{}) + // Clone makes a deep copy of TCPPortHandler. // The result aliases no memory with the original. func (src *TCPPortHandler) Clone() *TCPPortHandler { diff --git a/ipn/ipn_view.go b/ipn/ipn_view.go index ff48b9c8975f9..871270b8564f1 100644 --- a/ipn/ipn_view.go +++ b/ipn/ipn_view.go @@ -18,9 +18,75 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Prefs,ServeConfig,TCPPortHandler,HTTPHandler,WebServerConfig +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=LoginProfile,Prefs,ServeConfig,ServiceConfig,TCPPortHandler,HTTPHandler,WebServerConfig -// View returns a readonly view of Prefs. 
+// View returns a read-only view of LoginProfile. +func (p *LoginProfile) View() LoginProfileView { + return LoginProfileView{Đļ: p} +} + +// LoginProfileView provides a read-only view over LoginProfile. +// +// Its methods should only be called if `Valid()` returns true. +type LoginProfileView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *LoginProfile +} + +// Valid reports whether v's underlying value is non-nil. +func (v LoginProfileView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v LoginProfileView) AsStruct() *LoginProfile { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +func (v LoginProfileView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } + +func (v *LoginProfileView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x LoginProfile + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v LoginProfileView) ID() ProfileID { return v.Đļ.ID } +func (v LoginProfileView) Name() string { return v.Đļ.Name } +func (v LoginProfileView) NetworkProfile() NetworkProfile { return v.Đļ.NetworkProfile } +func (v LoginProfileView) Key() StateKey { return v.Đļ.Key } +func (v LoginProfileView) UserProfile() tailcfg.UserProfile { return v.Đļ.UserProfile } +func (v LoginProfileView) NodeID() tailcfg.StableNodeID { return v.Đļ.NodeID } +func (v LoginProfileView) LocalUserID() WindowsUserID { return v.Đļ.LocalUserID } +func (v LoginProfileView) ControlURL() string { return v.Đļ.ControlURL } + +// A compilation failure here means this code must be regenerated, with the command at the top 
of this file. +var _LoginProfileViewNeedsRegeneration = LoginProfile(struct { + ID ProfileID + Name string + NetworkProfile NetworkProfile + Key StateKey + UserProfile tailcfg.UserProfile + NodeID tailcfg.StableNodeID + LocalUserID WindowsUserID + ControlURL string +}{}) + +// View returns a read-only view of Prefs. func (p *Prefs) View() PrefsView { return PrefsView{Đļ: p} } @@ -36,7 +102,7 @@ type PrefsView struct { Đļ *Prefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v PrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -85,6 +151,9 @@ func (v PrefsView) Egg() bool { return v.Đļ.Eg func (v PrefsView) AdvertiseRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.AdvertiseRoutes) } +func (v PrefsView) AdvertiseServices() views.Slice[string] { + return views.SliceOf(v.Đļ.AdvertiseServices) +} func (v PrefsView) NoSNAT() bool { return v.Đļ.NoSNAT } func (v PrefsView) NoStatefulFiltering() opt.Bool { return v.Đļ.NoStatefulFiltering } func (v PrefsView) NetfilterMode() preftype.NetfilterMode { return v.Đļ.NetfilterMode } @@ -97,6 +166,10 @@ func (v PrefsView) NetfilterKind() string { return v.Đļ.Netfilte func (v PrefsView) DriveShares() views.SliceView[*drive.Share, drive.ShareView] { return views.SliceOfViews[*drive.Share, drive.ShareView](v.Đļ.DriveShares) } +func (v PrefsView) RelayServerPort() views.ValuePointer[int] { + return views.ValuePointerOf(v.Đļ.RelayServerPort) +} + func (v PrefsView) AllowSingleHosts() marshalAsTrueInJSON { return v.Đļ.AllowSingleHosts } func (v PrefsView) Persist() persist.PersistView { return v.Đļ.Persist.View() } @@ -120,6 +193,7 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { ForceDaemon bool Egg bool AdvertiseRoutes []netip.Prefix + AdvertiseServices []string NoSNAT bool NoStatefulFiltering opt.Bool NetfilterMode preftype.NetfilterMode @@ -130,11 +204,12 @@ var 
_PrefsViewNeedsRegeneration = Prefs(struct { PostureChecking bool NetfilterKind string DriveShares []*drive.Share + RelayServerPort *int AllowSingleHosts marshalAsTrueInJSON Persist *persist.Persist }{}) -// View returns a readonly view of ServeConfig. +// View returns a read-only view of ServeConfig. func (p *ServeConfig) View() ServeConfigView { return ServeConfigView{Đļ: p} } @@ -150,7 +225,7 @@ type ServeConfigView struct { Đļ *ServeConfig } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ServeConfigView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -191,6 +266,12 @@ func (v ServeConfigView) Web() views.MapFn[HostPort, *WebServerConfig, WebServer }) } +func (v ServeConfigView) Services() views.MapFn[tailcfg.ServiceName, *ServiceConfig, ServiceConfigView] { + return views.MapFnOf(v.Đļ.Services, func(t *ServiceConfig) ServiceConfigView { + return t.View() + }) +} + func (v ServeConfigView) AllowFunnel() views.Map[HostPort, bool] { return views.MapOf(v.Đļ.AllowFunnel) } @@ -206,12 +287,78 @@ func (v ServeConfigView) ETag() string { return v.Đļ.ETag } var _ServeConfigViewNeedsRegeneration = ServeConfig(struct { TCP map[uint16]*TCPPortHandler Web map[HostPort]*WebServerConfig + Services map[tailcfg.ServiceName]*ServiceConfig AllowFunnel map[HostPort]bool Foreground map[string]*ServeConfig ETag string }{}) -// View returns a readonly view of TCPPortHandler. +// View returns a read-only view of ServiceConfig. +func (p *ServiceConfig) View() ServiceConfigView { + return ServiceConfigView{Đļ: p} +} + +// ServiceConfigView provides a read-only view over ServiceConfig. +// +// Its methods should only be called if `Valid()` returns true. +type ServiceConfigView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. 
+ // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *ServiceConfig +} + +// Valid reports whether v's underlying value is non-nil. +func (v ServiceConfigView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v ServiceConfigView) AsStruct() *ServiceConfig { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +func (v ServiceConfigView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } + +func (v *ServiceConfigView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x ServiceConfig + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v ServiceConfigView) TCP() views.MapFn[uint16, *TCPPortHandler, TCPPortHandlerView] { + return views.MapFnOf(v.Đļ.TCP, func(t *TCPPortHandler) TCPPortHandlerView { + return t.View() + }) +} + +func (v ServiceConfigView) Web() views.MapFn[HostPort, *WebServerConfig, WebServerConfigView] { + return views.MapFnOf(v.Đļ.Web, func(t *WebServerConfig) WebServerConfigView { + return t.View() + }) +} +func (v ServiceConfigView) Tun() bool { return v.Đļ.Tun } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _ServiceConfigViewNeedsRegeneration = ServiceConfig(struct { + TCP map[uint16]*TCPPortHandler + Web map[HostPort]*WebServerConfig + Tun bool +}{}) + +// View returns a read-only view of TCPPortHandler. func (p *TCPPortHandler) View() TCPPortHandlerView { return TCPPortHandlerView{Đļ: p} } @@ -227,7 +374,7 @@ type TCPPortHandlerView struct { Đļ *TCPPortHandler } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v TCPPortHandlerView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -269,7 +416,7 @@ var _TCPPortHandlerViewNeedsRegeneration = TCPPortHandler(struct { TerminateTLS string }{}) -// View returns a readonly view of HTTPHandler. +// View returns a read-only view of HTTPHandler. func (p *HTTPHandler) View() HTTPHandlerView { return HTTPHandlerView{Đļ: p} } @@ -285,7 +432,7 @@ type HTTPHandlerView struct { Đļ *HTTPHandler } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v HTTPHandlerView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -325,7 +472,7 @@ var _HTTPHandlerViewNeedsRegeneration = HTTPHandler(struct { Text string }{}) -// View returns a readonly view of WebServerConfig. +// View returns a read-only view of WebServerConfig. func (p *WebServerConfig) View() WebServerConfigView { return WebServerConfigView{Đļ: p} } @@ -341,7 +488,7 @@ type WebServerConfigView struct { Đļ *WebServerConfig } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v WebServerConfigView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with diff --git a/ipn/ipnauth/access.go b/ipn/ipnauth/access.go new file mode 100644 index 0000000000000..74c66392221b2 --- /dev/null +++ b/ipn/ipnauth/access.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +// ProfileAccess is a bitmask representing the requested, required, or granted +// access rights to an [ipn.LoginProfile]. +// +// It is not to be written to disk or transmitted over the network in its integer form, +// but rather serialized to a string or other format if ever needed. 
+type ProfileAccess uint + +// Define access rights that might be granted or denied on a per-profile basis. +const ( + // Disconnect is required to disconnect (or switch from) a Tailscale profile. + Disconnect = ProfileAccess(1 << iota) +) diff --git a/ipn/ipnauth/actor.go b/ipn/ipnauth/actor.go index db3192c9100ad..108bdd341ae6a 100644 --- a/ipn/ipnauth/actor.go +++ b/ipn/ipnauth/actor.go @@ -4,9 +4,18 @@ package ipnauth import ( + "context" + "encoding/json" + "fmt" + + "tailscale.com/client/tailscale/apitype" "tailscale.com/ipn" + "tailscale.com/tailcfg" ) +// AuditLogFunc is any function that can be used to log audit actions performed by an [Actor]. +type AuditLogFunc func(action tailcfg.ClientAuditAction, details string) error + // Actor is any actor using the [ipnlocal.LocalBackend]. // // It typically represents a specific OS user, indicating that an operation @@ -20,6 +29,22 @@ type Actor interface { // Username returns the user name associated with the receiver, // or "" if the actor does not represent a specific user. Username() (string, error) + // ClientID returns a non-zero ClientID and true if the actor represents + // a connected LocalAPI client. Otherwise, it returns a zero value and false. + ClientID() (_ ClientID, ok bool) + + // Context returns the context associated with the actor. + // It carries additional information about the actor + // and is canceled when the actor is done. + Context() context.Context + + // CheckProfileAccess checks whether the actor has the necessary access rights + // to perform a given action on the specified Tailscale profile. + // It returns an error if access is denied. + // + // If the auditLogger is non-nil, it is used to write details about the action + // to the audit log when required by the policy. + CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogFn AuditLogFunc) error // IsLocalSystem reports whether the actor is the Windows' Local System account. 
// @@ -45,3 +70,65 @@ type ActorCloser interface { // Close releases resources associated with the receiver. Close() error } + +// ClientID is an opaque, comparable value used to identify a connected LocalAPI +// client, such as a connected Tailscale GUI or CLI. It does not necessarily +// correspond to the same [net.Conn] or any physical session. +// +// Its zero value is valid, but does not represent a specific connected client. +type ClientID struct { + v any +} + +// NoClientID is the zero value of [ClientID]. +var NoClientID ClientID + +// ClientIDFrom returns a new [ClientID] derived from the specified value. +// ClientIDs derived from equal values are equal. +func ClientIDFrom[T comparable](v T) ClientID { + return ClientID{v} +} + +// String implements [fmt.Stringer]. +func (id ClientID) String() string { + if id.v == nil { + return "(none)" + } + return fmt.Sprint(id.v) +} + +// MarshalJSON implements [json.Marshaler]. +// It is primarily used for testing. +func (id ClientID) MarshalJSON() ([]byte, error) { + return json.Marshal(id.v) +} + +// UnmarshalJSON implements [json.Unmarshaler]. +// It is primarily used for testing. +func (id *ClientID) UnmarshalJSON(b []byte) error { + return json.Unmarshal(b, &id.v) +} + +type actorWithRequestReason struct { + Actor + ctx context.Context +} + +// WithRequestReason returns an [Actor] that wraps the given actor and +// carries the specified request reason in its context. +func WithRequestReason(actor Actor, requestReason string) Actor { + ctx := apitype.RequestReasonKey.WithValue(actor.Context(), requestReason) + return &actorWithRequestReason{Actor: actor, ctx: ctx} +} + +// Context implements [Actor]. +func (a *actorWithRequestReason) Context() context.Context { return a.ctx } + +type withoutCloseActor struct{ Actor } + +// WithoutClose returns an [Actor] that does not expose the [ActorCloser] interface. 
+// In other words, _, ok := WithoutClose(actor).(ActorCloser) will always be false, +// even if the original actor implements [ActorCloser]. +func WithoutClose(actor Actor) Actor { + return withoutCloseActor{actor} +} diff --git a/ipn/ipnauth/actor_windows.go b/ipn/ipnauth/actor_windows.go new file mode 100644 index 0000000000000..90d3bdd362bbf --- /dev/null +++ b/ipn/ipnauth/actor_windows.go @@ -0,0 +1,102 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "context" + "errors" + + "golang.org/x/sys/windows" + "tailscale.com/ipn" + "tailscale.com/types/lazy" +) + +// WindowsActor implements [Actor]. +var _ Actor = (*WindowsActor)(nil) + +// WindowsActor represents a logged in Windows user. +type WindowsActor struct { + ctx context.Context + cancelCtx context.CancelFunc + token WindowsToken + uid ipn.WindowsUserID + username lazy.SyncValue[string] +} + +// NewWindowsActorWithToken returns a new [WindowsActor] for the user +// represented by the given [windows.Token]. +// It takes ownership of the token. +func NewWindowsActorWithToken(t windows.Token) (_ *WindowsActor, err error) { + tok := newToken(t) + uid, err := tok.UID() + if err != nil { + t.Close() + return nil, err + } + ctx, cancelCtx := context.WithCancel(context.Background()) + return &WindowsActor{ctx: ctx, cancelCtx: cancelCtx, token: tok, uid: uid}, nil +} + +// UserID implements [Actor]. +func (a *WindowsActor) UserID() ipn.WindowsUserID { + return a.uid +} + +// Username implements [Actor]. +func (a *WindowsActor) Username() (string, error) { + return a.username.GetErr(a.token.Username) +} + +// ClientID implements [Actor]. +func (a *WindowsActor) ClientID() (_ ClientID, ok bool) { + // TODO(nickkhyl): assign and return a client ID when the actor + // represents a connected LocalAPI client. + return NoClientID, false +} + +// Context implements [Actor]. 
+func (a *WindowsActor) Context() context.Context { + return a.ctx +} + +// CheckProfileAccess implements [Actor]. +func (a *WindowsActor) CheckProfileAccess(profile ipn.LoginProfileView, _ ProfileAccess, _ AuditLogFunc) error { + if profile.LocalUserID() != a.UserID() { + // TODO(nickkhyl): return errors of more specific types and have them + // translated to the appropriate HTTP status codes in the API handler. + return errors.New("the target profile does not belong to the user") + } + return nil +} + +// IsLocalSystem implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-02-06) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (a *WindowsActor) IsLocalSystem() bool { + // https://web.archive.org/web/2024/https://learn.microsoft.com/en-us/windows-server/identity/ad-ds/manage/understand-security-identifiers + const systemUID = ipn.WindowsUserID("S-1-5-18") + return a.uid == systemUID +} + +// IsLocalAdmin implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-02-06) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (a *WindowsActor) IsLocalAdmin(operatorUID string) bool { + return a.token.IsElevated() +} + +// Close releases resources associated with the actor +// and cancels its context. 
+func (a *WindowsActor) Close() error { + if a.token != nil { + if err := a.token.Close(); err != nil { + return err + } + a.token = nil + } + a.cancelCtx() + return nil +} diff --git a/ipn/ipnauth/ipnauth_notwindows.go b/ipn/ipnauth/ipnauth_notwindows.go index 3dad8233a2198..d9d11bd0a17a1 100644 --- a/ipn/ipnauth/ipnauth_notwindows.go +++ b/ipn/ipnauth/ipnauth_notwindows.go @@ -18,7 +18,9 @@ import ( func GetConnIdentity(_ logger.Logf, c net.Conn) (ci *ConnIdentity, err error) { ci = &ConnIdentity{conn: c, notWindows: true} _, ci.isUnixSock = c.(*net.UnixConn) - ci.creds, _ = peercred.Get(c) + if ci.creds, _ = peercred.Get(c); ci.creds != nil { + ci.pid, _ = ci.creds.PID() + } return ci, nil } diff --git a/ipn/ipnauth/ipnauth_windows.go b/ipn/ipnauth/ipnauth_windows.go index 9abd04cd19408..1138bc23d20fa 100644 --- a/ipn/ipnauth/ipnauth_windows.go +++ b/ipn/ipnauth/ipnauth_windows.go @@ -36,6 +36,12 @@ type token struct { t windows.Token } +func newToken(t windows.Token) *token { + tok := &token{t: t} + runtime.SetFinalizer(tok, func(t *token) { t.Close() }) + return tok +} + func (t *token) UID() (ipn.WindowsUserID, error) { sid, err := t.uid() if err != nil { @@ -184,7 +190,5 @@ func (ci *ConnIdentity) WindowsToken() (WindowsToken, error) { return nil, err } - result := &token{t: windows.Token(h)} - runtime.SetFinalizer(result, func(t *token) { t.Close() }) - return result, nil + return newToken(windows.Token(h)), nil } diff --git a/ipn/ipnauth/policy.go b/ipn/ipnauth/policy.go new file mode 100644 index 0000000000000..aa4ec4100ff93 --- /dev/null +++ b/ipn/ipnauth/policy.go @@ -0,0 +1,74 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "errors" + "fmt" + + "tailscale.com/client/tailscale/apitype" + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/util/syspolicy" +) + +type actorWithPolicyChecks struct{ Actor } + +// WithPolicyChecks returns an [Actor] that wraps the given actor 
and +// performs additional policy checks on top of the access checks +// implemented by the wrapped actor. +func WithPolicyChecks(actor Actor) Actor { + // TODO(nickkhyl): We should probably exclude the Windows Local System + // account from policy checks as well. + switch actor.(type) { + case unrestricted: + return actor + default: + return &actorWithPolicyChecks{Actor: actor} + } +} + +// CheckProfileAccess implements [Actor]. +func (a actorWithPolicyChecks) CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ProfileAccess, auditLogger AuditLogFunc) error { + if err := a.Actor.CheckProfileAccess(profile, requestedAccess, auditLogger); err != nil { + return err + } + requestReason := apitype.RequestReasonKey.Value(a.Context()) + return CheckDisconnectPolicy(a.Actor, profile, requestReason, auditLogger) +} + +// CheckDisconnectPolicy checks if the policy allows the specified actor to disconnect +// Tailscale with the given optional reason. It returns nil if the operation is allowed, +// or an error if it is not. If auditLogger is non-nil, it is called to log the action +// when required by the policy. +// +// Note: this function only checks the policy and does not check whether the actor has +// the necessary access rights to the device or profile. It is intended to be used by +// [Actor] implementations on platforms where [syspolicy] is supported. +// +// TODO(nickkhyl): unexport it when we move [ipn.Actor] implementations from [ipnserver] +// and corp to this package. 
+func CheckDisconnectPolicy(actor Actor, profile ipn.LoginProfileView, reason string, auditFn AuditLogFunc) error {
+	if alwaysOn, _ := syspolicy.GetBoolean(syspolicy.AlwaysOn, false); !alwaysOn {
+		return nil
+	}
+	if allowWithReason, _ := syspolicy.GetBoolean(syspolicy.AlwaysOnOverrideWithReason, false); !allowWithReason {
+		return errors.New("disconnect not allowed: always-on mode is enabled")
+	}
+	if reason == "" {
+		return errors.New("disconnect not allowed: reason required")
+	}
+	if auditFn != nil {
+		var details string
+		if username, _ := actor.Username(); username != "" { // best-effort; we don't have it on all platforms
+			details = fmt.Sprintf("%q is being disconnected by %q: %v", profile.Name(), username, reason)
+		} else {
+			details = fmt.Sprintf("%q is being disconnected: %v", profile.Name(), reason)
+		}
+		if err := auditFn(tailcfg.AuditNodeDisconnect, details); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/ipn/ipnauth/self.go b/ipn/ipnauth/self.go
new file mode 100644
index 0000000000000..9b430dc6d915e
--- /dev/null
+++ b/ipn/ipnauth/self.go
@@ -0,0 +1,51 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ipnauth
+
+import (
+	"context"
+
+	"tailscale.com/ipn"
+)
+
+// Self is a caller identity that represents the tailscaled itself and therefore
+// has unlimited access.
+var Self Actor = unrestricted{}
+
+// unrestricted is an [Actor] that has unlimited access to the currently running
+// tailscaled instance. It's typically used for operations performed by tailscaled
+// on its own, or upon a request from the control plane, rather than on behalf of a user.
+type unrestricted struct{}
+
+// UserID implements [Actor].
+func (unrestricted) UserID() ipn.WindowsUserID { return "" }
+
+// Username implements [Actor].
+func (unrestricted) Username() (string, error) { return "", nil }
+
+// Context implements [Actor].
+func (unrestricted) Context() context.Context { return context.Background() } + +// ClientID implements [Actor]. +// It always returns (NoClientID, false) because the tailscaled itself +// is not a connected LocalAPI client. +func (unrestricted) ClientID() (_ ClientID, ok bool) { return NoClientID, false } + +// CheckProfileAccess implements [Actor]. +func (unrestricted) CheckProfileAccess(_ ipn.LoginProfileView, _ ProfileAccess, _ AuditLogFunc) error { + // Unrestricted access to all profiles. + return nil +} + +// IsLocalSystem implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-01-28) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (unrestricted) IsLocalSystem() bool { return false } + +// IsLocalAdmin implements [Actor]. +// +// Deprecated: this method exists for compatibility with the current (as of 2025-01-28) +// permission model and will be removed as we progress on tailscale/corp#18342. +func (unrestricted) IsLocalAdmin(operatorUID string) bool { return false } diff --git a/ipn/ipnauth/test_actor.go b/ipn/ipnauth/test_actor.go new file mode 100644 index 0000000000000..80c5fcc8a6328 --- /dev/null +++ b/ipn/ipnauth/test_actor.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnauth + +import ( + "cmp" + "context" + "errors" + + "tailscale.com/ipn" +) + +var _ Actor = (*TestActor)(nil) + +// TestActor is an [Actor] used exclusively for testing purposes. 
+type TestActor struct { + UID ipn.WindowsUserID // OS-specific UID of the user, if the actor represents a local Windows user + Name string // username associated with the actor, or "" + NameErr error // error to be returned by [TestActor.Username] + CID ClientID // non-zero if the actor represents a connected LocalAPI client + Ctx context.Context // context associated with the actor + LocalSystem bool // whether the actor represents the special Local System account on Windows + LocalAdmin bool // whether the actor has local admin access +} + +// UserID implements [Actor]. +func (a *TestActor) UserID() ipn.WindowsUserID { return a.UID } + +// Username implements [Actor]. +func (a *TestActor) Username() (string, error) { return a.Name, a.NameErr } + +// ClientID implements [Actor]. +func (a *TestActor) ClientID() (_ ClientID, ok bool) { return a.CID, a.CID != NoClientID } + +// Context implements [Actor]. +func (a *TestActor) Context() context.Context { return cmp.Or(a.Ctx, context.Background()) } + +// CheckProfileAccess implements [Actor]. +func (a *TestActor) CheckProfileAccess(profile ipn.LoginProfileView, _ ProfileAccess, _ AuditLogFunc) error { + return errors.New("profile access denied") +} + +// IsLocalSystem implements [Actor]. +func (a *TestActor) IsLocalSystem() bool { return a.LocalSystem } + +// IsLocalAdmin implements [Actor]. +func (a *TestActor) IsLocalAdmin(operatorUID string) bool { return a.LocalAdmin } diff --git a/ipn/ipnext/ipnext.go b/ipn/ipnext/ipnext.go new file mode 100644 index 0000000000000..066763ba4d2fa --- /dev/null +++ b/ipn/ipnext/ipnext.go @@ -0,0 +1,401 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ipnext defines types and interfaces used for extending the core LocalBackend +// functionality with additional features and services. 
+package ipnext + +import ( + "errors" + "fmt" + "iter" + "net/netip" + + "tailscale.com/control/controlclient" + "tailscale.com/feature" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstime" + "tailscale.com/types/logger" + "tailscale.com/types/mapx" +) + +// Extension augments LocalBackend with additional functionality. +// +// An extension uses the provided [Host] to register callbacks +// and interact with the backend in a controlled, well-defined +// and thread-safe manner. +// +// Extensions are registered using [RegisterExtension]. +// +// They must be safe for concurrent use. +type Extension interface { + // Name is a unique name of the extension. + // It must be the same as the name used to register the extension. + Name() string + + // Init is called to initialize the extension when LocalBackend's + // Start method is called. Extensions are created but not initialized + // unless LocalBackend is started. + // + // If the extension cannot be initialized, it must return an error, + // and its Shutdown method will not be called on the host's shutdown. + // Returned errors are not fatal; they are used for logging. + // A [SkipExtension] error indicates an intentional decision rather than a failure. + Init(Host) error + + // Shutdown is called when LocalBackend is shutting down, + // provided the extension was initialized. For multiple extensions, + // Shutdown is called in the reverse order of Init. + // Returned errors are not fatal; they are used for logging. + // After a call to Shutdown, the extension will not be called again. + Shutdown() error +} + +// NewExtensionFn is a function that instantiates an [Extension]. +// If a registered extension cannot be instantiated, the function must return an error. +// If the extension should be skipped at runtime, it must return either [SkipExtension] +// or a wrapped [SkipExtension]. 
Any other error returned is fatal and will prevent +// the LocalBackend from starting. +type NewExtensionFn func(logger.Logf, SafeBackend) (Extension, error) + +// SkipExtension is an error returned by [NewExtensionFn] to indicate that the extension +// should be skipped rather than prevent the LocalBackend from starting. +// +// Skipping an extension should be reserved for cases where the extension is not supported +// on the current platform or configuration, or depends on a feature that is not available, +// or otherwise should be disabled permanently rather than temporarily. +// +// Specifically, it must not be returned if the extension is not required right now +// based on user preferences, policy settings, the current tailnet, or other factors +// that may change throughout the LocalBackend's lifetime. +var SkipExtension = errors.New("skipping extension") + +// Definition describes a registered [Extension]. +type Definition struct { + name string // name under which the extension is registered + newFn NewExtensionFn // function that creates a new instance of the extension +} + +// Name returns the name of the extension. +func (d *Definition) Name() string { + return d.name +} + +// MakeExtension instantiates the extension. +func (d *Definition) MakeExtension(logf logger.Logf, sb SafeBackend) (Extension, error) { + ext, err := d.newFn(logf, sb) + if err != nil { + return nil, err + } + if ext.Name() != d.name { + return nil, fmt.Errorf("extension name mismatch: registered %q; actual %q", d.name, ext.Name()) + } + return ext, nil +} + +// extensions is a map of registered extensions, +// where the key is the name of the extension. +var extensions mapx.OrderedMap[string, *Definition] + +// RegisterExtension registers a function that instantiates an [Extension]. +// The name must be the same as returned by the extension's [Extension.Name]. 
+// +// It must be called on the main goroutine before LocalBackend is created, +// such as from an init function of the package implementing the extension. +// +// It panics if newExt is nil or if an extension with the same name +// has already been registered. +func RegisterExtension(name string, newExt NewExtensionFn) { + if newExt == nil { + panic(fmt.Sprintf("ipnext: newExt is nil: %q", name)) + } + if extensions.Contains(name) { + panic(fmt.Sprintf("ipnext: duplicate extension name %q", name)) + } + extensions.Set(name, &Definition{name, newExt}) +} + +// Extensions iterates over the extensions in the order they were registered +// via [RegisterExtension]. +func Extensions() iter.Seq[*Definition] { + return extensions.Values() +} + +// DefinitionForTest returns a [Definition] for the specified [Extension]. +// It is primarily used for testing where the test code needs to instantiate +// and use an extension without registering it. +func DefinitionForTest(ext Extension) *Definition { + return &Definition{ + name: ext.Name(), + newFn: func(logger.Logf, SafeBackend) (Extension, error) { return ext, nil }, + } +} + +// DefinitionWithErrForTest returns a [Definition] with the specified extension name +// whose [Definition.MakeExtension] method returns the specified error. +// It is used for testing. +func DefinitionWithErrForTest(name string, err error) *Definition { + return &Definition{ + name: name, + newFn: func(logger.Logf, SafeBackend) (Extension, error) { return nil, err }, + } +} + +// Host is the API surface used by [Extension]s to interact with LocalBackend +// in a controlled manner. +// +// Extensions can register callbacks, request information, or perform actions +// via the [Host] interface. +// +// Typically, the host invokes registered callbacks when one of the following occurs: +// - LocalBackend notifies it of an event or state change that may be +// of interest to extensions, such as when switching [ipn.LoginProfile]. 
+// - LocalBackend needs to consult extensions for information, for example, +// determining the most appropriate profile for the current state of the system. +// - LocalBackend performs an extensible action, such as logging an auditable event, +// and delegates its execution to the extension. +// +// The callbacks are invoked synchronously, and the LocalBackend's state +// remains unchanged while callbacks execute. +// +// In contrast, actions initiated by extensions are generally asynchronous, +// as indicated by the "Async" suffix in their names. +// Performing actions may result in callbacks being invoked as described above. +// +// To prevent conflicts between extensions competing for shared state, +// such as the current profile or prefs, the host must not expose methods +// that directly modify that state. For example, instead of allowing extensions +// to switch profiles at-will, the host's [ProfileServices] provides a method +// to switch to the "best" profile. The host can then consult extensions +// to determine the appropriate profile to use and resolve any conflicts +// in a controlled manner. +// +// A host must be safe for concurrent use. +type Host interface { + // Extensions returns the host's [ExtensionServices]. + Extensions() ExtensionServices + + // Profiles returns the host's [ProfileServices]. + Profiles() ProfileServices + + // AuditLogger returns a function that calls all currently registered audit loggers. + // The function fails if any logger returns an error, indicating that the action + // cannot be logged and must not be performed. + // + // The returned function captures the current state (e.g., the current profile) at + // the time of the call and must not be persisted. + AuditLogger() ipnauth.AuditLogFunc + + // Hooks returns a non-nil pointer to a [Hooks] struct. + // Hooks must not be modified concurrently or after Tailscale has started. 
+ Hooks() *Hooks + + // SendNotifyAsync sends a notification to the IPN bus, + // typically to the GUI client. + SendNotifyAsync(ipn.Notify) + + // NodeBackend returns the [NodeBackend] for the currently active node + // (which is approximately the same as the current profile). + NodeBackend() NodeBackend +} + +// SafeBackend is a subset of the [ipnlocal.LocalBackend] type's methods that +// are safe to call from extension hooks at any time (even hooks called while +// LocalBackend's internal mutex is held). +type SafeBackend interface { + Sys() *tsd.System + Clock() tstime.Clock + TailscaleVarRoot() string +} + +// ExtensionServices provides access to the [Host]'s extension management services, +// such as fetching active extensions. +type ExtensionServices interface { + // FindExtensionByName returns an active extension with the given name, + // or nil if no such extension exists. + FindExtensionByName(name string) any + + // FindMatchingExtension finds the first active extension that matches target, + // and if one is found, sets target to that extension and returns true. + // Otherwise, it returns false. + // + // It panics if target is not a non-nil pointer to either a type + // that implements [ipnext.Extension], or to any interface type. + FindMatchingExtension(target any) bool +} + +// ProfileServices provides access to the [Host]'s profile management services, +// such as switching profiles and registering profile change callbacks. +type ProfileServices interface { + // CurrentProfileState returns read-only views of the current profile + // and its preferences. The returned views are always valid, + // but the profile's [ipn.LoginProfileView.ID] returns "" + // if the profile is new and has not been persisted yet. + // + // The returned views are immutable snapshots of the current profile + // and prefs at the time of the call. 
The actual state is only guaranteed + // to remain unchanged and match these views for the duration + // of a callback invoked by the host, if used within that callback. + // + // Extensions that need the current profile or prefs at other times + // should typically subscribe to [ProfileStateChangeCallback] + // to be notified if the profile or prefs change after retrieval. + // CurrentProfileState returns both the profile and prefs + // to guarantee that they are consistent with each other. + CurrentProfileState() (ipn.LoginProfileView, ipn.PrefsView) + + // CurrentPrefs is like [CurrentProfileState] but only returns prefs. + CurrentPrefs() ipn.PrefsView + + // SwitchToBestProfileAsync asynchronously selects the best profile to use + // and switches to it, unless it is already the current profile. + // + // If an extension needs to know when a profile switch occurs, + // it must use [ProfileServices.RegisterProfileStateChangeCallback] + // to register a [ProfileStateChangeCallback]. + // + // The reason indicates why the profile is being switched, such as due + // to a client connecting or disconnecting or a change in the desktop + // session state. It is used for logging. + SwitchToBestProfileAsync(reason string) +} + +// ProfileStore provides read-only access to available login profiles and their preferences. +// It is not safe for concurrent use and can only be used from the callback it is passed to. +type ProfileStore interface { + // CurrentUserID returns the current user ID. It is only non-empty on + // Windows where we have a multi-user system. + // + // Deprecated: this method exists for compatibility with the current (as of 2024-08-27) + // permission model and will be removed as we progress on tailscale/corp#18342. + CurrentUserID() ipn.WindowsUserID + + // CurrentProfile returns a read-only [ipn.LoginProfileView] of the current profile. 
+ // The returned view is always valid, but the profile's [ipn.LoginProfileView.ID] + // returns "" if the profile is new and has not been persisted yet. + CurrentProfile() ipn.LoginProfileView + + // CurrentPrefs returns a read-only view of the current prefs. + // The returned view is always valid. + CurrentPrefs() ipn.PrefsView + + // DefaultUserProfile returns a read-only view of the default (last used) profile for the specified user. + // It returns a read-only view of a new, non-persisted profile if the specified user does not have a default profile. + DefaultUserProfile(uid ipn.WindowsUserID) ipn.LoginProfileView +} + +// AuditLogProvider is a function that returns an [ipnauth.AuditLogFunc] for +// logging auditable actions. +type AuditLogProvider func() ipnauth.AuditLogFunc + +// ProfileResolver is a function that returns a read-only view of a login profile. +// An invalid view indicates no profile. A valid profile view with an empty [ipn.ProfileID] +// indicates that the profile is new and has not been persisted yet. +// The provided [ProfileStore] can only be used for the duration of the callback. +type ProfileResolver func(ProfileStore) ipn.LoginProfileView + +// ProfileStateChangeCallback is a function to be called when the current login profile +// or its preferences change. +// +// The sameNode parameter indicates whether the profile represents the same node as before, +// which is true when: +// - Only the profile's [ipn.Prefs] or metadata (e.g., [tailcfg.UserProfile]) have changed, +// but the node ID and [ipn.ProfileID] remain the same. +// - The profile has been persisted and assigned an [ipn.ProfileID] for the first time, +// so while its node ID and [ipn.ProfileID] have changed, it is still the same profile. +// +// It can be used to decide whether to reset state bound to the current profile or node identity. 
+// +// The profile and prefs are always valid, but the profile's [ipn.LoginProfileView.ID] +// returns "" if the profile is new and has not been persisted yet. +type ProfileStateChangeCallback func(_ ipn.LoginProfileView, _ ipn.PrefsView, sameNode bool) + +// NewControlClientCallback is a function to be called when a new [controlclient.Client] +// is created and before it is first used. The specified profile represents the node +// for which the cc is created and is always valid. Its [ipn.LoginProfileView.ID] +// returns "" if it is a new node whose profile has never been persisted. +// +// If the [controlclient.Client] is created due to a profile switch, any registered +// [ProfileStateChangeCallback]s are called first. +// +// It returns a function to be called when the cc is being shut down, +// or nil if no cleanup is needed. +type NewControlClientCallback func(controlclient.Client, ipn.LoginProfileView) (cleanup func()) + +// Hooks is a collection of hooks that extensions can add to (non-concurrently) +// during program initialization and can be called by LocalBackend and others at +// runtime. +// +// Each hook has its own rules about when it's called and what environment it +// has access to and what it's allowed to do. +type Hooks struct { + // BackendStateChange is called when the backend state changes. + BackendStateChange feature.Hooks[func(ipn.State)] + + // ProfileStateChange contains callbacks that are invoked when the current login profile + // or its [ipn.Prefs] change, after those changes have been made. The current login profile + // may be changed either because of a profile switch, or because the profile information + // was updated by [LocalBackend.SetControlClientStatus], including when the profile + // is first populated and persisted. + ProfileStateChange feature.Hooks[ProfileStateChangeCallback] + + // BackgroundProfileResolvers are registered background profile resolvers. 
+ // They're used to determine the profile to use when no GUI/CLI client is connected. + // + // TODO(nickkhyl): allow specifying some kind of priority/altitude for the resolver. + // TODO(nickkhyl): make it a "profile resolver" instead of a "background profile resolver". + // The concepts of the "current user", "foreground profile" and "background profile" + // only exist on Windows, and we're moving away from them anyway. + BackgroundProfileResolvers feature.Hooks[ProfileResolver] + + // AuditLoggers are registered [AuditLogProvider]s. + // Each provider is called to get an [ipnauth.AuditLogFunc] when an auditable action + // is about to be performed. If an audit logger returns an error, the action is denied. + AuditLoggers feature.Hooks[AuditLogProvider] + + // NewControlClient are the functions to be called when a new control client + // is created. It is called with the LocalBackend locked. + NewControlClient feature.Hooks[NewControlClientCallback] + + // OnSelfChange is called (with LocalBackend.mu held) when the self node + // changes, including changing to nothing (an invalid view). + OnSelfChange feature.Hooks[func(tailcfg.NodeView)] + + // MutateNotifyLocked is called to optionally mutate the provided Notify + // before sending it to the IPN bus. It is called with LocalBackend.mu held. + MutateNotifyLocked feature.Hooks[func(*ipn.Notify)] + + // SetPeerStatus is called to mutate PeerStatus. + // Callers must only use NodeBackend to read data. + SetPeerStatus feature.Hooks[func(*ipnstate.PeerStatus, tailcfg.NodeView, NodeBackend)] +} + +// NodeBackend is an interface to query the current node and its peers. +// +// It is not a snapshot in time but is locked to a particular node. +type NodeBackend interface { + // AppendMatchingPeers appends all peers that match the predicate + // to the base slice and returns it. 
+	AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView
+
+	// PeerCaps returns the capabilities that src has to this node.
+	PeerCaps(src netip.Addr) tailcfg.PeerCapMap
+
+	// PeerHasCap reports whether the peer has the specified peer capability.
+	PeerHasCap(peer tailcfg.NodeView, cap tailcfg.PeerCapability) bool
+
+	// PeerAPIBase returns the "http://ip:port" URL base to reach peer's
+	// PeerAPI, or the empty string if the peer is invalid or doesn't support
+	// PeerAPI.
+	PeerAPIBase(tailcfg.NodeView) string
+
+	// PeerHasPeerAPI reports whether the provided peer supports PeerAPI.
+	//
+	// It effectively just reports whether PeerAPIBase(node) is non-empty, but
+	// potentially more efficiently.
+	PeerHasPeerAPI(tailcfg.NodeView) bool
+}
diff --git a/ipn/ipnlocal/bus.go b/ipn/ipnlocal/bus.go
new file mode 100644
index 0000000000000..111a877d849d8
--- /dev/null
+++ b/ipn/ipnlocal/bus.go
@@ -0,0 +1,160 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package ipnlocal
+
+import (
+	"context"
+	"time"
+
+	"tailscale.com/ipn"
+	"tailscale.com/tstime"
+)
+
+type rateLimitingBusSender struct {
+	fn func(*ipn.Notify) (keepGoing bool)
+	lastFlush time.Time // last call to fn, or zero value if none
+	interval time.Duration // 0 to flush immediately; non-zero to rate limit sends
+	clock tstime.DefaultClock // non-nil for testing
+	didSendTestHook func() // non-nil for testing
+
+	// pending, if non-nil, is the pending notification that we
+	// haven't sent yet. We own this memory to mutate.
+	pending *ipn.Notify
+
+	// flushTimer is non-nil if the timer is armed.
+	flushTimer tstime.TimerController // effectively a *time.Timer
+	flushTimerC <-chan time.Time // ...
said ~Timer's C chan +} + +func (s *rateLimitingBusSender) close() { + if s.flushTimer != nil { + s.flushTimer.Stop() + } +} + +func (s *rateLimitingBusSender) flushChan() <-chan time.Time { + return s.flushTimerC +} + +func (s *rateLimitingBusSender) flush() (keepGoing bool) { + if n := s.pending; n != nil { + s.pending = nil + return s.flushNotify(n) + } + return true +} + +func (s *rateLimitingBusSender) flushNotify(n *ipn.Notify) (keepGoing bool) { + s.lastFlush = s.clock.Now() + return s.fn(n) +} + +// send conditionally sends n to the underlying fn, possibly rate +// limiting it, depending on whether s.interval is set, and whether +// n is a notable notification that the client (typically a GUI) would +// want to act on (render) immediately. +// +// It returns whether the caller should keep looping. +// +// The passed-in memory 'n' is owned by the caller and should +// not be mutated. +func (s *rateLimitingBusSender) send(n *ipn.Notify) (keepGoing bool) { + if s.interval <= 0 { + // No rate limiting case. + return s.fn(n) + } + if isNotableNotify(n) { + // Notable notifications are always sent immediately. + // But first send any boring one that was pending. + // TODO(bradfitz): there might be a boring one pending + // with a NetMap or Engine field that is redundant + // with the new one (n) with NetMap or Engine populated. + // We should clear the pending one's NetMap/Engine in + // that case. Or really, merge the two, but mergeBoringNotifies + // only handles the case of both sides being boring. + // So for now, flush both. 
+ if !s.flush() { + return false + } + return s.flushNotify(n) + } + s.pending = mergeBoringNotifies(s.pending, n) + d := s.clock.Now().Sub(s.lastFlush) + if d > s.interval { + return s.flush() + } + nextFlushIn := s.interval - d + if s.flushTimer == nil { + s.flushTimer, s.flushTimerC = s.clock.NewTimer(nextFlushIn) + } else { + s.flushTimer.Reset(nextFlushIn) + } + return true + } + + func (s *rateLimitingBusSender) Run(ctx context.Context, ch <-chan *ipn.Notify) { + for { + select { + case <-ctx.Done(): + return + case n, ok := <-ch: + if !ok { + return + } + if !s.send(n) { + return + } + if f := s.didSendTestHook; f != nil { + f() + } + case <-s.flushChan(): + if !s.flush() { + return + } + } + } + } + + // mergeBoringNotifies merges new notify 'src' into possibly-nil 'dst', + // either mutating 'dst' or allocating a new one if 'dst' is nil, + // returning the merged result. + // + // dst and src must both be "boring" (i.e. not notable per isNotableNotify). + func mergeBoringNotifies(dst, src *ipn.Notify) *ipn.Notify { + if dst == nil { + dst = &ipn.Notify{Version: src.Version} + } + if src.NetMap != nil { + dst.NetMap = src.NetMap + } + if src.Engine != nil { + dst.Engine = src.Engine + } + return dst + } + + // isNotableNotify reports whether n is a "notable" notification that + // should be sent on the IPN bus immediately (e.g. to GUIs) without + // rate limiting it for a few seconds. + // + // It effectively reports whether n contains any field set that's + // not NetMap or Engine.
+func isNotableNotify(n *ipn.Notify) bool { + if n == nil { + return false + } + return n.State != nil || + n.SessionID != "" || + n.BrowseToURL != nil || + n.LocalTCPPort != nil || + n.ClientVersion != nil || + n.Prefs != nil || + n.ErrMessage != nil || + n.LoginFinished != nil || + !n.DriveShares.IsNil() || + n.Health != nil || + len(n.IncomingFiles) > 0 || + len(n.OutgoingFiles) > 0 || + n.FilesWaiting != nil +} diff --git a/ipn/ipnlocal/bus_test.go b/ipn/ipnlocal/bus_test.go new file mode 100644 index 0000000000000..5c75ac54d688d --- /dev/null +++ b/ipn/ipnlocal/bus_test.go @@ -0,0 +1,220 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "reflect" + "slices" + "testing" + "time" + + "tailscale.com/drive" + "tailscale.com/ipn" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/views" +) + +func TestIsNotableNotify(t *testing.T) { + tests := []struct { + name string + notify *ipn.Notify + want bool + }{ + {"nil", nil, false}, + {"empty", &ipn.Notify{}, false}, + {"version", &ipn.Notify{Version: "foo"}, false}, + {"netmap", &ipn.Notify{NetMap: new(netmap.NetworkMap)}, false}, + {"engine", &ipn.Notify{Engine: new(ipn.EngineStatus)}, false}, + } + + // Then for all other fields, assume they're notable. + // We use reflect to catch fields that might be added in the future without + // remembering to update the [isNotableNotify] function. + rt := reflect.TypeFor[ipn.Notify]() + for i := range rt.NumField() { + n := &ipn.Notify{} + sf := rt.Field(i) + switch sf.Name { + case "_", "NetMap", "Engine", "Version": + // Already covered above or not applicable. 
+ continue + case "DriveShares": + n.DriveShares = views.SliceOfViews[*drive.Share, drive.ShareView](make([]*drive.Share, 1)) + default: + rf := reflect.ValueOf(n).Elem().Field(i) + switch rf.Kind() { + case reflect.Pointer: + rf.Set(reflect.New(rf.Type().Elem())) + case reflect.String: + rf.SetString("foo") + case reflect.Slice: + rf.Set(reflect.MakeSlice(rf.Type(), 1, 1)) + default: + t.Errorf("unhandled field kind %v for %q", rf.Kind(), sf.Name) + } + } + + tests = append(tests, struct { + name string + notify *ipn.Notify + want bool + }{ + name: "field-" + rt.Field(i).Name, + notify: n, + want: true, + }) + } + + for _, tt := range tests { + if got := isNotableNotify(tt.notify); got != tt.want { + t.Errorf("%v: got %v; want %v", tt.name, got, tt.want) + } + } +} + +type rateLimitingBusSenderTester struct { + tb testing.TB + got []*ipn.Notify + clock *tstest.Clock + s *rateLimitingBusSender +} + +func (st *rateLimitingBusSenderTester) init() { + if st.s != nil { + return + } + st.clock = tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(1731777537, 0), // time I wrote this test :) + }) + st.s = &rateLimitingBusSender{ + clock: tstime.DefaultClock{Clock: st.clock}, + fn: func(n *ipn.Notify) bool { + st.got = append(st.got, n) + return true + }, + } +} + +func (st *rateLimitingBusSenderTester) send(n *ipn.Notify) { + st.tb.Helper() + st.init() + if !st.s.send(n) { + st.tb.Fatal("unexpected send failed") + } +} + +func (st *rateLimitingBusSenderTester) advance(d time.Duration) { + st.tb.Helper() + st.clock.Advance(d) + select { + case <-st.s.flushChan(): + if !st.s.flush() { + st.tb.Fatal("unexpected flush failed") + } + default: + } +} + +func TestRateLimitingBusSender(t *testing.T) { + nm1 := &ipn.Notify{NetMap: new(netmap.NetworkMap)} + nm2 := &ipn.Notify{NetMap: new(netmap.NetworkMap)} + eng1 := &ipn.Notify{Engine: new(ipn.EngineStatus)} + eng2 := &ipn.Notify{Engine: new(ipn.EngineStatus)} + + t.Run("unbuffered", func(t *testing.T) { + st := 
&rateLimitingBusSenderTester{tb: t} + st.send(nm1) + st.send(nm2) + st.send(eng1) + st.send(eng2) + if !slices.Equal(st.got, []*ipn.Notify{nm1, nm2, eng1, eng2}) { + t.Errorf("got %d items; want 4 specific ones, unmodified", len(st.got)) + } + }) + + t.Run("buffered", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.init() + st.s.interval = 1 * time.Second + st.send(&ipn.Notify{Version: "initial"}) + if len(st.got) != 1 { + t.Fatalf("got %d items; expected 1 (first to flush immediately)", len(st.got)) + } + st.send(nm1) + st.send(nm2) + st.send(eng1) + st.send(eng2) + if len(st.got) != 1 { + if len(st.got) != 1 { + t.Fatalf("got %d items; expected still just that first 1", len(st.got)) + } + } + + // But moving the clock should flush the rest, coalesced into one new one. + st.advance(5 * time.Second) + if len(st.got) != 2 { + t.Fatalf("got %d items; want 2", len(st.got)) + } + gotn := st.got[1] + if gotn.NetMap != nm2.NetMap { + t.Errorf("got wrong NetMap; got %p", gotn.NetMap) + } + if gotn.Engine != eng2.Engine { + t.Errorf("got wrong Engine; got %p", gotn.Engine) + } + if t.Failed() { + t.Logf("failed Notify was: %v", logger.AsJSON(gotn)) + } + }) + + // Test the Run method + t.Run("run", func(t *testing.T) { + st := &rateLimitingBusSenderTester{tb: t} + st.init() + st.s.interval = 1 * time.Second + st.s.lastFlush = st.clock.Now() // pretend we just flushed + + flushc := make(chan *ipn.Notify, 1) + st.s.fn = func(n *ipn.Notify) bool { + flushc <- n + return true + } + didSend := make(chan bool, 2) + st.s.didSendTestHook = func() { didSend <- true } + waitSend := func() { + select { + case <-didSend: + case <-time.After(5 * time.Second): + t.Error("timeout waiting for call to send") + } + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + incoming := make(chan *ipn.Notify, 2) + go func() { + incoming <- nm1 + waitSend() + incoming <- nm2 + waitSend() + st.advance(5 * time.Second) + select { + case n :=
<-flushc: + if n.NetMap != nm2.NetMap { + t.Errorf("got wrong NetMap; got %p", n.NetMap) + } + case <-time.After(10 * time.Second): + t.Error("timeout") + } + cancel() + }() + + st.s.Run(ctx, incoming) + }) +} diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index de6ca2321a741..b3379475164ae 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -10,19 +10,16 @@ import ( "errors" "fmt" "io" - "net" "net/http" "os" "os/exec" "path" "path/filepath" "runtime" - "sort" "strconv" "strings" "time" - "github.com/kortschak/wol" "tailscale.com/clientupdate" "tailscale.com/envknob" "tailscale.com/ipn" @@ -66,9 +63,6 @@ var c2nHandlers = map[methodAndPath]c2nHandler{ req("GET /update"): handleC2NUpdateGet, req("POST /update"): handleC2NUpdatePost, - // Wake-on-LAN. - req("POST /wol"): handleC2NWoL, - // Device posture. req("GET /posture/identity"): handleC2NPostureIdentityGet, @@ -77,6 +71,21 @@ var c2nHandlers = map[methodAndPath]c2nHandler{ // Linux netfilter. req("POST /netfilter-kind"): handleC2NSetNetfilterKind, + + // VIP services. + req("GET /vip-services"): handleC2NVIPServicesGet, +} + +// RegisterC2N registers a new c2n handler for the given pattern. +// +// A pattern is like "GET /foo" (specific to an HTTP method) or "/foo" (all +// methods). It panics if the pattern is already registered. 
+func RegisterC2N(pattern string, h func(*LocalBackend, http.ResponseWriter, *http.Request)) { + k := req(pattern) + if _, ok := c2nHandlers[k]; ok { + panic(fmt.Sprintf("c2n: duplicate handler for %q", pattern)) + } + c2nHandlers[k] = h } type c2nHandler func(*LocalBackend, http.ResponseWriter, *http.Request) @@ -269,6 +278,16 @@ func handleC2NSetNetfilterKind(b *LocalBackend, w http.ResponseWriter, r *http.R w.WriteHeader(http.StatusNoContent) } +func handleC2NVIPServicesGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { + b.logf("c2n: GET /vip-services received") + var res tailcfg.C2NVIPServicesResponse + res.VIPServices = b.VIPServices() + res.ServicesHash = b.vipServiceHash(res.VIPServices) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) +} + func handleC2NUpdateGet(b *LocalBackend, w http.ResponseWriter, r *http.Request) { b.logf("c2n: GET /update received") @@ -332,18 +351,16 @@ func handleC2NPostureIdentityGet(b *LocalBackend, w http.ResponseWriter, r *http } if choice.ShouldEnable(b.Prefs().PostureChecking()) { - sns, err := posture.GetSerialNumbers(b.logf) + res.SerialNumbers, err = posture.GetSerialNumbers(b.logf) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + b.logf("c2n: GetSerialNumbers returned error: %v", err) } - res.SerialNumbers = sns // TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release // and looks good in client metrics, remove this parameter and always report MAC // addresses. 
if r.FormValue("hwaddrs") == "true" { - res.IfaceHardwareAddrs, err = posture.GetHardwareAddrs() + res.IfaceHardwareAddrs, err = b.getHardwareAddrs() if err != nil { b.logf("c2n: GetHardwareAddrs returned error: %v", err) } @@ -352,6 +369,8 @@ func handleC2NPostureIdentityGet(b *LocalBackend, w http.ResponseWriter, r *http res.PostureDisabled = true } + b.logf("c2n: posture identity disabled=%v reported %d serials %d hwaddrs", res.PostureDisabled, len(res.SerialNumbers), len(res.IfaceHardwareAddrs)) + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(res) } @@ -490,55 +509,6 @@ func regularFileExists(path string) bool { return err == nil && fi.Mode().IsRegular() } -func handleC2NWoL(b *LocalBackend, w http.ResponseWriter, r *http.Request) { - r.ParseForm() - var macs []net.HardwareAddr - for _, macStr := range r.Form["mac"] { - mac, err := net.ParseMAC(macStr) - if err != nil { - http.Error(w, "bad 'mac' param", http.StatusBadRequest) - return - } - macs = append(macs, mac) - } - var res struct { - SentTo []string - Errors []string - } - st := b.sys.NetMon.Get().InterfaceState() - if st == nil { - res.Errors = append(res.Errors, "no interface state") - writeJSON(w, &res) - return - } - var password []byte // TODO(bradfitz): support? does anything use WoL passwords? - for _, mac := range macs { - for ifName, ips := range st.InterfaceIPs { - for _, ip := range ips { - if ip.Addr().IsLoopback() || ip.Addr().Is6() { - continue - } - local := &net.UDPAddr{ - IP: ip.Addr().AsSlice(), - Port: 0, - } - remote := &net.UDPAddr{ - IP: net.IPv4bcast, - Port: 0, - } - if err := wol.Wake(mac, password, local, remote); err != nil { - res.Errors = append(res.Errors, err.Error()) - } else { - res.SentTo = append(res.SentTo, ifName) - } - break // one per interface is enough - } - } - } - sort.Strings(res.SentTo) - writeJSON(w, &res) -} - // handleC2NTLSCertStatus returns info about the last TLS certificate issued for the // provided domain. 
This can be called by the controlplane to clean up DNS TXT // records when they're no longer needed by LetsEncrypt. diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index d87374bbbcd61..86052eb8d5861 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -32,7 +32,6 @@ import ( "sync" "time" - "github.com/tailscale/golang-x-crypto/acme" "tailscale.com/atomicfile" "tailscale.com/envknob" "tailscale.com/hostinfo" @@ -40,6 +39,8 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store" "tailscale.com/ipn/store/mem" + "tailscale.com/net/bakedroots" + "tailscale.com/tempfork/acme" "tailscale.com/types/logger" "tailscale.com/util/testenv" "tailscale.com/version" @@ -118,6 +119,9 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string } if pair, err := getCertPEMCached(cs, domain, now); err == nil { + if envknob.IsCertShareReadOnlyMode() { + return pair, nil + } // If we got here, we have a valid unexpired cert. // Check whether we should start an async renewal. shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, pair, minValidity) @@ -133,7 +137,7 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string if minValidity == 0 { logf("starting async renewal") // Start renewal in the background, return current valid cert. 
- go b.getCertPEM(context.Background(), cs, logf, traceACME, domain, now, minValidity) + b.goTracker.Go(func() { getCertPEM(context.Background(), b, cs, logf, traceACME, domain, now, minValidity) }) return pair, nil } // If the caller requested a specific validity duration, fall through @@ -141,7 +145,11 @@ func (b *LocalBackend) GetCertPEMWithValidity(ctx context.Context, domain string logf("starting sync renewal") } - pair, err := b.getCertPEM(ctx, cs, logf, traceACME, domain, now, minValidity) + if envknob.IsCertShareReadOnlyMode() { + return nil, fmt.Errorf("retrieving cached TLS certificate failed and cert store is configured in read-only mode, not attempting to issue a new certificate: %w", err) + } + + pair, err := getCertPEM(ctx, b, cs, logf, traceACME, domain, now, minValidity) if err != nil { logf("getCertPEM: %v", err) return nil, err @@ -249,15 +257,13 @@ type certStore interface { // for now. If they're expired, it returns errCertExpired. // If they don't exist, it returns ipn.ErrStateNotExist. Read(domain string, now time.Time) (*TLSCertKeyPair, error) - // WriteCert writes the cert for domain. - WriteCert(domain string, cert []byte) error - // WriteKey writes the key for domain. - WriteKey(domain string, key []byte) error // ACMEKey returns the value previously stored via WriteACMEKey. // It is a PEM encoded ECDSA key. ACMEKey() ([]byte, error) // WriteACMEKey stores the provided PEM encoded ECDSA key. WriteACMEKey([]byte) error + // WriteTLSCertAndKey writes the cert and key for domain. 
+ WriteTLSCertAndKey(domain string, cert, key []byte) error } var errCertExpired = errors.New("cert expired") @@ -343,6 +349,13 @@ func (f certFileStore) WriteKey(domain string, key []byte) error { return atomicfile.WriteFile(keyFile(f.dir, domain), key, 0600) } +func (f certFileStore) WriteTLSCertAndKey(domain string, cert, key []byte) error { + if err := f.WriteKey(domain, key); err != nil { + return err + } + return f.WriteCert(domain, cert) +} + // certStateStore implements certStore by storing the cert & key files in an ipn.StateStore. type certStateStore struct { ipn.StateStore @@ -352,7 +365,29 @@ type certStateStore struct { testRoots *x509.CertPool } +// TLSCertKeyReader is an interface implemented by state stores where it makes +// sense to read the TLS cert and key in a single operation that can be +// distinguished from generic state value reads. Currently this is only implemented +// by the kubestore.Store, which, in some cases, need to read cert and key from a +// non-cached TLS Secret. +type TLSCertKeyReader interface { + ReadTLSCertAndKey(domain string) ([]byte, []byte, error) +} + func (s certStateStore) Read(domain string, now time.Time) (*TLSCertKeyPair, error) { + // If we're using a store that supports atomic reads, use that + if kr, ok := s.StateStore.(TLSCertKeyReader); ok { + cert, key, err := kr.ReadTLSCertAndKey(domain) + if err != nil { + return nil, err + } + if !validCertPEM(domain, key, cert, s.testRoots, now) { + return nil, errCertExpired + } + return &TLSCertKeyPair{CertPEM: cert, KeyPEM: key, Cached: true}, nil + } + + // Otherwise fall back to separate reads certPEM, err := s.ReadState(ipn.StateKey(domain + ".crt")) if err != nil { return nil, err @@ -383,6 +418,27 @@ func (s certStateStore) WriteACMEKey(key []byte) error { return ipn.WriteState(s.StateStore, ipn.StateKey(acmePEMName), key) } +// TLSCertKeyWriter is an interface implemented by state stores that can write the TLS +// cert and key in a single atomic operation. 
Currently this is only implemented +// by the kubestore.StoreKube. +type TLSCertKeyWriter interface { + WriteTLSCertAndKey(domain string, cert, key []byte) error +} + +// WriteTLSCertAndKey writes the TLS cert and key for domain to the current +// LocalBackend's StateStore. +func (s certStateStore) WriteTLSCertAndKey(domain string, cert, key []byte) error { + // If we're using a store that supports atomic writes, use that. + if aw, ok := s.StateStore.(TLSCertKeyWriter); ok { + return aw.WriteTLSCertAndKey(domain, cert, key) + } + // Otherwise fall back to separate writes for cert and key. + if err := s.WriteKey(domain, key); err != nil { + return err + } + return s.WriteCert(domain, cert) +} + // TLSCertKeyPair is a TLS public and private key, and whether they were obtained // from cache or freshly obtained. type TLSCertKeyPair struct { @@ -419,21 +475,24 @@ func getCertPEMCached(cs certStore, domain string, now time.Time) (p *TLSCertKey return cs.Read(domain, now) } -func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { +// getCertPem checks if a cert needs to be renewed and if so, renews it. +// It can be overridden in tests. +var getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { acmeMu.Lock() defer acmeMu.Unlock() // In case this method was triggered multiple times in parallel (when // serving incoming requests), check whether one of the other goroutines // already renewed the cert before us. - if p, err := getCertPEMCached(cs, domain, now); err == nil { + previous, err := getCertPEMCached(cs, domain, now) + if err == nil { // shouldStartDomainRenewal caches its result so it's OK to call this // frequently. 
- shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, p, minValidity) + shouldRenew, err := b.shouldStartDomainRenewal(cs, domain, now, previous, minValidity) if err != nil { logf("error checking for certificate renewal: %v", err) } else if !shouldRenew { - return p, nil + return previous, nil } } else if !errors.Is(err, ipn.ErrStateNotExist) && !errors.Is(err, errCertExpired) { return nil, err @@ -444,6 +503,10 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } + if !isDefaultDirectoryURL(ac.DirectoryURL) { + logf("acme: using Directory URL %q", ac.DirectoryURL) + } + a, err := ac.GetReg(ctx, "" /* pre-RFC param */) switch { case err == nil: @@ -474,7 +537,17 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } - order, err := ac.AuthorizeOrder(ctx, []acme.AuthzID{{Type: "dns", Value: domain}}) + // If we have a previous cert, include it in the order. Assuming we're + // within the ARI renewal window this should exclude us from LE rate + // limits. + var opts []acme.OrderOption + if previous != nil { + prevCrt, err := previous.parseCertificate() + if err == nil { + opts = append(opts, acme.WithOrderReplacesCert(prevCrt)) + } + } + order, err := ac.AuthorizeOrder(ctx, []acme.AuthzID{{Type: "dns", Value: domain}}, opts...) 
if err != nil { return nil, err } @@ -545,9 +618,6 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger if err := encodeECDSAKey(&privPEM, certPrivKey); err != nil { return nil, err } - if err := cs.WriteKey(domain, privPEM.Bytes()); err != nil { - return nil, err - } csr, err := certRequest(certPrivKey, domain, nil) if err != nil { @@ -555,6 +625,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger } logf("requesting cert...") + traceACME(csr) der, _, err := ac.CreateOrderCert(ctx, order.FinalizeURL, csr, true) if err != nil { return nil, fmt.Errorf("CreateOrder: %v", err) @@ -568,7 +639,7 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger return nil, err } } - if err := cs.WriteCert(domain, certPEM.Bytes()); err != nil { + if err := cs.WriteTLSCertAndKey(domain, certPEM.Bytes(), privPEM.Bytes()); err != nil { return nil, err } b.domainRenewed(domain) @@ -577,10 +648,10 @@ func (b *LocalBackend) getCertPEM(ctx context.Context, cs certStore, logf logger } // certRequest generates a CSR for the given common name cn and optional SANs. -func certRequest(key crypto.Signer, cn string, ext []pkix.Extension, san ...string) ([]byte, error) { +func certRequest(key crypto.Signer, name string, ext []pkix.Extension) ([]byte, error) { req := &x509.CertificateRequest{ - Subject: pkix.Name{CommonName: cn}, - DNSNames: san, + Subject: pkix.Name{CommonName: name}, + DNSNames: []string{name}, ExtraExtensions: ext, } return x509.CreateCertificateRequest(rand.Reader, req, key) @@ -657,15 +728,16 @@ func acmeClient(cs certStore) (*acme.Client, error) { // LetsEncrypt), we should make sure that they support ARI extension (see // shouldStartDomainRenewalARI). 
return &acme.Client{ - Key: key, - UserAgent: "tailscaled/" + version.Long(), + Key: key, + UserAgent: "tailscaled/" + version.Long(), + DirectoryURL: envknob.String("TS_DEBUG_ACME_DIRECTORY_URL"), }, nil } // validCertPEM reports whether the given certificate is valid for domain at now. // // If roots != nil, it is used instead of the system root pool. This is meant -// to support testing, and production code should pass roots == nil. +// to support testing; production code should pass roots == nil. func validCertPEM(domain string, keyPEM, certPEM []byte, roots *x509.CertPool, now time.Time) bool { if len(keyPEM) == 0 || len(certPEM) == 0 { return false @@ -688,16 +760,51 @@ func validCertPEM(domain string, keyPEM, certPEM []byte, roots *x509.CertPool, n intermediates.AddCert(cert) } } + return validateLeaf(leaf, intermediates, domain, now, roots) +} + +// validateLeaf is a helper for [validCertPEM]. +// +// If called with roots == nil, it will use the system root pool as well as the +// baked-in roots. If non-nil, only those roots are used. +func validateLeaf(leaf *x509.Certificate, intermediates *x509.CertPool, domain string, now time.Time, roots *x509.CertPool) bool { if leaf == nil { return false } - _, err = leaf.Verify(x509.VerifyOptions{ + _, err := leaf.Verify(x509.VerifyOptions{ DNSName: domain, CurrentTime: now, Roots: roots, Intermediates: intermediates, }) - return err == nil + if err != nil && roots == nil { + // If validation failed and they specified nil for roots (meaning to use + // the system roots), then give it another chance to validate using the + // binary's baked-in roots (LetsEncrypt). See tailscale/tailscale#14690. + return validateLeaf(leaf, intermediates, domain, now, bakedroots.Get()) + } + + if err == nil { + return true + } + + // When pointed at a non-prod ACME server, we don't expect to have the CA + // in our system or baked-in roots. 
Verify only throws UnknownAuthorityError + // after first checking the leaf cert's expiry, hostnames etc, so we know + // that the only reason for an error is to do with constructing a full chain. + // Allow this error so that cert caching still works in testing environments. + if errors.As(err, &x509.UnknownAuthorityError{}) { + acmeURL := envknob.String("TS_DEBUG_ACME_DIRECTORY_URL") + if !isDefaultDirectoryURL(acmeURL) { + return true + } + } + + return false +} + +func isDefaultDirectoryURL(u string) bool { + return u == "" || u == acme.LetsEncryptURL } // validLookingCertDomain reports whether name looks like a valid domain name that diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index 3ae7870e3174f..e2398f670b5ad 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -6,6 +6,7 @@ package ipnlocal import ( + "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" @@ -14,11 +15,17 @@ import ( "embed" "encoding/pem" "math/big" + "os" + "path/filepath" "testing" "time" "github.com/google/go-cmp/cmp" + "tailscale.com/envknob" "tailscale.com/ipn/store/mem" + "tailscale.com/tstest" + "tailscale.com/types/logger" + "tailscale.com/util/must" ) func TestValidLookingCertDomain(t *testing.T) { @@ -47,10 +54,10 @@ var certTestFS embed.FS func TestCertStoreRoundTrip(t *testing.T) { const testDomain = "example.com" - // Use a fixed verification timestamp so validity doesn't fall off when the - // cert expires. If you update the test data below, this may also need to be - // updated. + // Use fixed verification timestamps so validity doesn't change over time. + // If you update the test data below, these may also need to be updated. 
testNow := time.Date(2023, time.February, 10, 0, 0, 0, 0, time.UTC) + testExpired := time.Date(2026, time.February, 10, 0, 0, 0, 0, time.UTC) // To re-generate a root certificate and domain certificate for testing, // use: @@ -78,21 +85,23 @@ func TestCertStoreRoundTrip(t *testing.T) { } tests := []struct { - name string - store certStore + name string + store certStore + debugACMEURL bool }{ - {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}}, - {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}}, + {"FileStore", certFileStore{dir: t.TempDir(), testRoots: roots}, false}, + {"FileStore_UnknownCA", certFileStore{dir: t.TempDir()}, true}, + {"StateStore", certStateStore{StateStore: new(mem.Store), testRoots: roots}, false}, + {"StateStore_UnknownCA", certStateStore{StateStore: new(mem.Store)}, true}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - if err := test.store.WriteCert(testDomain, testCert); err != nil { - t.Fatalf("WriteCert: unexpected error: %v", err) + if test.debugACMEURL { + t.Setenv("TS_DEBUG_ACME_DIRECTORY_URL", "https://acme-staging-v02.api.letsencrypt.org/directory") } - if err := test.store.WriteKey(testDomain, testKey); err != nil { - t.Fatalf("WriteKey: unexpected error: %v", err) + if err := test.store.WriteTLSCertAndKey(testDomain, testCert, testKey); err != nil { + t.Fatalf("WriteTLSCertAndKey: unexpected error: %v", err) } - kp, err := test.store.Read(testDomain, testNow) if err != nil { t.Fatalf("Read: unexpected error: %v", err) @@ -103,6 +112,10 @@ func TestCertStoreRoundTrip(t *testing.T) { if diff := cmp.Diff(kp.KeyPEM, testKey); diff != "" { t.Errorf("Key (-got, +want):\n%s", diff) } + unexpected, err := test.store.Read(testDomain, testExpired) + if err != errCertExpired { + t.Fatalf("Read: expected expiry error: %v", string(unexpected.CertPEM)) + } }) } } @@ -199,3 +212,167 @@ func TestShouldStartDomainRenewal(t *testing.T) { }) } } + +func TestDebugACMEDirectoryURL(t 
*testing.T) { + for _, tc := range []string{"", "https://acme-staging-v02.api.letsencrypt.org/directory"} { + const setting = "TS_DEBUG_ACME_DIRECTORY_URL" + t.Run(tc, func(t *testing.T) { + t.Setenv(setting, tc) + ac, err := acmeClient(certStateStore{StateStore: new(mem.Store)}) + if err != nil { + t.Fatalf("acmeClient creation err: %v", err) + } + if ac.DirectoryURL != tc { + t.Fatalf("acmeClient.DirectoryURL = %q, want %q", ac.DirectoryURL, tc) + } + }) + } +} + +func TestGetCertPEMWithValidity(t *testing.T) { + const testDomain = "example.com" + b := &LocalBackend{ + store: &mem.Store{}, + varRoot: t.TempDir(), + ctx: context.Background(), + logf: t.Logf, + } + certDir, err := b.certDir() + if err != nil { + t.Fatalf("certDir error: %v", err) + } + if _, err := b.getCertStore(); err != nil { + t.Fatalf("getCertStore error: %v", err) + } + testRoot, err := certTestFS.ReadFile("testdata/rootCA.pem") + if err != nil { + t.Fatal(err) + } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(testRoot) { + t.Fatal("Unable to add test CA to the cert pool") + } + testX509Roots = roots + defer func() { testX509Roots = nil }() + tests := []struct { + name string + now time.Time + // storeCerts is true if the test cert and key should be written to store. 
+ storeCerts bool + readOnlyMode bool // TS_READ_ONLY_CERTS env var + wantAsyncRenewal bool // async issuance should be started + wantIssuance bool // sync issuance should be started + wantErr bool + }{ + { + name: "valid_no_renewal", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "issuance_needed", + now: time.Date(2023, time.February, 20, 0, 0, 0, 0, time.UTC), + storeCerts: false, + wantAsyncRenewal: false, + wantIssuance: true, + wantErr: false, + }, + { + name: "renewal_needed", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + wantAsyncRenewal: true, + wantIssuance: false, + wantErr: false, + }, + { + name: "renewal_needed_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: true, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: false, + }, + { + name: "no_certs_read_only_mode", + now: time.Date(2025, time.May, 1, 0, 0, 0, 0, time.UTC), + storeCerts: false, + readOnlyMode: true, + wantAsyncRenewal: false, + wantIssuance: false, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + if tt.readOnlyMode { + envknob.Setenv("TS_CERT_SHARE_MODE", "ro") + } + + os.RemoveAll(certDir) + if tt.storeCerts { + os.MkdirAll(certDir, 0755) + if err := os.WriteFile(filepath.Join(certDir, "example.com.crt"), + must.Get(os.ReadFile("testdata/example.com.pem")), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(certDir, "example.com.key"), + must.Get(os.ReadFile("testdata/example.com-key.pem")), 0644); err != nil { + t.Fatal(err) + } + } + + b.clock = tstest.NewClock(tstest.ClockOpts{Start: tt.now}) + + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case 
allDone <- true: + default: + } + })() + + // Set to true if get getCertPEM is called. GetCertPEM can be called in a goroutine for async + // renewal or in the main goroutine if issuance is required to obtain valid TLS credentials. + getCertPemWasCalled := false + getCertPEM = func(ctx context.Context, b *LocalBackend, cs certStore, logf logger.Logf, traceACME func(any), domain string, now time.Time, minValidity time.Duration) (*TLSCertKeyPair, error) { + getCertPemWasCalled = true + return nil, nil + } + prevGoRoutines := b.goTracker.StartedGoroutines() + _, err = b.GetCertPEMWithValidity(context.Background(), testDomain, 0) + if (err != nil) != tt.wantErr { + t.Errorf("b.GetCertPemWithValidity got err %v, wants error: '%v'", err, tt.wantErr) + } + // GetCertPEMWithValidity calls getCertPEM in a goroutine if async renewal is needed. That's the + // only goroutine it starts, so this can be used to test if async renewal was started. + gotAsyncRenewal := b.goTracker.StartedGoroutines()-prevGoRoutines != 0 + if gotAsyncRenewal { + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutines to finish") + case <-allDone: + } + } + // Verify that async renewal was triggered if expected. + if tt.wantAsyncRenewal != gotAsyncRenewal { + t.Fatalf("wants getCertPem to be called async: %v, got called %v", tt.wantAsyncRenewal, gotAsyncRenewal) + } + // Verify that (non-async) issuance was started if expected. 
+ gotIssuance := getCertPemWasCalled && !gotAsyncRenewal + if tt.wantIssuance != gotIssuance { + t.Errorf("wants getCertPem to be called: %v, got called %v", tt.wantIssuance, gotIssuance) + } + }) + } +} diff --git a/ipn/ipnlocal/dnsconfig_test.go b/ipn/ipnlocal/dnsconfig_test.go index 19d8e8b86b5ee..c0f5b25f38b11 100644 --- a/ipn/ipnlocal/dnsconfig_test.go +++ b/ipn/ipnlocal/dnsconfig_test.go @@ -382,14 +382,14 @@ func TestAllowExitNodeDNSProxyToServeName(t *testing.T) { t.Fatal("unexpected true on backend with nil NetMap") } - b.netMap = &netmap.NetworkMap{ + b.currentNode().SetNetMap(&netmap.NetworkMap{ DNS: tailcfg.DNSConfig{ ExitNodeFilteredSet: []string{ ".ts.net", "some.exact.bad", }, }, - } + }) tests := []struct { name string want bool diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index 98d563d8746b1..a06ea5e8c41ba 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -4,7 +4,6 @@ package ipnlocal import ( - "cmp" "fmt" "os" "slices" @@ -26,26 +25,14 @@ const ( // enabled. This is currently based on checking for the drive:share node // attribute. func (b *LocalBackend) DriveSharingEnabled() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.driveSharingEnabledLocked() -} - -func (b *LocalBackend) driveSharingEnabledLocked() bool { - return b.netMap != nil && b.netMap.SelfNode.HasCap(tailcfg.NodeAttrsTaildriveShare) + return b.currentNode().SelfHasCap(tailcfg.NodeAttrsTaildriveShare) } // DriveAccessEnabled reports whether accessing Taildrive shares on remote nodes // is enabled. This is currently based on checking for the drive:access node // attribute. 
func (b *LocalBackend) DriveAccessEnabled() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.driveAccessEnabledLocked() -} - -func (b *LocalBackend) driveAccessEnabledLocked() bool { - return b.netMap != nil && b.netMap.SelfNode.HasCap(tailcfg.NodeAttrsTaildriveAccess) + return b.currentNode().SelfHasCap(tailcfg.NodeAttrsTaildriveAccess) } // DriveSetServerAddr tells Taildrive to use the given address for connecting @@ -266,7 +253,7 @@ func (b *LocalBackend) driveNotifyShares(shares views.SliceView[*drive.Share, dr // shares has changed since the last notification. func (b *LocalBackend) driveNotifyCurrentSharesLocked() { var shares views.SliceView[*drive.Share, drive.ShareView] - if b.driveSharingEnabledLocked() { + if b.DriveSharingEnabled() { // Only populate shares if sharing is enabled. shares = b.pm.prefs.DriveShares() } @@ -310,12 +297,12 @@ func (b *LocalBackend) updateDrivePeersLocked(nm *netmap.NetworkMap) { } var driveRemotes []*drive.Remote - if b.driveAccessEnabledLocked() { + if b.DriveAccessEnabled() { // Only populate peers if access is enabled, otherwise leave blank. driveRemotes = b.driveRemotesFromPeers(nm) } - fs.SetRemotes(b.netMap.Domain, driveRemotes, b.newDriveTransport()) + fs.SetRemotes(nm.Domain, driveRemotes, b.newDriveTransport()) } func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Remote { @@ -330,36 +317,27 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem // Peers are available to Taildrive if: // - They are online // - They are allowed to share at least one folder with us - b.mu.Lock() - latestNetMap := b.netMap - b.mu.Unlock() - - idx, found := slices.BinarySearchFunc(latestNetMap.Peers, peerID, func(candidate tailcfg.NodeView, id tailcfg.NodeID) int { - return cmp.Compare(candidate.ID(), id) - }) - if !found { + cn := b.currentNode() + peer, ok := cn.PeerByID(peerID) + if !ok { return false } - peer := latestNetMap.Peers[idx] - // Exclude offline peers. 
// TODO(oxtoacart): for some reason, this correctly // catches when a node goes from offline to online, // but not the other way around... - online := peer.Online() - if online == nil || !*online { + // TODO(oxtoacart,nickkhyl): the reason was probably + // that we were using netmap.Peers instead of b.peers. + // The netmap.Peers slice is not updated in all cases. + // It should be fixed now that we use PeerByIDOk. + if !peer.Online().Get() { return false } // Check that the peer is allowed to share with us. - addresses := peer.Addresses() - for i := range addresses.Len() { - addr := addresses.At(i) - capsMap := b.PeerCaps(addr.Addr()) - if capsMap.HasCapability(tailcfg.PeerCapabilityTaildriveSharer) { - return true - } + if cn.PeerHasCap(peer, tailcfg.PeerCapabilityTaildriveSharer) { + return true } return false diff --git a/ipn/ipnlocal/expiry.go b/ipn/ipnlocal/expiry.go index 04c10226d50a0..d1119981594da 100644 --- a/ipn/ipnlocal/expiry.go +++ b/ipn/ipnlocal/expiry.go @@ -116,7 +116,7 @@ func (em *expiryManager) flagExpiredPeers(netmap *netmap.NetworkMap, localNow ti // since we discover endpoints via DERP, and due to DERP return // path optimization. mut.Endpoints = nil - mut.DERP = "" + mut.HomeDERP = 0 // Defense-in-depth: break the node's public key as well, in // case something tries to communicate. 
diff --git a/ipn/ipnlocal/expiry_test.go b/ipn/ipnlocal/expiry_test.go index af1aa337bbe0c..a2b10fe325b8a 100644 --- a/ipn/ipnlocal/expiry_test.go +++ b/ipn/ipnlocal/expiry_test.go @@ -283,11 +283,11 @@ func formatNodes(nodes []tailcfg.NodeView) string { } fmt.Fprintf(&sb, "(%d, %q", n.ID(), n.Name()) - if n.Online() != nil { - fmt.Fprintf(&sb, ", online=%v", *n.Online()) + if online, ok := n.Online().GetOk(); ok { + fmt.Fprintf(&sb, ", online=%v", online) } - if n.LastSeen() != nil { - fmt.Fprintf(&sb, ", lastSeen=%v", n.LastSeen().Unix()) + if lastSeen, ok := n.LastSeen().GetOk(); ok { + fmt.Fprintf(&sb, ", lastSeen=%v", lastSeen.Unix()) } if n.Key() != (key.NodePublic{}) { fmt.Fprintf(&sb, ", key=%v", n.Key().String()) diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go new file mode 100644 index 0000000000000..ca802ab89f747 --- /dev/null +++ b/ipn/ipnlocal/extension_host.go @@ -0,0 +1,621 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "context" + "errors" + "fmt" + "maps" + "reflect" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" + "tailscale.com/util/execqueue" + "tailscale.com/util/mak" + "tailscale.com/util/testenv" +) + +// ExtensionHost is a bridge between the [LocalBackend] and the registered [ipnext.Extension]s. +// It implements [ipnext.Host] and is safe for concurrent use. +// +// A nil pointer to [ExtensionHost] is a valid, no-op extension host which is primarily used in tests +// that instantiate [LocalBackend] directly without using [NewExtensionHost]. 
+// +// The [LocalBackend] is not required to hold its mutex when calling the host's methods, +// but it typically does so either to prevent changes to its state (for example, the current profile) +// while callbacks are executing, or because it calls the host's methods as part of a larger operation +// that requires the mutex to be held. +// +// Extensions might invoke the host's methods either from callbacks triggered by the [LocalBackend], +// or in a response to external events. Some methods can be called by both the extensions and the backend. +// +// As a general rule, the host cannot assume anything about the current state of the [LocalBackend]'s +// internal mutex on entry to its methods, and therefore cannot safely call [LocalBackend] methods directly. +// +// The following are typical and supported patterns: +// - LocalBackend notifies the host about an event, such as a change in the current profile. +// The host invokes callbacks registered by Extensions, forwarding the event arguments to them. +// If necessary, the host can also update its own state for future use. +// - LocalBackend requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the [ipn.LoginProfile] to use when no GUI/CLI client is connected. Typically, [LocalBackend] +// provides the required context to the host, and the host returns the result to [LocalBackend] +// after forwarding the request to the extensions. +// - Extension invokes the host's method to perform an action, such as switching to the "best" profile +// in response to a change in the device's state. Since the host does not know whether the [LocalBackend]'s +// internal mutex is held, it cannot invoke any methods on the [LocalBackend] directly and must instead +// do so asynchronously, such as by using [ExtensionHost.enqueueBackendOperation]. +// - Extension requests information from the host, such as the effective [ipnauth.AuditLogFunc] +// or the current [ipn.LoginProfile]. 
Since the host cannot invoke any methods on the [LocalBackend] directly, +// it should maintain its own view of the current state, updating it when the [LocalBackend] notifies it +// about a change or event. +// +// To safeguard against adopting incorrect or risky patterns, the host does not store [LocalBackend] in its fields +// and instead provides [ExtensionHost.enqueueBackendOperation]. Additionally, to make it easier to test extensions +// and to further reduce the risk of accessing unexported methods or fields of [LocalBackend], the host interacts +// with it via the [Backend] interface. +type ExtensionHost struct { + b Backend + hooks ipnext.Hooks + logf logger.Logf // prefixed with "ipnext:" + + // allExtensions holds the extensions in the order they were registered, + // including those that have not yet attempted initialization or have failed to initialize. + allExtensions []ipnext.Extension + + // initOnce is used to ensure that the extensions are initialized only once, + // even if [extensionHost.Init] is called multiple times. + initOnce sync.Once + initDone atomic.Bool + // shutdownOnce is like initOnce, but for [ExtensionHost.Shutdown]. + shutdownOnce sync.Once + + // workQueue maintains execution order for asynchronous operations requested by extensions. + // It is always an [execqueue.ExecQueue] except in some tests. + workQueue execQueue + // doEnqueueBackendOperation adds an asynchronous [LocalBackend] operation to the workQueue. + doEnqueueBackendOperation func(func(Backend)) + + shuttingDown atomic.Bool + + extByType sync.Map // reflect.Type -> ipnext.Extension + + // mu protects the following fields. + // It must not be held when calling [LocalBackend] methods + // or when invoking callbacks registered by extensions. + mu sync.Mutex + // initialized is whether the host and extensions have been fully initialized. + initialized atomic.Bool + // activeExtensions is a subset of allExtensions that have been initialized and are ready to use. 
+ activeExtensions []ipnext.Extension + // extensionsByName are the extensions indexed by their names. + // They are not necessarily initialized (in activeExtensions) yet. + extensionsByName map[string]ipnext.Extension + // postInitWorkQueue is a queue of functions to be executed + // by the workQueue after all extensions have been initialized. + postInitWorkQueue []func(Backend) + + // currentProfile is a read-only view of the currently used profile. + // The view is always Valid, but might be of an empty, non-persisted profile. + currentProfile ipn.LoginProfileView + // currentPrefs is a read-only view of the current profile's [ipn.Prefs] + // with any private keys stripped. It is always Valid. + currentPrefs ipn.PrefsView +} + +// Backend is a subset of [LocalBackend] methods that are used by [ExtensionHost]. +// It is primarily used for testing. +type Backend interface { + // SwitchToBestProfile switches to the best profile for the current state of the system. + // The reason indicates why the profile is being switched. + SwitchToBestProfile(reason string) + + SendNotify(ipn.Notify) + + NodeBackend() ipnext.NodeBackend + + ipnext.SafeBackend +} + +// NewExtensionHost returns a new [ExtensionHost] which manages registered extensions for the given backend. +// The extensions are instantiated, but are not initialized until [ExtensionHost.Init] is called. +// It returns an error if instantiating any extension fails. +func NewExtensionHost(logf logger.Logf, b Backend) (*ExtensionHost, error) { + return newExtensionHost(logf, b) +} + +func NewExtensionHostForTest(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (*ExtensionHost, error) { + if !testenv.InTest() { + panic("use outside of test") + } + return newExtensionHost(logf, b, overrideExts...) +} + +// newExtensionHost is the shared implementation of [NewExtensionHost] and +// [NewExtensionHostForTest]. 
+// +// If overrideExts is non-nil, the registered extensions are ignored and the +// provided extensions are used instead. Overriding extensions is primarily used +// for testing. +func newExtensionHost(logf logger.Logf, b Backend, overrideExts ...*ipnext.Definition) (_ *ExtensionHost, err error) { + host := &ExtensionHost{ + b: b, + logf: logger.WithPrefix(logf, "ipnext: "), + workQueue: &execqueue.ExecQueue{}, + // The host starts with an empty profile and default prefs. + // We'll update them once [profileManager] notifies us of the initial profile. + currentProfile: zeroProfile, + currentPrefs: defaultPrefs, + } + + // All operations on the backend must be executed asynchronously by the work queue. + // DO NOT retain a direct reference to the backend in the host. + // See the docstring for [ExtensionHost] for more details. + host.doEnqueueBackendOperation = func(f func(Backend)) { + if f == nil { + panic("nil backend operation") + } + host.workQueue.Add(func() { f(b) }) + } + + // Use registered extensions. + extDef := ipnext.Extensions() + if overrideExts != nil { + // Use the provided, potentially empty, overrideExts + // instead of the registered ones. + extDef = slices.Values(overrideExts) + } + + for d := range extDef { + ext, err := d.MakeExtension(logf, b) + if errors.Is(err, ipnext.SkipExtension) { + // The extension wants to be skipped. 
+ host.logf("%q: %v", d.Name(), err) + continue + } else if err != nil { + return nil, fmt.Errorf("failed to create %q extension: %v", d.Name(), err) + } + host.allExtensions = append(host.allExtensions, ext) + + if d.Name() != ext.Name() { + return nil, fmt.Errorf("extension name %q does not match the registered name %q", ext.Name(), d.Name()) + } + + if _, ok := host.extensionsByName[ext.Name()]; ok { + return nil, fmt.Errorf("duplicate extension name %q", ext.Name()) + } else { + mak.Set(&host.extensionsByName, ext.Name(), ext) + } + + typ := reflect.TypeOf(ext) + if _, ok := host.extByType.Load(typ); ok { + if _, ok := ext.(interface{ PermitDoubleRegister() }); !ok { + return nil, fmt.Errorf("duplicate extension type %T", ext) + } + } + host.extByType.Store(typ, ext) + } + return host, nil +} + +func (h *ExtensionHost) NodeBackend() ipnext.NodeBackend { + if h == nil { + return nil + } + return h.b.NodeBackend() +} + +// Init initializes the host and the extensions it manages. +func (h *ExtensionHost) Init() { + if h != nil { + h.initOnce.Do(h.init) + } +} + +var zeroHooks ipnext.Hooks + +func (h *ExtensionHost) Hooks() *ipnext.Hooks { + if h == nil { + return &zeroHooks + } + return &h.hooks +} + +func (h *ExtensionHost) init() { + defer h.initDone.Store(true) + + // Initialize the extensions in the order they were registered. + for _, ext := range h.allExtensions { + // Do not hold the lock while calling [ipnext.Extension.Init]. + // Extensions call back into the host to register their callbacks, + // and that would cause a deadlock if the h.mu is already held. + if err := ext.Init(h); err != nil { + // As per the [ipnext.Extension] interface, failures to initialize + // an extension are never fatal. The extension is simply skipped. + // + // But we handle [ipnext.SkipExtension] differently for nicer logging + // if the extension wants to be skipped and not actually failing. 
+ if errors.Is(err, ipnext.SkipExtension) { + h.logf("%q: %v", ext.Name(), err) + } else { + h.logf("%q init failed: %v", ext.Name(), err) + } + continue + } + // Update the initialized extensions lists as soon as the extension is initialized. + // We'd like to make them visible to other extensions that are initialized later. + h.mu.Lock() + h.activeExtensions = append(h.activeExtensions, ext) + h.mu.Unlock() + } + + // Report active extensions to the log. + // TODO(nickkhyl): update client metrics to include the active/failed/skipped extensions. + h.mu.Lock() + extensionNames := slices.Collect(maps.Keys(h.extensionsByName)) + h.mu.Unlock() + h.logf("active extensions: %v", strings.Join(extensionNames, ", ")) + + // Additional init steps that need to be performed after all extensions have been initialized. + h.mu.Lock() + wq := h.postInitWorkQueue + h.postInitWorkQueue = nil + h.initialized.Store(true) + h.mu.Unlock() + + // Enqueue work that was requested and deferred during initialization. + h.doEnqueueBackendOperation(func(b Backend) { + for _, f := range wq { + f(b) + } + }) +} + +// Extensions implements [ipnext.Host]. +func (h *ExtensionHost) Extensions() ipnext.ExtensionServices { + // Currently, [ExtensionHost] implements [ExtensionServices] directly. + // We might want to extract it to a separate type in the future. + return h +} + +// FindExtensionByName implements [ipnext.ExtensionServices] +// and is also used by the [LocalBackend]. +// It returns nil if the extension is not found. +func (h *ExtensionHost) FindExtensionByName(name string) any { + if h == nil { + return nil + } + h.mu.Lock() + defer h.mu.Unlock() + return h.extensionsByName[name] +} + +// extensionIfaceType is the runtime type of the [ipnext.Extension] interface. +var extensionIfaceType = reflect.TypeFor[ipnext.Extension]() + +// GetExt returns the extension of type T registered with lb. +// If lb is nil or the extension is not found, it returns zero, false. 
+func GetExt[T ipnext.Extension](lb *LocalBackend) (_ T, ok bool) { + var zero T + if lb == nil { + return zero, false + } + if ext, ok := lb.extHost.extensionOfType(reflect.TypeFor[T]()); ok { + return ext.(T), true + } + return zero, false +} + +func (h *ExtensionHost) extensionOfType(t reflect.Type) (_ ipnext.Extension, ok bool) { + if h == nil { + return nil, false + } + if v, ok := h.extByType.Load(t); ok { + return v.(ipnext.Extension), true + } + return nil, false +} + +// FindMatchingExtension implements [ipnext.ExtensionServices] +// and is also used by the [LocalBackend]. +func (h *ExtensionHost) FindMatchingExtension(target any) bool { + if h == nil { + return false + } + + if target == nil { + panic("ipnext: target cannot be nil") + } + + val := reflect.ValueOf(target) + typ := val.Type() + if typ.Kind() != reflect.Ptr || val.IsNil() { + panic("ipnext: target must be a non-nil pointer") + } + targetType := typ.Elem() + if targetType.Kind() != reflect.Interface && !targetType.Implements(extensionIfaceType) { + panic("ipnext: *target must be interface or implement ipnext.Extension") + } + + h.mu.Lock() + defer h.mu.Unlock() + for _, ext := range h.activeExtensions { + if reflect.TypeOf(ext).AssignableTo(targetType) { + val.Elem().Set(reflect.ValueOf(ext)) + return true + } + } + return false +} + +// Profiles implements [ipnext.Host]. +func (h *ExtensionHost) Profiles() ipnext.ProfileServices { + // Currently, [ExtensionHost] implements [ipnext.ProfileServices] directly. + // We might want to extract it to a separate type in the future. + return h +} + +// CurrentProfileState implements [ipnext.ProfileServices]. +func (h *ExtensionHost) CurrentProfileState() (ipn.LoginProfileView, ipn.PrefsView) { + if h == nil { + return zeroProfile, defaultPrefs + } + h.mu.Lock() + defer h.mu.Unlock() + return h.currentProfile, h.currentPrefs +} + +// CurrentPrefs implements [ipnext.ProfileServices]. 
+func (h *ExtensionHost) CurrentPrefs() ipn.PrefsView { + _, prefs := h.CurrentProfileState() + return prefs +} + +// SwitchToBestProfileAsync implements [ipnext.ProfileServices]. +func (h *ExtensionHost) SwitchToBestProfileAsync(reason string) { + if h == nil { + return + } + h.enqueueBackendOperation(func(b Backend) { + b.SwitchToBestProfile(reason) + }) +} + +// SendNotifyAsync implements [ipnext.Host]. +func (h *ExtensionHost) SendNotifyAsync(n ipn.Notify) { + if h == nil { + return + } + h.enqueueBackendOperation(func(b Backend) { + b.SendNotify(n) + }) +} + +// NotifyProfileChange invokes registered profile state change callbacks +// and updates the current profile and prefs in the host. +// It strips private keys from the [ipn.Prefs] before preserving +// or passing them to the callbacks. +func (h *ExtensionHost) NotifyProfileChange(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + if !h.active() { + return + } + h.mu.Lock() + // Strip private keys from the prefs before preserving or passing them to the callbacks. + // Extensions should not need them (unless proven otherwise in the future), + // and this is a good way to ensure that they won't accidentally leak them. + prefs = stripKeysFromPrefs(prefs) + // Update the current profile and prefs in the host, + // so we can provide them to the extensions later if they ask. + h.currentPrefs = prefs + h.currentProfile = profile + h.mu.Unlock() + + for _, cb := range h.hooks.ProfileStateChange { + cb(profile, prefs, sameNode) + } +} + +// NotifyProfilePrefsChanged invokes registered profile state change callbacks, +// and updates the current profile and prefs in the host. +// It strips private keys from the [ipn.Prefs] before preserving or using them. 
+func (h *ExtensionHost) NotifyProfilePrefsChanged(profile ipn.LoginProfileView, oldPrefs, newPrefs ipn.PrefsView) { + if !h.active() { + return + } + h.mu.Lock() + // Strip private keys from the prefs before preserving or passing them to the callbacks. + // Extensions should not need them (unless proven otherwise in the future), + // and this is a good way to ensure that they won't accidentally leak them. + newPrefs = stripKeysFromPrefs(newPrefs) + // Update the current profile and prefs in the host, + // so we can provide them to the extensions later if they ask. + h.currentPrefs = newPrefs + h.currentProfile = profile + // Get the callbacks to be invoked. + h.mu.Unlock() + + for _, cb := range h.hooks.ProfileStateChange { + cb(profile, newPrefs, true) + } +} + +func (h *ExtensionHost) active() bool { + return h != nil && !h.shuttingDown.Load() +} + +// DetermineBackgroundProfile returns a read-only view of the profile +// used when no GUI/CLI client is connected, using background profile +// resolvers registered by extensions. +// +// It returns an invalid view if Tailscale should not run in the background +// and instead disconnect until a GUI/CLI client connects. +// +// As of 2025-02-07, this is only used on Windows. +func (h *ExtensionHost) DetermineBackgroundProfile(profiles ipnext.ProfileStore) ipn.LoginProfileView { + if !h.active() { + return ipn.LoginProfileView{} + } + // TODO(nickkhyl): check if the returned profile is allowed on the device, + // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + + // Attempt to resolve the background profile using the registered + // background profile resolvers (e.g., [ipn/desktop.desktopSessionsExt] on Windows). 
+ for _, resolver := range h.hooks.BackgroundProfileResolvers { + if profile := resolver(profiles); profile.Valid() { + return profile + } + } + + // Otherwise, switch to an empty profile and disconnect Tailscale + // until a GUI or CLI client connects. + return ipn.LoginProfileView{} +} + +// NotifyNewControlClient invokes all registered control client callbacks. +// It returns callbacks to be executed when the control client shuts down. +func (h *ExtensionHost) NotifyNewControlClient(cc controlclient.Client, profile ipn.LoginProfileView) (ccShutdownCbs []func()) { + if !h.active() { + return nil + } + for _, cb := range h.hooks.NewControlClient { + if shutdown := cb(cc, profile); shutdown != nil { + ccShutdownCbs = append(ccShutdownCbs, shutdown) + } + } + return ccShutdownCbs +} + +// AuditLogger returns a function that reports an auditable action +// to all registered audit loggers. It fails if any of them returns an error, +// indicating that the action cannot be logged and must not be performed. +// +// It implements [ipnext.Host], but is also used by the [LocalBackend]. +// +// The returned function closes over the current state of the host and extensions, +// which typically includes the current profile and the audit loggers registered by extensions. +// It must not be persisted outside of the auditable action context. +func (h *ExtensionHost) AuditLogger() ipnauth.AuditLogFunc { + if !h.active() { + return func(tailcfg.ClientAuditAction, string) error { return nil } + } + loggers := make([]ipnauth.AuditLogFunc, 0, len(h.hooks.AuditLoggers)) + for _, provider := range h.hooks.AuditLoggers { + loggers = append(loggers, provider()) + } + return func(action tailcfg.ClientAuditAction, details string) error { + // Log auditable actions to the host's log regardless of whether + // the audit loggers are available or not. + h.logf("auditlog: %v: %v", action, details) + + // Invoke all registered audit loggers and collect errors. 
+ // If any of them returns an error, the action is denied. + var errs []error + for _, logger := range loggers { + if err := logger(action, details); err != nil { + errs = append(errs, err) + } + } + return errors.Join(errs...) + } +} + +// Shutdown shuts down the extension host and all initialized extensions. +func (h *ExtensionHost) Shutdown() { + if h == nil { + return + } + // Ensure that the init function has completed before shutting down, + // or prevent any further init calls from happening. + h.initOnce.Do(func() {}) + h.shutdownOnce.Do(h.shutdown) +} + +func (h *ExtensionHost) shutdown() { + h.shuttingDown.Store(true) + // Prevent any queued but not yet started operations from running, + // block new operations from being enqueued, and wait for the + // currently executing operation (if any) to finish. + h.shutdownWorkQueue() + // Invoke shutdown callbacks registered by extensions. + h.shutdownExtensions() +} + +func (h *ExtensionHost) shutdownWorkQueue() { + h.workQueue.Shutdown() + var ctx context.Context + if testenv.InTest() { + // In tests, we'd like to wait indefinitely for the current operation to finish, + // mostly to help avoid flaky tests. Test runners can be pretty slow. + ctx = context.Background() + } else { + // In prod, however, we want to avoid blocking indefinitely. + // The 5s timeout is somewhat arbitrary; LocalBackend operations + // should not take that long. + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + } + // Since callbacks are invoked synchronously, this will also wait + // for in-flight callbacks associated with those operations to finish. + if err := h.workQueue.Wait(ctx); err != nil { + h.logf("work queue shutdown failed: %v", err) + } +} + +func (h *ExtensionHost) shutdownExtensions() { + h.mu.Lock() + extensions := h.activeExtensions + h.mu.Unlock() + + // h.mu must not be held while shutting down extensions. 
+ // Extensions might call back into the host and that would cause + // a deadlock if the h.mu is already held. + // + // Shutdown is called in the reverse order of Init. + for _, ext := range slices.Backward(extensions) { + if err := ext.Shutdown(); err != nil { + // Extension shutdown errors are never fatal, but we log them for debugging purposes. + h.logf("%q: shutdown callback failed: %v", ext.Name(), err) + } + } +} + +// enqueueBackendOperation enqueues a function to perform an operation on the [Backend]. +// If the host has not yet been initialized (e.g., when called from an extension's Init method), +// the operation is deferred until after the host and all extensions have completed initialization. +// It panics if the f is nil. +func (h *ExtensionHost) enqueueBackendOperation(f func(Backend)) { + if h == nil { + return + } + if f == nil { + panic("nil backend operation") + } + h.mu.Lock() // protects h.initialized and h.postInitWorkQueue + defer h.mu.Unlock() + if h.initialized.Load() { + h.doEnqueueBackendOperation(f) + } else { + h.postInitWorkQueue = append(h.postInitWorkQueue, f) + } +} + +// execQueue is an ordered asynchronous queue for executing functions. +// It is implemented by [execqueue.ExecQueue]. The interface is used +// to allow testing with a mock implementation. 
+type execQueue interface { + Add(func()) + Shutdown() + Wait(context.Context) error +} diff --git a/ipn/ipnlocal/extension_host_test.go b/ipn/ipnlocal/extension_host_test.go new file mode 100644 index 0000000000000..f655c477fcb36 --- /dev/null +++ b/ipn/ipnlocal/extension_host_test.go @@ -0,0 +1,1411 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "cmp" + "context" + "errors" + "net/netip" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + + deepcmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/tstime" + "tailscale.com/types/key" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" + "tailscale.com/types/persist" + "tailscale.com/util/must" +) + +// defaultCmpOpts are the default options used for deepcmp comparisons in tests. +var defaultCmpOpts = []deepcmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}), +} + +// TestExtensionInitShutdown tests that [ExtensionHost] correctly initializes +// and shuts down extensions. +func TestExtensionInitShutdown(t *testing.T) { + t.Parallel() + + // As of 2025-04-08, [ipn.Host.Init] and [ipn.Host.Shutdown] do not return errors + // as extension initialization and shutdown errors are not fatal. + // If these methods are updated to return errors, this test should also be updated. + // The conversions below will fail to compile if their signatures change, reminding us to update the test. 
+ _ = (func(*ExtensionHost))((*ExtensionHost).Init) + _ = (func(*ExtensionHost))((*ExtensionHost).Shutdown) + + tests := []struct { + name string + nilHost bool + exts []*testExtension + wantInit []string + wantShutdown []string + skipInit bool + }{ + { + name: "nil-host", + nilHost: true, + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "empty-extensions", + exts: []*testExtension{}, + wantInit: []string{}, + wantShutdown: []string{}, + }, + { + name: "single-extension", + exts: []*testExtension{{name: "A"}}, + wantInit: []string{"A"}, + wantShutdown: []string{"A"}, + }, + { + name: "multiple-extensions/all-ok", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B", "A"}, + }, + { + name: "multiple-extensions/no-init-no-shutdown", + exts: []*testExtension{{name: "A"}, {name: "B"}, {name: "C"}}, + wantInit: []string{}, + wantShutdown: []string{}, + skipInit: true, + }, + { + name: "multiple-extensions/init-failed/first", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "B"}, + }, + { + name: "multiple-extensions/init-failed/second", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + { + name: "multiple-extensions/init-failed/third", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + 
InitHook: func(*testExtension) error { return nil }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"B", "A"}, + }, + { + name: "multiple-extensions/init-failed/all", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "B", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }, { + name: "C", + InitHook: func(*testExtension) error { return errors.New("init failed") }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{}, + }, + { + name: "multiple-extensions/init-skipped", + exts: []*testExtension{{ + name: "A", + InitHook: func(*testExtension) error { return nil }, + }, { + name: "B", + InitHook: func(*testExtension) error { return ipnext.SkipExtension }, + }, { + name: "C", + InitHook: func(*testExtension) error { return nil }, + }}, + wantInit: []string{"A", "B", "C"}, + wantShutdown: []string{"C", "A"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Configure all extensions to append their names + // to the gotInit and gotShutdown slices + // during initialization and shutdown, + // so we can check that they are called in the right order + // and that shutdown is not unless init succeeded. + var gotInit, gotShutdown []string + for _, ext := range tt.exts { + oldInitHook := ext.InitHook + ext.InitHook = func(e *testExtension) error { + gotInit = append(gotInit, e.name) + if oldInitHook == nil { + return nil + } + return oldInitHook(e) + } + ext.ShutdownHook = func(e *testExtension) error { + gotShutdown = append(gotShutdown, e.name) + return nil + } + } + + var h *ExtensionHost + if !tt.nilHost { + h = newExtensionHostForTest(t, &testBackend{}, false, tt.exts...) + } + + if !tt.skipInit { + h.Init() + } + + // Check that the extensions were initialized in the right order. 
+ if !slices.Equal(gotInit, tt.wantInit) { + t.Errorf("Init extensions: got %v; want %v", gotInit, tt.wantInit) + } + + // Calling Init again on the host should be a no-op. + // The [testExtension.Init] method fails the test if called more than once, + // regardless of which test is running, so we don't need to check it here. + // Similarly, calling Shutdown again on the host should be a no-op as well. + // It is verified by the [testExtension.Shutdown] method itself. + if !tt.skipInit { + h.Init() + } + + // Extensions should not be shut down before the host is shut down, + // even if they are not initialized successfully. + for _, ext := range tt.exts { + if gotShutdown := ext.ShutdownCalled(); gotShutdown { + t.Errorf("%q: Extension shutdown called before host shutdown", ext.name) + } + } + + h.Shutdown() + // Check that the extensions were shut down in the right order, + // and that they were not shut down if they were not initialized successfully. + if !slices.Equal(gotShutdown, tt.wantShutdown) { + t.Errorf("Shutdown extensions: got %v; want %v", gotShutdown, tt.wantShutdown) + } + + }) + } +} + +// TestNewExtensionHost tests that [NewExtensionHost] correctly creates +// an [ExtensionHost], instantiates the extensions and handles errors +// if an extension cannot be created. 
+func TestNewExtensionHost(t *testing.T) { + t.Parallel() + tests := []struct { + name string + defs []*ipnext.Definition + wantErr bool + wantExts []string + }{ + { + name: "no-exts", + defs: []*ipnext.Definition{}, + wantErr: false, + wantExts: []string{}, + }, + { + name: "exts-ok", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionForTest(&testExtension{name: "B"}), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, + wantExts: []string{"A", "B", "C"}, + }, + { + name: "exts-skipped", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", ipnext.SkipExtension), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: false, // extension B is skipped, that's ok + wantExts: []string{"A", "C"}, + }, + { + name: "exts-fail", + defs: []*ipnext.Definition{ + ipnext.DefinitionForTest(&testExtension{name: "A"}), + ipnext.DefinitionWithErrForTest("B", errors.New("failed creating Ext-2")), + ipnext.DefinitionForTest(&testExtension{name: "C"}), + }, + wantErr: true, // extension B failed to create, that's not ok + wantExts: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + logf := tstest.WhileTestRunningLogger(t) + h, err := NewExtensionHostForTest(logf, &testBackend{}, tt.defs...) + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("NewExtensionHost: gotErr %v(%v); wantErr %v", gotErr, err, tt.wantErr) + } + if err != nil { + return + } + + var gotExts []string + for _, ext := range h.allExtensions { + gotExts = append(gotExts, ext.Name()) + } + + if !slices.Equal(gotExts, tt.wantExts) { + t.Errorf("Shutdown extensions: got %v; want %v", gotExts, tt.wantExts) + } + }) + } +} + +// TestFindMatchingExtension tests that [ExtensionHost.FindMatchingExtension] correctly +// finds extensions by their type or interface. 
+func TestFindMatchingExtension(t *testing.T) { + t.Parallel() + + // Define test extension types and a couple of interfaces + type ( + extensionA struct { + testExtension + } + extensionB struct { + testExtension + } + extensionC struct { + testExtension + } + supportedIface interface { + Name() string + } + unsupportedIface interface { + Unsupported() + } + ) + + // Register extensions A and B, but not C. + extA := &extensionA{testExtension: testExtension{name: "A"}} + extB := &extensionB{testExtension: testExtension{name: "B"}} + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, true, extA, extB) + + var gotA *extensionA + if !h.FindMatchingExtension(&gotA) { + t.Errorf("LookupExtension(%T): not found", gotA) + } else if gotA != extA { + t.Errorf("LookupExtension(%T): got %v; want %v", gotA, gotA, extA) + } + + var gotB *extensionB + if !h.FindMatchingExtension(&gotB) { + t.Errorf("LookupExtension(%T): extension B not found", gotB) + } else if gotB != extB { + t.Errorf("LookupExtension(%T): got %v; want %v", gotB, gotB, extB) + } + + var gotC *extensionC + if h.FindMatchingExtension(&gotC) { + t.Errorf("LookupExtension(%T): found, but it should not exist", gotC) + } + + // All extensions implement the supportedIface interface, + // but LookupExtension should only return the first one found, + // which is extA. 
+ var gotSupportedIface supportedIface + if !h.FindMatchingExtension(&gotSupportedIface) { + t.Errorf("LookupExtension(%T): not found", gotSupportedIface) + } else if gotName, wantName := gotSupportedIface.Name(), extA.Name(); gotName != wantName { + t.Errorf("LookupExtension(%T): name: got %v; want %v", gotSupportedIface, gotName, wantName) + } else if gotSupportedIface != extA { + t.Errorf("LookupExtension(%T): got %v; want %v", gotSupportedIface, gotSupportedIface, extA) + } + + var gotUnsupportedIface unsupportedIface + if h.FindMatchingExtension(&gotUnsupportedIface) { + t.Errorf("LookupExtension(%T): found, but it should not exist", gotUnsupportedIface) + } +} + +// TestFindExtensionByName tests that [ExtensionHost.FindExtensionByName] correctly +// finds extensions by their name. +func TestFindExtensionByName(t *testing.T) { + // Register extensions A and B, but not C. + extA := &testExtension{name: "A"} + extB := &testExtension{name: "B"} + h := newExtensionHostForTest(t, &testBackend{}, true, extA, extB) + + gotA, ok := h.FindExtensionByName(extA.Name()).(*testExtension) + if !ok { + t.Errorf("FindExtensionByName(%q): not found", extA.Name()) + } else if gotA != extA { + t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extA.Name(), gotA, extA) + } + + gotB, ok := h.FindExtensionByName(extB.Name()).(*testExtension) + if !ok { + t.Errorf("FindExtensionByName(%q): not found", extB.Name()) + } else if gotB != extB { + t.Errorf(`FindExtensionByName(%q): got %v; want %v`, extB.Name(), gotB, extB) + } + + gotC, ok := h.FindExtensionByName("C").(*testExtension) + if ok { + t.Errorf(`FindExtensionByName("C"): found, but it should not exist: %v`, gotC) + } +} + +// TestExtensionHostEnqueueBackendOperation verifies that [ExtensionHost] enqueues +// backend operations and executes them asynchronously in the order they were received. 
+// It also checks that operations requested before the host and all extensions are initialized +// are not executed immediately but rather after the host and extensions are initialized. +func TestExtensionHostEnqueueBackendOperation(t *testing.T) { + t.Parallel() + tests := []struct { + name string + preInitCalls []string // before host init + extInitCalls []string // from [Extension.Init]; "" means no call + wantInitCalls []string // what we expect to be called after host init + postInitCalls []string // after host init + }{ + { + name: "no-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{}, + }, + { + name: "pre-init-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{}, + wantInitCalls: []string{"pre-init-1", "pre-init-2"}, + postInitCalls: []string{}, + }, + { + name: "init-calls", + preInitCalls: []string{}, + extInitCalls: []string{"init-1", "init-2"}, + wantInitCalls: []string{"init-1", "init-2"}, + postInitCalls: []string{}, + }, + { + name: "post-init-calls", + preInitCalls: []string{}, + extInitCalls: []string{}, + wantInitCalls: []string{}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + { + name: "mixed-calls", + preInitCalls: []string{"pre-init-1", "pre-init-2"}, + extInitCalls: []string{"init-1", "", "init-2"}, + wantInitCalls: []string{"pre-init-1", "pre-init-2", "init-1", "init-2"}, + postInitCalls: []string{"post-init-1", "post-init-2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var gotCalls []string + var h *ExtensionHost + b := &testBackend{ + switchToBestProfileHook: func(reason string) { + gotCalls = append(gotCalls, reason) + }, + } + + exts := make([]*testExtension, len(tt.extInitCalls)) + for i, reason := range tt.extInitCalls { + exts[i] = &testExtension{} + if reason != "" { + exts[i].InitHook = func(e *testExtension) error { + 
e.host.Profiles().SwitchToBestProfileAsync(reason) + return nil + } + } + } + + h = newExtensionHostForTest(t, b, false, exts...) + wq := h.SetWorkQueueForTest(t) // use a test queue instead of [execqueue.ExecQueue]. + + // Issue some pre-init calls. They should be deferred and not + // added to the queue until the host is initialized. + for _, call := range tt.preInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // The queue should be empty before the host is initialized. + wq.Drain() + if len(gotCalls) != 0 { + t.Errorf("Pre-init calls: got %v; want (none)", gotCalls) + } + gotCalls = nil + + // Initialize the host and all extensions. + // The extensions will make their calls during initialization. + h.Init() + + // Calls made before or during initialization should now be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.wantInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + + // Let's make some more calls, as if extensions were making them in a response + // to external events. + for _, call := range tt.postInitCalls { + h.Profiles().SwitchToBestProfileAsync(call) + } + + // Any calls made after initialization should be enqueued and running. + wq.Drain() + if diff := deepcmp.Diff(tt.postInitCalls, gotCalls, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("Init calls: (+got -want): %v", diff) + } + gotCalls = nil + }) + } +} + +// TestExtensionHostProfileStateChangeCallback verifies that [ExtensionHost] correctly handles the registration, +// invocation, and unregistration of profile state change callbacks. This includes callbacks triggered by profile changes +// and by changes to the profile's [ipn.Prefs]. It also checks that the callbacks are called with the correct arguments +// and that any private keys are stripped from [ipn.Prefs] before being passed to the callback. 
+func TestExtensionHostProfileStateChangeCallback(t *testing.T) { + t.Parallel() + + type stateChange struct { + Profile *ipn.LoginProfile + Prefs *ipn.Prefs + SameNode bool + } + type prefsChange struct { + Profile *ipn.LoginProfile + Old, New *ipn.Prefs + } + + // newStateChange creates a new [stateChange] with deep copies of the profile and prefs. + newStateChange := func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) stateChange { + return stateChange{ + Profile: profile.AsStruct(), + Prefs: prefs.AsStruct(), + SameNode: sameNode, + } + } + // makeStateChangeAppender returns a callback that appends profile state changes to the extension's state. + makeStateChangeAppender := func(e *testExtension) ipnext.ProfileStateChangeCallback { + return func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + UpdateExtState(e, "changes", func(changes []stateChange) []stateChange { + return append(changes, newStateChange(profile, prefs, sameNode)) + }) + } + } + // getStateChanges returns the profile state changes stored in the extension's state. + getStateChanges := func(e *testExtension) []stateChange { + changes, _ := GetExtStateOk[[]stateChange](e, "changes") + return changes + } + + tests := []struct { + name string + ext *testExtension + stateCalls []stateChange + prefsCalls []prefsChange + wantChanges []stateChange + }{ + { + // Register the callback for the lifetime of the extension. 
+ name: "Register/Lifetime", + ext: &testExtension{}, + stateCalls: []stateChange{ + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + wantChanges: []stateChange{ // all calls are received by the callback + {Profile: &ipn.LoginProfile{ID: "profile-1"}}, + {Profile: &ipn.LoginProfile{ID: "profile-2"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}}, + {Profile: &ipn.LoginProfile{ID: "profile-3"}, SameNode: true}, + }, + }, + { + // Ensure that ipn.Prefs are passed to the callback. + name: "CheckPrefs", + ext: &testExtension{}, + stateCalls: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + wantChanges: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + AdvertiseRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("192.168.2.0/24"), + }, + }, + }}, + }, + { + // Ensure that private keys are stripped from persist.Persist shared with extensions. 
+ name: "StripPrivateKeys", + ext: &testExtension{}, + stateCalls: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + wantChanges: []stateChange{{ + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NodePrivate{}, // stripped + OldPrivateNodeKey: key.NodePrivate{}, // stripped + NetworkLockKey: key.NLPrivate{}, // stripped + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }}, + }, + { + // Ensure that profile state callbacks are also invoked when prefs (rather than profile) change. 
+ name: "PrefsChange", + ext: &testExtension{}, + prefsCalls: []prefsChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Old: &ipn.Prefs{WantRunning: false, LoggedOut: true}, + New: &ipn.Prefs{WantRunning: true, LoggedOut: false}, + }, + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Old: &ipn.Prefs{AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")}}, + New: &ipn.Prefs{AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("10.10.10.0/24")}}, + }, + }, + wantChanges: []stateChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{WantRunning: true, LoggedOut: false}, + SameNode: true, // must be true for prefs changes + }, + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("10.10.10.0/24")}}, + SameNode: true, // must be true for prefs changes + }, + }, + }, + { + // Ensure that private keys are stripped from prefs when state change callback + // is invoked by prefs change. 
+ name: "PrefsChange/StripPrivateKeys", + ext: &testExtension{}, + prefsCalls: []prefsChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Old: &ipn.Prefs{ + WantRunning: false, + LoggedOut: true, + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + New: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NewNode(), + OldPrivateNodeKey: key.NewNode(), + NetworkLockKey: key.NewNLPrivate(), + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + }, + }, + wantChanges: []stateChange{ + { + Profile: &ipn.LoginProfile{ID: "profile-1"}, + Prefs: &ipn.Prefs{ + WantRunning: true, + LoggedOut: false, + Persist: &persist.Persist{ + NodeID: "12345", + PrivateNodeKey: key.NodePrivate{}, // stripped + OldPrivateNodeKey: key.NodePrivate{}, // stripped + NetworkLockKey: key.NLPrivate{}, // stripped + UserProfile: tailcfg.UserProfile{ + ID: 12345, + LoginName: "test@example.com", + DisplayName: "Test User", + ProfilePicURL: "https://example.com/profile.png", + }, + }, + }, + SameNode: true, // must be true for prefs changes + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Use the default InitHook if not provided by the test. + if tt.ext.InitHook == nil { + tt.ext.InitHook = func(e *testExtension) error { + // Create and register the callback on init. 
+ handler := makeStateChangeAppender(e) + e.host.Hooks().ProfileStateChange.Add(handler) + return nil + } + } + + h := newExtensionHostForTest(t, &testBackend{}, true, tt.ext) + for _, call := range tt.stateCalls { + h.NotifyProfileChange(call.Profile.View(), call.Prefs.View(), call.SameNode) + } + for _, call := range tt.prefsCalls { + h.NotifyProfilePrefsChanged(call.Profile.View(), call.Old.View(), call.New.View()) + } + if diff := deepcmp.Diff(tt.wantChanges, getStateChanges(tt.ext), defaultCmpOpts...); diff != "" { + t.Errorf("StateChange callbacks: (-want +got): %v", diff) + } + }) + } +} + +// TestCurrentProfileState tests that the current profile and prefs are correctly +// initialized and updated when the host is notified of changes. +func TestCurrentProfileState(t *testing.T) { + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, false) + + // The initial profile and prefs should be valid and set to the default values. + gotProfile, gotPrefs := h.Profiles().CurrentProfileState() + checkViewsEqual(t, "Initial profile (from state)", gotProfile, zeroProfile) + checkViewsEqual(t, "Initial prefs (from state)", gotPrefs, defaultPrefs) + gotPrefs = h.Profiles().CurrentPrefs() // same when we only ask for prefs + checkViewsEqual(t, "Initial prefs (direct)", gotPrefs, defaultPrefs) + + // Create a new profile and prefs, and notify the host of the change. + profile := &ipn.LoginProfile{ID: "profile-A"} + prefsV1 := &ipn.Prefs{ProfileName: "Prefs V1", WantRunning: true} + h.NotifyProfileChange(profile.View(), prefsV1.View(), false) + // The current profile and prefs should be updated. 
+ gotProfile, gotPrefs = h.Profiles().CurrentProfileState() + checkViewsEqual(t, "Changed profile (from state)", gotProfile, profile.View()) + checkViewsEqual(t, "New prefs (from state)", gotPrefs, prefsV1.View()) + gotPrefs = h.Profiles().CurrentPrefs() + checkViewsEqual(t, "New prefs (direct)", gotPrefs, prefsV1.View()) + + // Notify the host of a change to the profile's prefs. + prefsV2 := &ipn.Prefs{ProfileName: "Prefs V2", WantRunning: false} + h.NotifyProfilePrefsChanged(profile.View(), prefsV1.View(), prefsV2.View()) + // The current prefs should be updated. + gotProfile, gotPrefs = h.Profiles().CurrentProfileState() + checkViewsEqual(t, "Unchanged profile (from state)", gotProfile, profile.View()) + checkViewsEqual(t, "Changed (from state)", gotPrefs, prefsV2.View()) + gotPrefs = h.Profiles().CurrentPrefs() + checkViewsEqual(t, "Changed prefs (direct)", gotPrefs, prefsV2.View()) +} + +// TestBackgroundProfileResolver tests that the background profile resolvers +// are correctly registered, unregistered and invoked by the [ExtensionHost]. +func TestBackgroundProfileResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + profiles []ipn.LoginProfile // the first one is the current profile + resolvers []ipnext.ProfileResolver + wantProfile *ipn.LoginProfile + }{ + { + name: "No-Profiles/No-Resolvers", + profiles: nil, + resolvers: nil, + wantProfile: nil, + }, + { + // TODO(nickkhyl): update this test as we change "background profile resolvers" + // to just "profile resolvers". The wantProfile should be the current profile by default. 
+ name: "Has-Profiles/No-Resolvers", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: nil, + wantProfile: nil, + }, + { + name: "Has-Profiles/Single-Resolver", + profiles: []ipn.LoginProfile{{ID: "profile-1"}}, + resolvers: []ipnext.ProfileResolver{ + func(ps ipnext.ProfileStore) ipn.LoginProfileView { + return ps.CurrentProfile() + }, + }, + wantProfile: &ipn.LoginProfile{ID: "profile-1"}, + }, + // TODO(nickkhyl): add more tests for multiple resolvers and different profiles + // once we change "background profile resolvers" to just "profile resolvers" + // and add proper conflict resolution logic. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + // Create a new profile manager and add the profiles to it. + // We expose the profile manager to the extensions via the read-only [ipnext.ProfileStore] interface. + pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) + for i, p := range tt.profiles { + // Generate a unique ID and key for each profile, + // unless the profile already has them set + // or is an empty, unnamed profile. + if p.Name != "" { + if p.ID == "" { + p.ID = ipn.ProfileID("profile-" + strconv.Itoa(i)) + } + if p.Key == "" { + p.Key = "key-" + ipn.StateKey(p.ID) + } + } + pv := p.View() + pm.knownProfiles[p.ID] = pv + if i == 0 { + // Set the first profile as the current one. + // A profileManager starts with an empty profile, + // so it's okay if the list of profiles is empty. + pm.SwitchToProfile(pv) + } + } + + h := newExtensionHostForTest[ipnext.Extension](t, &testBackend{}, false) + + // Register the resolvers with the host. + // This is typically done by the extensions themselves, + // but we do it here for testing purposes. + for _, r := range tt.resolvers { + h.Hooks().BackgroundProfileResolvers.Add(r) + } + h.Init() + + // Call the resolver to get the profile. 
+ gotProfile := h.DetermineBackgroundProfile(pm) + if !gotProfile.Equals(tt.wantProfile.View()) { + t.Errorf("Resolved profile: got %v; want %v", gotProfile, tt.wantProfile) + } + }) + } +} + +// TestAuditLogProviders tests that the [ExtensionHost] correctly handles +// the registration and invocation of audit log providers. It verifies that +// the audit loggers are called with the correct actions and details, +// and that any errors returned by the providers are properly propagated. +func TestAuditLogProviders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + auditLoggers []ipnauth.AuditLogFunc // each represents an extension + actions []tailcfg.ClientAuditAction + wantErr bool + }{ + { + name: "No-Providers", + auditLoggers: nil, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Single-Provider/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, + }, + { + name: "Many-Providers/Ok", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { return nil }, + func(tailcfg.ClientAuditAction, string) error { return nil }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: false, + }, + { + name: "Many-Providers/Err", + auditLoggers: []ipnauth.AuditLogFunc{ + func(tailcfg.ClientAuditAction, string) error { + return errors.New("failed to log") + }, + func(tailcfg.ClientAuditAction, string) error { + return nil // all good + }, + func(tailcfg.ClientAuditAction, string) error { + return errors.New("also failed to 
log") + }, + }, + actions: []tailcfg.ClientAuditAction{"TestAction-1", "TestAction-2"}, + wantErr: true, // some providers failed to log, so that's an error + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create extensions that register the audit log providers. + // Each extension/provider will append auditable actions to its state, + // then call the test's auditLogger function. + var exts []*testExtension + for _, auditLogger := range tt.auditLoggers { + ext := &testExtension{} + provider := func() ipnauth.AuditLogFunc { + return func(action tailcfg.ClientAuditAction, details string) error { + UpdateExtState(ext, "actions", func(actions []tailcfg.ClientAuditAction) []tailcfg.ClientAuditAction { + return append(actions, action) + }) + return auditLogger(action, details) + } + } + ext.InitHook = func(e *testExtension) error { + e.host.Hooks().AuditLoggers.Add(provider) + return nil + } + exts = append(exts, ext) + } + + // Initialize the host and the extensions. + h := newExtensionHostForTest(t, &testBackend{}, true, exts...) + + // Use [ExtensionHost.AuditLogger] to log actions. + for _, action := range tt.actions { + err := h.AuditLogger()(action, "Test details") + if gotErr := err != nil; gotErr != tt.wantErr { + t.Errorf("AuditLogger: gotErr %v (%v); wantErr %v", gotErr, err, tt.wantErr) + } + } + + // Check that the actions were logged correctly by each provider. + for _, ext := range exts { + gotActions := GetExtState[[]tailcfg.ClientAuditAction](ext, "actions") + if !slices.Equal(gotActions, tt.actions) { + t.Errorf("Actions: got %v; want %v", gotActions, tt.actions) + } + } + }) + } +} + +// TestNilExtensionHostMethodCall tests that calling exported methods +// on a nil [ExtensionHost] does not panic. We should treat it as a valid +// value since it's used in various tests that instantiate [LocalBackend] +// manually without calling [NewLocalBackend]. 
It also verifies that if +// a method returns a single func value (e.g., a cleanup function), +// it should not be nil. This is a basic sanity check to ensure that +// typical method calls on a nil receiver work as expected. +// It does not replace the need for more thorough testing of specific methods. +func TestNilExtensionHostMethodCall(t *testing.T) { + t.Parallel() + + var h *ExtensionHost + typ := reflect.TypeOf(h) + for i := range typ.NumMethod() { + m := typ.Method(i) + if strings.HasSuffix(m.Name, "ForTest") { + // Skip methods that are only for testing. + continue + } + + t.Run(m.Name, func(t *testing.T) { + t.Parallel() + // Calling the method on the nil receiver should not panic. + ret := checkMethodCallWithZeroArgs(t, m, h) + if len(ret) == 1 && ret[0].Kind() == reflect.Func { + // If the method returns a single func, such as a cleanup function, + // it should not be nil. + fn := ret[0] + if fn.IsNil() { + t.Fatalf("(%T).%s returned a nil func", h, m.Name) + } + // We expect it to be a no-op and calling it should not panic. + args := makeZeroArgsFor(fn) + func() { + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling the func returned by (%T).%s: %v", e, m.Name, e) + } + }() + fn.Call(args) + }() + } + }) + } +} + +// extBeforeStartExtension is a test extension used by TestGetExtBeforeStart. +// It is registered with the [ipnext.RegisterExtension]. 
+type extBeforeStartExtension struct{} + +func init() { + ipnext.RegisterExtension("ext-before-start", mkExtBeforeStartExtension) +} + +func mkExtBeforeStartExtension(logger.Logf, ipnext.SafeBackend) (ipnext.Extension, error) { + return extBeforeStartExtension{}, nil +} + +func (extBeforeStartExtension) Name() string { return "ext-before-start" } +func (extBeforeStartExtension) Init(ipnext.Host) error { + return nil +} +func (extBeforeStartExtension) Shutdown() error { + return nil +} + +// TestGetExtBeforeStart verifies that an extension registered via +// RegisterExtension can be retrieved with GetExt before the host is started +// (via LocalBackend.Start) +func TestGetExtBeforeStart(t *testing.T) { + lb := newTestBackend(t) + // Now call GetExt without calling Start on the LocalBackend. + _, ok := GetExt[extBeforeStartExtension](lb) + if !ok { + t.Fatal("didn't find extension") + } +} + +// checkMethodCallWithZeroArgs calls the method m on the receiver r +// with zero values for all its arguments, except the receiver itself. +// It returns the result of the method call, or fails the test if the call panics. +func checkMethodCallWithZeroArgs[T any](t *testing.T, m reflect.Method, r T) []reflect.Value { + t.Helper() + args := makeZeroArgsFor(m.Func) + // The first arg is the receiver. + args[0] = reflect.ValueOf(r) + // Calling the method should not panic. + defer func() { + if e := recover(); e != nil { + t.Fatalf("panic calling (%T).%s: %v", r, m.Name, e) + } + }() + return m.Func.Call(args) +} + +func makeZeroArgsFor(fn reflect.Value) []reflect.Value { + args := make([]reflect.Value, fn.Type().NumIn()) + for i := range args { + args[i] = reflect.Zero(fn.Type().In(i)) + } + return args +} + +// newExtensionHostForTest creates an [ExtensionHost] with the given backend and extensions. +// It associates each extension that either is or embeds a [testExtension] with the test +// and assigns a name if one isn’t already set. 
+// +// If the host cannot be created, it fails the test. +// +// The host is initialized if the initialize parameter is true. +// It is shut down automatically when the test ends. +func newExtensionHostForTest[T ipnext.Extension](t *testing.T, b Backend, initialize bool, exts ...T) *ExtensionHost { + t.Helper() + + // testExtensionIface is a subset of the methods implemented by [testExtension] that are used here. + // We use testExtensionIface in type assertions instead of using the [testExtension] type directly, + // which supports scenarios where an extension type embeds a [testExtension]. + type testExtensionIface interface { + Name() string + setName(string) + setT(*testing.T) + checkShutdown() + } + + logf := tstest.WhileTestRunningLogger(t) + defs := make([]*ipnext.Definition, len(exts)) + for i, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + ext.setName(cmp.Or(ext.Name(), "Ext-"+strconv.Itoa(i))) + ext.setT(t) + } + defs[i] = ipnext.DefinitionForTest(ext) + } + h, err := NewExtensionHostForTest(logf, b, defs...) + if err != nil { + t.Fatalf("NewExtensionHost: %v", err) + } + // Replace doEnqueueBackendOperation with the one that's marked as a helper, + // so that we'll have better output if [testExecQueue.Add] fails a test. + h.doEnqueueBackendOperation = func(f func(Backend)) { + t.Helper() + h.workQueue.Add(func() { f(b) }) + } + for _, ext := range exts { + if ext, ok := any(ext).(testExtensionIface); ok { + t.Cleanup(ext.checkShutdown) + } + } + t.Cleanup(h.Shutdown) + if initialize { + h.Init() + } + return h +} + +// testExtension is an [ipnext.Extension] that: +// - Calls the provided init and shutdown callbacks +// when [Init] and [Shutdown] are called. +// - Ensures that [Init] and [Shutdown] are called at most once, +// that [Shutdown] is called after [Init], but is not called if [Init] fails +// and is called before the test ends if [Init] succeeds. 
+// +// Typically, [testExtension]s are created and passed to [newExtensionHostForTest] +// when creating an [ExtensionHost] for testing. +type testExtension struct { + t *testing.T // test that created the extension + name string // name of the extension, used for logging + + host ipnext.Host // or nil if not initialized + + // InitHook and ShutdownHook are optional hooks that can be set by tests. + InitHook, ShutdownHook func(*testExtension) error + + // initCnt, initOkCnt and shutdownCnt are used to verify that Init and Shutdown + // are called at most once and in the correct order. + initCnt, initOkCnt, shutdownCnt atomic.Int32 + + // mu protects the following fields. + mu sync.Mutex + // state is the optional state used by tests. + // It can be accessed by tests using [setTestExtensionState], + // [getTestExtensionStateOk] and [getTestExtensionState]. + state map[string]any +} + +var _ ipnext.Extension = (*testExtension)(nil) + +// PermitDoubleRegister is a sentinel method whose existence tells the +// ExtensionHost to permit it to be registered multiple times. +func (*testExtension) PermitDoubleRegister() {} + +func (e *testExtension) setT(t *testing.T) { + e.t = t +} + +func (e *testExtension) setName(name string) { + e.name = name +} + +// Name implements [ipnext.Extension]. +func (e *testExtension) Name() string { + return e.name +} + +// Init implements [ipnext.Extension]. +func (e *testExtension) Init(host ipnext.Host) (err error) { + e.t.Helper() + e.host = host + if e.initCnt.Add(1) == 1 { + e.mu.Lock() + e.state = make(map[string]any) + e.mu.Unlock() + } else { + e.t.Errorf("%q: Init called more than once", e.name) + } + if e.InitHook != nil { + err = e.InitHook(e) + } + if err == nil { + e.initOkCnt.Add(1) + } + return err // may be nil or non-nil +} + +// InitCalled reports whether the Init method was called on the receiver. +func (e *testExtension) InitCalled() bool { + return e.initCnt.Load() != 0 +} + +// Shutdown implements [ipnext.Extension]. 
+func (e *testExtension) Shutdown() (err error) { + e.t.Helper() + e.mu.Lock() + e.mu.Unlock() + if e.ShutdownHook != nil { + err = e.ShutdownHook(e) + } + if e.shutdownCnt.Add(1) != 1 { + e.t.Errorf("%q: Shutdown called more than once", e.name) + } + if e.initCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called without Init", e.name) + } else if e.initOkCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown called despite failed Init", e.name) + } + e.host = nil + return err // may be nil or non-nil +} + +func (e *testExtension) checkShutdown() { + e.t.Helper() + if e.initOkCnt.Load() != 0 && e.shutdownCnt.Load() == 0 { + e.t.Errorf("%q: Shutdown has not been called before test end", e.name) + } +} + +// ShutdownCalled reports whether the Shutdown method was called on the receiver. +func (e *testExtension) ShutdownCalled() bool { + return e.shutdownCnt.Load() != 0 +} + +// SetExtState sets a keyed state on [testExtension] to the given value. +// Tests use it to propagate test-specific state throughout the extension lifecycle +// (e.g., between [testExtension.Init], [testExtension.Shutdown], and registered callbacks) +func SetExtState[T any](e *testExtension, key string, value T) { + e.mu.Lock() + defer e.mu.Unlock() + e.state[key] = value +} + +// UpdateExtState updates a keyed state of the extension using the provided update function. +func UpdateExtState[T any](e *testExtension, key string, update func(T) T) { + e.mu.Lock() + defer e.mu.Unlock() + old, _ := e.state[key].(T) + new := update(old) + e.state[key] = new +} + +// GetExtState returns the value of the keyed state of the extension. +// It returns a zero value of T if the state is not set or is of a different type. +func GetExtState[T any](e *testExtension, key string) T { + v, _ := GetExtStateOk[T](e, key) + return v +} + +// GetExtStateOk is like [getExtState], but also reports whether the state +// with the given key exists and is of the expected type. 
+func GetExtStateOk[T any](e *testExtension, key string) (_ T, ok bool) { + e.mu.Lock() + defer e.mu.Unlock() + v, ok := e.state[key].(T) + return v, ok +} + +// testExecQueue is a test implementation of [execQueue] +// that defers execution of the enqueued funcs until +// [testExecQueue.Drain] is called, and fails the test if +// if [execQueue.Add] is called before the host is initialized. +// +// It is typically used by calling [ExtensionHost.SetWorkQueueForTest]. +type testExecQueue struct { + t *testing.T // test that created the queue + h *ExtensionHost // host to own the queue + + mu sync.Mutex + queue []func() +} + +var _ execQueue = (*testExecQueue)(nil) + +// SetWorkQueueForTest is a helper function that creates a new [testExecQueue] +// and sets it as the work queue for the specified [ExtensionHost], +// returning the new queue. +// +// It fails the test if the host is already initialized. +func (h *ExtensionHost) SetWorkQueueForTest(t *testing.T) *testExecQueue { + t.Helper() + if h.initialized.Load() { + t.Fatalf("UseTestWorkQueue: host is already initialized") + return nil + } + q := &testExecQueue{t: t, h: h} + h.workQueue = q + return q +} + +// Add implements [execQueue]. +func (q *testExecQueue) Add(f func()) { + q.t.Helper() + + if !q.h.initialized.Load() { + q.t.Fatal("ExecQueue.Add must not be called until the host is initialized") + return + } + + q.mu.Lock() + q.queue = append(q.queue, f) + q.mu.Unlock() +} + +// Drain executes all queued functions in the order they were added. +func (q *testExecQueue) Drain() { + q.mu.Lock() + queue := q.queue + q.queue = nil + q.mu.Unlock() + + for _, f := range queue { + f() + } +} + +// Shutdown implements [execQueue]. +func (q *testExecQueue) Shutdown() {} + +// Wait implements [execQueue]. +func (q *testExecQueue) Wait(context.Context) error { return nil } + +// testBackend implements [ipnext.Backend] for testing purposes +// by calling the provided hooks when its methods are called. 
+type testBackend struct { + lazySys lazy.SyncValue[*tsd.System] + switchToBestProfileHook func(reason string) + + // mu protects the backend state. + // It is acquired on entry to the exported methods of the backend + // and released on exit, mimicking the behavior of the [LocalBackend]. + mu sync.Mutex +} + +func (b *testBackend) Clock() tstime.Clock { return tstime.StdClock{} } +func (b *testBackend) Sys() *tsd.System { + return b.lazySys.Get(tsd.NewSystem) +} +func (b *testBackend) SendNotify(ipn.Notify) { panic("not implemented") } +func (b *testBackend) NodeBackend() ipnext.NodeBackend { panic("not implemented") } +func (b *testBackend) TailscaleVarRoot() string { panic("not implemented") } + +func (b *testBackend) SwitchToBestProfile(reason string) { + b.mu.Lock() + defer b.mu.Unlock() + if b.switchToBestProfileHook != nil { + b.switchToBestProfileHook(reason) + } +} + +// equatableView is an interface implemented by views +// that can be compared for equality. +type equatableView[T any] interface { + Valid() bool + Equals(other T) bool +} + +// checkViewsEqual checks that the two views are equal +// and fails the test if they are not. The prefix is used +// to format the error message. 
+func checkViewsEqual[T equatableView[T]](t *testing.T, prefix string, got, want T) { + t.Helper() + switch { + case got.Equals(want): + return + case got.Valid() && want.Valid(): + t.Errorf("%s: got %v; want %v", prefix, got, want) + case got.Valid() && !want.Valid(): + t.Errorf("%s: got %v; want invalid", prefix, got) + case !got.Valid() && want.Valid(): + t.Errorf("%s: got invalid; want %v", prefix, want) + default: + panic("unreachable") + } +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 06dd84831254c..468fd72eb59cf 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -9,7 +9,9 @@ import ( "bytes" "cmp" "context" + "crypto/sha256" "encoding/base64" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -24,11 +26,9 @@ import ( "net/url" "os" "os/exec" - "path/filepath" "reflect" "runtime" "slices" - "sort" "strconv" "strings" "sync" @@ -37,7 +37,6 @@ import ( "go4.org/mem" "go4.org/netipx" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/tcpip" "tailscale.com/appc" @@ -51,12 +50,14 @@ import ( "tailscale.com/doctor/routetable" "tailscale.com/drive" "tailscale.com/envknob" + "tailscale.com/envknob/featureknob" "tailscale.com/health" "tailscale.com/health/healthmsg" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/conffile" "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/policy" "tailscale.com/log/sockstatlog" @@ -71,13 +72,14 @@ import ( "tailscale.com/net/netmon" "tailscale.com/net/netns" "tailscale.com/net/netutil" + "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" "tailscale.com/paths" "tailscale.com/portlist" + "tailscale.com/posture" "tailscale.com/syncs" "tailscale.com/tailcfg" - "tailscale.com/taildrop" "tailscale.com/tka" "tailscale.com/tsd" "tailscale.com/tstime" @@ -85,7 +87,6 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/empty" "tailscale.com/types/key" - 
"tailscale.com/types/lazy" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" @@ -94,24 +95,25 @@ import ( "tailscale.com/types/preftype" "tailscale.com/types/ptr" "tailscale.com/types/views" + "tailscale.com/util/clientmetric" "tailscale.com/util/deephash" "tailscale.com/util/dnsname" + "tailscale.com/util/goroutines" "tailscale.com/util/httpm" "tailscale.com/util/mak" "tailscale.com/util/multierr" - "tailscale.com/util/osshare" "tailscale.com/util/osuser" "tailscale.com/util/rands" "tailscale.com/util/set" + "tailscale.com/util/slicesx" "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/rsop" "tailscale.com/util/systemd" "tailscale.com/util/testenv" - "tailscale.com/util/uniq" "tailscale.com/util/usermetric" "tailscale.com/version" "tailscale.com/version/distro" "tailscale.com/wgengine" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/router" @@ -154,14 +156,18 @@ func RegisterNewSSHServer(fn newSSHServerFunc) { newSSHServer = fn } -// watchSession represents a WatchNotifications channel +// watchSession represents a WatchNotifications channel, +// an [ipnauth.Actor] that owns it (e.g., a connected GUI/CLI), // and sessionID as required to close targeted buses. type watchSession struct { ch chan *ipn.Notify + owner ipnauth.Actor // or nil sessionID string - cancel func() // call to signal that the session must be terminated + cancel context.CancelFunc // to shut down the session } +var metricCaptivePortalDetected = clientmetric.NewCounter("captiveportal_detected") + // LocalBackend is the glue between the major pieces of the Tailscale // network software: the cloud control plane (via controlclient), the // network data plane (via wgengine), and the user-facing UIs and CLIs @@ -174,35 +180,36 @@ type watchSession struct { // state machine generates events back out to zero or more components. 
type LocalBackend struct { // Elements that are thread-safe or constant after construction. - ctx context.Context // canceled by Close - ctxCancel context.CancelFunc // cancels ctx - logf logger.Logf // general logging - keyLogf logger.Logf // for printing list of peers on change - statsLogf logger.Logf // for printing peers stats on change - sys *tsd.System - health *health.Tracker // always non-nil - metrics metrics - e wgengine.Engine // non-nil; TODO(bradfitz): remove; use sys - store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys - dialer *tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys - pushDeviceToken syncs.AtomicValue[string] - backendLogID logid.PublicID - unregisterNetMon func() - unregisterHealthWatch func() - portpoll *portlist.Poller // may be nil - portpollOnce sync.Once // guards starting readPoller - varRoot string // or empty if SetVarRoot never called - logFlushFunc func() // or nil if SetLogFlusher wasn't called - em *expiryManager // non-nil - sshAtomicBool atomic.Bool + ctx context.Context // canceled by [LocalBackend.Shutdown] + ctxCancel context.CancelFunc // cancels ctx + logf logger.Logf // general logging + keyLogf logger.Logf // for printing list of peers on change + statsLogf logger.Logf // for printing peers stats on change + sys *tsd.System + health *health.Tracker // always non-nil + metrics metrics + e wgengine.Engine // non-nil; TODO(bradfitz): remove; use sys + store ipn.StateStore // non-nil; TODO(bradfitz): remove; use sys + dialer *tsdial.Dialer // non-nil; TODO(bradfitz): remove; use sys + pushDeviceToken syncs.AtomicValue[string] + backendLogID logid.PublicID + unregisterNetMon func() + unregisterHealthWatch func() + unregisterSysPolicyWatch func() + portpoll *portlist.Poller // may be nil + portpollOnce sync.Once // guards starting readPoller + varRoot string // or empty if SetVarRoot never called + logFlushFunc func() // or nil if SetLogFlusher wasn't called + em *expiryManager // non-nil; 
TODO(nickkhyl): move to nodeContext + sshAtomicBool atomic.Bool // TODO(nickkhyl): move to nodeContext // webClientAtomicBool controls whether the web client is running. This should // be true unless the disable-web-client node attribute has been set. - webClientAtomicBool atomic.Bool + webClientAtomicBool atomic.Bool // TODO(nickkhyl): move to nodeContext // exposeRemoteWebClientAtomicBool controls whether the web client is exposed over // Tailscale on port 5252. - exposeRemoteWebClientAtomicBool atomic.Bool - shutdownCalled bool // if Shutdown has been called - debugSink *capture.Sink + exposeRemoteWebClientAtomicBool atomic.Bool // TODO(nickkhyl): move to nodeContext + shutdownCalled bool // if Shutdown has been called + debugSink packet.CaptureSink sockstatLogger *sockstatlog.Logger // getTCPHandlerForFunnelFlow returns a handler for an incoming TCP flow for @@ -221,82 +228,84 @@ type LocalBackend struct { // is never called. getTCPHandlerForFunnelFlow func(srcAddr netip.AddrPort, dstPort uint16) (handler func(net.Conn)) - filterAtomic atomic.Pointer[filter.Filter] - containsViaIPFuncAtomic syncs.AtomicValue[func(netip.Addr) bool] - shouldInterceptTCPPortAtomic syncs.AtomicValue[func(uint16) bool] - numClientStatusCalls atomic.Uint32 + containsViaIPFuncAtomic syncs.AtomicValue[func(netip.Addr) bool] // TODO(nickkhyl): move to nodeContext + shouldInterceptTCPPortAtomic syncs.AtomicValue[func(uint16) bool] // TODO(nickkhyl): move to nodeContext + shouldInterceptVIPServicesTCPPortAtomic syncs.AtomicValue[func(netip.AddrPort) bool] // TODO(nickkhyl): move to nodeContext + numClientStatusCalls atomic.Uint32 // TODO(nickkhyl): move to nodeContext + + // goTracker accounts for all goroutines started by LocalBacked, primarily + // for testing and graceful shutdown purposes. 
+ goTracker goroutines.Tracker + + startOnce sync.Once // protects the one‑time initialization in [LocalBackend.Start] + + // extHost is the bridge between [LocalBackend] and the registered [ipnext.Extension]s. + // It may be nil in tests that use direct composite literal initialization of [LocalBackend] + // instead of calling [NewLocalBackend]. A nil pointer is a valid, no-op host. + // It can be used with or without b.mu held, but is typically used with it held + // to prevent state changes while invoking callbacks. + extHost *ExtensionHost // The mutex protects the following elements. - mu sync.Mutex - conf *conffile.Config // latest parsed config, or nil if not in declarative mode - pm *profileManager // mu guards access - filterHash deephash.Sum + mu sync.Mutex + + // currentNodeAtomic is the current node context. It is always non-nil. + // It must be re-created when [LocalBackend] switches to a different profile/node + // (see tailscale/corp#28014 for a bug), but can be mutated in place (via its methods) + // while [LocalBackend] represents the same node. + // + // It is safe for reading with or without holding b.mu, but mutating it in place + // or creating a new one must be done with b.mu held. If both mutexes must be held, + // the LocalBackend's mutex must be acquired first before acquiring the nodeContext's mutex. + // + // We intend to relax this in the future and only require holding b.mu when replacing it, + // but that requires a better (strictly ordered?) state machine and better management + // of [LocalBackend]'s own state that is not tied to the node context. + currentNodeAtomic atomic.Pointer[nodeBackend] + + conf *conffile.Config // latest parsed config, or nil if not in declarative mode + pm *profileManager // mu guards access + filterHash deephash.Sum // TODO(nickkhyl): move to nodeContext httpTestClient *http.Client // for controlclient. nil by default, used by tests. 
ccGen clientGen // function for producing controlclient; lazily populated sshServer SSHServer // or nil, initialized lazily. appConnector *appc.AppConnector // or nil, initialized when configured. // notifyCancel cancels notifications to the current SetNotifyCallback. notifyCancel context.CancelFunc - cc controlclient.Client - ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto + cc controlclient.Client // TODO(nickkhyl): move to nodeContext + ccAuto *controlclient.Auto // if cc is of type *controlclient.Auto; TODO(nickkhyl): move to nodeContext machinePrivKey key.MachinePrivate - tka *tkaState - state ipn.State - capFileSharing bool // whether netMap contains the file sharing capability - capTailnetLock bool // whether netMap contains the tailnet lock capability + tka *tkaState // TODO(nickkhyl): move to nodeContext + state ipn.State // TODO(nickkhyl): move to nodeContext + capTailnetLock bool // whether netMap contains the tailnet lock capability // hostinfo is mutated in-place while mu is held. - hostinfo *tailcfg.Hostinfo - // netMap is the most recently set full netmap from the controlclient. - // It can't be mutated in place once set. Because it can't be mutated in place, - // delta updates from the control server don't apply to it. Instead, use - // the peers map to get up-to-date information on the state of peers. - // In general, avoid using the netMap.Peers slice. We'd like it to go away - // as of 2023-09-17. - netMap *netmap.NetworkMap - // peers is the set of current peers and their current values after applying - // delta node mutations as they come in (with mu held). The map values can - // be given out to callers, but the map itself must not escape the LocalBackend. 
- peers map[tailcfg.NodeID]tailcfg.NodeView - nodeByAddr map[netip.Addr]tailcfg.NodeID // by Node.Addresses only (not subnet routes) - nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil - activeLogin string // last logged LoginName from netMap - engineStatus ipn.EngineStatus - endpoints []tailcfg.Endpoint - blocked bool - keyExpired bool - authURL string // non-empty if not Running - authURLTime time.Time // when the authURL was received from the control server - interact bool // indicates whether a user requested interactive login - egg bool - prevIfState *netmon.State - peerAPIServer *peerAPIServer // or nil - peerAPIListeners []*peerAPIListener - loginFlags controlclient.LoginFlags - fileWaiters set.HandleSet[context.CancelFunc] // of wake-up funcs - notifyWatchers map[string]*watchSession // by session ID - lastStatusTime time.Time // status.AsOf value of the last processed status update - // directFileRoot, if non-empty, means to write received files - // directly to this directory, without staging them in an - // intermediate buffered directory for "pick-up" later. If - // empty, the files are received in a daemon-owned location - // and the localapi is used to enumerate, download, and delete - // them. This is used on macOS where the GUI lifetime is the - // same as the Network Extension lifetime and we can thus avoid - // double-copying files by writing them to the right location - // immediately. - // It's also used on several NAS platforms (Synology, TrueNAS, etc) - // but in that case DoFinalRename is also set true, which moves the - // *.partial file to its final name on completion. - directFileRoot string + hostinfo *tailcfg.Hostinfo // TODO(nickkhyl): move to nodeContext + nmExpiryTimer tstime.TimerController // for updating netMap on node expiry; can be nil; TODO(nickkhyl): move to nodeContext + activeLogin string // last logged LoginName from netMap; TODO(nickkhyl): move to nodeContext (or remove? 
it's in [ipn.LoginProfile]). + engineStatus ipn.EngineStatus + endpoints []tailcfg.Endpoint + blocked bool + keyExpired bool // TODO(nickkhyl): move to nodeContext + authURL string // non-empty if not Running; TODO(nickkhyl): move to nodeContext + authURLTime time.Time // when the authURL was received from the control server; TODO(nickkhyl): move to nodeContext + authActor ipnauth.Actor // an actor who called [LocalBackend.StartLoginInteractive] last, or nil; TODO(nickkhyl): move to nodeContext + egg bool + prevIfState *netmon.State + peerAPIServer *peerAPIServer // or nil + peerAPIListeners []*peerAPIListener + loginFlags controlclient.LoginFlags + notifyWatchers map[string]*watchSession // by session ID + lastStatusTime time.Time // status.AsOf value of the last processed status update componentLogUntil map[string]componentLogState // c2nUpdateStatus is the status of c2n-triggered client update. - c2nUpdateStatus updateStatus - currentUser ipnauth.Actor + c2nUpdateStatus updateStatus + currentUser ipnauth.Actor + selfUpdateProgress []ipnstate.UpdateProgress lastSelfUpdateState ipnstate.SelfUpdateStatus // capForcedNetfilter is the netfilter that control instructs Linux clients // to use, unless overridden locally. - capForcedNetfilter string + capForcedNetfilter string // TODO(nickkhyl): move to nodeContext // offlineAutoUpdateCancel stops offline auto-updates when called. It // should be used via stopOfflineAutoUpdate and // maybeStartOfflineAutoUpdate. It is nil when offline auto-updates are @@ -306,8 +315,9 @@ type LocalBackend struct { offlineAutoUpdateCancel func() // ServeConfig fields. 
(also guarded by mu) - lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig - serveConfig ipn.ServeConfigView // or !Valid if none + lastServeConfJSON mem.RO // last JSON that was parsed into serveConfig + serveConfig ipn.ServeConfigView // or !Valid if none + ipVIPServiceMap netmap.IPServiceMappings // map of VIPService IPs to their corresponding service names; TODO(nickkhyl): move to nodeContext webClient webClient webClientListeners map[netip.AddrPort]*localListener // listeners for local web client traffic @@ -322,7 +332,7 @@ type LocalBackend struct { // dialPlan is any dial plan that we've received from the control // server during a previous connection; it is cleared on logout. - dialPlan atomic.Pointer[tailcfg.ControlDialPlan] + dialPlan atomic.Pointer[tailcfg.ControlDialPlan] // TODO(nickkhyl): maybe move to nodeContext? // tkaSyncLock is used to make tkaSyncIfNeeded an exclusive // section. This is needed to stop two map-responses in quick succession @@ -343,15 +353,24 @@ type LocalBackend struct { // notified about. lastNotifiedDriveShares *views.SliceView[*drive.Share, drive.ShareView] - // outgoingFiles keeps track of Taildrop outgoing files keyed to their OutgoingFile.ID - outgoingFiles map[string]*ipn.OutgoingFile + // lastKnownHardwareAddrs is a list of the previous known hardware addrs. + // Previously known hwaddrs are kept to work around an issue on Windows + // where all addresses might disappear. + // http://go/corp/25168 + lastKnownHardwareAddrs syncs.AtomicValue[[]string] // lastSuggestedExitNode stores the last suggested exit node suggestion to // avoid unnecessary churn between multiple equally-good options. lastSuggestedExitNode tailcfg.StableNodeID + // allowedSuggestedExitNodes is a set of exit nodes permitted by the most recent + // [syspolicy.AllowedSuggestedExitNodes] value. The allowedSuggestedExitNodesMu + // mutex guards access to this set. 
+ allowedSuggestedExitNodesMu sync.Mutex + allowedSuggestedExitNodes set.Set[tailcfg.StableNodeID] + // refreshAutoExitNode indicates if the exit node should be recomputed when the next netcheck report is available. - refreshAutoExitNode bool + refreshAutoExitNode bool // guarded by mu // captiveCtx and captiveCancel are used to control captive portal // detection. They are protected by 'mu' and can be changed during the @@ -367,6 +386,18 @@ type LocalBackend struct { // backend is healthy and captive portal detection is not required // (sending false). needsCaptiveDetection chan bool + + // overrideAlwaysOn is whether [syspolicy.AlwaysOn] is overridden by the user + // and should have no impact on the WantRunning state until the policy changes, + // or the user re-connects manually, switches to a different profile, etc. + // Notably, this is true when [syspolicy.AlwaysOnOverrideWithReason] is enabled, + // and the user has disconnected with a reason. + // See tailscale/corp#26146. + overrideAlwaysOn bool + + // reconnectTimer is used to schedule a reconnect by setting [ipn.Prefs.WantRunning] + // to true after a delay, or nil if no reconnect is scheduled. + reconnectTimer tstime.TimerController } // HealthTracker returns the health tracker for the backend. @@ -396,11 +427,6 @@ type metrics struct { // approvedRoutes is a metric that reports the number of network routes served by the local node and approved // by the control server. approvedRoutes *usermetric.Gauge - - // primaryRoutes is a metric that reports the number of primary network routes served by the local node. - // A route being a primary route implies that the route is currently served by this node, and not by another - // subnet router in a high availability configuration. - primaryRoutes *usermetric.Gauge } // clientGen is a func that creates a control plane client. @@ -411,7 +437,7 @@ type clientGen func(controlclient.Options) (controlclient.Client, error) // but is not actually running. 
// // If dialer is nil, a new one is made. -func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, loginFlags controlclient.LoginFlags) (*LocalBackend, error) { +func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, loginFlags controlclient.LoginFlags) (_ *LocalBackend, err error) { e := sys.Engine.Get() store := sys.StateStore.Get() dialer := sys.Dialer.Get() @@ -436,7 +462,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } envknob.LogCurrent(logf) - osshare.SetFileSharingEnabled(false, logf) ctx, cancel := context.WithCancel(context.Background()) clock := tstime.StdClock{} @@ -451,8 +476,6 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo "tailscaled_advertised_routes", "Number of advertised network routes (e.g. by a subnet router)"), approvedRoutes: sys.UserMetricsRegistry().NewGauge( "tailscaled_approved_routes", "Number of approved network routes (e.g. by a subnet router)"), - primaryRoutes: sys.UserMetricsRegistry().NewGauge( - "tailscaled_primary_routes", "Number of network routes for which this node is a primary router (in high availability configuration)"), } b := &LocalBackend{ @@ -480,14 +503,29 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo captiveCancel: nil, // so that we start checkCaptivePortalLoop when Running needsCaptiveDetection: make(chan bool), } + b.currentNodeAtomic.Store(newNodeBackend()) mConn.SetNetInfoCallback(b.setNetInfo) if sys.InitialConfig != nil { - if err := b.setConfigLocked(sys.InitialConfig); err != nil { + if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil { return nil, err } } + if b.extHost, err = NewExtensionHost(logf, b); err != nil { + return nil, fmt.Errorf("failed to create extension host: %w", err) + } + b.pm.SetExtensionHost(b.extHost) + + if b.unregisterSysPolicyWatch, err = b.registerSysPolicyWatch(); err != nil { + return nil, err + } + defer func() { 
+ if err != nil { + b.unregisterSysPolicyWatch() + } + }() + netMon := sys.NetMon.Get() b.sockstatLogger, err = sockstatlog.NewLogger(logpolicy.LogsDir(logf), logf, logID, netMon, sys.HealthTracker()) if err != nil { @@ -504,6 +542,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo b.e.SetJailedFilter(noneFilter) b.setTCPPortsIntercepted(nil) + b.setVIPServicesTCPPortsIntercepted(nil) b.statusChanged = sync.NewCond(&b.statusLock) b.e.SetStatusCallback(b.setWgengineStatus) @@ -531,21 +570,41 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } } } + return b, nil +} - // initialize Taildrive shares from saved state - fs, ok := b.sys.DriveForRemote.GetOK() - if ok { - currentShares := b.pm.prefs.DriveShares() - if currentShares.Len() > 0 { - var shares []*drive.Share - for _, share := range currentShares.All() { - shares = append(shares, share.AsStruct()) - } - fs.SetShares(shares) - } +func (b *LocalBackend) Clock() tstime.Clock { return b.clock } +func (b *LocalBackend) Sys() *tsd.System { return b.sys } + +// NodeBackend returns the current node's NodeBackend interface. +func (b *LocalBackend) NodeBackend() ipnext.NodeBackend { + return b.currentNode() +} + +func (b *LocalBackend) currentNode() *nodeBackend { + if v := b.currentNodeAtomic.Load(); v != nil || !testenv.InTest() { + return v } + // Auto-init one in tests for LocalBackend created without the NewLocalBackend constructor... + v := newNodeBackend() + b.currentNodeAtomic.CompareAndSwap(nil, v) + return b.currentNodeAtomic.Load() +} - return b, nil +// FindExtensionByName returns an active extension with the given name, +// or nil if no such extension exists. +func (b *LocalBackend) FindExtensionByName(name string) any { + return b.extHost.Extensions().FindExtensionByName(name) +} + +// FindMatchingExtension finds the first active extension that matches target, +// and if one is found, sets target to that extension and returns true. 
+// Otherwise, it returns false. +// +// It panics if target is not a non-nil pointer to either a type +// that implements [ipnext.Extension], or to any interface type. +func (b *LocalBackend) FindMatchingExtension(target any) bool { + return b.extHost.Extensions().FindMatchingExtension(target) } type componentLogState struct { @@ -700,24 +759,13 @@ func (b *LocalBackend) Dialer() *tsdial.Dialer { return b.dialer } -// SetDirectFileRoot sets the directory to download files to directly, -// without buffering them through an intermediate daemon-owned -// tailcfg.UserID-specific directory. -// -// This must be called before the LocalBackend starts being used. -func (b *LocalBackend) SetDirectFileRoot(dir string) { - b.mu.Lock() - defer b.mu.Unlock() - b.directFileRoot = dir -} - // ReloadConfig reloads the backend's config from disk. // // It returns (false, nil) if not running in declarative mode, (true, nil) on // success, or (false, error) on failure. func (b *LocalBackend) ReloadConfig() (ok bool, err error) { - b.mu.Lock() - defer b.mu.Unlock() + unlock := b.lockAndGetUnlock() + defer unlock() if b.conf == nil { return false, nil } @@ -725,18 +773,21 @@ func (b *LocalBackend) ReloadConfig() (ok bool, err error) { if err != nil { return false, err } - if err := b.setConfigLocked(conf); err != nil { + if err := b.setConfigLockedOnEntry(conf, unlock); err != nil { return false, fmt.Errorf("error setting config: %w", err) } return true, nil } -func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { - - // TODO(irbekrm): notify the relevant components to consume any prefs - // updates. Currently only initial configfile settings are applied - // immediately. +// initPrefsFromConfig initializes the backend's prefs from the provided config. +// This should only be called once, at startup. For updates at runtime, use +// [LocalBackend.setConfigLocked]. 
+func (b *LocalBackend) initPrefsFromConfig(conf *conffile.Config) error { + // TODO(maisem,bradfitz): combine this with setConfigLocked. This is called + // before anything is running, so there's no need to lock and we don't + // update any subsystems. At runtime, we both need to lock and update + // subsystems with the new prefs. p := b.pm.CurrentPrefs().AsStruct() mp, err := conf.Parsed.ToPrefs() if err != nil { @@ -746,13 +797,14 @@ func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { if err := b.pm.SetPrefs(p.View(), ipn.NetworkProfile{}); err != nil { return err } + b.setStaticEndpointsFromConfigLocked(conf) + b.conf = conf + return nil +} - defer func() { - b.conf = conf - }() - +func (b *LocalBackend) setStaticEndpointsFromConfigLocked(conf *conffile.Config) { if conf.Parsed.StaticEndpoints == nil && (b.conf == nil || b.conf.Parsed.StaticEndpoints == nil) { - return nil + return } // Ensure that magicsock conn has the up to date static wireguard @@ -766,6 +818,32 @@ func (b *LocalBackend) setConfigLocked(conf *conffile.Config) error { ms.SetStaticEndpoints(views.SliceOf(conf.Parsed.StaticEndpoints)) } } +} + +func (b *LocalBackend) setStateLocked(state ipn.State) { + if b.state == state { + return + } + b.state = state + for _, f := range b.extHost.Hooks().BackendStateChange { + f(state) + } +} + +// setConfigLockedOnEntry uses the provided config to update the backend's prefs +// and other state. 
+func (b *LocalBackend) setConfigLockedOnEntry(conf *conffile.Config, unlock unlockOnce) error { + defer unlock() + p := b.pm.CurrentPrefs().AsStruct() + mp, err := conf.Parsed.ToPrefs() + if err != nil { + return fmt.Errorf("error parsing config to prefs: %w", err) + } + p.ApplyEdits(&mp) + b.setStaticEndpointsFromConfigLocked(conf) + b.setPrefsLockedOnEntry(p, unlock) + + b.conf = conf return nil } @@ -781,7 +859,20 @@ func (b *LocalBackend) pauseOrResumeControlClientLocked() { return } networkUp := b.prevIfState.AnyInterfaceUp() - b.cc.SetPaused((b.state == ipn.Stopped && b.netMap != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest())) + b.cc.SetPaused((b.state == ipn.Stopped && b.NetMap() != nil) || (!networkUp && !testenv.InTest() && !assumeNetworkUpdateForTest())) +} + +// DisconnectControl shuts down control client. This can be run before node shutdown to force control to consider this ndoe +// inactive. This can be used to ensure that nodes that are HA subnet router or app connector replicas are shutting +// down, clients switch over to other replicas whilst the existing connections are kept alive for some period of time. +func (b *LocalBackend) DisconnectControl() { + b.mu.Lock() + defer b.mu.Unlock() + cc := b.resetControlClientLocked() + if cc == nil { + return + } + cc.Shutdown() } // captivePortalDetectionInterval is the duration to wait in an unhealthy state with connectivity broken @@ -820,20 +911,24 @@ func (b *LocalBackend) linkChange(delta *netmon.ChangeDelta) { // TODO(raggi,tailscale/corp#22574): authReconfig should be refactored such that we can call the // necessary operations here and avoid the need for asynchronous behavior that is racy and hard // to test here, and do less extra work in these conditions. - go b.authReconfig() + b.goTracker.Go(b.authReconfig) } } // If the local network configuration has changed, our filter may // need updating to tweak default routes. 
- b.updateFilterLocked(b.netMap, b.pm.CurrentPrefs()) + b.updateFilterLocked(b.pm.CurrentPrefs()) updateExitNodeUsageWarning(b.pm.CurrentPrefs(), delta.New, b.health) - if peerAPIListenAsync && b.netMap != nil && b.state == ipn.Running { - want := b.netMap.GetAddresses().Len() - if len(b.peerAPIListeners) < want { + cn := b.currentNode() + nm := cn.NetMap() + if peerAPIListenAsync && nm != nil && b.state == ipn.Running { + want := nm.GetAddresses().Len() + have := len(b.peerAPIListeners) + b.logf("[v1] linkChange: have %d peerAPIListeners, want %d", have, want) + if have < want { b.logf("linkChange: peerAPIListeners too low; trying again") - go b.initPeerAPIListener() + b.goTracker.Go(b.initPeerAPIListener) } } } @@ -893,6 +988,40 @@ func (b *LocalBackend) onHealthChange(w *health.Warnable, us *health.UnhealthySt } } +// GetOrSetCaptureSink returns the current packet capture sink, creating it +// with the provided newSink function if it does not already exist. +func (b *LocalBackend) GetOrSetCaptureSink(newSink func() packet.CaptureSink) packet.CaptureSink { + b.mu.Lock() + defer b.mu.Unlock() + + if b.debugSink != nil { + return b.debugSink + } + s := newSink() + b.debugSink = s + b.e.InstallCaptureHook(s.CaptureCallback()) + return s +} + +func (b *LocalBackend) ClearCaptureSink() { + // Shut down & uninstall the sink if there are no longer + // any outputs on it. + b.mu.Lock() + defer b.mu.Unlock() + + select { + case <-b.ctx.Done(): + return + default: + } + if b.debugSink != nil && b.debugSink.NumOutputs() == 0 { + s := b.debugSink + b.e.InstallCaptureHook(nil) + b.debugSink = nil + s.Close() + } +} + // Shutdown halts the backend and all its sub-components. The backend // can no longer be used after Shutdown returns. 
func (b *LocalBackend) Shutdown() { @@ -908,6 +1037,8 @@ func (b *LocalBackend) Shutdown() { b.captiveCancel() } + b.stopReconnectTimerLocked() + if b.loginFlags&controlclient.LoginEphemeral != 0 { b.mu.Unlock() ctx, cancel := context.WithTimeout(b.ctx, 5*time.Second) @@ -936,6 +1067,8 @@ func (b *LocalBackend) Shutdown() { if b.notifyCancel != nil { b.notifyCancel() } + extHost := b.extHost + b.extHost = nil b.mu.Unlock() b.webClientShutdown() @@ -944,19 +1077,45 @@ func (b *LocalBackend) Shutdown() { defer cancel() b.sockstatLogger.Shutdown(ctx) } - if b.peerAPIServer != nil { - b.peerAPIServer.taildrop.Shutdown() - } b.stopOfflineAutoUpdate() b.unregisterNetMon() b.unregisterHealthWatch() + b.unregisterSysPolicyWatch() if cc != nil { cc.Shutdown() } + extHost.Shutdown() b.ctxCancel() b.e.Close() <-b.e.Done() + b.awaitNoGoroutinesInTest() +} + +func (b *LocalBackend) awaitNoGoroutinesInTest() { + if !testenv.InTest() { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second) + defer cancel() + + ch := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { ch <- true })() + + for { + n := b.goTracker.RunningGoroutines() + if n == 0 { + return + } + select { + case <-ctx.Done(): + // TODO(bradfitz): pass down some TB-like failer interface from + // tests, without depending on testing from here? 
+ // But this is fine in tests too: + panic(fmt.Sprintf("timeout waiting for %d goroutines to stop", n)) + case <-ch: + } + } } func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView { @@ -965,7 +1124,6 @@ func stripKeysFromPrefs(p ipn.PrefsView) ipn.PrefsView { } p2 := p.AsStruct() - p2.Persist.LegacyFrontendPrivateMachineKey = key.MachinePrivate{} p2.Persist.PrivateNodeKey = key.NodePrivate{} p2.Persist.OldPrivateNodeKey = key.NodePrivate{} p2.Persist.NetworkLockKey = key.NLPrivate{} @@ -1006,6 +1164,8 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { b.mu.Lock() defer b.mu.Unlock() + cn := b.currentNode() + nm := cn.NetMap() sb.MutateStatus(func(s *ipnstate.Status) { s.Version = version.Long() s.TUN = !b.sys.IsNetstack() @@ -1022,28 +1182,24 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { if m := b.sshOnButUnusableHealthCheckMessageLocked(); m != "" { s.Health = append(s.Health, m) } - if b.netMap != nil { - s.CertDomains = append([]string(nil), b.netMap.DNS.CertDomains...) - s.MagicDNSSuffix = b.netMap.MagicDNSSuffix() + if nm != nil { + s.CertDomains = append([]string(nil), nm.DNS.CertDomains...) 
+ s.MagicDNSSuffix = nm.MagicDNSSuffix() if s.CurrentTailnet == nil { s.CurrentTailnet = &ipnstate.TailnetStatus{} } - s.CurrentTailnet.MagicDNSSuffix = b.netMap.MagicDNSSuffix() - s.CurrentTailnet.MagicDNSEnabled = b.netMap.DNS.Proxied - s.CurrentTailnet.Name = b.netMap.Domain + s.CurrentTailnet.MagicDNSSuffix = nm.MagicDNSSuffix() + s.CurrentTailnet.MagicDNSEnabled = nm.DNS.Proxied + s.CurrentTailnet.Name = nm.Domain if prefs := b.pm.CurrentPrefs(); prefs.Valid() { - if !prefs.RouteAll() && b.netMap.AnyPeersAdvertiseRoutes() { + if !prefs.RouteAll() && nm.AnyPeersAdvertiseRoutes() { s.Health = append(s.Health, healthmsg.WarnAcceptRoutesOff) } if !prefs.ExitNodeID().IsZero() { - if exitPeer, ok := b.netMap.PeerWithStableID(prefs.ExitNodeID()); ok { - online := false - if v := exitPeer.Online(); v != nil { - online = *v - } + if exitPeer, ok := nm.PeerWithStableID(prefs.ExitNodeID()); ok { s.ExitNodeStatus = &ipnstate.ExitNodeStatus{ ID: prefs.ExitNodeID(), - Online: online, + Online: exitPeer.Online().Get(), TailscaleIPs: exitPeer.Addresses().AsSlice(), } } @@ -1053,8 +1209,8 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { }) var tailscaleIPs []netip.Addr - if b.netMap != nil { - addrs := b.netMap.GetAddresses() + if nm != nil { + addrs := nm.GetAddresses() for i := range addrs.Len() { if addr := addrs.At(i); addr.IsSingleIP() { sb.AddTailscaleIP(addr.Addr()) @@ -1066,24 +1222,23 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { sb.MutateSelfStatus(func(ss *ipnstate.PeerStatus) { ss.OS = version.OS() ss.Online = b.health.GetInPollNetMap() - if b.netMap != nil { + if nm != nil { ss.InNetworkMap = true - if hi := b.netMap.SelfNode.Hostinfo(); hi.Valid() { + if hi := nm.SelfNode.Hostinfo(); hi.Valid() { ss.HostName = hi.Hostname() } - ss.DNSName = b.netMap.Name - ss.UserID = b.netMap.User() - if sn := b.netMap.SelfNode; sn.Valid() { + ss.DNSName = nm.Name + ss.UserID = nm.User() + if sn := nm.SelfNode; sn.Valid() { 
peerStatusFromNode(ss, sn) if cm := sn.CapMap(); cm.Len() > 0 { ss.Capabilities = make([]tailcfg.NodeCapability, 1, cm.Len()+1) ss.Capabilities[0] = "HTTPS://TAILSCALE.COM/s/DEPRECATED-NODE-CAPS#see-https://github.com/tailscale/tailscale/issues/11508" ss.CapMap = make(tailcfg.NodeCapMap, sn.CapMap().Len()) - cm.Range(func(k tailcfg.NodeCapability, v views.Slice[tailcfg.RawMessage]) bool { + for k, v := range cm.All() { ss.CapMap[k] = v.AsSlice() ss.Capabilities = append(ss.Capabilities, k) - return true - }) + } slices.Sort(ss.Capabilities[1:]) } } @@ -1092,7 +1247,7 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { } } else { - ss.HostName, _ = os.Hostname() + ss.HostName, _ = hostinfo.Hostname() } for _, pln := range b.peerAPIListeners { ss.PeerAPIURL = append(ss.PeerAPIURL, pln.urlStr) @@ -1107,18 +1262,16 @@ func (b *LocalBackend) UpdateStatus(sb *ipnstate.StatusBuilder) { } func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { - if b.netMap == nil { + cn := b.currentNode() + nm := cn.NetMap() + if nm == nil { return } - for id, up := range b.netMap.UserProfiles { + for id, up := range nm.UserProfiles { sb.AddUser(id, up) } exitNodeID := b.pm.CurrentPrefs().ExitNodeID() - for _, p := range b.peers { - var lastSeen time.Time - if p.LastSeen() != nil { - lastSeen = *p.LastSeen() - } + for _, p := range cn.Peers() { tailscaleIPs := make([]netip.Addr, 0, p.Addresses().Len()) for i := range p.Addresses().Len() { addr := p.Addresses().At(i) @@ -1126,7 +1279,6 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { tailscaleIPs = append(tailscaleIPs, addr.Addr()) } } - online := p.Online() ps := &ipnstate.PeerStatus{ InNetworkMap: true, UserID: p.User(), @@ -1135,20 +1287,22 @@ func (b *LocalBackend) populatePeerStatusLocked(sb *ipnstate.StatusBuilder) { HostName: p.Hostinfo().Hostname(), DNSName: p.Name(), OS: p.Hostinfo().OS(), - LastSeen: lastSeen, - Online: online != nil && *online, + LastSeen: 
p.LastSeen().Get(), + Online: p.Online().Get(), ShareeNode: p.Hostinfo().ShareeNode(), ExitNode: p.StableID() != "" && p.StableID() == exitNodeID, SSH_HostKeys: p.Hostinfo().SSH_HostKeys().AsSlice(), - Location: p.Hostinfo().Location(), + Location: p.Hostinfo().Location().AsStruct(), Capabilities: p.Capabilities().AsSlice(), } + for _, f := range b.extHost.Hooks().SetPeerStatus { + f(ps, p, cn) + } if cm := p.CapMap(); cm.Len() > 0 { ps.CapMap = make(tailcfg.NodeCapMap, cm.Len()) - cm.Range(func(k tailcfg.NodeCapability, v views.Slice[tailcfg.RawMessage]) bool { + for k, v := range cm.All() { ps.CapMap[k] = v.AsSlice() - return true - }) + } } peerStatusFromNode(ps, p) @@ -1192,20 +1346,25 @@ func peerStatusFromNode(ps *ipnstate.PeerStatus, n tailcfg.NodeView) { } } +func profileFromView(v tailcfg.UserProfileView) tailcfg.UserProfile { + if v.Valid() { + return tailcfg.UserProfile{ + ID: v.ID(), + LoginName: v.LoginName(), + DisplayName: v.DisplayName(), + ProfilePicURL: v.ProfilePicURL(), + } + } + return tailcfg.UserProfile{} +} + // WhoIsNodeKey returns the peer info of given public key, if it exists. func (b *LocalBackend) WhoIsNodeKey(k key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { - b.mu.Lock() - defer b.mu.Unlock() - // TODO(bradfitz): add nodeByKey like nodeByAddr instead of walking peers. 
- if b.netMap == nil { - return n, u, false - } - if self := b.netMap.SelfNode; self.Valid() && self.Key() == k { - return self, b.netMap.UserProfiles[self.User()], true - } - for _, n := range b.peers { - if n.Key() == k { - u, ok = b.netMap.UserProfiles[n.User()] + cn := b.currentNode() + if nid, ok := cn.NodeByKey(k); ok { + if n, ok := cn.PeerByID(nid); ok { + up, ok := cn.NetMap().UserProfiles[n.User()] + u = profileFromView(up) return n, u, ok } } @@ -1237,7 +1396,8 @@ func (b *LocalBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeVi return zero, u, false } - nid, ok := b.nodeByAddr[ipp.Addr()] + cn := b.currentNode() + nid, ok := cn.NodeByAddr(ipp.Addr()) if !ok { var ip netip.Addr if ipp.Port() != 0 { @@ -1259,57 +1419,42 @@ func (b *LocalBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeVi if !ok { return failf("no IP found in ProxyMapper for %v", ipp) } - nid, ok = b.nodeByAddr[ip] + nid, ok = cn.NodeByAddr(ip) if !ok { return failf("no node for proxymapped IP %v", ip) } } - if b.netMap == nil { + nm := cn.NetMap() + if nm == nil { return failf("no netmap") } - n, ok = b.peers[nid] + n, ok = cn.PeerByID(nid) if !ok { // Check if this the self-node, which would not appear in peers. - if !b.netMap.SelfNode.Valid() || nid != b.netMap.SelfNode.ID() { + if !nm.SelfNode.Valid() || nid != nm.SelfNode.ID() { return zero, u, false } - n = b.netMap.SelfNode + n = nm.SelfNode } - u, ok = b.netMap.UserProfiles[n.User()] + up, ok := cn.UserByID(n.User()) if !ok { return failf("no userprofile for node %v", n.Key()) } - return n, u, true + return n, profileFromView(up), true } // PeerCaps returns the capabilities that remote src IP has to // ths current node. 
func (b *LocalBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { - b.mu.Lock() - defer b.mu.Unlock() - return b.peerCapsLocked(src) + return b.currentNode().PeerCaps(src) } -func (b *LocalBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { - if b.netMap == nil { - return nil - } - filt := b.filterAtomic.Load() - if filt == nil { - return nil +func (b *LocalBackend) GetFilterForTest() *filter.Filter { + if !testenv.InTest() { + panic("GetFilterForTest called outside of test") } - addrs := b.netMap.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) - if !a.IsSingleIP() { - continue - } - dst := a.Addr() - if dst.BitLen() == src.BitLen() { // match on family - return filt.CapsWithValues(src, dst) - } - } - return nil + nb := b.currentNode() + return nb.filterAtomic.Load() } // SetControlClientStatus is the callback invoked by the control client whenever it posts a new status. @@ -1412,8 +1557,9 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.mu.Lock() prefsChanged := false + cn := b.currentNode() prefs := b.pm.CurrentPrefs().AsStruct() - oldNetMap := b.netMap + oldNetMap := cn.NetMap() curNetMap := st.NetMap if curNetMap == nil { // The status didn't include a netmap update, so the old one is still @@ -1459,10 +1605,10 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.logf("SetControlClientStatus failed to select auto exit node: %v", err) } } - if setExitNodeID(prefs, curNetMap, b.lastSuggestedExitNode) { + if applySysPolicy(prefs, b.lastSuggestedExitNode, b.overrideAlwaysOn) { prefsChanged = true } - if applySysPolicy(prefs) { + if setExitNodeID(prefs, curNetMap) { prefsChanged = true } @@ -1482,6 +1628,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.logf("Failed to save new controlclient state: %v", err) } } + // initTKALocked is dependent on CurrentProfile.ID, which is initialized // (for new profiles) on the first call to 
b.pm.SetPrefs. if err := b.initTKALocked(); err != nil { @@ -1517,7 +1664,7 @@ func (b *LocalBackend) SetControlClientStatus(c controlclient.Client, st control b.tkaFilterNetmapLocked(st.NetMap) } b.setNetMapLocked(st.NetMap) - b.updateFilterLocked(st.NetMap, prefs.View()) + b.updateFilterLocked(prefs.View()) } b.mu.Unlock() @@ -1628,12 +1775,73 @@ var preferencePolicies = []preferencePolicyInfo{ // applySysPolicy overwrites configured preferences with policies that may be // configured by the system administrator in an OS-specific way. -func applySysPolicy(prefs *ipn.Prefs) (anyChange bool) { +func applySysPolicy(prefs *ipn.Prefs, lastSuggestedExitNode tailcfg.StableNodeID, overrideAlwaysOn bool) (anyChange bool) { if controlURL, err := syspolicy.GetString(syspolicy.ControlURL, prefs.ControlURL); err == nil && prefs.ControlURL != controlURL { prefs.ControlURL = controlURL anyChange = true } + const sentinel = "HostnameDefaultValue" + hostnameFromPolicy, _ := syspolicy.GetString(syspolicy.Hostname, sentinel) + switch hostnameFromPolicy { + case sentinel: + // An empty string for this policy value means that the admin wants to delete + // the hostname stored in the ipn.Prefs. To make that work, we need to + // distinguish between an empty string and a policy that was not set. + // We cannot do that with the current implementation of syspolicy.GetString. + // It currently does not return an error if a policy was not configured. + // Instead, it returns the default value provided as the second argument. + // This behavior makes it impossible to distinguish between a policy that + // was not set and a policy that was set to an empty default value. + // Checking for sentinel here is a workaround to distinguish between + // the two cases. If we get it, we do nothing because the policy was not set. + // + // TODO(angott,nickkhyl): clean up this behavior once syspolicy.GetString starts + // properly returning errors. 
+ case "": + // The policy was set to an empty string, which means the admin intends + // to clear the hostname stored in preferences. + prefs.Hostname = "" + anyChange = true + default: + // The policy was set to a non-empty string, which means the admin wants + // to override the hostname stored in preferences. + if prefs.Hostname != hostnameFromPolicy { + prefs.Hostname = hostnameFromPolicy + anyChange = true + } + } + + if exitNodeIDStr, _ := syspolicy.GetString(syspolicy.ExitNodeID, ""); exitNodeIDStr != "" { + exitNodeID := tailcfg.StableNodeID(exitNodeIDStr) + if shouldAutoExitNode() && lastSuggestedExitNode != "" { + exitNodeID = lastSuggestedExitNode + } + // Note: when exitNodeIDStr == "auto" && lastSuggestedExitNode == "", + // then exitNodeID is now "auto" which will never match a peer's node ID. + // When there is no a peer matching the node ID, traffic will blackhole, + // preventing accidental non-exit-node usage when a policy is in effect that requires an exit node. + if prefs.ExitNodeID != exitNodeID || prefs.ExitNodeIP.IsValid() { + anyChange = true + } + prefs.ExitNodeID = exitNodeID + prefs.ExitNodeIP = netip.Addr{} + } else if exitNodeIPStr, _ := syspolicy.GetString(syspolicy.ExitNodeIP, ""); exitNodeIPStr != "" { + exitNodeIP, err := netip.ParseAddr(exitNodeIPStr) + if exitNodeIP.IsValid() && err == nil { + if prefs.ExitNodeID != "" || prefs.ExitNodeIP != exitNodeIP { + anyChange = true + } + prefs.ExitNodeID = "" + prefs.ExitNodeIP = exitNodeIP + } + } + + if alwaysOn, _ := syspolicy.GetBoolean(syspolicy.AlwaysOn, false); alwaysOn && !overrideAlwaysOn && !prefs.WantRunning { + prefs.WantRunning = true + anyChange = true + } + for _, opt := range preferencePolicies { if po, err := syspolicy.GetPreferenceOption(opt.key); err == nil { curVal := opt.get(prefs.View()) @@ -1648,6 +1856,63 @@ func applySysPolicy(prefs *ipn.Prefs) (anyChange bool) { return anyChange } +// registerSysPolicyWatch subscribes to syspolicy change notifications +// and 
immediately applies the effective syspolicy settings to the current profile. +func (b *LocalBackend) registerSysPolicyWatch() (unregister func(), err error) { + if unregister, err = syspolicy.RegisterChangeCallback(b.sysPolicyChanged); err != nil { + return nil, fmt.Errorf("syspolicy: LocalBacked failed to register policy change callback: %v", err) + } + if prefs, anyChange := b.applySysPolicy(); anyChange { + b.logf("syspolicy: changed initial profile prefs: %v", prefs.Pretty()) + } + b.refreshAllowedSuggestions() + return unregister, nil +} + +// applySysPolicy overwrites the current profile's preferences with policies +// that may be configured by the system administrator in an OS-specific way. +// +// b.mu must not be held. +func (b *LocalBackend) applySysPolicy() (_ ipn.PrefsView, anyChange bool) { + unlock := b.lockAndGetUnlock() + prefs := b.pm.CurrentPrefs().AsStruct() + if !applySysPolicy(prefs, b.lastSuggestedExitNode, b.overrideAlwaysOn) { + unlock.UnlockEarly() + return prefs.View(), false + } + return b.setPrefsLockedOnEntry(prefs, unlock), true +} + +// sysPolicyChanged is a callback triggered by syspolicy when it detects +// a change in one or more syspolicy settings. +func (b *LocalBackend) sysPolicyChanged(policy *rsop.PolicyChange) { + if policy.HasChanged(syspolicy.AlwaysOn) || policy.HasChanged(syspolicy.AlwaysOnOverrideWithReason) { + // If the AlwaysOn or the AlwaysOnOverrideWithReason policy has changed, + // we should reset the overrideAlwaysOn flag, as the override might + // no longer be valid. + b.mu.Lock() + b.overrideAlwaysOn = false + b.mu.Unlock() + } + + if policy.HasChanged(syspolicy.AllowedSuggestedExitNodes) { + b.refreshAllowedSuggestions() + // Re-evaluate exit node suggestion now that the policy setting has changed. 
+ b.mu.Lock() + _, err := b.suggestExitNodeLocked(nil) + b.mu.Unlock() + if err != nil && !errors.Is(err, ErrNoPreferredDERP) { + b.logf("failed to select auto exit node: %v", err) + } + // If [syspolicy.ExitNodeID] is set to `auto:any`, the suggested exit node ID + // will be used when [applySysPolicy] updates the current profile's prefs. + } + + if prefs, anyChange := b.applySysPolicy(); anyChange { + b.logf("syspolicy: changed profile prefs: %v", prefs.Pretty()) + } +} + var _ controlclient.NetmapDeltaUpdater = (*LocalBackend)(nil) // UpdateNetmapDelta implements controlclient.NetmapDeltaUpdater. @@ -1662,27 +1927,33 @@ func (b *LocalBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bo b.send(*notify) } }() - unlock := b.lockAndGetUnlock() - defer unlock() - if !b.updateNetmapDeltaLocked(muts) { - return false - } + b.mu.Lock() + defer b.mu.Unlock() - if b.netMap != nil && mutationsAreWorthyOfTellingIPNBus(muts) { - nm := ptr.To(*b.netMap) // shallow clone - nm.Peers = make([]tailcfg.NodeView, 0, len(b.peers)) - shouldAutoExitNode := shouldAutoExitNode() - for _, p := range b.peers { - nm.Peers = append(nm.Peers, p) - // If the auto exit node currently set goes offline, find another auto exit node. - if shouldAutoExitNode && b.pm.prefs.ExitNodeID() == p.StableID() && p.Online() != nil && !*p.Online() { - b.setAutoExitNodeIDLockedOnEntry(unlock) - return false + cn := b.currentNode() + cn.UpdateNetmapDelta(muts) + + // If auto exit nodes are enabled and our exit node went offline, + // we need to schedule picking a new one. + // TODO(nickkhyl): move the auto exit node logic to a feature package. 
+ if shouldAutoExitNode() { + exitNodeID := b.pm.prefs.ExitNodeID() + for _, m := range muts { + mo, ok := m.(netmap.NodeMutationOnline) + if !ok || mo.Online { + continue } + n, ok := cn.PeerByID(m.NodeIDBeingMutated()) + if !ok || n.StableID() != exitNodeID { + continue + } + b.goTracker.Go(b.pickNewAutoExitNode) + break } - slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { - return cmp.Compare(a.ID(), b.ID()) - }) + } + + if cn.NetMap() != nil && mutationsAreWorthyOfTellingIPNBus(muts) { + nm := cn.netMapWithPeers() notify = &ipn.Notify{NetMap: nm} } else if testenv.InTest() { // In tests, send an empty Notify as a wake-up so end-to-end @@ -1709,61 +1980,23 @@ func mutationsAreWorthyOfTellingIPNBus(muts []netmap.NodeMutation) bool { return false } -func (b *LocalBackend) updateNetmapDeltaLocked(muts []netmap.NodeMutation) (handled bool) { - if b.netMap == nil || len(b.peers) == 0 { - return false - } - - // Locally cloned mutable nodes, to avoid calling AsStruct (clone) - // multiple times on a node if it's mutated multiple times in this - // call (e.g. its endpoints + online status both change) - var mutableNodes map[tailcfg.NodeID]*tailcfg.Node +// pickNewAutoExitNode picks a new automatic exit node if needed. +func (b *LocalBackend) pickNewAutoExitNode() { + unlock := b.lockAndGetUnlock() + defer unlock() - for _, m := range muts { - n, ok := mutableNodes[m.NodeIDBeingMutated()] - if !ok { - nv, ok := b.peers[m.NodeIDBeingMutated()] - if !ok { - // TODO(bradfitz): unexpected metric? - return false - } - n = nv.AsStruct() - mak.Set(&mutableNodes, nv.ID(), n) - } - m.Apply(n) - } - for nid, n := range mutableNodes { - b.peers[nid] = n.View() + newPrefs := b.setAutoExitNodeIDLockedOnEntry(unlock) + if !newPrefs.Valid() { + // Unchanged. + return } - return true + + b.send(ipn.Notify{Prefs: &newPrefs}) } // setExitNodeID updates prefs to reference an exit node by ID, rather // than by IP. It returns whether prefs was mutated. 
-func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNode tailcfg.StableNodeID) (prefsChanged bool) { - if exitNodeIDStr, _ := syspolicy.GetString(syspolicy.ExitNodeID, ""); exitNodeIDStr != "" { - exitNodeID := tailcfg.StableNodeID(exitNodeIDStr) - if shouldAutoExitNode() && lastSuggestedExitNode != "" { - exitNodeID = lastSuggestedExitNode - } - // Note: when exitNodeIDStr == "auto" && lastSuggestedExitNode == "", then exitNodeID is now "auto" which will never match a peer's node ID. - // When there is no a peer matching the node ID, traffic will blackhole, preventing accidental non-exit-node usage when a policy is in effect that requires an exit node. - changed := prefs.ExitNodeID != exitNodeID || prefs.ExitNodeIP.IsValid() - prefs.ExitNodeID = exitNodeID - prefs.ExitNodeIP = netip.Addr{} - return changed - } - - oldExitNodeID := prefs.ExitNodeID - if exitNodeIPStr, _ := syspolicy.GetString(syspolicy.ExitNodeIP, ""); exitNodeIPStr != "" { - exitNodeIP, err := netip.ParseAddr(exitNodeIPStr) - if exitNodeIP.IsValid() && err == nil { - prefsChanged = prefs.ExitNodeID != "" || prefs.ExitNodeIP != exitNodeIP - prefs.ExitNodeID = "" - prefs.ExitNodeIP = exitNodeIP - } - } - +func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap) (prefsChanged bool) { if nm == nil { // No netmap, can't resolve anything. return false @@ -1781,9 +2014,9 @@ func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNod prefsChanged = true } + oldExitNodeID := prefs.ExitNodeID for _, peer := range nm.Peers { - for i := range peer.Addresses().Len() { - addr := peer.Addresses().At(i) + for _, addr := range peer.Addresses().All() { if !addr.IsSingleIP() || addr.Addr() != prefs.ExitNodeIP { continue } @@ -1791,7 +2024,7 @@ func setExitNodeID(prefs *ipn.Prefs, nm *netmap.NetworkMap, lastSuggestedExitNod // reference it directly for next time. 
prefs.ExitNodeID = peer.StableID() prefs.ExitNodeIP = netip.Addr{} - return oldExitNodeID != prefs.ExitNodeID + return prefsChanged || oldExitNodeID != prefs.ExitNodeID } } @@ -1903,15 +2136,6 @@ func (b *LocalBackend) SetControlClientGetterForTesting(newControlClient func(co b.ccGen = newControlClient } -// NodeViewByIDForTest returns the state of the node with the given ID -// for integration tests in another repo. -func (b *LocalBackend) NodeViewByIDForTest(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { - b.mu.Lock() - defer b.mu.Unlock() - n, ok := b.peers[id] - return n, ok -} - // DisablePortMapperForTest disables the portmapper for tests. // It must be called before Start. func (b *LocalBackend) DisablePortMapperForTest() { @@ -1923,13 +2147,7 @@ func (b *LocalBackend) DisablePortMapperForTest() { // PeersForTest returns all the current peers, sorted by Node.ID, // for integration tests in another repo. func (b *LocalBackend) PeersForTest() []tailcfg.NodeView { - b.mu.Lock() - defer b.mu.Unlock() - ret := xmaps.Values(b.peers) - slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { - return cmp.Compare(a.ID(), b.ID()) - }) - return ret + return b.currentNode().PeersForTest() } func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { @@ -1944,6 +2162,11 @@ func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { return b.ccGen } +// initOnce is called on the first call to [LocalBackend.Start]. +func (b *LocalBackend) initOnce() { + b.extHost.Init() +} + // Start applies the configuration specified in opts, and starts the // state machine. 
// @@ -1957,6 +2180,8 @@ func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { func (b *LocalBackend) Start(opts ipn.Options) error { b.logf("Start") + b.startOnce.Do(b.initOnce) + var clientToShutdown controlclient.Client defer func() { if clientToShutdown != nil { @@ -2014,25 +2239,31 @@ func (b *LocalBackend) Start(opts ipn.Options) error { hostinfo.Services = b.hostinfo.Services // keep any previous services } b.hostinfo = hostinfo - b.state = ipn.NoState + b.setStateLocked(ipn.NoState) + cn := b.currentNode() if opts.UpdatePrefs != nil { oldPrefs := b.pm.CurrentPrefs() newPrefs := opts.UpdatePrefs.Clone() newPrefs.Persist = oldPrefs.Persist().AsStruct() pv := newPrefs.View() - if err := b.pm.SetPrefs(pv, ipn.NetworkProfile{ - MagicDNSName: b.netMap.MagicDNSSuffix(), - DomainName: b.netMap.DomainName(), - }); err != nil { + if err := b.pm.SetPrefs(pv, cn.NetworkProfile()); err != nil { b.logf("failed to save UpdatePrefs state: %v", err) } - b.setAtomicValuesFromPrefsLocked(pv) - } else { - b.setAtomicValuesFromPrefsLocked(b.pm.CurrentPrefs()) } + // Reset the always-on override whenever Start is called. + b.resetAlwaysOnOverrideLocked() + // And also apply syspolicy settings to the current profile. + // This is important in two cases: when opts.UpdatePrefs is not nil, + // and when Always Mode is enabled and we need to set WantRunning to true. 
+ if newp := b.pm.CurrentPrefs().AsStruct(); applySysPolicy(newp, b.lastSuggestedExitNode, b.overrideAlwaysOn) { + setExitNodeID(newp, cn.NetMap()) + b.pm.setPrefsNoPermCheck(newp.View()) + } prefs := b.pm.CurrentPrefs() + b.setAtomicValuesFromPrefsLocked(prefs) + wantRunning := prefs.WantRunning() if wantRunning { if err := b.initMachineKeyLocked(); err != nil { @@ -2048,16 +2279,14 @@ func (b *LocalBackend) Start(opts ipn.Options) error { } b.applyPrefsToHostinfoLocked(hostinfo, prefs) - b.setNetMapLocked(nil) persistv := prefs.Persist().AsStruct() if persistv == nil { persistv = new(persist.Persist) } - b.updateFilterLocked(nil, ipn.PrefsView{}) if b.portpoll != nil { b.portpollOnce.Do(func() { - go b.readPoller() + b.goTracker.Go(b.readPoller) }) } @@ -2071,6 +2300,12 @@ func (b *LocalBackend) Start(opts ipn.Options) error { debugFlags = append([]string{"netstack"}, debugFlags...) } + var ccShutdownCbs []func() + ccShutdown := func() { + for _, cb := range ccShutdownCbs { + cb() + } + } // TODO(apenwarr): The only way to change the ServerURL is to // re-run b.Start, because this is the only place we create a // new controlclient. EditPrefs allows you to overwrite ServerURL, @@ -2096,6 +2331,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { C2NHandler: http.HandlerFunc(b.handleC2N), DialPlan: &b.dialPlan, // pointer because it can't be copied ControlKnobs: b.sys.ControlKnobs(), + Shutdown: ccShutdown, // Don't warn about broken Linux IP forwarding when // netstack is being used. 
@@ -2104,6 +2340,7 @@ func (b *LocalBackend) Start(opts ipn.Options) error { if err != nil { return err } + ccShutdownCbs = b.extHost.NotifyNewControlClient(cc, b.pm.CurrentProfile()) b.setControlClientLocked(cc) endpoints := b.endpoints @@ -2128,10 +2365,17 @@ func (b *LocalBackend) Start(opts ipn.Options) error { blid := b.backendLogID.String() b.logf("Backend: logs: be:%v fe:%v", blid, opts.FrontendLogID) - b.sendLocked(ipn.Notify{ - BackendLogID: &blid, - Prefs: &prefs, - }) + b.sendToLocked(ipn.Notify{Prefs: &prefs}, allClients) + + // initialize Taildrive shares from saved state + if fs, ok := b.sys.DriveForRemote.GetOK(); ok { + currentShares := b.pm.CurrentPrefs().DriveShares() + var shares []*drive.Share + for _, share := range currentShares.All() { + shares = append(shares, share.AsStruct()) + } + fs.SetShares(shares) + } if !loggedOut && (b.hasNodeKeyLocked() || confWantRunning) { // If we know that we're either logged in or meant to be @@ -2148,6 +2392,29 @@ func (b *LocalBackend) Start(opts ipn.Options) error { return nil } +// addServiceIPs adds the IP addresses of any VIP Services sent from the +// coordination server to the list of addresses that we expect to handle. +func addServiceIPs(localNetsB *netipx.IPSetBuilder, selfNode tailcfg.NodeView) error { + if !selfNode.Valid() { + return nil + } + + serviceMap, err := tailcfg.UnmarshalNodeCapViewJSON[tailcfg.ServiceIPMappings](selfNode.CapMap(), tailcfg.NodeAttrServiceHost) + if err != nil { + return err + } + + for _, sm := range serviceMap { // typically there will be exactly one of these + for _, serviceAddrs := range sm { + for _, addr := range serviceAddrs { // typically there will be exactly two of these + localNetsB.Add(addr) + } + } + } + + return nil +} + // invalidPacketFilterWarnable is a Warnable to warn the user that the control server sent an invalid packet filter. 
var invalidPacketFilterWarnable = health.Register(&health.Warnable{ Code: "invalid-packet-filter", @@ -2160,13 +2427,24 @@ var invalidPacketFilterWarnable = health.Register(&health.Warnable{ // given netMap and user preferences. // // b.mu must be held. -func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.PrefsView) { +func (b *LocalBackend) updateFilterLocked(prefs ipn.PrefsView) { + // TODO(nickkhyl) split this into two functions: + // - (*nodeBackend).RebuildFilters() (normalFilter, jailedFilter *filter.Filter, changed bool), + // which would return packet filters for the current state and whether they changed since the last call. + // - (*LocalBackend).updateFilters(), which would use the above to update the engine with the new filters, + // notify b.sshServer, etc. + // + // For this, we would need to plumb a few more things into the [nodeBackend]. Most importantly, + // the current [ipn.PrefsView]), but also maybe also a b.logf and a b.health? + // // NOTE(danderson): keep change detection as the first thing in // this function. Don't try to optimize by returning early, more // likely than not you'll just end up breaking the change // detection and end up with the wrong filter installed. This is // quite hard to debug, so save yourself the trouble. 
var ( + cn = b.currentNode() + netMap = cn.NetMap() haveNetmap = netMap != nil addrs views.Slice[netip.Prefix] packetFilter []filter.Match @@ -2185,12 +2463,16 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P } packetFilter = netMap.PacketFilter - if packetFilterPermitsUnlockedNodes(b.peers, packetFilter) { + if cn.unlockedNodesPermitted(packetFilter) { b.health.SetUnhealthy(invalidPacketFilterWarnable, nil) packetFilter = nil } else { b.health.SetHealthy(invalidPacketFilterWarnable) } + + if err := addServiceIPs(&localNetsB, netMap.SelfNode); err != nil { + b.logf("addServiceIPs: %v", err) + } } if prefs.Valid() { for _, r := range prefs.AdvertiseRoutes().All() { @@ -2274,7 +2556,7 @@ func (b *LocalBackend) updateFilterLocked(netMap *netmap.NetworkMap, prefs ipn.P b.e.SetJailedFilter(filter.NewShieldsUpFilter(localNets, logNets, oldJailedFilter, b.logf)) if b.sshServer != nil { - go b.sshServer.OnPolicyChange() + b.goTracker.Go(b.sshServer.OnPolicyChange) } } @@ -2362,11 +2644,9 @@ func (b *LocalBackend) performCaptiveDetection() { } d := captivedetection.NewDetector(b.logf) - var dm *tailcfg.DERPMap - b.mu.Lock() - if b.netMap != nil { - dm = b.netMap.DERPMap - } + b.mu.Lock() // for b.hostinfo + cn := b.currentNode() + dm := cn.DERPMap() preferredDERP := 0 if b.hostinfo != nil { if b.hostinfo.NetInfo != nil { @@ -2378,6 +2658,9 @@ func (b *LocalBackend) performCaptiveDetection() { b.mu.Unlock() found := d.Detect(ctx, netMon, dm, preferredDERP) if found { + if !b.health.IsUnhealthy(captivePortalWarnable) { + metricCaptivePortalDetected.Add(1) + } b.health.SetUnhealthy(captivePortalWarnable, health.Args{}) } else { b.health.SetHealthy(captivePortalWarnable) @@ -2433,8 +2716,10 @@ func packetFilterPermitsUnlockedNodes(peers map[tailcfg.NodeID]tailcfg.NodeView, return false } +// TODO(nickkhyl): this should be non-existent with a proper [LocalBackend.updateFilterLocked]. +// See the comment in that function for more details. 
func (b *LocalBackend) setFilter(f *filter.Filter) { - b.filterAtomic.Store(f) + b.currentNode().setFilter(f) b.e.SetFilter(f) } @@ -2656,10 +2941,15 @@ func applyConfigToHostinfo(hi *tailcfg.Hostinfo, c *conffile.Config) { // notifications. There is currently (2022-11-22) no mechanism provided to // detect when a message has been dropped. func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { - ch := make(chan *ipn.Notify, 128) + b.WatchNotificationsAs(ctx, nil, mask, onWatchAdded, fn) +} +// WatchNotificationsAs is like WatchNotifications but takes an [ipnauth.Actor] +// as an additional parameter. If non-nil, the specified callback is invoked +// only for notifications relevant to this actor. +func (b *LocalBackend) WatchNotificationsAs(ctx context.Context, actor ipnauth.Actor, mask ipn.NotifyWatchOpt, onWatchAdded func(), fn func(roNotify *ipn.Notify) (keepGoing bool)) { + ch := make(chan *ipn.Notify, 128) sessionID := rands.HexString(16) - origFn := fn if mask&ipn.NotifyNoPrivateKeys != 0 { fn = func(n *ipn.Notify) bool { @@ -2684,6 +2974,7 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa const initialBits = ipn.NotifyInitialState | ipn.NotifyInitialPrefs | ipn.NotifyInitialNetMap | ipn.NotifyInitialDriveShares if mask&initialBits != 0 { + cn := b.currentNode() ini = &ipn.Notify{Version: version.Long()} if mask&ipn.NotifyInitialState != 0 { ini.SessionID = sessionID @@ -2696,9 +2987,9 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa ini.Prefs = ptr.To(b.sanitizedPrefsLocked()) } if mask&ipn.NotifyInitialNetMap != 0 { - ini.NetMap = b.netMap + ini.NetMap = cn.NetMap() } - if mask&ipn.NotifyInitialDriveShares != 0 && b.driveSharingEnabledLocked() { + if mask&ipn.NotifyInitialDriveShares != 0 && b.DriveSharingEnabled() { ini.DriveShares = b.pm.prefs.DriveShares() } if 
mask&ipn.NotifyInitialHealthState != 0 { @@ -2711,12 +3002,16 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa session := &watchSession{ ch: ch, + owner: actor, sessionID: sessionID, cancel: cancel, } mak.Set(&b.notifyWatchers, sessionID, session) b.mu.Unlock() + metricCurrentWatchIPNBus.Add(1) + defer metricCurrentWatchIPNBus.Add(-1) + defer func() { b.mu.Lock() delete(b.notifyWatchers, sessionID) @@ -2745,23 +3040,20 @@ func (b *LocalBackend) WatchNotifications(ctx context.Context, mask ipn.NotifyWa // request every 2 seconds. // TODO(bradfitz): plumb this further and only send a Notify on change. if mask&ipn.NotifyWatchEngineUpdates != 0 { - go b.pollRequestEngineStatus(ctx) + b.goTracker.Go(func() { b.pollRequestEngineStatus(ctx) }) } - // TODO(marwan-at-work): check err // TODO(marwan-at-work): streaming background logs? defer b.DeleteForegroundSession(sessionID) - for { - select { - case <-ctx.Done(): - return - case n := <-ch: - if !fn(n) { - return - } - } + sender := &rateLimitingBusSender{fn: fn} + defer sender.close() + + if mask&ipn.NotifyRateLimit != 0 { + sender.interval = 3 * time.Second } + + sender.Run(ctx, ch) } // pollRequestEngineStatus calls b.e.RequestStatus every 2 seconds until ctx @@ -2791,11 +3083,7 @@ func (b *LocalBackend) DebugNotify(n ipn.Notify) { // // It should only be used via the LocalAPI's debug handler. func (b *LocalBackend) DebugNotifyLastNetMap() { - b.mu.Lock() - nm := b.netMap - b.mu.Unlock() - - if nm != nil { + if nm := b.currentNode().NetMap(); nm != nil { b.send(ipn.Notify{NetMap: nm}) } } @@ -2809,7 +3097,8 @@ func (b *LocalBackend) DebugNotifyLastNetMap() { func (b *LocalBackend) DebugForceNetmapUpdate() { b.mu.Lock() defer b.mu.Unlock() - nm := b.netMap + // TODO(nickkhyl): this all should be done in [LocalBackend.setNetMapLocked]. 
+ nm := b.currentNode().NetMap() b.e.SetNetworkMap(nm) if nm != nil { b.MagicConn().SetDERPMap(nm.DERPMap) @@ -2823,6 +3112,12 @@ func (b *LocalBackend) DebugPickNewDERP() error { return b.sys.MagicSock.Get().DebugPickNewDERP() } +// DebugForcePreferDERP forwards to netcheck.DebugForcePreferDERP. +// See its docs. +func (b *LocalBackend) DebugForcePreferDERP(n int) { + b.sys.MagicSock.Get().DebugForcePreferDERP(n) +} + // send delivers n to the connected frontend and any API watchers from // LocalBackend.WatchNotifications (via the LocalAPI). // @@ -2833,13 +3128,77 @@ func (b *LocalBackend) DebugPickNewDERP() error { // // b.mu must not be held. func (b *LocalBackend) send(n ipn.Notify) { + b.sendTo(n, allClients) +} + +// SendNotify sends a notification to the IPN bus, +// typically to the GUI client. +func (b *LocalBackend) SendNotify(n ipn.Notify) { + b.send(n) +} + +// notificationTarget describes a notification recipient. +// A zero value is valid and indicate that the notification +// should be broadcast to all active [watchSession]s. +type notificationTarget struct { + // userID is the OS-specific UID of the target user. + // If empty, the notification is not user-specific and + // will be broadcast to all connected users. + // TODO(nickkhyl): make this field cross-platform rather + // than Windows-specific. + userID ipn.WindowsUserID + // clientID identifies a client that should be the exclusive recipient + // of the notification. A zero value indicates that notification should + // be sent to all sessions of the specified user. + clientID ipnauth.ClientID +} + +var allClients = notificationTarget{} // broadcast to all connected clients + +// toNotificationTarget returns a [notificationTarget] that matches only actors +// representing the same user as the specified actor. If the actor represents +// a specific connected client, the [ipnauth.ClientID] must also match. +// If the actor is nil, the [notificationTarget] matches all actors. 
+func toNotificationTarget(actor ipnauth.Actor) notificationTarget { + t := notificationTarget{} + if actor != nil { + t.userID = actor.UserID() + t.clientID, _ = actor.ClientID() + } + return t +} + +// match reports whether the specified actor should receive notifications +// targeting t. If the actor is nil, it should only receive notifications +// intended for all users. +func (t notificationTarget) match(actor ipnauth.Actor) bool { + if t == allClients { + return true + } + if actor == nil { + return false + } + if t.userID != "" && t.userID != actor.UserID() { + return false + } + if t.clientID != ipnauth.NoClientID { + clientID, ok := actor.ClientID() + if !ok || clientID != t.clientID { + return false + } + } + return true +} + +// sendTo is like [LocalBackend.send] but allows specifying a recipient. +func (b *LocalBackend) sendTo(n ipn.Notify, recipient notificationTarget) { b.mu.Lock() defer b.mu.Unlock() - b.sendLocked(n) + b.sendToLocked(n, recipient) } -// sendLocked is like send, but assumes b.mu is already held. -func (b *LocalBackend) sendLocked(n ipn.Notify) { +// sendToLocked is like [LocalBackend.sendTo], but assumes b.mu is already held. +func (b *LocalBackend) sendToLocked(n ipn.Notify, recipient notificationTarget) { if n.Prefs != nil { n.Prefs = ptr.To(stripKeysFromPrefs(*n.Prefs)) } @@ -2847,59 +3206,37 @@ func (b *LocalBackend) sendLocked(n ipn.Notify) { n.Version = version.Long() } - apiSrv := b.peerAPIServer - if mayDeref(apiSrv).taildrop.HasFilesWaiting() { - n.FilesWaiting = &empty.Message{} + for _, f := range b.extHost.Hooks().MutateNotifyLocked { + f(&n) } for _, sess := range b.notifyWatchers { - select { - case sess.ch <- &n: - default: - // Drop the notification if the channel is full. + if recipient.match(sess.owner) { + select { + case sess.ch <- &n: + default: + // Drop the notification if the channel is full. 
+ } } } } -func (b *LocalBackend) sendFileNotify() { - var n ipn.Notify - - b.mu.Lock() - for _, wakeWaiter := range b.fileWaiters { - wakeWaiter() - } - apiSrv := b.peerAPIServer - if apiSrv == nil { - b.mu.Unlock() - return - } - - // Make sure we always set n.IncomingFiles non-nil so it gets encoded - // in JSON to clients. They distinguish between empty and non-nil - // to know whether a Notify should be able about files. - n.IncomingFiles = apiSrv.taildrop.IncomingFiles() - b.mu.Unlock() - - sort.Slice(n.IncomingFiles, func(i, j int) bool { - return n.IncomingFiles[i].Started.Before(n.IncomingFiles[j].Started) - }) - - b.send(n) -} - // setAuthURL sets the authURL and triggers [LocalBackend.popBrowserAuthNow] if the URL has changed. // This method is called when a new authURL is received from the control plane, meaning that either a user // has started a new interactive login (e.g., by running `tailscale login` or clicking Login in the GUI), // or the control plane was unable to authenticate this node non-interactively (e.g., due to key expiration). -// b.interact indicates whether an interactive login is in progress. +// A non-nil b.authActor indicates that an interactive login is in progress and was initiated by the specified actor. // If url is "", it is equivalent to calling [LocalBackend.resetAuthURLLocked] with b.mu held. func (b *LocalBackend) setAuthURL(url string) { var popBrowser, keyExpired bool + var recipient ipnauth.Actor b.mu.Lock() switch { case url == "": b.resetAuthURLLocked() + b.mu.Unlock() + return case b.authURL != url: b.authURL = url b.authURLTime = b.clock.Now() @@ -2908,26 +3245,27 @@ func (b *LocalBackend) setAuthURL(url string) { popBrowser = true default: // Otherwise, only open it if the user explicitly requests interactive login. 
- popBrowser = b.interact + popBrowser = b.authActor != nil } keyExpired = b.keyExpired + recipient = b.authActor // or nil // Consume the StartLoginInteractive call, if any, that caused the control // plane to send us this URL. - b.interact = false + b.authActor = nil b.mu.Unlock() if popBrowser { - b.popBrowserAuthNow(url, keyExpired) + b.popBrowserAuthNow(url, keyExpired, recipient) } } -// popBrowserAuthNow shuts down the data plane and sends an auth URL -// to the connected frontend, if any. +// popBrowserAuthNow shuts down the data plane and sends the URL to the recipient's +// [watchSession]s if the recipient is non-nil; otherwise, it sends the URL to all watchSessions. // keyExpired is the value of b.keyExpired upon entry and indicates // whether the node's key has expired. // It must not be called with b.mu held. -func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool) { - b.logf("popBrowserAuthNow: url=%v, key-expired=%v, seamless-key-renewal=%v", url != "", keyExpired, b.seamlessRenewalEnabled()) +func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool, recipient ipnauth.Actor) { + b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v, seamless-key-renewal=%v", maybeUsernameOf(recipient), url != "", keyExpired, b.seamlessRenewalEnabled()) // Deconfigure the local network data plane if: // - seamless key renewal is not enabled; @@ -2936,7 +3274,7 @@ func (b *LocalBackend) popBrowserAuthNow(url string, keyExpired bool) { b.blockEngineUpdates(true) b.stopEngineAndWait() } - b.tellClientToBrowseToURL(url) + b.tellRecipientToBrowseToURL(url, toNotificationTarget(recipient)) if b.State() == ipn.Running { b.enterState(ipn.Starting) } @@ -2977,8 +3315,13 @@ func (b *LocalBackend) validPopBrowserURL(urlStr string) bool { } func (b *LocalBackend) tellClientToBrowseToURL(url string) { + b.tellRecipientToBrowseToURL(url, allClients) +} + +// tellRecipientToBrowseToURL is like tellClientToBrowseToURL but allows specifying a recipient. 
+func (b *LocalBackend) tellRecipientToBrowseToURL(url string, recipient notificationTarget) { if b.validPopBrowserURL(url) { - b.send(ipn.Notify{BrowseToURL: &url}) + b.sendTo(ipn.Notify{BrowseToURL: &url}, recipient) } } @@ -3012,18 +3355,20 @@ func (b *LocalBackend) onTailnetDefaultAutoUpdate(au bool) { // can still manually enable auto-updates on this node. return } - b.logf("using tailnet default auto-update setting: %v", au) - prefsClone := prefs.AsStruct() - prefsClone.AutoUpdate.Apply = opt.NewBool(au) - _, err := b.editPrefsLockedOnEntry(&ipn.MaskedPrefs{ - Prefs: *prefsClone, - AutoUpdateSet: ipn.AutoUpdatePrefsMask{ - ApplySet: true, - }, - }, unlock) - if err != nil { - b.logf("failed to apply tailnet-wide default for auto-updates (%v): %v", au, err) - return + if clientupdate.CanAutoUpdate() { + b.logf("using tailnet default auto-update setting: %v", au) + prefsClone := prefs.AsStruct() + prefsClone.AutoUpdate.Apply = opt.NewBool(au) + _, err := b.editPrefsLockedOnEntry(&ipn.MaskedPrefs{ + Prefs: *prefsClone, + AutoUpdateSet: ipn.AutoUpdatePrefsMask{ + ApplySet: true, + }, + }, unlock) + if err != nil { + b.logf("failed to apply tailnet-wide default for auto-updates (%v): %v", au, err) + return + } } } @@ -3063,11 +3408,6 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { return nil } - var legacyMachineKey key.MachinePrivate - if p := b.pm.CurrentPrefs().Persist(); p.Valid() { - legacyMachineKey = p.LegacyFrontendPrivateMachineKey() - } - keyText, err := b.store.ReadState(ipn.MachineKeyStateKey) if err == nil { if err := b.machinePrivKey.UnmarshalText(keyText); err != nil { @@ -3076,9 +3416,6 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { if b.machinePrivKey.IsZero() { return fmt.Errorf("invalid zero key stored in %v key of %v", ipn.MachineKeyStateKey, b.store) } - if !legacyMachineKey.IsZero() && !legacyMachineKey.Equal(b.machinePrivKey) { - b.logf("frontend-provided legacy machine key ignored; used value from server 
state") - } return nil } if err != ipn.ErrStateNotExist { @@ -3088,12 +3425,8 @@ func (b *LocalBackend) initMachineKeyLocked() (err error) { // If we didn't find one already on disk and the prefs already // have a legacy machine key, use that. Otherwise generate a // new one. - if !legacyMachineKey.IsZero() { - b.machinePrivKey = legacyMachineKey - } else { - b.logf("generating new machine key") - b.machinePrivKey = key.NewMachine() - } + b.logf("generating new machine key") + b.machinePrivKey = key.NewMachine() keyText, _ = b.machinePrivKey.MarshalText() if err := ipn.WriteState(b.store, ipn.MachineKeyStateKey, keyText); err != nil { @@ -3119,12 +3452,9 @@ func (b *LocalBackend) clearMachineKeyLocked() error { return nil } -// setTCPPortsIntercepted populates b.shouldInterceptTCPPortAtomic with an -// efficient func for ShouldInterceptTCPPort to use, which is called on every -// incoming packet. -func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) { +func generateInterceptTCPPortFunc(ports []uint16) func(uint16) bool { slices.Sort(ports) - uniq.ModifySlice(&ports) + ports = slices.Compact(ports) var f func(uint16) bool switch len(ports) { case 0: @@ -3153,7 +3483,63 @@ func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) { } } } - b.shouldInterceptTCPPortAtomic.Store(f) + return f +} + +// setTCPPortsIntercepted populates b.shouldInterceptTCPPortAtomic with an +// efficient func for ShouldInterceptTCPPort to use, which is called on every +// incoming packet. 
+func (b *LocalBackend) setTCPPortsIntercepted(ports []uint16) { + b.shouldInterceptTCPPortAtomic.Store(generateInterceptTCPPortFunc(ports)) +} + +func generateInterceptVIPServicesTCPPortFunc(svcAddrPorts map[netip.Addr]func(uint16) bool) func(netip.AddrPort) bool { + return func(ap netip.AddrPort) bool { + if f, ok := svcAddrPorts[ap.Addr()]; ok { + return f(ap.Port()) + } + return false + } +} + +// setVIPServicesTCPPortsIntercepted populates b.shouldInterceptVIPServicesTCPPortAtomic with an +// efficient func for ShouldInterceptTCPPort to use, which is called on every incoming packet. +func (b *LocalBackend) setVIPServicesTCPPortsIntercepted(svcPorts map[tailcfg.ServiceName][]uint16) { + b.mu.Lock() + defer b.mu.Unlock() + b.setVIPServicesTCPPortsInterceptedLocked(svcPorts) +} + +func (b *LocalBackend) setVIPServicesTCPPortsInterceptedLocked(svcPorts map[tailcfg.ServiceName][]uint16) { + if len(svcPorts) == 0 { + b.shouldInterceptVIPServicesTCPPortAtomic.Store(func(netip.AddrPort) bool { return false }) + return + } + nm := b.currentNode().NetMap() + if nm == nil { + b.logf("can't set intercept function for Service TCP Ports, netMap is nil") + return + } + vipServiceIPMap := nm.GetVIPServiceIPMap() + if len(vipServiceIPMap) == 0 { + // No approved VIP Services + return + } + + svcAddrPorts := make(map[netip.Addr]func(uint16) bool) + // Only set the intercept function if the service has been assigned a VIP. 
+ for svcName, ports := range svcPorts { + addrs, ok := vipServiceIPMap[svcName] + if !ok { + continue + } + interceptFn := generateInterceptTCPPortFunc(ports) + for _, addr := range addrs { + svcAddrPorts[addr] = interceptFn + } + } + + b.shouldInterceptVIPServicesTCPPortAtomic.Store(generateInterceptVIPServicesTCPPortFunc(svcAddrPorts)) } // setAtomicValuesFromPrefsLocked populates sshAtomicBool, containsViaIPFuncAtomic, @@ -3166,6 +3552,7 @@ func (b *LocalBackend) setAtomicValuesFromPrefsLocked(p ipn.PrefsView) { if !p.Valid() { b.containsViaIPFuncAtomic.Store(ipset.FalseContainsIPFunc()) b.setTCPPortsIntercepted(nil) + b.setVIPServicesTCPPortsInterceptedLocked(nil) b.lastServeConfJSON = mem.B(nil) b.serveConfig = ipn.ServeConfigView{} } else { @@ -3183,23 +3570,6 @@ func (b *LocalBackend) State() ipn.State { return b.state } -// InServerMode reports whether the Tailscale backend is explicitly running in -// "server mode" where it continues to run despite whatever the platform's -// default is. In practice, this is only used on Windows, where the default -// tailscaled behavior is to shut down whenever the GUI disconnects. -// -// On non-Windows platforms, this usually returns false (because people don't -// set unattended mode on other platforms) and also isn't checked on other -// platforms. -// -// TODO(bradfitz): rename to InWindowsUnattendedMode or something? Or make this -// return true on Linux etc and always be called? It's kinda messy now. -func (b *LocalBackend) InServerMode() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.pm.CurrentPrefs().ForceDaemon() -} - // CheckIPNConnectionAllowed returns an error if the specified actor should not // be allowed to connect or make requests to the LocalAPI currently. 
// @@ -3209,16 +3579,10 @@ func (b *LocalBackend) InServerMode() bool { func (b *LocalBackend) CheckIPNConnectionAllowed(actor ipnauth.Actor) error { b.mu.Lock() defer b.mu.Unlock() - serverModeUid := b.pm.CurrentUserID() - if serverModeUid == "" { - // Either this platform isn't a "multi-user" platform or we're not yet - // running as one. + if b.pm.CurrentUserID() == "" { + // There's no "current user" yet; allow the connection. return nil } - if !b.pm.CurrentPrefs().ForceDaemon() { - return nil - } - // Always allow Windows SYSTEM user to connect, // even if Tailscale is currently being used by another user. if actor.IsLocalSystem() { @@ -3229,10 +3593,21 @@ func (b *LocalBackend) CheckIPNConnectionAllowed(actor ipnauth.Actor) error { if uid == "" { return errors.New("empty user uid in connection identity") } - if uid != serverModeUid { - return fmt.Errorf("Tailscale running in server mode (%q); connection from %q not allowed", b.tryLookupUserName(string(serverModeUid)), b.tryLookupUserName(string(uid))) + if uid == b.pm.CurrentUserID() { + // The connection is from the current user; allow it. + return nil + } + + // The connection is from a different user; block it. + var reason string + if b.pm.CurrentPrefs().ForceDaemon() { + reason = "running in server mode" + } else { + reason = "already in use" } - return nil + return fmt.Errorf("Tailscale %s (%q); connection from %q not allowed", + reason, b.tryLookupUserName(string(b.pm.CurrentUserID())), + b.tryLookupUserName(string(uid))) } // tryLookupUserName tries to look up the username for the uid. @@ -3250,6 +3625,15 @@ func (b *LocalBackend) tryLookupUserName(uid string) string { // StartLoginInteractive attempts to pick up the in-progress flow where it left // off. func (b *LocalBackend) StartLoginInteractive(ctx context.Context) error { + return b.StartLoginInteractiveAs(ctx, nil) +} + +// StartLoginInteractiveAs is like StartLoginInteractive but takes an [ipnauth.Actor] +// as an additional parameter. 
If non-nil, the specified user is expected to complete +// the interactive login, and therefore will receive the BrowseToURL notification once +// the control plane sends us one. Otherwise, the notification will be delivered to all +// active [watchSession]s. +func (b *LocalBackend) StartLoginInteractiveAs(ctx context.Context, user ipnauth.Actor) error { b.mu.Lock() if b.cc == nil { panic("LocalBackend.assertClient: b.cc == nil") @@ -3263,17 +3647,17 @@ func (b *LocalBackend) StartLoginInteractive(ctx context.Context) error { hasValidURL := url != "" && timeSinceAuthURLCreated < ((7*24*time.Hour)-(1*time.Hour)) if !hasValidURL { // A user wants to log in interactively, but we don't have a valid authURL. - // Set a flag to indicate that interactive login is in progress, forcing - // a BrowseToURL notification once the authURL becomes available. - b.interact = true + // Remember the user who initiated the login, so that we can notify them + // once the authURL is available. + b.authActor = user } cc := b.cc b.mu.Unlock() - b.logf("StartLoginInteractive: url=%v", hasValidURL) + b.logf("StartLoginInteractiveAs(%q): url=%v", maybeUsernameOf(user), hasValidURL) if hasValidURL { - b.popBrowserAuthNow(url, keyExpired) + b.popBrowserAuthNow(url, keyExpired, user) } else { cc.Login(b.loginFlags | controlclient.LoginInteractive) } @@ -3387,19 +3771,21 @@ func (b *LocalBackend) parseWgStatusLocked(s *wgengine.Status) (ret ipn.EngineSt // in Hostinfo. When the user preferences currently request "shields up" // mode, all inbound connections are refused, so services are not reported. // Otherwise, shouldUploadServices respects NetMap.CollectServices. +// TODO(nickkhyl): move this into [nodeBackend]? 
func (b *LocalBackend) shouldUploadServices() bool { b.mu.Lock() defer b.mu.Unlock() p := b.pm.CurrentPrefs() - if !p.Valid() || b.netMap == nil { + nm := b.currentNode().NetMap() + if !p.Valid() || nm == nil { return false // default to safest setting } - return !p.ShieldsUp() && b.netMap.CollectServices + return !p.ShieldsUp() && nm.CollectServices } // SetCurrentUser is used to implement support for multi-user systems (only -// Windows 2022-11-25). On such systems, the uid is used to determine which +// Windows 2022-11-25). On such systems, the actor is used to determine which // user's state should be used. The current user is maintained by active // connections open to the backend. // @@ -3413,29 +3799,152 @@ func (b *LocalBackend) shouldUploadServices() bool { // unattended mode. The user must disable unattended mode before the user can be // changed. // -// On non-multi-user systems, the user should be set to nil. -// -// SetCurrentUser returns the ipn.WindowsUserID associated with the user -// when successful. -func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) (ipn.WindowsUserID, error) { - var uid ipn.WindowsUserID - if actor != nil { - uid = actor.UserID() +// On non-multi-user systems, the actor should be set to nil. 
+func (b *LocalBackend) SetCurrentUser(actor ipnauth.Actor) { + unlock := b.lockAndGetUnlock() + defer unlock() + + var userIdentifier string + if user := cmp.Or(actor, b.currentUser); user != nil { + maybeUsername, _ := user.Username() + userIdentifier = cmp.Or(maybeUsername, string(user.UserID())) } - unlock := b.lockAndGetUnlock() + if actor != b.currentUser { + if c, ok := b.currentUser.(ipnauth.ActorCloser); ok { + c.Close() + } + b.currentUser = actor + } + + var action string + if actor == nil { + action = "disconnected" + } else { + action = "connected" + } + reason := fmt.Sprintf("client %s (%s)", action, userIdentifier) + b.switchToBestProfileLockedOnEntry(reason, unlock) +} + +// SwitchToBestProfile selects the best profile to use, +// as reported by [LocalBackend.resolveBestProfileLocked], and switches +// to it, unless it's already the current profile. The reason indicates +// why the profile is being switched, such as due to a client connecting +// or disconnecting, or a change in the desktop session state, and is used +// for logging. +func (b *LocalBackend) SwitchToBestProfile(reason string) { + b.switchToBestProfileLockedOnEntry(reason, b.lockAndGetUnlock()) +} + +// switchToBestProfileLockedOnEntry is like [LocalBackend.SwitchToBestProfile], +// but b.mu must held on entry. It is released on exit. 
+func (b *LocalBackend) switchToBestProfileLockedOnEntry(reason string, unlock unlockOnce) { defer unlock() + oldControlURL := b.pm.CurrentPrefs().ControlURLOrDefault() + profile, background := b.resolveBestProfileLocked() + cp, switched, err := b.pm.SwitchToProfile(profile) + switch { + case !switched && cp.ID() == "": + if err != nil { + b.logf("%s: an error occurred; staying on empty profile: %v", reason, err) + } else { + b.logf("%s: staying on empty profile", reason) + } + case !switched: + if err != nil { + b.logf("%s: an error occurred; staying on profile %q (%s): %v", reason, cp.UserProfile().LoginName, cp.ID(), err) + } else { + b.logf("%s: staying on profile %q (%s)", reason, cp.UserProfile().LoginName, cp.ID()) + } + case cp.ID() == "": + b.logf("%s: disconnecting Tailscale", reason) + case background: + b.logf("%s: switching to background profile %q (%s)", reason, cp.UserProfile().LoginName, cp.ID()) + default: + b.logf("%s: switching to profile %q (%s)", reason, cp.UserProfile().LoginName, cp.ID()) + } + if !switched { + return + } + // As an optimization, only reset the dialPlan if the control URL changed. + if newControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(); oldControlURL != newControlURL { + b.resetDialPlan() + } + if err := b.resetForProfileChangeLockedOnEntry(unlock); err != nil { + // TODO(nickkhyl): The actual reset cannot fail. However, + // the TKA initialization or [LocalBackend.Start] can fail. + // These errors are not critical as far as we're concerned. + // But maybe we should post a notification to the API watchers? + b.logf("failed switching profile to %q: %v", profile.ID(), err) + } +} - if b.pm.CurrentUserID() == uid { - return uid, nil +// resolveBestProfileLocked returns the best profile to use based on the current +// state of the backend, such as whether a GUI/CLI client is connected, whether +// the unattended mode is enabled, the current state of the desktop sessions, +// and other factors. 
+// +// It returns a read-only view of the profile and whether it is considered +// a background profile. A background profile is used when no OS user is actively +// using Tailscale, such as when no GUI/CLI client is connected and Unattended Mode +// is enabled (see also [LocalBackend.getBackgroundProfileLocked]). +// +// An invalid view indicates no profile, meaning Tailscale should disconnect +// and remain idle until a GUI or CLI client connects. +// A valid profile view with an empty [ipn.ProfileID] indicates a new profile that +// has not been persisted yet. +// +// b.mu must be held. +func (b *LocalBackend) resolveBestProfileLocked() (_ ipn.LoginProfileView, isBackground bool) { + // TODO(nickkhyl): delegate all of this to the extensions and remove the distinction + // between "foreground" and "background" profiles as we migrate away from the concept + // of a single "current user" on Windows. See tailscale/corp#18342. + // + // If a GUI/CLI client is connected, use the connected user's profile, which means + // either the current profile if owned by the user, or their default profile. + if b.currentUser != nil { + profile := b.pm.CurrentProfile() + // TODO(nickkhyl): check if the current profile is allowed on the device, + // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + if uid := b.currentUser.UserID(); profile.LocalUserID() != uid { + profile = b.pm.DefaultUserProfile(uid) + } + return profile, false } - b.pm.SetCurrentUserID(uid) - if c, ok := b.currentUser.(ipnauth.ActorCloser); ok { - c.Close() + + // Otherwise, if on Windows, use the background profile if one is set. + // This includes staying on the current profile if Unattended Mode is enabled + // or if AlwaysOn mode is enabled and the current user is still signed in. + // If the returned background profileID is "", Tailscale will disconnect + // and remain idle until a GUI or CLI client connects. 
+ if goos := envknob.GOOS(); goos == "windows" { + // If Unattended Mode is enabled for the current profile, keep using it. + if b.pm.CurrentPrefs().ForceDaemon() { + return b.pm.CurrentProfile(), true + } + // Otherwise, use the profile returned by the extension. + profile := b.extHost.DetermineBackgroundProfile(b.pm) + return profile, true } - b.currentUser = actor - b.resetForProfileChangeLockedOnEntry(unlock) - return uid, nil + + // On other platforms, however, Tailscale continues to run in the background + // using the current profile. + // + // TODO(nickkhyl): check if the current profile is allowed on the device, + // such as when [syspolicy.Tailnet] policy setting requires a specific Tailnet. + // See tailscale/corp#26249. + return b.pm.CurrentProfile(), false +} + +// CurrentUserForTest returns the current user and the associated WindowsUserID. +// It is used for testing only, and will be removed along with the rest of the +// "current user" functionality as we progress on the multi-user improvements (tailscale/corp#18342). +func (b *LocalBackend) CurrentUserForTest() (ipn.WindowsUserID, ipnauth.Actor) { + b.mu.Lock() + defer b.mu.Unlock() + return b.pm.CurrentUserID(), b.currentUser } func (b *LocalBackend) CheckPrefs(p *ipn.Prefs) error { @@ -3484,7 +3993,7 @@ func (b *LocalBackend) checkSSHPrefsLocked(p *ipn.Prefs) error { if !p.RunSSH { return nil } - if err := envknob.CanRunTailscaleSSH(); err != nil { + if err := featureknob.CanRunTailscaleSSH(); err != nil { return err } if runtime.GOOS == "linux" { @@ -3493,13 +4002,12 @@ func (b *LocalBackend) checkSSHPrefsLocked(p *ipn.Prefs) error { if envknob.SSHIgnoreTailnetPolicy() || envknob.SSHPolicyFile() != "" { return nil } - if b.netMap != nil { - if !b.netMap.HasCap(tailcfg.CapabilitySSH) { - if b.isDefaultServerLocked() { - return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet. 
See https://tailscale.com/s/ssh") - } - return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet.") + // Assume that we do have the SSH capability if don't have a netmap yet. + if !b.currentNode().SelfHasCapOr(tailcfg.CapabilitySSH, true) { + if b.isDefaultServerLocked() { + return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet. See https://tailscale.com/s/ssh") } + return errors.New("Unable to enable local Tailscale SSH server; not enabled on Tailnet.") } return nil } @@ -3511,7 +4019,7 @@ func (b *LocalBackend) sshOnButUnusableHealthCheckMessageLocked() (healthMessage if envknob.SSHIgnoreTailnetPolicy() || envknob.SSHPolicyFile() != "" { return "development SSH policy in use" } - nm := b.netMap + nm := b.currentNode().NetMap() if nm == nil { return "" } @@ -3565,7 +4073,16 @@ func updateExitNodeUsageWarning(p ipn.PrefsView, state *netmon.State, healthTrac } func (b *LocalBackend) checkExitNodePrefsLocked(p *ipn.Prefs) error { - if (p.ExitNodeIP.IsValid() || p.ExitNodeID != "") && p.AdvertisesExitNode() { + tryingToUseExitNode := p.ExitNodeIP.IsValid() || p.ExitNodeID != "" + if !tryingToUseExitNode { + return nil + } + + if err := featureknob.CanUseExitNode(); err != nil { + return err + } + + if p.AdvertisesExitNode() { return errors.New("Cannot advertise an exit node and use an exit node at the same time.") } return nil @@ -3638,7 +4155,15 @@ func (b *LocalBackend) MaybeClearAppConnector(mp *ipn.MaskedPrefs) error { return err } +// EditPrefs applies the changes in mp to the current prefs, +// acting as the tailscaled itself rather than a specific user. func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { + return b.EditPrefsAs(mp, ipnauth.Self) +} + +// EditPrefsAs is like EditPrefs, but makes the change as the specified actor. +// It returns an error if the actor is not allowed to make the change. 
+func (b *LocalBackend) EditPrefsAs(mp *ipn.MaskedPrefs, actor ipnauth.Actor) (ipn.PrefsView, error) { if mp.SetsInternal() { return ipn.PrefsView{}, errors.New("can't set Internal fields") } @@ -3649,11 +4174,94 @@ func (b *LocalBackend) EditPrefs(mp *ipn.MaskedPrefs) (ipn.PrefsView, error) { mp.InternalExitNodePriorSet = true } + // Acquire the lock before checking the profile access to prevent + // TOCTOU issues caused by the current profile changing between the + // check and the actual edit. unlock := b.lockAndGetUnlock() defer unlock() + if mp.WantRunningSet && !mp.WantRunning && b.pm.CurrentPrefs().WantRunning() { + if err := actor.CheckProfileAccess(b.pm.CurrentProfile(), ipnauth.Disconnect, b.extHost.AuditLogger()); err != nil { + b.logf("check profile access failed: %v", err) + return ipn.PrefsView{}, err + } + + // If a user has enough rights to disconnect, such as when [syspolicy.AlwaysOn] + // is disabled, or [syspolicy.AlwaysOnOverrideWithReason] is also set and the user + // provides a reason for disconnecting, then we should not force the "always on" + // mode on them until the policy changes, they switch to a different profile, etc. + b.overrideAlwaysOn = true + + if reconnectAfter, _ := syspolicy.GetDuration(syspolicy.ReconnectAfter, 0); reconnectAfter > 0 { + b.startReconnectTimerLocked(reconnectAfter) + } + } + return b.editPrefsLockedOnEntry(mp, unlock) } +// startReconnectTimerLocked sets a timer to automatically set WantRunning to true +// after the specified duration. +func (b *LocalBackend) startReconnectTimerLocked(d time.Duration) { + if b.reconnectTimer != nil { + // Stop may return false if the timer has already fired, + // and the function has been called in its own goroutine, + // but lost the race to acquire b.mu. In this case, it'll + // end up as a no-op due to a reconnectTimer mismatch + // once it manages to acquire the lock. This is fine, and we + // don't need to check the return value. 
+ b.reconnectTimer.Stop() + } + profileID := b.pm.CurrentProfile().ID() + var reconnectTimer tstime.TimerController + reconnectTimer = b.clock.AfterFunc(d, func() { + unlock := b.lockAndGetUnlock() + defer unlock() + + if b.reconnectTimer != reconnectTimer { + // We're either not the most recent timer, or we lost the race when + // the timer was stopped. No need to reconnect. + return + } + b.reconnectTimer = nil + + cp := b.pm.CurrentProfile() + if cp.ID() != profileID { + // The timer fired before the profile changed but we lost the race + // and acquired the lock shortly after. + // No need to reconnect. + return + } + + mp := &ipn.MaskedPrefs{WantRunningSet: true, Prefs: ipn.Prefs{WantRunning: true}} + if _, err := b.editPrefsLockedOnEntry(mp, unlock); err != nil { + b.logf("failed to automatically reconnect as %q after %v: %v", cp.Name(), d, err) + } else { + b.logf("automatically reconnected as %q after %v", cp.Name(), d) + } + }) + b.reconnectTimer = reconnectTimer + b.logf("reconnect for %q has been scheduled and will be performed in %v", b.pm.CurrentProfile().Name(), d) +} + +func (b *LocalBackend) resetAlwaysOnOverrideLocked() { + b.overrideAlwaysOn = false + b.stopReconnectTimerLocked() +} + +func (b *LocalBackend) stopReconnectTimerLocked() { + if b.reconnectTimer != nil { + // Stop may return false if the timer has already fired, + // and the function has been called in its own goroutine, + // but lost the race to acquire b.mu. + // In this case, it'll end up as a no-op due to a reconnectTimer + // mismatch (see [LocalBackend.startReconnectTimerLocked]) + // once it manages to acquire the lock. This is fine, and we + // don't need to check the return value. + b.reconnectTimer.Stop() + b.reconnectTimer = nil + } +} + // Warning: b.mu must be held on entry, but it unlocks it on the way out. // TODO(bradfitz): redo the locking on all these weird methods like this. 
func (b *LocalBackend) editPrefsLockedOnEntry(mp *ipn.MaskedPrefs, unlock unlockOnce) (ipn.PrefsView, error) { @@ -3662,11 +4270,13 @@ func (b *LocalBackend) editPrefsLockedOnEntry(mp *ipn.MaskedPrefs, unlock unlock if mp.EggSet { mp.EggSet = false b.egg = true - go b.doSetHostinfoFilterServices() + b.goTracker.Go(b.doSetHostinfoFilterServices) } + p0 := b.pm.CurrentPrefs() p1 := b.pm.CurrentPrefs().AsStruct() p1.ApplyEdits(mp) + if err := b.checkPrefsLocked(p1); err != nil { b.logf("EditPrefs check error: %v", err) return ipn.PrefsView{}, err @@ -3678,9 +4288,23 @@ func (b *LocalBackend) editPrefsLockedOnEntry(mp *ipn.MaskedPrefs, unlock unlock if p1.View().Equals(p0) { return stripKeysFromPrefs(p0), nil } + b.logf("EditPrefs: %v", mp.Pretty()) newPrefs := b.setPrefsLockedOnEntry(p1, unlock) + // This is recorded here in the EditPrefs path, not the setPrefs path on purpose. + // recordForEdit records metrics related to edits and changes, not the final state. + // If, in the future, we want to record gauge-metrics related to the state of prefs, + // that should be done in the setPrefs path. + e := prefsMetricsEditEvent{ + change: mp, + pNew: p1.View(), + pOld: p0, + node: b.currentNode(), + lastSuggestedExitNode: b.lastSuggestedExitNode, + } + e.record() + // Note: don't perform any actions for the new prefs here. Not // every prefs change goes through EditPrefs. Put your actions // in setPrefsLocksOnEntry instead. @@ -3699,7 +4323,7 @@ func (b *LocalBackend) checkProfileNameLocked(p *ipn.Prefs) error { // No profile with that name exists. That's fine. return nil } - if id != b.pm.CurrentProfile().ID { + if id != b.pm.CurrentProfile().ID() { // Name is already in use by another profile. 
return fmt.Errorf("profile name %q already in use", p.ProfileName) } @@ -3720,25 +4344,39 @@ func (b *LocalBackend) wantIngressLocked() bool { return b.serveConfig.Valid() && b.serveConfig.HasAllowFunnel() } +// hasIngressEnabledLocked reports whether the node has any funnel endpoint enabled. This bool is sent to control (in +// Hostinfo.IngressEnabled) to determine whether 'Funnel' badge should be displayed on this node in the admin panel. +func (b *LocalBackend) hasIngressEnabledLocked() bool { + return b.serveConfig.Valid() && b.serveConfig.IsFunnelOn() +} + +// shouldWireInactiveIngressLocked reports whether the node is in a state where funnel is not actively enabled, but it +// seems that it is intended to be used with funnel. +func (b *LocalBackend) shouldWireInactiveIngressLocked() bool { + return b.serveConfig.Valid() && !b.hasIngressEnabledLocked() && b.wantIngressLocked() +} + // setPrefsLockedOnEntry requires b.mu be held to call it, but it // unlocks b.mu when done. newp ownership passes to this function. -// It returns a readonly copy of the new prefs. +// It returns a read-only copy of the new prefs. func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) ipn.PrefsView { defer unlock() - netMap := b.netMap + cn := b.currentNode() + netMap := cn.NetMap() b.setAtomicValuesFromPrefsLocked(newp.View()) oldp := b.pm.CurrentPrefs() if oldp.Valid() { newp.Persist = oldp.Persist().AsStruct() // caller isn't allowed to override this } - // setExitNodeID returns whether it updated b.prefs, but - // everything in this function treats b.prefs as completely new - // anyway. No-op if no exit node resolution is needed. - setExitNodeID(newp, netMap, b.lastSuggestedExitNode) - // applySysPolicy does likewise so we can also ignore its return value. - applySysPolicy(newp) + // applySysPolicy returns whether it updated newp, + // but everything in this function treats b.prefs as completely new + // anyway, so its return value can be ignored here. 
+ applySysPolicy(newp, b.lastSuggestedExitNode, b.overrideAlwaysOn) + // setExitNodeID does likewise. No-op if no exit node resolution is needed. + setExitNodeID(newp, netMap) + // We do this to avoid holding the lock while doing everything else. oldHi := b.hostinfo @@ -3751,16 +4389,16 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) hostInfoChanged := !oldHi.Equal(newHi) cc := b.cc - b.updateFilterLocked(netMap, newp.View()) + b.updateFilterLocked(newp.View()) if oldp.ShouldSSHBeRunning() && !newp.ShouldSSHBeRunning() { if b.sshServer != nil { - go b.sshServer.Shutdown() + b.goTracker.Go(b.sshServer.Shutdown) b.sshServer = nil } } if netMap != nil { - newProfile := netMap.UserProfiles[netMap.User()] + newProfile := profileFromView(netMap.UserProfiles[netMap.User()]) if newLoginName := newProfile.LoginName; newLoginName != "" { if !oldp.Persist().Valid() { b.logf("active login: %s", newLoginName) @@ -3775,11 +4413,14 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) } prefs := newp.View() - if err := b.pm.SetPrefs(prefs, ipn.NetworkProfile{ - MagicDNSName: b.netMap.MagicDNSSuffix(), - DomainName: b.netMap.DomainName(), - }); err != nil { + np := cmp.Or(cn.NetworkProfile(), b.pm.CurrentProfile().NetworkProfile()) + if err := b.pm.SetPrefs(prefs, np); err != nil { b.logf("failed to save new controlclient state: %v", err) + } else if prefs.WantRunning() { + // Reset the always-on override if WantRunning is true in the new prefs, + // such as when the user toggles the Connected switch in the GUI + // or runs `tailscale up`. 
+ b.resetAlwaysOnOverrideLocked() } if newp.AutoUpdate.Apply.EqualBool(true) { @@ -3800,7 +4441,7 @@ func (b *LocalBackend) setPrefsLockedOnEntry(newp *ipn.Prefs, unlock unlockOnce) b.MagicConn().SetDERPMap(netMap.DERPMap) } - if !oldp.WantRunning() && newp.WantRunning { + if !oldp.WantRunning() && newp.WantRunning && cc != nil { b.logf("transitioning to running; doing Login...") cc.Login(controlclient.LoginDefault) } @@ -3881,6 +4522,11 @@ func (b *LocalBackend) TCPHandlerForDst(src, dst netip.AddrPort) (handler func(c } } + // TODO(tailscale/corp#26001): Get handler for VIP services and Local IPs using + // the same function. + if handler := b.tcpHandlerForVIPService(dst, src); handler != nil { + return handler, opts + } // Then handle external connections to the local IP. if !b.isLocalIP(dst.Addr()) { return nil, nil @@ -3932,7 +4578,7 @@ func (b *LocalBackend) peerAPIServicesLocked() (ret []tailcfg.Service) { }) } switch runtime.GOOS { - case "linux", "freebsd", "openbsd", "illumos", "darwin", "windows", "android", "ios": + case "linux", "freebsd", "openbsd", "illumos", "solaris", "darwin", "windows", "android", "ios": // These are the platforms currently supported by // net/dns/resolver/tsdns.go:Resolver.HandleExitNodeDNSQuery. ret = append(ret, tailcfg.Service{ @@ -3980,15 +4626,38 @@ func (b *LocalBackend) doSetHostinfoFilterServices() { c := len(hi.Services) hi.Services = append(hi.Services[:c:c], peerAPIServices...) hi.PushDeviceToken = b.pushDeviceToken.Load() + + // Compare the expected ports from peerAPIServices to the actual ports in hi.Services. 
+ expectedPorts := extractPeerAPIPorts(peerAPIServices) + actualPorts := extractPeerAPIPorts(hi.Services) + if expectedPorts != actualPorts { + b.logf("Hostinfo peerAPI ports changed: expected %v, got %v", expectedPorts, actualPorts) + } + cc.SetHostinfo(&hi) } +type portPair struct { + v4, v6 uint16 +} + +func extractPeerAPIPorts(services []tailcfg.Service) portPair { + var p portPair + for _, s := range services { + switch s.Proto { + case "peerapi4": + p.v4 = s.Port + case "peerapi6": + p.v6 = s.Port + } + } + return p +} + // NetMap returns the latest cached network map received from // controlclient, or nil if no network map was received yet. func (b *LocalBackend) NetMap() *netmap.NetworkMap { - b.mu.Lock() - defer b.mu.Unlock() - return b.netMap + return b.currentNode().NetMap() } func (b *LocalBackend) isEngineBlocked() bool { @@ -4045,10 +4714,7 @@ func (b *LocalBackend) reconfigAppConnectorLocked(nm *netmap.NetworkMap, prefs i return } - // TODO(raggi): rework the view infrastructure so the large deep clone is no - // longer required - sn := nm.SelfNode.AsStruct() - attrs, err := tailcfg.UnmarshalNodeCapJSON[appctype.AppConnectorAttr](sn.CapMap, appConnectorCapName) + attrs, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorAttr](nm.SelfNode.CapMap(), appConnectorCapName) if err != nil { b.logf("[unexpected] error parsing app connector mapcap: %v", err) return @@ -4078,6 +4744,41 @@ func (b *LocalBackend) reconfigAppConnectorLocked(nm *netmap.NetworkMap, prefs i b.appConnector.UpdateDomainsAndRoutes(domains, routes) } +func (b *LocalBackend) readvertiseAppConnectorRoutes() { + // Note: we should never call b.appConnector methods while holding b.mu. + // This can lead to a deadlock, like + // https://github.com/tailscale/corp/issues/25965. + // + // Grab a copy of the field, since b.mu only guards access to the + // b.appConnector field itself. 
+ b.mu.Lock() + appConnector := b.appConnector + b.mu.Unlock() + + if appConnector == nil { + return + } + domainRoutes := appConnector.DomainRoutes() + if domainRoutes == nil { + return + } + + // Re-advertise the stored routes, in case stored state got out of + // sync with previously advertised routes in prefs. + var prefixes []netip.Prefix + for _, ips := range domainRoutes { + for _, ip := range ips { + prefixes = append(prefixes, netip.PrefixFrom(ip, ip.BitLen())) + } + } + // Note: AdvertiseRoute will trim routes that are already + // advertised, so if everything is already being advertised this is + // a noop. + if err := b.AdvertiseRoute(prefixes...); err != nil { + b.logf("error advertising stored app connector routes: %v", err) + } +} + // authReconfig pushes a new configuration into wgengine, if engine // updates are not currently blocked, based on the cached netmap and // user prefs. @@ -4085,16 +4786,22 @@ func (b *LocalBackend) authReconfig() { b.mu.Lock() blocked := b.blocked prefs := b.pm.CurrentPrefs() - nm := b.netMap + cn := b.currentNode() + nm := cn.NetMap() hasPAC := b.prevIfState.HasPAC() - disableSubnetsIfPAC := nm.HasCap(tailcfg.NodeAttrDisableSubnetsIfPAC) - userDialUseRoutes := nm.HasCap(tailcfg.NodeAttrUserDialUseRoutes) - dohURL, dohURLOK := exitNodeCanProxyDNS(nm, b.peers, prefs.ExitNodeID()) - dcfg := dnsConfigForNetmap(nm, b.peers, prefs, b.keyExpired, b.logf, version.OS()) + disableSubnetsIfPAC := cn.SelfHasCap(tailcfg.NodeAttrDisableSubnetsIfPAC) + dohURL, dohURLOK := cn.exitNodeCanProxyDNS(prefs.ExitNodeID()) + dcfg := cn.dnsConfigForNetmap(prefs, b.keyExpired, b.logf, version.OS()) // If the current node is an app connector, ensure the app connector machine is started b.reconfigAppConnectorLocked(nm, prefs) + closing := b.shutdownCalled b.mu.Unlock() + if closing { + b.logf("[v1] authReconfig: skipping because in shutdown") + return + } + if blocked { b.logf("[v1] authReconfig: blocked, skipping.") return @@ -4134,7 +4841,7 @@ 
func (b *LocalBackend) authReconfig() { return } - oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.ControlKnobs(), version.OS()) + oneCGNATRoute := shouldUseOneCGNATRoute(b.logf, b.sys.NetMon.Get(), b.sys.ControlKnobs(), version.OS()) rcfg := b.routerConfig(cfg, prefs, oneCGNATRoute) err = b.e.Reconfig(cfg, rcfg, dcfg) @@ -4143,13 +4850,8 @@ func (b *LocalBackend) authReconfig() { } b.logf("[v1] authReconfig: ra=%v dns=%v 0x%02x: %v", prefs.RouteAll(), prefs.CorpDNS(), flags, err) - if userDialUseRoutes { - b.dialer.SetRoutes(rcfg.Routes, rcfg.LocalRoutes) - } else { - b.dialer.SetRoutes(nil, nil) - } - b.initPeerAPIListener() + b.readvertiseAppConnectorRoutes() } // shouldUseOneCGNATRoute reports whether we should prefer to make one big @@ -4157,7 +4859,7 @@ func (b *LocalBackend) authReconfig() { // // The versionOS is a Tailscale-style version ("iOS", "macOS") and not // a runtime.GOOS. -func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs, versionOS string) bool { +func shouldUseOneCGNATRoute(logf logger.Logf, mon *netmon.Monitor, controlKnobs *controlknobs.Knobs, versionOS string) bool { if controlKnobs != nil { // Explicit enabling or disabling always take precedence. if v, ok := controlKnobs.OneCGNAT.Load().Get(); ok { @@ -4166,13 +4868,18 @@ func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs, } } + if versionOS == "plan9" { + // Just temporarily during plan9 bringup to have fewer routes to debug. + return true + } + // Also prefer to do this on the Mac, so that we don't need to constantly // update the network extension configuration (which is disruptive to // Chrome, see https://github.com/tailscale/tailscale/issues/3102). Only // use fine-grained routes if another interfaces is also using the CGNAT // IP range. 
if versionOS == "macOS" { - hasCGNATInterface, err := netmon.HasCGNATInterface() + hasCGNATInterface, err := mon.HasCGNATInterface() if err != nil { logf("shouldUseOneCGNATRoute: Could not determine if any interfaces use CGNAT: %v", err) return false @@ -4185,193 +4892,6 @@ func shouldUseOneCGNATRoute(logf logger.Logf, controlKnobs *controlknobs.Knobs, return false } -// dnsConfigForNetmap returns a *dns.Config for the given netmap, -// prefs, client OS version, and cloud hosting environment. -// -// The versionOS is a Tailscale-style version ("iOS", "macOS") and not -// a runtime.GOOS. -func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config { - if nm == nil { - return nil - } - - // If the current node's key is expired, then we don't program any DNS - // configuration into the operating system. This ensures that if the - // DNS configuration specifies a DNS server that is only reachable over - // Tailscale, we don't break connectivity for the user. - // - // TODO(andrew-d): this also stops returning anything from quad-100; we - // could do the same thing as having "CorpDNS: false" and keep that but - // not program the OS? - if selfExpired { - return &dns.Config{} - } - - dcfg := &dns.Config{ - Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, - Hosts: map[dnsname.FQDN][]netip.Addr{}, - } - - // selfV6Only is whether we only have IPv6 addresses ourselves. - selfV6Only := nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs6) && - !nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs4) - dcfg.OnlyIPv6 = selfV6Only - - // Populate MagicDNS records. We do this unconditionally so that - // quad-100 can always respond to MagicDNS queries, even if the OS - // isn't configured to make MagicDNS resolution truly - // magic. Details in - // https://github.com/tailscale/tailscale/issues/1886. 
- set := func(name string, addrs views.Slice[netip.Prefix]) { - if addrs.Len() == 0 || name == "" { - return - } - fqdn, err := dnsname.ToFQDN(name) - if err != nil { - return // TODO: propagate error? - } - var have4 bool - for _, addr := range addrs.All() { - if addr.Addr().Is4() { - have4 = true - break - } - } - var ips []netip.Addr - for _, addr := range addrs.All() { - if selfV6Only { - if addr.Addr().Is6() { - ips = append(ips, addr.Addr()) - } - continue - } - // If this node has an IPv4 address, then - // remove peers' IPv6 addresses for now, as we - // don't guarantee that the peer node actually - // can speak IPv6 correctly. - // - // https://github.com/tailscale/tailscale/issues/1152 - // tracks adding the right capability reporting to - // enable AAAA in MagicDNS. - if addr.Addr().Is6() && have4 { - continue - } - ips = append(ips, addr.Addr()) - } - dcfg.Hosts[fqdn] = ips - } - set(nm.Name, nm.GetAddresses()) - for _, peer := range peers { - set(peer.Name(), peer.Addresses()) - } - for _, rec := range nm.DNS.ExtraRecords { - switch rec.Type { - case "", "A", "AAAA": - // Treat these all the same for now: infer from the value - default: - // TODO: more - continue - } - ip, err := netip.ParseAddr(rec.Value) - if err != nil { - // Ignore. - continue - } - fqdn, err := dnsname.ToFQDN(rec.Name) - if err != nil { - continue - } - dcfg.Hosts[fqdn] = append(dcfg.Hosts[fqdn], ip) - } - - if !prefs.CorpDNS() { - return dcfg - } - - for _, dom := range nm.DNS.Domains { - fqdn, err := dnsname.ToFQDN(dom) - if err != nil { - logf("[unexpected] non-FQDN search domain %q", dom) - } - dcfg.SearchDomains = append(dcfg.SearchDomains, fqdn) - } - if nm.DNS.Proxied { // actually means "enable MagicDNS" - for _, dom := range magicDNSRootDomains(nm) { - dcfg.Routes[dom] = nil // resolve internally with dcfg.Hosts - } - } - - addDefault := func(resolvers []*dnstype.Resolver) { - dcfg.DefaultResolvers = append(dcfg.DefaultResolvers, resolvers...) 
- } - - // If we're using an exit node and that exit node is new enough (1.19.x+) - // to run a DoH DNS proxy, then send all our DNS traffic through it. - if dohURL, ok := exitNodeCanProxyDNS(nm, peers, prefs.ExitNodeID()); ok { - addDefault([]*dnstype.Resolver{{Addr: dohURL}}) - return dcfg - } - - // If the user has set default resolvers ("override local DNS"), prefer to - // use those resolvers as the default, otherwise if there are WireGuard exit - // node resolvers, use those as the default. - if len(nm.DNS.Resolvers) > 0 { - addDefault(nm.DNS.Resolvers) - } else { - if resolvers, ok := wireguardExitNodeDNSResolvers(nm, peers, prefs.ExitNodeID()); ok { - addDefault(resolvers) - } - } - - for suffix, resolvers := range nm.DNS.Routes { - fqdn, err := dnsname.ToFQDN(suffix) - if err != nil { - logf("[unexpected] non-FQDN route suffix %q", suffix) - } - - // Create map entry even if len(resolvers) == 0; Issue 2706. - // This lets the control plane send ExtraRecords for which we - // can authoritatively answer "name not exists" for when the - // control plane also sends this explicit but empty route - // making it as something we handle. - // - // While we're already populating it, might as well size the - // slice appropriately. - // Per #9498 the exact requirements of nil vs empty slice remain - // unclear, this is a haunted graveyard to be resolved. - dcfg.Routes[fqdn] = make([]*dnstype.Resolver, 0, len(resolvers)) - dcfg.Routes[fqdn] = append(dcfg.Routes[fqdn], resolvers...) - } - - // Set FallbackResolvers as the default resolvers in the - // scenarios that can't handle a purely split-DNS config. See - // https://github.com/tailscale/tailscale/issues/1743 for - // details. - switch { - case len(dcfg.DefaultResolvers) != 0: - // Default resolvers already set. - case !prefs.ExitNodeID().IsZero(): - // When using an exit node, we send all DNS traffic to the exit node, so - // we don't need a fallback resolver. 
- // - // However, if the exit node is too old to run a DoH DNS proxy, then we - // need to use a fallback resolver as it's very likely the LAN resolvers - // will become unreachable. - // - // This is especially important on Apple OSes, where - // adding the default route to the tunnel interface makes - // it "primary", and we MUST provide VPN-sourced DNS - // settings or we break all DNS resolution. - // - // https://github.com/tailscale/tailscale/issues/1713 - addDefault(nm.DNS.FallbackResolvers) - case len(dcfg.Routes) == 0: - // No settings requiring split DNS, no problem. - } - - return dcfg -} - // SetTCPHandlerForFunnelFlow sets the TCP handler for Funnel flows. // It should only be called before the LocalBackend is used. func (b *LocalBackend) SetTCPHandlerForFunnelFlow(h func(src netip.AddrPort, dstPort uint16) (handler func(net.Conn))) { @@ -4425,26 +4945,6 @@ func (b *LocalBackend) TailscaleVarRoot() string { return "" } -func (b *LocalBackend) fileRootLocked(uid tailcfg.UserID) string { - if v := b.directFileRoot; v != "" { - return v - } - varRoot := b.TailscaleVarRoot() - if varRoot == "" { - b.logf("Taildrop disabled; no state directory") - return "" - } - baseDir := fmt.Sprintf("%s-uid-%d", - strings.ReplaceAll(b.activeLogin, "@", "-"), - uid) - dir := filepath.Join(varRoot, "files", baseDir) - if err := os.MkdirAll(dir, 0700); err != nil { - b.logf("Taildrop disabled; error making directory: %v", err) - return "" - } - return dir -} - // closePeerAPIListenersLocked closes any existing PeerAPI listeners // and clears out the PeerAPI server state. 
// @@ -4467,22 +4967,27 @@ func (b *LocalBackend) closePeerAPIListenersLocked() { const peerAPIListenAsync = runtime.GOOS == "windows" || runtime.GOOS == "android" func (b *LocalBackend) initPeerAPIListener() { + b.logf("[v1] initPeerAPIListener: entered") b.mu.Lock() defer b.mu.Unlock() if b.shutdownCalled { + b.logf("[v1] initPeerAPIListener: shutting down") return } - if b.netMap == nil { + cn := b.currentNode() + nm := cn.NetMap() + if nm == nil { // We're called from authReconfig which checks that // netMap is non-nil, but if a concurrent Logout, // ResetForClientDisconnect, or Start happens when its // mutex was released, the netMap could be // nil'ed out (Issue 1996). Bail out early here if so. + b.logf("[v1] initPeerAPIListener: no netmap") return } - addrs := b.netMap.GetAddresses() + addrs := nm.GetAddresses() if addrs.Len() == len(b.peerAPIListeners) { allSame := true for i, pln := range b.peerAPIListeners { @@ -4493,32 +4998,21 @@ func (b *LocalBackend) initPeerAPIListener() { } if allSame { // Nothing to do. 
+ b.logf("[v1] initPeerAPIListener: %d netmap addresses match existing listeners", addrs.Len()) return } } b.closePeerAPIListenersLocked() - selfNode := b.netMap.SelfNode - if !selfNode.Valid() || b.netMap.GetAddresses().Len() == 0 { + selfNode := nm.SelfNode + if !selfNode.Valid() || nm.GetAddresses().Len() == 0 { + b.logf("[v1] initPeerAPIListener: no addresses in netmap") return } - fileRoot := b.fileRootLocked(selfNode.User()) - if fileRoot == "" { - b.logf("peerapi starting without Taildrop directory configured") - } - ps := &peerAPIServer{ b: b, - taildrop: taildrop.ManagerOptions{ - Logf: b.logf, - Clock: tstime.DefaultClock{Clock: b.clock}, - State: b.store, - Dir: fileRoot, - DirectFileMode: b.directFileRoot != "", - SendFileNotify: b.sendFileNotify, - }.New(), } if dm, ok := b.sys.DNSManager.GetOK(); ok { ps.resolver = dm.Resolver() @@ -4534,6 +5028,7 @@ func (b *LocalBackend) initPeerAPIListener() { ln, err = ps.listen(a.Addr(), b.prevIfState) if err != nil { if peerAPIListenAsync { + b.logf("[v1] possibly transient peerapi listen(%q) error, will try again on linkChange: %v", a.Addr(), err) // Expected. But we fix it later in linkChange // ("peerAPIListeners too low"). continue @@ -4559,7 +5054,7 @@ func (b *LocalBackend) initPeerAPIListener() { b.peerAPIListeners = append(b.peerAPIListeners, pln) } - go b.doSetHostinfoFilterServices() + b.goTracker.Go(b.doSetHostinfoFilterServices) } // magicDNSRootDomains returns the subset of nm.DNS.Domains that are the search domains for MagicDNS. @@ -4772,13 +5267,20 @@ func (b *LocalBackend) applyPrefsToHostinfoLocked(hi *tailcfg.Hostinfo, prefs ip } hi.SSH_HostKeys = sshHostKeys - // The Hostinfo.WantIngress field tells control whether this node wants to - // be wired up for ingress connections. If harmless if it's accidentally - // true; the actual policy is controlled in tailscaled by ServeConfig. But - // if this is accidentally false, then control may not configure DNS - // properly. 
This exists as an optimization to control to program fewer DNS - records that have ingress enabled but are not actually being used. - hi.WireIngress = b.wantIngressLocked() + hi.ServicesHash = b.vipServiceHash(b.vipServicesFromPrefsLocked(prefs)) + + // The Hostinfo.IngressEnabled field is used to communicate to control whether + // the node has funnel enabled. + hi.IngressEnabled = b.hasIngressEnabledLocked() + // The Hostinfo.WantIngress field tells control whether the user intends + // to use funnel with this node even though it is not currently enabled. + // This is an optimization to control: Funnel requires creation of DNS + // records and because DNS propagation can take time, we want to ensure + // that the records exist for any node that intends to use funnel even + // if it's not enabled. If hi.IngressEnabled is true, control knows that + // DNS records are needed, so we can save bandwidth and not send + // WireIngress. + hi.WireIngress = b.shouldWireInactiveIngressLocked() hi.AppConnector.Set(prefs.AppConnector().Advertise) } @@ -4797,8 +5299,9 @@ func (b *LocalBackend) enterState(newState ipn.State) { // enterStateLockedOnEntry is like enterState but requires b.mu be held to call // it, but it unlocks b.mu when done (via unlock, a once func). 
func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlockOnce) { + cn := b.currentNode() oldState := b.state - b.state = newState + b.setStateLocked(newState) prefs := b.pm.CurrentPrefs() // Some temporary (2024-05-05) debugging code to help us catch @@ -4809,7 +5312,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock panic("[unexpected] use of main control server in integration test") } - netMap := b.netMap + netMap := cn.NetMap() activeLogin := b.activeLogin authURL := b.authURL if newState == ipn.Running { @@ -4820,7 +5323,7 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock // can be shut down if we transition away from Running. if b.captiveCancel == nil { b.captiveCtx, b.captiveCancel = context.WithCancel(b.ctx) - go b.checkCaptivePortalLoop(b.captiveCtx) + b.goTracker.Go(func() { b.checkCaptivePortalLoop(b.captiveCtx) }) } } else if oldState == ipn.Running { // Transitioning away from running. @@ -4864,13 +5367,15 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock } b.blockEngineUpdates(true) fallthrough - case ipn.Stopped: + case ipn.Stopped, ipn.NoState: + // Unconfigure the engine if it has stopped (WantRunning is set to false) + // or if we've switched to a different profile and the state is unknown. 
err := b.e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) if err != nil { b.logf("Reconfig(down): %v", err) } - if authURL == "" { + if newState == ipn.Stopped && authURL == "" { systemd.Status("Stopped; run 'tailscale up' to log in") } case ipn.Starting, ipn.NeedsMachineAuth: @@ -4880,12 +5385,10 @@ func (b *LocalBackend) enterStateLockedOnEntry(newState ipn.State, unlock unlock case ipn.Running: var addrStrs []string addrs := netMap.GetAddresses() - for i := range addrs.Len() { - addrStrs = append(addrStrs, addrs.At(i).Addr().String()) + for _, p := range addrs.All() { + addrStrs = append(addrStrs, p.Addr().String()) } systemd.Status("Connected; %s; %s", activeLogin, strings.Join(addrStrs, " ")) - case ipn.NoState: - // Do nothing. default: b.logf("[unexpected] unknown newState %#v", newState) } @@ -4914,7 +5417,8 @@ func (b *LocalBackend) NodeKey() key.NodePublic { func (b *LocalBackend) nextStateLocked() ipn.State { var ( cc = b.cc - netMap = b.netMap + cn = b.currentNode() + netMap = cn.NetMap() state = b.state blocked = b.blocked st = b.engineStatus @@ -5072,7 +5576,7 @@ func (b *LocalBackend) requestEngineStatusAndWait() { b.statusLock.Lock() defer b.statusLock.Unlock() - go b.e.RequestStatus() + b.goTracker.Go(b.e.RequestStatus) b.logf("requestEngineStatusAndWait: waiting...") b.statusChanged.Wait() // temporarily releases lock while waiting b.logf("requestEngineStatusAndWait: got status update.") @@ -5119,41 +5623,7 @@ func (b *LocalBackend) resetControlClientLocked() controlclient.Client { func (b *LocalBackend) resetAuthURLLocked() { b.authURL = "" b.authURLTime = time.Time{} - b.interact = false -} - -// ResetForClientDisconnect resets the backend for GUI clients running -// in interactive (non-headless) mode. This is currently used only by -// Windows. This causes all state to be cleared, lest an unrelated user -// connect to tailscaled next. 
But it does not trigger a logout; we -// don't want to the user to have to reauthenticate in the future -// when they restart the GUI. -func (b *LocalBackend) ResetForClientDisconnect() { - b.logf("LocalBackend.ResetForClientDisconnect") - - unlock := b.lockAndGetUnlock() - defer unlock() - - prevCC := b.resetControlClientLocked() - if prevCC != nil { - // Needs to happen without b.mu held. - defer prevCC.Shutdown() - } - - b.setNetMapLocked(nil) - b.pm.Reset() - if b.currentUser != nil { - if c, ok := b.currentUser.(ipnauth.ActorCloser); ok { - c.Close() - } - b.currentUser = nil - } - b.keyExpired = false - b.resetAuthURLLocked() - b.activeLogin = "" - b.resetDialPlan() - b.setAtomicValuesFromPrefsLocked(ipn.PrefsView{}) - b.enterStateLockedOnEntry(ipn.Stopped, unlock) + b.authActor = nil } func (b *LocalBackend) ShouldRunSSH() bool { return b.sshAtomicBool.Load() && envknob.CanSSHD() } @@ -5183,7 +5653,7 @@ func (b *LocalBackend) setWebClientAtomicBoolLocked(nm *netmap.NetworkMap) { shouldRun := !nm.HasCap(tailcfg.NodeAttrDisableWebClient) wasRunning := b.webClientAtomicBool.Swap(shouldRun) if wasRunning && !shouldRun { - go b.webClientShutdown() // stop web client + b.goTracker.Go(b.webClientShutdown) // stop web client } } @@ -5253,7 +5723,7 @@ func (b *LocalBackend) Logout(ctx context.Context) error { unlock = b.lockAndGetUnlock() defer unlock() - if err := b.pm.DeleteProfile(profile.ID); err != nil { + if err := b.pm.DeleteProfile(profile.ID()); err != nil { b.logf("error deleting profile: %v", err) return err } @@ -5304,45 +5774,51 @@ func (b *LocalBackend) setNetInfo(ni *tailcfg.NetInfo) { } } -func (b *LocalBackend) setAutoExitNodeIDLockedOnEntry(unlock unlockOnce) { +func (b *LocalBackend) setAutoExitNodeIDLockedOnEntry(unlock unlockOnce) (newPrefs ipn.PrefsView) { + var zero ipn.PrefsView defer unlock() prefs := b.pm.CurrentPrefs() if !prefs.Valid() { b.logf("[unexpected]: received tailnet exit node ID pref change callback but current prefs are nil") - 
return + return zero } prefsClone := prefs.AsStruct() newSuggestion, err := b.suggestExitNodeLocked(nil) if err != nil { b.logf("setAutoExitNodeID: %v", err) - return + return zero + } + if prefsClone.ExitNodeID == newSuggestion.ID { + return zero } prefsClone.ExitNodeID = newSuggestion.ID - _, err = b.editPrefsLockedOnEntry(&ipn.MaskedPrefs{ + newPrefs, err = b.editPrefsLockedOnEntry(&ipn.MaskedPrefs{ Prefs: *prefsClone, ExitNodeIDSet: true, }, unlock) if err != nil { b.logf("setAutoExitNodeID: failed to apply exit node ID preference: %v", err) - return + return zero } + return newPrefs } // setNetMapLocked updates the LocalBackend state to reflect the newly // received nm. If nm is nil, it resets all configuration as though // Tailscale is turned off. func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { + oldSelf := b.currentNode().NetMap().SelfNodeOrZero() + b.dialer.SetNetMap(nm) if ns, ok := b.sys.Netstack.GetOK(); ok { ns.UpdateNetstackIPs(nm) } var login string if nm != nil { - login = cmp.Or(nm.UserProfiles[nm.User()].LoginName, "") + login = cmp.Or(profileFromView(nm.UserProfiles[nm.User()]).LoginName, "") } - b.netMap = nm - b.updatePeersFromNetmapLocked(nm) + b.currentNode().SetNetMap(nm) if login != b.activeLogin { b.logf("active login: %v", login) b.activeLogin = login @@ -5355,13 +5831,6 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { b.health.SetControlHealth(nil) } - // Determine if file sharing is enabled - fs := nm.HasCap(tailcfg.CapabilityFileSharing) - if fs != b.capFileSharing { - osshare.SetFileSharingEnabled(fs, b.logf) - } - b.capFileSharing = fs - if nm.HasCap(tailcfg.NodeAttrLinuxMustUseIPTables) { b.capForcedNetfilter = "iptables" } else if nm.HasCap(tailcfg.NodeAttrLinuxMustUseNfTables) { @@ -5380,34 +5849,22 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { netns.SetDisableBindConnToInterface(nm.HasCap(tailcfg.CapabilityDebugDisableBindConnToInterface)) 
b.setTCPPortsInterceptedFromNetmapAndPrefsLocked(b.pm.CurrentPrefs()) - if nm == nil { - b.nodeByAddr = nil + b.ipVIPServiceMap = nm.GetIPVIPServiceMap() + + if !oldSelf.Equal(nm.SelfNodeOrZero()) { + for _, f := range b.extHost.Hooks().OnSelfChange { + f(nm.SelfNode) + } + } + if nm == nil { // If there is no netmap, the client is going into a "turned off" // state so reset the metrics. b.metrics.approvedRoutes.Set(0) - b.metrics.primaryRoutes.Set(0) return } - // Update the nodeByAddr index. - if b.nodeByAddr == nil { - b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{} - } - // First pass, mark everything unwanted. - for k := range b.nodeByAddr { - b.nodeByAddr[k] = 0 - } - addNode := func(n tailcfg.NodeView) { - for _, ipp := range n.Addresses().All() { - if ipp.IsSingleIP() { - b.nodeByAddr[ipp.Addr()] = n.ID() - } - } - } - if nm.SelfNode.Valid() { - addNode(nm.SelfNode) - + if nm.SelfNode.Valid() { var approved float64 for _, route := range nm.SelfNode.AllowedIPs().All() { if !views.SliceContains(nm.SelfNode.Addresses(), route) && !tsaddr.IsExitRoute(route) { @@ -5415,46 +5872,12 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { } } b.metrics.approvedRoutes.Set(approved) - b.metrics.primaryRoutes.Set(float64(tsaddr.WithoutExitRoute(nm.SelfNode.PrimaryRoutes()).Len())) - } - for _, p := range nm.Peers { - addNode(p) - } - // Third pass, actually delete the unwanted items. - for k, v := range b.nodeByAddr { - if v == 0 { - delete(b.nodeByAddr, k) - } } b.updateDrivePeersLocked(nm) b.driveNotifyCurrentSharesLocked() } -func (b *LocalBackend) updatePeersFromNetmapLocked(nm *netmap.NetworkMap) { - if nm == nil { - b.peers = nil - return - } - - // First pass, mark everything unwanted. - for k := range b.peers { - b.peers[k] = tailcfg.NodeView{} - } - - // Second pass, add everything wanted. - for _, p := range nm.Peers { - mak.Set(&b.peers, p.ID(), p) - } - - // Third pass, remove deleted things. 
- for k, v := range b.peers { - if !v.Valid() { - delete(b.peers, k) - } - } -} - // responseBodyWrapper wraps an io.ReadCloser and stores // the number of bytesRead. type responseBodyWrapper struct { @@ -5550,7 +5973,7 @@ func (dt *driveTransport) RoundTrip(req *http.Request) (resp *http.Response, err } dt.b.mu.Lock() - selfNodeKey := dt.b.netMap.SelfNode.Key().ShortString() + selfNodeKey := dt.b.currentNode().Self().Key().ShortString() dt.b.mu.Unlock() n, _, ok := dt.b.WhoIs("tcp", netip.MustParseAddrPort(req.URL.Host)) shareNodeKey := "unknown" @@ -5625,7 +6048,7 @@ func (b *LocalBackend) setDebugLogsByCapabilityLocked(nm *netmap.NetworkMap) { // the method to only run the reset-logic and not reload the store from memory to ensure // foreground sessions are not removed if they are not saved on disk. func (b *LocalBackend) reloadServeConfigLocked(prefs ipn.PrefsView) { - if b.netMap == nil || !b.netMap.SelfNode.Valid() || !prefs.Valid() || b.pm.CurrentProfile().ID == "" { + if !b.currentNode().Self().Valid() || !prefs.Valid() || b.pm.CurrentProfile().ID() == "" { // We're not logged in, so we don't have a profile. // Don't try to load the serve config. b.lastServeConfJSON = mem.B(nil) @@ -5633,7 +6056,7 @@ func (b *LocalBackend) reloadServeConfigLocked(prefs ipn.PrefsView) { return } - confKey := ipn.ServeConfigKey(b.pm.CurrentProfile().ID) + confKey := ipn.ServeConfigKey(b.pm.CurrentProfile().ID()) // TODO(maisem,bradfitz): prevent reading the config from disk // if the profile has not changed. confj, err := b.store.ReadState(confKey) @@ -5668,6 +6091,7 @@ func (b *LocalBackend) reloadServeConfigLocked(prefs ipn.PrefsView) { // b.mu must be held. 
func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn.PrefsView) { handlePorts := make([]uint16, 0, 4) + var vipServicesPorts map[tailcfg.ServiceName][]uint16 if prefs.Valid() && prefs.RunSSH() && envknob.CanSSHD() { handlePorts = append(handlePorts, 22) @@ -5684,14 +6108,27 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn. b.reloadServeConfigLocked(prefs) if b.serveConfig.Valid() { servePorts := make([]uint16, 0, 3) - b.serveConfig.RangeOverTCPs(func(port uint16, _ ipn.TCPPortHandlerView) bool { + for port := range b.serveConfig.TCPs() { if port > 0 { servePorts = append(servePorts, uint16(port)) } - return true - }) + } handlePorts = append(handlePorts, servePorts...) + for svc, cfg := range b.serveConfig.Services().All() { + servicePorts := make([]uint16, 0, 3) + for port := range cfg.TCP().All() { + if port > 0 { + servicePorts = append(servicePorts, uint16(port)) + } + } + if _, ok := vipServicesPorts[svc]; !ok { + mak.Set(&vipServicesPorts, svc, servicePorts) + } else { + mak.Set(&vipServicesPorts, svc, append(vipServicesPorts[svc], servicePorts...)) + } + } + b.setServeProxyHandlersLocked() // don't listen on netmap addresses if we're in userspace mode @@ -5699,14 +6136,36 @@ func (b *LocalBackend) setTCPPortsInterceptedFromNetmapAndPrefsLocked(prefs ipn. b.updateServeTCPPortNetMapAddrListenersLocked(servePorts) } } - // Kick off a Hostinfo update to control if WireIngress changed. - if wire := b.wantIngressLocked(); b.hostinfo != nil && b.hostinfo.WireIngress != wire { + + // Update funnel info in hostinfo and kick off control update if needed. + b.updateIngressLocked() + b.setTCPPortsIntercepted(handlePorts) + b.setVIPServicesTCPPortsInterceptedLocked(vipServicesPorts) +} + +// updateIngressLocked updates the hostinfo.WireIngress and hostinfo.IngressEnabled fields and kicks off a Hostinfo +// update if the values have changed. +// +// b.mu must be held. 
+func (b *LocalBackend) updateIngressLocked() { + if b.hostinfo == nil { + return + } + hostInfoChanged := false + if ie := b.hasIngressEnabledLocked(); b.hostinfo.IngressEnabled != ie { + b.logf("Hostinfo.IngressEnabled changed to %v", ie) + b.hostinfo.IngressEnabled = ie + hostInfoChanged = true + } + if wire := b.shouldWireInactiveIngressLocked(); b.hostinfo.WireIngress != wire { b.logf("Hostinfo.WireIngress changed to %v", wire) b.hostinfo.WireIngress = wire - go b.doSetHostinfoFilterServices() + hostInfoChanged = true + } + // Kick off a Hostinfo update to control if ingress status has changed. + if hostInfoChanged { + b.goTracker.Go(b.doSetHostinfoFilterServices) } - - b.setTCPPortsIntercepted(handlePorts) } // setServeProxyHandlersLocked ensures there is an http proxy handler for each @@ -5717,16 +6176,16 @@ func (b *LocalBackend) setServeProxyHandlersLocked() { return } var backends map[string]bool - b.serveConfig.RangeOverWebs(func(_ ipn.HostPort, conf ipn.WebServerConfigView) (cont bool) { - conf.Handlers().Range(func(_ string, h ipn.HTTPHandlerView) (cont bool) { + for _, conf := range b.serveConfig.Webs() { + for _, h := range conf.Handlers().All() { backend := h.Proxy() if backend == "" { // Only create proxy handlers for servers with a proxy backend. - return true + continue } mak.Set(&backends, backend, true) if _, ok := b.serveProxyHandlers.Load(backend); ok { - return true + continue } b.logf("serve: creating a new proxy handler for %s", backend) @@ -5735,13 +6194,11 @@ func (b *LocalBackend) setServeProxyHandlersLocked() { // The backend endpoint (h.Proxy) should have been validated by expandProxyTarget // in the CLI, so just log the error here. b.logf("[unexpected] could not create proxy for %v: %s", backend, err) - return true + continue } b.serveProxyHandlers.Store(backend, p) - return true - }) - return true - }) + } + } // Clean up handlers for proxy backends that are no longer present // in configuration. 
@@ -5801,141 +6258,6 @@ func (b *LocalBackend) TestOnlyPublicKeys() (machineKey key.MachinePublic, nodeK return mk, nk } -func (b *LocalBackend) removeFileWaiter(handle set.Handle) { - b.mu.Lock() - defer b.mu.Unlock() - delete(b.fileWaiters, handle) -} - -func (b *LocalBackend) addFileWaiter(wakeWaiter context.CancelFunc) set.Handle { - b.mu.Lock() - defer b.mu.Unlock() - return b.fileWaiters.Add(wakeWaiter) -} - -func (b *LocalBackend) WaitingFiles() ([]apitype.WaitingFile, error) { - b.mu.Lock() - apiSrv := b.peerAPIServer - b.mu.Unlock() - return mayDeref(apiSrv).taildrop.WaitingFiles() -} - -// AwaitWaitingFiles is like WaitingFiles but blocks while ctx is not done, -// waiting for any files to be available. -// -// On return, exactly one of the results will be non-empty or non-nil, -// respectively. -func (b *LocalBackend) AwaitWaitingFiles(ctx context.Context) ([]apitype.WaitingFile, error) { - if ff, err := b.WaitingFiles(); err != nil || len(ff) > 0 { - return ff, err - } - - for { - gotFile, gotFileCancel := context.WithCancel(context.Background()) - defer gotFileCancel() - - handle := b.addFileWaiter(gotFileCancel) - defer b.removeFileWaiter(handle) - - // Now that we've registered ourselves, check again, in case - // of race. Otherwise there's a small window where we could - // miss a file arrival and wait forever. 
- if ff, err := b.WaitingFiles(); err != nil || len(ff) > 0 { - return ff, err - } - - select { - case <-gotFile.Done(): - if ff, err := b.WaitingFiles(); err != nil || len(ff) > 0 { - return ff, err - } - case <-ctx.Done(): - return nil, ctx.Err() - } - } -} - -func (b *LocalBackend) DeleteFile(name string) error { - b.mu.Lock() - apiSrv := b.peerAPIServer - b.mu.Unlock() - return mayDeref(apiSrv).taildrop.DeleteFile(name) -} - -func (b *LocalBackend) OpenFile(name string) (rc io.ReadCloser, size int64, err error) { - b.mu.Lock() - apiSrv := b.peerAPIServer - b.mu.Unlock() - return mayDeref(apiSrv).taildrop.OpenFile(name) -} - -// hasCapFileSharing reports whether the current node has the file -// sharing capability enabled. -func (b *LocalBackend) hasCapFileSharing() bool { - b.mu.Lock() - defer b.mu.Unlock() - return b.capFileSharing -} - -// FileTargets lists nodes that the current node can send files to. -func (b *LocalBackend) FileTargets() ([]*apitype.FileTarget, error) { - var ret []*apitype.FileTarget - - b.mu.Lock() - defer b.mu.Unlock() - nm := b.netMap - if b.state != ipn.Running || nm == nil { - return nil, errors.New("not connected to the tailnet") - } - if !b.capFileSharing { - return nil, errors.New("file sharing not enabled by Tailscale admin") - } - for _, p := range b.peers { - if !b.peerIsTaildropTargetLocked(p) { - continue - } - if p.Hostinfo().OS() == "tvOS" { - continue - } - peerAPI := peerAPIBase(b.netMap, p) - if peerAPI == "" { - continue - } - ret = append(ret, &apitype.FileTarget{ - Node: p.AsStruct(), - PeerAPIURL: peerAPI, - }) - } - slices.SortFunc(ret, func(a, b *apitype.FileTarget) int { - return cmp.Compare(a.Node.Name, b.Node.Name) - }) - return ret, nil -} - -// peerIsTaildropTargetLocked reports whether p is a valid Taildrop file -// recipient from this node according to its ownership and the capabilities in -// the netmap. -// -// b.mu must be locked. 
-func (b *LocalBackend) peerIsTaildropTargetLocked(p tailcfg.NodeView) bool { - if b.netMap == nil || !p.Valid() { - return false - } - if b.netMap.User() == p.User() { - return true - } - if p.Addresses().Len() > 0 && - b.peerHasCapLocked(p.Addresses().At(0).Addr(), tailcfg.PeerCapabilityFileSharingTarget) { - // Explicitly noted in the netmap ACL caps as a target. - return true - } - return false -} - -func (b *LocalBackend) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { - return b.peerCapsLocked(addr).HasCapability(wantCap) -} - // SetDNS adds a DNS record for the given domain name & TXT record // value. // @@ -5974,8 +6296,7 @@ func (b *LocalBackend) SetDNS(ctx context.Context, name, value string) error { func peerAPIPorts(peer tailcfg.NodeView) (p4, p6 uint16) { svcs := peer.Hostinfo().Services() - for i := range svcs.Len() { - s := svcs.At(i) + for _, s := range svcs.All() { switch s.Proto { case tailcfg.PeerAPI4: p4 = s.Port @@ -5986,59 +6307,6 @@ func peerAPIPorts(peer tailcfg.NodeView) (p4, p6 uint16) { return } -// peerAPIURL returns an HTTP URL for the peer's peerapi service, -// without a trailing slash. -// -// If ip or port is the zero value then it returns the empty string. -func peerAPIURL(ip netip.Addr, port uint16) string { - if port == 0 || !ip.IsValid() { - return "" - } - return fmt.Sprintf("http://%v", netip.AddrPortFrom(ip, port)) -} - -// peerAPIBase returns the "http://ip:port" URL base to reach peer's peerAPI. -// It returns the empty string if the peer doesn't support the peerapi -// or there's no matching address family based on the netmap's own addresses. 
-func peerAPIBase(nm *netmap.NetworkMap, peer tailcfg.NodeView) string { - if nm == nil || !peer.Valid() || !peer.Hostinfo().Valid() { - return "" - } - - var have4, have6 bool - addrs := nm.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) - if !a.IsSingleIP() { - continue - } - switch { - case a.Addr().Is4(): - have4 = true - case a.Addr().Is6(): - have6 = true - } - } - p4, p6 := peerAPIPorts(peer) - switch { - case have4 && p4 != 0: - return peerAPIURL(nodeIP(peer, netip.Addr.Is4), p4) - case have6 && p6 != 0: - return peerAPIURL(nodeIP(peer, netip.Addr.Is6), p6) - } - return "" -} - -func nodeIP(n tailcfg.NodeView, pred func(netip.Addr) bool) netip.Addr { - for i := range n.Addresses().Len() { - a := n.Addresses().At(i) - if a.IsSingleIP() && pred(a.Addr()) { - return a.Addr() - } - } - return netip.Addr{} -} - func (b *LocalBackend) CheckIPForwarding() error { if b.sys.IsNetstackRouter() { return nil @@ -6125,12 +6393,7 @@ func (b *LocalBackend) SetUDPGROForwarding() error { // DERPMap returns the current DERPMap in use, or nil if not connected. func (b *LocalBackend) DERPMap() *tailcfg.DERPMap { - b.mu.Lock() - defer b.mu.Unlock() - if b.netMap == nil { - return nil - } - return b.netMap.DERPMap + return b.currentNode().DERPMap() } // OfferingExitNode reports whether b is currently offering exit node @@ -6170,7 +6433,7 @@ func (b *LocalBackend) OfferingAppConnector() bool { func (b *LocalBackend) allowExitNodeDNSProxyToServeName(name string) bool { b.mu.Lock() defer b.mu.Unlock() - nm := b.netMap + nm := b.NetMap() if nm == nil { return false } @@ -6211,6 +6474,20 @@ func (b *LocalBackend) SetExpirySooner(ctx context.Context, expiry time.Time) er return cc.SetExpirySooner(ctx, expiry) } +// SetDeviceAttrs does a synchronous call to the control plane to update +// the node's attributes. +// +// See docs on [tailcfg.SetDeviceAttributesRequest] for background. 
+func (b *LocalBackend) SetDeviceAttrs(ctx context.Context, attrs tailcfg.AttrUpdate) error { + b.mu.Lock() + cc := b.ccAuto + b.mu.Unlock() + if cc == nil { + return errors.New("not running") + } + return cc.SetDeviceAttrs(ctx, attrs) +} + // exitNodeCanProxyDNS reports the DoH base URL ("http://foo/dns-query") without query parameters // to exitNodeID's DoH service, if available. // @@ -6263,8 +6540,8 @@ func peerCanProxyDNS(p tailcfg.NodeView) bool { // If p.Cap is not populated (e.g. older control server), then do the old // thing of searching through services. services := p.Hostinfo().Services() - for i := range services.Len() { - if s := services.At(i); s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { + for _, s := range services.All() { + if s.Proto == tailcfg.PeerAPIDNS && s.Port >= 1 { return true } } @@ -6302,7 +6579,7 @@ func (n keyProvingNoiseRoundTripper) RoundTrip(req *http.Request) (*http.Respons b.mu.Lock() cc := b.ccAuto - if nm := b.netMap; nm != nil { + if nm := b.NetMap(); nm != nil { priv = nm.PrivateKey } b.mu.Unlock() @@ -6434,11 +6711,12 @@ func (b *LocalBackend) handleQuad100Port80Conn(w http.ResponseWriter, r *http.Re defer b.mu.Unlock() io.WriteString(w, "

Tailscale

\n") - if b.netMap == nil { + nm := b.currentNode().NetMap() + if nm == nil { io.WriteString(w, "No netmap.\n") return } - addrs := b.netMap.GetAddresses() + addrs := nm.GetAddresses() if addrs.Len() == 0 { io.WriteString(w, "No local addresses.\n") return @@ -6469,7 +6747,7 @@ func (b *LocalBackend) Doctor(ctx context.Context, logf logger.Logf) { // controlplane. checks = append(checks, doctor.CheckFunc("dns-resolvers", func(_ context.Context, logf logger.Logf) error { b.mu.Lock() - nm := b.netMap + nm := b.NetMap() b.mu.Unlock() if nm == nil { return nil @@ -6529,25 +6807,26 @@ func (b *LocalBackend) ShouldInterceptTCPPort(port uint16) bool { return b.shouldInterceptTCPPortAtomic.Load()(port) } +// ShouldInterceptVIPServiceTCPPort reports whether the given TCP port number +// to a VIP service should be intercepted by Tailscaled and handled in-process. +func (b *LocalBackend) ShouldInterceptVIPServiceTCPPort(ap netip.AddrPort) bool { + return b.shouldInterceptVIPServicesTCPPortAtomic.Load()(ap) +} + // SwitchProfile switches to the profile with the given id. // It will restart the backend on success. // If the profile is not known, it returns an errProfileNotFound. func (b *LocalBackend) SwitchProfile(profile ipn.ProfileID) error { - if b.CurrentProfile().ID == profile { - return nil - } unlock := b.lockAndGetUnlock() defer unlock() oldControlURL := b.pm.CurrentPrefs().ControlURLOrDefault() - if err := b.pm.SwitchProfile(profile); err != nil { - return err + if _, changed, err := b.pm.SwitchToProfileByID(profile); !changed || err != nil { + return err // nil if we're already on the target profile } - // As an optimization, only reset the dialPlan if the control URL - // changed; we treat an empty URL as "unknown" and always reset. - newControlURL := b.pm.CurrentPrefs().ControlURLOrDefault() - if oldControlURL != newControlURL || oldControlURL == "" || newControlURL == "" { + // As an optimization, only reset the dialPlan if the control URL changed. 
+ if newControlURL := b.pm.CurrentPrefs().ControlURLOrDefault(); oldControlURL != newControlURL { b.resetDialPlan() } @@ -6556,12 +6835,12 @@ func (b *LocalBackend) SwitchProfile(profile ipn.ProfileID) error { func (b *LocalBackend) initTKALocked() error { cp := b.pm.CurrentProfile() - if cp.ID == "" { + if cp.ID() == "" { b.tka = nil return nil } if b.tka != nil { - if b.tka.profile == cp.ID { + if b.tka.profile == cp.ID() { // Already initialized. return nil } @@ -6591,7 +6870,7 @@ func (b *LocalBackend) initTKALocked() error { } b.tka = &tkaState{ - profile: cp.ID, + profile: cp.ID(), authority: authority, storage: storage, } @@ -6612,6 +6891,25 @@ func (b *LocalBackend) resetDialPlan() { } } +// getHardwareAddrs returns the hardware addresses for the machine. If the list +// of hardware addresses is empty, it will return the previously known hardware +// addresses. Both the current, and previously known hardware addresses might be +// empty. +func (b *LocalBackend) getHardwareAddrs() ([]string, error) { + addrs, err := posture.GetHardwareAddrs() + if err != nil { + return nil, err + } + + if len(addrs) == 0 { + b.logf("getHardwareAddrs: got empty list of hwaddrs, returning previous list") + return b.lastKnownHardwareAddrs.Load(), nil + } + + b.lastKnownHardwareAddrs.Store(addrs) + return addrs, nil +} + // resetForProfileChangeLockedOnEntry resets the backend for a profile change. // // b.mu must held on entry. It is released on exit. @@ -6624,17 +6922,30 @@ func (b *LocalBackend) resetForProfileChangeLockedOnEntry(unlock unlockOnce) err // down, so no need to do any work. return nil } + b.currentNodeAtomic.Store(newNodeBackend()) b.setNetMapLocked(nil) // Reset netmap. + b.updateFilterLocked(ipn.PrefsView{}) // Reset the NetworkMap in the engine b.e.SetNetworkMap(new(netmap.NetworkMap)) - if err := b.initTKALocked(); err != nil { - return err + if prevCC := b.resetControlClientLocked(); prevCC != nil { + // Needs to happen without b.mu held. 
+ defer prevCC.Shutdown() } + // TKA errors should not prevent resetting the backend state. + // However, we should still return the error to the caller. + tkaErr := b.initTKALocked() b.lastServeConfJSON = mem.B(nil) b.serveConfig = ipn.ServeConfigView{} b.lastSuggestedExitNode = "" + b.keyExpired = false + b.resetAlwaysOnOverrideLocked() + b.extHost.NotifyProfileChange(b.pm.CurrentProfile(), b.pm.CurrentPrefs(), false) + b.setAtomicValuesFromPrefsLocked(b.pm.CurrentPrefs()) b.enterStateLockedOnEntry(ipn.NoState, unlock) // Reset state; releases b.mu b.health.SetLocalLogConfigHealth(nil) + if tkaErr != nil { + return tkaErr + } return b.Start(ipn.Options{}) } @@ -6644,7 +6955,7 @@ func (b *LocalBackend) DeleteProfile(p ipn.ProfileID) error { unlock := b.lockAndGetUnlock() defer unlock() - needToRestart := b.pm.CurrentProfile().ID == p + needToRestart := b.pm.CurrentProfile().ID() == p if err := b.pm.DeleteProfile(p); err != nil { if err == errProfileNotFound { return nil @@ -6659,7 +6970,7 @@ func (b *LocalBackend) DeleteProfile(p ipn.ProfileID) error { // CurrentProfile returns the current LoginProfile. // The value may be zero if the profile is not persisted. -func (b *LocalBackend) CurrentProfile() ipn.LoginProfile { +func (b *LocalBackend) CurrentProfile() ipn.LoginProfileView { b.mu.Lock() defer b.mu.Unlock() return b.pm.CurrentProfile() @@ -6670,7 +6981,7 @@ func (b *LocalBackend) NewProfile() error { unlock := b.lockAndGetUnlock() defer unlock() - b.pm.NewProfile() + b.pm.SwitchToNewProfile() // The new profile doesn't yet have a ControlURL because it hasn't been // set. Conservatively reset the dialPlan. @@ -6680,7 +6991,7 @@ func (b *LocalBackend) NewProfile() error { } // ListProfiles returns a list of all LoginProfiles. 
-func (b *LocalBackend) ListProfiles() []ipn.LoginProfile { +func (b *LocalBackend) ListProfiles() []ipn.LoginProfileView { b.mu.Lock() defer b.mu.Unlock() return b.pm.Profiles() @@ -6708,48 +7019,6 @@ func (b *LocalBackend) ResetAuth() error { return b.resetForProfileChangeLockedOnEntry(unlock) } -// StreamDebugCapture writes a pcap stream of packets traversing -// tailscaled to the provided response writer. -func (b *LocalBackend) StreamDebugCapture(ctx context.Context, w io.Writer) error { - var s *capture.Sink - - b.mu.Lock() - if b.debugSink == nil { - s = capture.New() - b.debugSink = s - b.e.InstallCaptureHook(s.LogPacket) - } else { - s = b.debugSink - } - b.mu.Unlock() - - unregister := s.RegisterOutput(w) - - select { - case <-ctx.Done(): - case <-s.WaitCh(): - } - unregister() - - // Shut down & uninstall the sink if there are no longer - // any outputs on it. - b.mu.Lock() - defer b.mu.Unlock() - - select { - case <-b.ctx.Done(): - return nil - default: - } - if b.debugSink != nil && b.debugSink.NumOutputs() == 0 { - s := b.debugSink - b.e.InstallCaptureHook(nil) - b.debugSink = nil - return s.Close() - } - return nil -} - func (b *LocalBackend) GetPeerEndpointChanges(ctx context.Context, ip netip.Addr) ([]magicsock.EndpointChange, error) { pip, ok := b.e.PeerForIP(ip) if !ok { @@ -6830,17 +7099,17 @@ func (b *LocalBackend) DoSelfUpdate() { // ObserveDNSResponse passes a DNS response from the PeerAPI DNS server to the // App Connector to enable route discovery. -func (b *LocalBackend) ObserveDNSResponse(res []byte) { +func (b *LocalBackend) ObserveDNSResponse(res []byte) error { var appConnector *appc.AppConnector b.mu.Lock() if b.appConnector == nil { b.mu.Unlock() - return + return nil } appConnector = b.appConnector b.mu.Unlock() - appConnector.ObserveDNSResponse(res) + return appConnector.ObserveDNSResponse(res) } // ErrDisallowedAutoRoute is returned by AdvertiseRoute when a route that is not allowed is requested. 
@@ -6851,7 +7120,7 @@ var ErrDisallowedAutoRoute = errors.New("route is not allowed") // If the route is disallowed, ErrDisallowedAutoRoute is returned. func (b *LocalBackend) AdvertiseRoute(ipps ...netip.Prefix) error { finalRoutes := b.Prefs().AdvertiseRoutes().AsSlice() - newRoutes := false + var newRoutes []netip.Prefix for _, ipp := range ipps { if !allowedAutoRoute(ipp) { @@ -6867,13 +7136,14 @@ func (b *LocalBackend) AdvertiseRoute(ipps ...netip.Prefix) error { } finalRoutes = append(finalRoutes, ipp) - newRoutes = true + newRoutes = append(newRoutes, ipp) } - if !newRoutes { + if len(newRoutes) == 0 { return nil } + b.logf("advertising new app connector routes: %v", newRoutes) _, err := b.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ AdvertiseRoutes: finalRoutes, @@ -6927,7 +7197,7 @@ func (b *LocalBackend) UnadvertiseRoute(toRemove ...netip.Prefix) error { // namespace a key with the profile manager's current profile key, if any func namespaceKeyForCurrentProfile(pm *profileManager, key ipn.StateKey) ipn.StateKey { - return pm.CurrentProfile().Key + "||" + key + return pm.CurrentProfile().Key() + "||" + key } const routeInfoStateStoreKey ipn.StateKey = "_routeInfo" @@ -6935,7 +7205,7 @@ const routeInfoStateStoreKey ipn.StateKey = "_routeInfo" func (b *LocalBackend) storeRouteInfo(ri *appc.RouteInfo) error { b.mu.Lock() defer b.mu.Unlock() - if b.pm.CurrentProfile().ID == "" { + if b.pm.CurrentProfile().ID() == "" { return nil } key := namespaceKeyForCurrentProfile(b.pm, routeInfoStateStoreKey) @@ -6947,7 +7217,7 @@ func (b *LocalBackend) storeRouteInfo(ri *appc.RouteInfo) error { } func (b *LocalBackend) readRouteInfoLocked() (*appc.RouteInfo, error) { - if b.pm.CurrentProfile().ID == "" { + if b.pm.CurrentProfile().ID() == "" { return &appc.RouteInfo{}, nil } key := namespaceKeyForCurrentProfile(b.pm, routeInfoStateStoreKey) @@ -7000,14 +7270,6 @@ func allowedAutoRoute(ipp netip.Prefix) bool { return true } -// mayDeref dereferences p if non-nil, 
otherwise it returns the zero value. -func mayDeref[T any](p *T) (v T) { - if p == nil { - return v - } - return *p -} - var ErrNoPreferredDERP = errors.New("no preferred DERP, try again later") // suggestExitNodeLocked computes a suggestion based on the current netmap and last netcheck report. If @@ -7025,12 +7287,12 @@ var ErrNoPreferredDERP = errors.New("no preferred DERP, try again later") func (b *LocalBackend) suggestExitNodeLocked(netMap *netmap.NetworkMap) (response apitype.ExitNodeSuggestionResponse, err error) { // netMap is an optional netmap to use that overrides b.netMap (needed for SetControlClientStatus before b.netMap is updated). If netMap is nil, then b.netMap is used. if netMap == nil { - netMap = b.netMap + netMap = b.NetMap() } lastReport := b.MagicConn().GetLastNetcheckReport(b.ctx) prevSuggestion := b.lastSuggestedExitNode - res, err := suggestExitNode(lastReport, netMap, prevSuggestion, randomRegion, randomNode, getAllowedSuggestions()) + res, err := suggestExitNode(lastReport, netMap, prevSuggestion, randomRegion, randomNode, b.getAllowedSuggestions()) if err != nil { return res, err } @@ -7044,6 +7306,22 @@ func (b *LocalBackend) SuggestExitNode() (response apitype.ExitNodeSuggestionRes return b.suggestExitNodeLocked(nil) } +// getAllowedSuggestions returns a set of exit nodes permitted by the most recent +// [syspolicy.AllowedSuggestedExitNodes] value. Callers must not mutate the returned set. +func (b *LocalBackend) getAllowedSuggestions() set.Set[tailcfg.StableNodeID] { + b.allowedSuggestedExitNodesMu.Lock() + defer b.allowedSuggestedExitNodesMu.Unlock() + return b.allowedSuggestedExitNodes +} + +// refreshAllowedSuggestions rebuilds the set of permitted exit nodes +// from the current [syspolicy.AllowedSuggestedExitNodes] value. 
+func (b *LocalBackend) refreshAllowedSuggestions() { + b.allowedSuggestedExitNodesMu.Lock() + defer b.allowedSuggestedExitNodesMu.Unlock() + b.allowedSuggestedExitNodes = fillAllowedSuggestions() +} + // selectRegionFunc returns a DERP region from the slice of candidate regions. // The value is returned, not the slice index. type selectRegionFunc func(views.Slice[int]) int @@ -7053,8 +7331,6 @@ type selectRegionFunc func(views.Slice[int]) int // choice. type selectNodeFunc func(nodes views.Slice[tailcfg.NodeView], last tailcfg.StableNodeID) tailcfg.NodeView -var getAllowedSuggestions = lazy.SyncFunc(fillAllowedSuggestions) - func fillAllowedSuggestions() set.Set[tailcfg.StableNodeID] { nodes, err := syspolicy.GetStringArray(syspolicy.AllowedSuggestedExitNodes, nil) if err != nil { @@ -7093,8 +7369,8 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug if len(candidates) == 1 { peer := candidates[0] if hi := peer.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } res.ID = peer.StableID() @@ -7114,15 +7390,7 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug } distances := make([]nodeDistance, 0, len(candidates)) for _, c := range candidates { - if c.DERP() != "" { - ipp, err := netip.ParseAddrPort(c.DERP()) - if err != nil { - continue - } - if ipp.Addr() != tailcfg.DerpMagicIPAddr { - continue - } - regionID := int(ipp.Port()) + if regionID := c.HomeDERP(); regionID != 0 { candidatesByRegion[regionID] = append(candidatesByRegion[regionID], c) continue } @@ -7138,10 +7406,10 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug continue } loc := hi.Location() - if loc == nil { + if !loc.Valid() { continue } - distance := longLatDistance(preferredDERP.Latitude, preferredDERP.Longitude, loc.Latitude, loc.Longitude) + distance := longLatDistance(preferredDERP.Latitude, 
preferredDERP.Longitude, loc.Latitude(), loc.Longitude()) if distance < minDistance { minDistance = distance } @@ -7150,9 +7418,9 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug // First, try to select an exit node that has the closest DERP home, based on lastReport's DERP latency. // If there are no latency values, it returns an arbitrary region if len(candidatesByRegion) > 0 { - minRegion := minLatencyDERPRegion(xmaps.Keys(candidatesByRegion), report) + minRegion := minLatencyDERPRegion(slicesx.MapKeys(candidatesByRegion), report) if minRegion == 0 { - minRegion = selectRegion(views.SliceOf(xmaps.Keys(candidatesByRegion))) + minRegion = selectRegion(views.SliceOf(slicesx.MapKeys(candidatesByRegion))) } regionCandidates, ok := candidatesByRegion[minRegion] if !ok { @@ -7162,8 +7430,8 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7192,8 +7460,8 @@ func suggestExitNode(report *netcheck.Report, netMap *netmap.NetworkMap, prevSug res.ID = chosen.StableID() res.Name = chosen.Name() if hi := chosen.Hostinfo(); hi.Valid() { - if loc := hi.Location(); loc != nil { - res.Location = loc.View() + if loc := hi.Location(); loc.Valid() { + res.Location = loc } } return res, nil @@ -7209,13 +7477,13 @@ func pickWeighted(candidates []tailcfg.NodeView) []tailcfg.NodeView { continue } loc := hi.Location() - if loc == nil || loc.Priority < maxWeight { + if !loc.Valid() || loc.Priority() < maxWeight { continue } - if maxWeight != loc.Priority { + if maxWeight != loc.Priority() { best = best[:0] } - maxWeight = loc.Priority + maxWeight = loc.Priority() best = append(best, c) } return best @@ -7344,23 +7612,89 @@ func (b *LocalBackend) startAutoUpdate(logPrefix string) 
(retErr error) { // rules that require a source IP to have a certain node capability. // // TODO(bradfitz): optimize this later if/when it matters. +// TODO(nickkhyl): move this into [nodeBackend] along with [LocalBackend.updateFilterLocked]. func (b *LocalBackend) srcIPHasCapForFilter(srcIP netip.Addr, cap tailcfg.NodeCapability) bool { if cap == "" { // Shouldn't happen, but just in case. // But the empty cap also shouldn't be found in Node.CapMap. return false } - - b.mu.Lock() - defer b.mu.Unlock() - - nodeID, ok := b.nodeByAddr[srcIP] + cn := b.currentNode() + nodeID, ok := cn.NodeByAddr(srcIP) if !ok { return false } - n, ok := b.peers[nodeID] + n, ok := cn.PeerByID(nodeID) if !ok { return false } return n.HasCap(cap) } + +// maybeUsernameOf returns the actor's username if the actor +// is non-nil and its username can be resolved. +func maybeUsernameOf(actor ipnauth.Actor) string { + var username string + if actor != nil { + username, _ = actor.Username() + } + return username +} + +// VIPServices returns the list of tailnet services that this node +// is serving as a destination for. +// The returned memory is owned by the caller. 
+func (b *LocalBackend) VIPServices() []*tailcfg.VIPService { + b.mu.Lock() + defer b.mu.Unlock() + return b.vipServicesFromPrefsLocked(b.pm.CurrentPrefs()) +} + +func (b *LocalBackend) vipServiceHash(services []*tailcfg.VIPService) string { + if len(services) == 0 { + return "" + } + buf, err := json.Marshal(services) + if err != nil { + b.logf("vipServiceHashLocked: %v", err) + return "" + } + hash := sha256.Sum256(buf) + return hex.EncodeToString(hash[:]) +} + +func (b *LocalBackend) vipServicesFromPrefsLocked(prefs ipn.PrefsView) []*tailcfg.VIPService { + // keyed by service name + var services map[tailcfg.ServiceName]*tailcfg.VIPService + if b.serveConfig.Valid() { + for svc, config := range b.serveConfig.Services().All() { + mak.Set(&services, svc, &tailcfg.VIPService{ + Name: svc, + Ports: config.ServicePortRange(), + }) + } + } + + for _, s := range prefs.AdvertiseServices().All() { + sn := tailcfg.ServiceName(s) + if services == nil || services[sn] == nil { + mak.Set(&services, sn, &tailcfg.VIPService{ + Name: sn, + }) + } + services[sn].Active = true + } + + servicesList := slicesx.MapValues(services) + // [slicesx.MapValues] provides the values in an indeterminate order, but since we'll + // be hashing a representation of this list later we want it to be in a consistent + // order. 
+ slices.SortFunc(servicesList, func(a, b *tailcfg.VIPService) int { + return strings.Compare(a.Name.String(), b.Name.String()) + }) + return servicesList +} + +var ( + metricCurrentWatchIPNBus = clientmetric.NewGauge("localbackend_current_watch_ipn_bus") +) diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index b0e12d5005431..19cfd91953e4c 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -13,8 +13,10 @@ import ( "net/http" "net/netip" "os" + "path/filepath" "reflect" "slices" + "strings" "sync" "testing" "time" @@ -31,6 +33,8 @@ import ( "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" + "tailscale.com/ipn/conffile" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/store/mem" "tailscale.com/net/netcheck" "tailscale.com/net/netmon" @@ -40,6 +44,7 @@ import ( "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/types/dnstype" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" @@ -52,8 +57,11 @@ import ( "tailscale.com/util/must" "tailscale.com/util/set" "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/filter/filtertype" "tailscale.com/wgengine/wgcfg" ) @@ -428,20 +436,30 @@ func (panicOnUseTransport) RoundTrip(*http.Request) (*http.Response, error) { } func newTestLocalBackend(t testing.TB) *LocalBackend { + return newTestLocalBackendWithSys(t, tsd.NewSystem()) +} + +// newTestLocalBackendWithSys creates a new LocalBackend with the given tsd.System. +// If the state store or engine are not set in sys, they will be set to a new +// in-memory store and fake userspace engine, respectively. 
+func newTestLocalBackendWithSys(t testing.TB, sys *tsd.System) *LocalBackend { var logf logger.Logf = logger.Discard - sys := new(tsd.System) - store := new(mem.Store) - sys.Set(store) - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) - if err != nil { - t.Fatalf("NewFakeUserspaceEngine: %v", err) + if _, ok := sys.StateStore.GetOK(); !ok { + sys.Set(new(mem.Store)) + } + if _, ok := sys.Engine.GetOK(); !ok { + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(eng.Close) + sys.Set(eng) } - t.Cleanup(eng.Close) - sys.Set(eng) lb, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(lb.Shutdown) return lb } @@ -557,54 +575,6 @@ func TestSetUseExitNodeEnabled(t *testing.T) { } } -func TestFileTargets(t *testing.T) { - b := new(LocalBackend) - _, err := b.FileTargets() - if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { - t.Errorf("before connect: got %q; want %q", got, want) - } - - b.netMap = new(netmap.NetworkMap) - _, err = b.FileTargets() - if got, want := fmt.Sprint(err), "not connected to the tailnet"; got != want { - t.Errorf("non-running netmap: got %q; want %q", got, want) - } - - b.state = ipn.Running - _, err = b.FileTargets() - if got, want := fmt.Sprint(err), "file sharing not enabled by Tailscale admin"; got != want { - t.Errorf("without cap: got %q; want %q", got, want) - } - - b.capFileSharing = true - got, err := b.FileTargets() - if err != nil { - t.Fatal(err) - } - if len(got) != 0 { - t.Fatalf("unexpected %d peers", len(got)) - } - - var peerMap map[tailcfg.NodeID]tailcfg.NodeView - mak.NonNil(&peerMap) - var nodeID tailcfg.NodeID - nodeID = 1234 - peer := &tailcfg.Node{ - ID: 1234, - Hostinfo: (&tailcfg.Hostinfo{OS: "tvOS"}).View(), - } - 
peerMap[nodeID] = peer.View() - b.peers = peerMap - got, err = b.FileTargets() - if err != nil { - t.Fatal(err) - } - if len(got) != 0 { - t.Fatalf("unexpected %d peers", len(got)) - } - // (other cases handled by TestPeerAPIBase above) -} - func TestInternalAndExternalInterfaces(t *testing.T) { type interfacePrefix struct { i netmon.Interface @@ -950,15 +920,15 @@ func TestWatchNotificationsCallbacks(t *testing.T) { // tests LocalBackend.updateNetmapDeltaLocked func TestUpdateNetmapDelta(t *testing.T) { b := newTestLocalBackend(t) - if b.updateNetmapDeltaLocked(nil) { + if b.currentNode().UpdateNetmapDelta(nil) { t.Errorf("updateNetmapDeltaLocked() = true, want false with nil netmap") } - b.netMap = &netmap.NetworkMap{} + nm := &netmap.NetworkMap{} for i := range 5 { - b.netMap.Peers = append(b.netMap.Peers, (&tailcfg.Node{ID: (tailcfg.NodeID(i) + 1)}).View()) + nm.Peers = append(nm.Peers, (&tailcfg.Node{ID: (tailcfg.NodeID(i) + 1)}).View()) } - b.updatePeersFromNetmapLocked(b.netMap) + b.currentNode().SetNetMap(nm) someTime := time.Unix(123, 0) muts, ok := netmap.MutationsFromMapResponse(&tailcfg.MapResponse{ @@ -985,14 +955,14 @@ func TestUpdateNetmapDelta(t *testing.T) { t.Fatal("netmap.MutationsFromMapResponse failed") } - if !b.updateNetmapDeltaLocked(muts) { + if !b.currentNode().UpdateNetmapDelta(muts) { t.Fatalf("updateNetmapDeltaLocked() = false, want true with new netmap") } wants := []*tailcfg.Node{ { - ID: 1, - DERP: "127.3.3.40:1", + ID: 1, + HomeDERP: 1, }, { ID: 2, @@ -1008,9 +978,9 @@ func TestUpdateNetmapDelta(t *testing.T) { }, } for _, want := range wants { - gotv, ok := b.peers[want.ID] + gotv, ok := b.currentNode().PeerByID(want.ID) if !ok { - t.Errorf("netmap.Peer %v missing from b.peers", want.ID) + t.Errorf("netmap.Peer %v missing from b.profile.Peers", want.ID) continue } got := gotv.AsStruct() @@ -1036,13 +1006,13 @@ func TestWhoIs(t *testing.T) { Addresses: []netip.Prefix{netip.MustParsePrefix("100.200.200.200/32")}, }).View(), }, - 
UserProfiles: map[tailcfg.UserID]tailcfg.UserProfile{ - 10: { + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + 10: (&tailcfg.UserProfile{ DisplayName: "Myself", - }, - 20: { + }).View(), + 20: (&tailcfg.UserProfile{ DisplayName: "Peer", - }, + }).View(), }, }) tests := []struct { @@ -1356,7 +1326,9 @@ func TestObserveDNSResponse(t *testing.T) { b := newTestBackend(t) // ensure no error when no app connector is configured - b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + if err := b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } rc := &appctest.RouteCollector{} if shouldStore { @@ -1367,7 +1339,9 @@ func TestObserveDNSResponse(t *testing.T) { b.appConnector.UpdateDomains([]string{"example.com"}) b.appConnector.Wait(context.Background()) - b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")) + if err := b.ObserveDNSResponse(dnsResponse("example.com.", "192.0.0.8")); err != nil { + t.Errorf("ObserveDNSResponse: %v", err) + } b.appConnector.Wait(context.Background()) wantRoutes := []netip.Prefix{netip.MustParsePrefix("192.0.0.8/32")} if !slices.Equal(rc.Routes(), wantRoutes) { @@ -1424,7 +1398,7 @@ func TestCoveredRouteRangeNoDefault(t *testing.T) { func TestReconfigureAppConnector(t *testing.T) { b := newTestBackend(t) - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) if b.appConnector != nil { t.Fatal("unexpected app connector") } @@ -1437,7 +1411,7 @@ func TestReconfigureAppConnector(t *testing.T) { }, AppConnectorSet: true, }) - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) if b.appConnector == nil { t.Fatal("expected app connector") } @@ -1448,15 +1422,19 @@ func TestReconfigureAppConnector(t *testing.T) { "connectors": ["tag:example"] }` - b.netMap.SelfNode = (&tailcfg.Node{ - Name: "example.ts.net", - Tags: []string{"tag:example"}, - 
CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ - "tailscale.com/app-connectors": {tailcfg.RawMessage(appCfg)}, - }), - }).View() + nm := &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "example.ts.net", + Tags: []string{"tag:example"}, + CapMap: (tailcfg.NodeCapMap)(map[tailcfg.NodeCapability][]tailcfg.RawMessage{ + "tailscale.com/app-connectors": {tailcfg.RawMessage(appCfg)}, + }), + }).View(), + } + + b.currentNode().SetNetMap(nm) - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) b.appConnector.Wait(context.Background()) want := []string{"example.com"} @@ -1476,7 +1454,7 @@ func TestReconfigureAppConnector(t *testing.T) { }, AppConnectorSet: true, }) - b.reconfigAppConnectorLocked(b.netMap, b.pm.prefs) + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) if b.appConnector != nil { t.Fatal("expected no app connector") } @@ -1485,6 +1463,62 @@ func TestReconfigureAppConnector(t *testing.T) { } } +func TestBackfillAppConnectorRoutes(t *testing.T) { + // Create backend with an empty app connector. + b := newTestBackend(t) + // newTestBackend creates a backend with a non-nil netmap, + // but this test requires a nil netmap. + // Otherwise, instead of backfilling, [LocalBackend.reconfigAppConnectorLocked] + // uses the domains and routes from netmap's [appctype.AppConnectorAttr]. + // Additionally, a non-nil netmap makes reconfigAppConnectorLocked + // asynchronous, resulting in a flaky test. + // Therefore, we set the netmap to nil to simulate a fresh backend start + // or a profile switch where the netmap is not yet available. 
+ b.setNetMapLocked(nil) + if err := b.Start(ipn.Options{}); err != nil { + t.Fatal(err) + } + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + AppConnector: ipn.AppConnectorPrefs{Advertise: true}, + }, + AppConnectorSet: true, + }); err != nil { + t.Fatal(err) + } + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + + // Smoke check that AdvertiseRoutes doesn't have the test IP. + ip := netip.MustParseAddr("1.2.3.4") + routes := b.Prefs().AdvertiseRoutes().AsSlice() + if slices.Contains(routes, netip.PrefixFrom(ip, ip.BitLen())) { + t.Fatalf("AdvertiseRoutes %v on a fresh backend already contains advertised route for %v", routes, ip) + } + + // Store the test IP in profile data, but not in Prefs.AdvertiseRoutes. + b.ControlKnobs().AppCStoreRoutes.Store(true) + if err := b.storeRouteInfo(&appc.RouteInfo{ + Domains: map[string][]netip.Addr{ + "example.com": {ip}, + }, + }); err != nil { + t.Fatal(err) + } + + // Mimic b.authReconfigure for the app connector bits. + b.mu.Lock() + b.reconfigAppConnectorLocked(b.NetMap(), b.pm.prefs) + b.mu.Unlock() + b.readvertiseAppConnectorRoutes() + + // Check that Prefs.AdvertiseRoutes got backfilled with routes stored in + // profile data. 
+ routes = b.Prefs().AdvertiseRoutes().AsSlice() + if !slices.Contains(routes, netip.PrefixFrom(ip, ip.BitLen())) { + t.Fatalf("AdvertiseRoutes %v was not backfilled from stored app connector routes with %v", routes, ip) + } +} + func resolversEqual(t *testing.T, a, b []*dnstype.Resolver) bool { if a == nil && b == nil { return true @@ -1557,94 +1591,6 @@ func dnsResponse(domain, address string) []byte { return must.Get(b.Finish()) } -type errorSyspolicyHandler struct { - t *testing.T - err error - key syspolicy.Key - allowKeys map[syspolicy.Key]*string -} - -func (h *errorSyspolicyHandler) ReadString(key string) (string, error) { - sk := syspolicy.Key(key) - if _, ok := h.allowKeys[sk]; !ok { - h.t.Errorf("ReadString: %q is not in list of permitted keys", h.key) - } - if sk == h.key { - return "", h.err - } - return "", syspolicy.ErrNoSuchKey -} - -func (h *errorSyspolicyHandler) ReadUInt64(key string) (uint64, error) { - h.t.Errorf("ReadUInt64(%q) unexpectedly called", key) - return 0, syspolicy.ErrNoSuchKey -} - -func (h *errorSyspolicyHandler) ReadBoolean(key string) (bool, error) { - h.t.Errorf("ReadBoolean(%q) unexpectedly called", key) - return false, syspolicy.ErrNoSuchKey -} - -func (h *errorSyspolicyHandler) ReadStringArray(key string) ([]string, error) { - h.t.Errorf("ReadStringArray(%q) unexpectedly called", key) - return nil, syspolicy.ErrNoSuchKey -} - -type mockSyspolicyHandler struct { - t *testing.T - // stringPolicies is the collection of policies that we expect to see - // queried by the current test. If the policy is expected but unset, then - // use nil, otherwise use a string equal to the policy's desired value. - stringPolicies map[syspolicy.Key]*string - // stringArrayPolicies is the collection of policies that we expected to see - // queries by the current test, that return policy string arrays. 
- stringArrayPolicies map[syspolicy.Key][]string - // failUnknownPolicies is set if policies other than those in stringPolicies - // (uint64 or bool policies are not supported by mockSyspolicyHandler yet) - // should be considered a test failure if they are queried. - failUnknownPolicies bool -} - -func (h *mockSyspolicyHandler) ReadString(key string) (string, error) { - if s, ok := h.stringPolicies[syspolicy.Key(key)]; ok { - if s == nil { - return "", syspolicy.ErrNoSuchKey - } - return *s, nil - } - if h.failUnknownPolicies { - h.t.Errorf("ReadString(%q) unexpectedly called", key) - } - return "", syspolicy.ErrNoSuchKey -} - -func (h *mockSyspolicyHandler) ReadUInt64(key string) (uint64, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadUInt64(%q) unexpectedly called", key) - } - return 0, syspolicy.ErrNoSuchKey -} - -func (h *mockSyspolicyHandler) ReadBoolean(key string) (bool, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadBoolean(%q) unexpectedly called", key) - } - return false, syspolicy.ErrNoSuchKey -} - -func (h *mockSyspolicyHandler) ReadStringArray(key string) ([]string, error) { - if h.failUnknownPolicies { - h.t.Errorf("ReadStringArray(%q) unexpectedly called", key) - } - if s, ok := h.stringArrayPolicies[syspolicy.Key(key)]; ok { - if s == nil { - return []string{}, syspolicy.ErrNoSuchKey - } - return s, nil - } - return nil, syspolicy.ErrNoSuchKey -} - func TestSetExitNodeIDPolicy(t *testing.T) { pfx := netip.MustParsePrefix tests := []struct { @@ -1854,23 +1800,21 @@ func TestSetExitNodeIDPolicy(t *testing.T) { }, } + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, test := range tests { t.Run(test.name, func(t *testing.T) { b := newTestBackend(t) - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: nil, - syspolicy.ExitNodeIP: nil, - }, - } + + policyStore := source.NewTestStore(t) if test.exitNodeIDKey { - msh.stringPolicies[syspolicy.ExitNodeID] = &test.exitNodeID + 
policyStore.SetStrings(source.TestSettingOf(syspolicy.ExitNodeID, test.exitNodeID)) } if test.exitNodeIPKey { - msh.stringPolicies[syspolicy.ExitNodeIP] = &test.exitNodeIP + policyStore.SetStrings(source.TestSettingOf(syspolicy.ExitNodeIP, test.exitNodeIP)) } - syspolicy.SetHandlerForTest(t, msh) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + if test.nm == nil { test.nm = new(netmap.NetworkMap) } @@ -1879,10 +1823,19 @@ func TestSetExitNodeIDPolicy(t *testing.T) { } pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker))) pm.prefs = test.prefs.View() - b.netMap = test.nm + b.currentNode().SetNetMap(test.nm) b.pm = pm b.lastSuggestedExitNode = test.lastSuggestedExitNode - changed := setExitNodeID(b.pm.prefs.AsStruct(), test.nm, tailcfg.StableNodeID(test.lastSuggestedExitNode)) + + prefs := b.pm.prefs.AsStruct() + if changed := applySysPolicy(prefs, test.lastSuggestedExitNode, false) || setExitNodeID(prefs, test.nm); changed != test.prefsChanged { + t.Errorf("wanted prefs changed %v, got prefs changed %v", test.prefsChanged, changed) + } + + // Both [LocalBackend.SetPrefsForTest] and [LocalBackend.EditPrefs] + // apply syspolicy settings to the current profile's preferences. Therefore, + // we pass the current, unmodified preferences and expect the effective + // preferences to change. 
b.SetPrefsForTest(pm.CurrentPrefs().AsStruct()) if got := b.pm.prefs.ExitNodeID(); got != tailcfg.StableNodeID(test.exitNodeIDWant) { @@ -1895,10 +1848,6 @@ func TestSetExitNodeIDPolicy(t *testing.T) { } else if got.String() != test.exitNodeIPWant { t.Errorf("got %v want %v", got, test.exitNodeIPWant) } - - if changed != test.prefsChanged { - t.Errorf("wanted prefs changed %v, got prefs changed %v", test.prefsChanged, changed) - } }) } } @@ -1935,16 +1884,16 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { PreferredDERP: 2, } tests := []struct { - name string - lastSuggestedExitNode tailcfg.StableNodeID - netmap *netmap.NetworkMap - muts []*tailcfg.PeerChange - exitNodeIDWant tailcfg.StableNodeID - updateNetmapDeltaResponse bool - report *netcheck.Report + name string + lastSuggestedExitNode tailcfg.StableNodeID + netmap *netmap.NetworkMap + muts []*tailcfg.PeerChange + exitNodeIDWant tailcfg.StableNodeID + report *netcheck.Report }{ { - name: "selected auto exit node goes offline", + // selected auto exit node goes offline + name: "exit-node-goes-offline", lastSuggestedExitNode: peer1.StableID(), netmap: &netmap.NetworkMap{ Peers: []tailcfg.NodeView{ @@ -1963,12 +1912,12 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { Online: ptr.To(true), }, }, - exitNodeIDWant: peer2.StableID(), - updateNetmapDeltaResponse: false, - report: report, + exitNodeIDWant: peer2.StableID(), + report: report, }, { - name: "other exit node goes offline doesn't change selected auto exit node that's still online", + // other exit node goes offline doesn't change selected auto exit node that's still online + name: "other-node-goes-offline", lastSuggestedExitNode: peer2.StableID(), netmap: &netmap.NetworkMap{ Peers: []tailcfg.NodeView{ @@ -1987,26 +1936,38 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { Online: ptr.To(true), }, }, - exitNodeIDWant: peer2.StableID(), - updateNetmapDeltaResponse: true, - report: report, - }, - } - msh := &mockSyspolicyHandler{ - t: 
t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), + exitNodeIDWant: peer2.StableID(), + report: report, }, } - syspolicy.SetHandlerForTest(t, msh) + + syspolicy.RegisterWellKnownSettingsForTest(t) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, "auto:any", + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b := newTestLocalBackend(t) - b.netMap = tt.netmap - b.updatePeersFromNetmapLocked(b.netMap) + b.currentNode().SetNetMap(tt.netmap) b.lastSuggestedExitNode = tt.lastSuggestedExitNode b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, tt.report) b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) + + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + someTime := time.Unix(123, 0) muts, ok := netmap.MutationsFromMapResponse(&tailcfg.MapResponse{ PeersChangedPatch: tt.muts, @@ -2014,16 +1975,34 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) { if !ok { t.Fatal("netmap.MutationsFromMapResponse failed") } + if b.pm.prefs.ExitNodeID() != tt.lastSuggestedExitNode { t.Fatalf("did not set exit node ID to last suggested exit node despite auto policy") } + was := b.goTracker.StartedGoroutines() got := b.UpdateNetmapDelta(muts) - if got != tt.updateNetmapDeltaResponse { - t.Fatalf("got %v expected %v from UpdateNetmapDelta", got, tt.updateNetmapDeltaResponse) + if !got { + t.Error("got false from UpdateNetmapDelta") + } + startedGoroutine := b.goTracker.StartedGoroutines() != was + + wantChange := tt.exitNodeIDWant != tt.lastSuggestedExitNode + if startedGoroutine != wantChange { + t.Errorf("got startedGoroutine %v, want %v", startedGoroutine, wantChange) + } + if startedGoroutine { + select { + case 
<-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutine to finish") + case <-allDone: + } } - if b.pm.prefs.ExitNodeID() != tt.exitNodeIDWant { - t.Fatalf("did not get expected exit node id after UpdateNetmapDelta") + b.mu.Lock() + gotExitNode := b.pm.prefs.ExitNodeID() + b.mu.Unlock() + if gotExitNode != tt.exitNodeIDWant { + t.Fatalf("exit node ID after UpdateNetmapDelta = %v; want %v", gotExitNode, tt.exitNodeIDWant) } }) } @@ -2047,13 +2026,11 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { } cc = newClient(t, opts) b.cc = cc - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), - }, - } - syspolicy.SetHandlerForTest(t, msh) + syspolicy.RegisterWellKnownSettingsForTest(t) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, "auto:any", + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes()) peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes()) selfNode := tailcfg.Node{ @@ -2061,7 +2038,7 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { netip.MustParsePrefix("100.64.1.1/32"), netip.MustParsePrefix("fe70::1/128"), }, - DERP: "127.3.3.40:2", + HomeDERP: 2, } defaultDERPMap := &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ @@ -2091,14 +2068,14 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) { }, }, } - b.netMap = &netmap.NetworkMap{ + b.currentNode().SetNetMap(&netmap.NetworkMap{ SelfNode: selfNode.View(), Peers: []tailcfg.NodeView{ peer1, peer2, }, DERPMap: defaultDERPMap, - } + }) b.lastSuggestedExitNode = peer1.StableID() b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) if eid := b.Prefs().ExitNodeID(); eid != peer1.StableID() { @@ -2158,14 +2135,12 @@ func TestSetControlClientStatusAutoExitNode(t *testing.T) { DERPMap: derpMap, } b := newTestLocalBackend(t) 
- msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To("auto:any"), - }, - } - syspolicy.SetHandlerForTest(t, msh) - b.netMap = nm + syspolicy.RegisterWellKnownSettingsForTest(t) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, "auto:any", + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + b.currentNode().SetNetMap(nm) b.lastSuggestedExitNode = peer1.StableID() b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, report) b.SetPrefsForTest(b.pm.CurrentPrefs().AsStruct()) @@ -2398,22 +2373,21 @@ func TestApplySysPolicy(t *testing.T) { }, } + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: make(map[syspolicy.Key]*string, len(tt.stringPolicies)), - } + settings := make([]source.TestSetting[string], 0, len(tt.stringPolicies)) for p, v := range tt.stringPolicies { - v := v // construct a unique pointer for each policy value - msh.stringPolicies[p] = &v + settings = append(settings, source.TestSettingOf(p, v)) } - syspolicy.SetHandlerForTest(t, msh) + policyStore := source.NewTestStoreOf(t, settings...) 
+ syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) t.Run("unit", func(t *testing.T) { prefs := tt.prefs.Clone() - gotAnyChange := applySysPolicy(prefs) + gotAnyChange := applySysPolicy(prefs, "", false) if gotAnyChange && prefs.Equals(&tt.prefs) { t.Errorf("anyChange but prefs is unchanged: %v", prefs.Pretty()) @@ -2544,40 +2518,24 @@ func TestPreferencePolicyInfo(t *testing.T) { }, } + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { for _, pp := range preferencePolicies { t.Run(string(pp.key), func(t *testing.T) { - var h syspolicy.Handler - - allPolicies := make(map[syspolicy.Key]*string, len(preferencePolicies)+1) - allPolicies[syspolicy.ControlURL] = nil - for _, pp := range preferencePolicies { - allPolicies[pp.key] = nil - } - - if tt.policyError != nil { - h = &errorSyspolicyHandler{ - t: t, - err: tt.policyError, - key: pp.key, - allowKeys: allPolicies, - } - } else { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: allPolicies, - failUnknownPolicies: true, - } - msh.stringPolicies[pp.key] = &tt.policyValue - h = msh + s := source.TestSetting[string]{ + Key: pp.key, + Error: tt.policyError, + Value: tt.policyValue, } - syspolicy.SetHandlerForTest(t, h) + policyStore := source.NewTestStoreOf(t, s) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) prefs := defaultPrefs.AsStruct() pp.set(prefs, tt.initialValue) - gotAnyChange := applySysPolicy(prefs) + gotAnyChange := applySysPolicy(prefs, "", false) if gotAnyChange != tt.wantChange { t.Errorf("anyChange=%v, want %v", gotAnyChange, tt.wantChange) @@ -2662,7 +2620,7 @@ func TestOnTailnetDefaultAutoUpdate(t *testing.T) { // On platforms that don't support auto-update we can never // transition to auto-updates being enabled. The value should // remain unchanged after onTailnetDefaultAutoUpdate. 
- if !clientupdate.CanAutoUpdate() && want.EqualBool(true) { + if !clientupdate.CanAutoUpdate() { want = tt.before } if got := b.pm.CurrentPrefs().AutoUpdate().Apply; got != want { @@ -2674,7 +2632,6 @@ func TestOnTailnetDefaultAutoUpdate(t *testing.T) { func TestTCPHandlerForDst(t *testing.T) { b := newTestBackend(t) - tests := []struct { desc string dst string @@ -2697,11 +2654,6 @@ func TestTCPHandlerForDst(t *testing.T) { }, { desc: "intercept port 8080 (Taildrive) on quad100 IPv4", - dst: "100.100.100.100:8080", - intercept: true, - }, - { - desc: "intercept port 8080 (Taildrive) on quad100 IPv6", dst: "[fd7a:115c:a1e0::53]:8080", intercept: true, }, @@ -2726,7 +2678,6 @@ func TestTCPHandlerForDst(t *testing.T) { intercept: false, }, } - for _, tt := range tests { t.Run(tt.dst, func(t *testing.T) { t.Log(tt.desc) @@ -2741,83 +2692,305 @@ func TestTCPHandlerForDst(t *testing.T) { } } -func TestDriveManageShares(t *testing.T) { +func TestTCPHandlerForDstWithVIPService(t *testing.T) { + b := newTestBackend(t) + svcIPMap := tailcfg.ServiceIPMappings{ + "svc:foo": []netip.Addr{ + netip.MustParseAddr("100.101.101.101"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:6565:6565"), + }, + "svc:bar": []netip.Addr{ + netip.MustParseAddr("100.99.99.99"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:626b:628b"), + }, + "svc:baz": []netip.Addr{ + netip.MustParseAddr("100.133.133.133"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:8585:8585"), + }, + } + svcIPMapJSON, err := json.Marshal(svcIPMap) + if err != nil { + t.Fatal(err) + } + b.setNetMapLocked( + &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "example.ts.net", + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{tailcfg.RawMessage(svcIPMapJSON)}, + }, + }).View(), + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ + LoginName: "someone@example.com", + DisplayName: "Some One", + ProfilePicURL: 
"https://example.com/photo.jpg", + }).View(), + }, + }, + ) + + err = b.setServeConfigLocked( + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 882: {HTTP: true}, + 883: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "foo.example.ts.net:882": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Proxy: "http://127.0.0.1:3000"}, + }, + }, + "foo.example.ts.net:883": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Text: "test"}, + }, + }, + }, + }, + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 990: {TCPForward: "127.0.0.1:8443"}, + 991: {TCPForward: "127.0.0.1:5432", TerminateTLS: "bar.test.ts.net"}, + }, + }, + "svc:qux": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 600: {HTTPS: true}, + }, + Web: map[ipn.HostPort]*ipn.WebServerConfig{ + "qux.example.ts.net:600": { + Handlers: map[string]*ipn.HTTPHandler{ + "/": {Text: "qux"}, + }, + }, + }, + }, + }, + }, + "", + ) + if err != nil { + t.Fatal(err) + } + tests := []struct { - name string - disabled bool - existing []*drive.Share - add *drive.Share - remove string - rename [2]string - expect any + desc string + dst string + intercept bool }{ { - name: "append", - existing: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, - }, - add: &drive.Share{Name: " E "}, - expect: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, - {Name: "e"}, - }, + desc: "intercept port 80 (Web UI) on quad100 IPv4", + dst: "100.100.100.100:80", + intercept: true, }, { - name: "prepend", - existing: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, - }, - add: &drive.Share{Name: " A "}, - expect: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, - {Name: "d"}, - }, + desc: "intercept port 80 (Web UI) on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:80", + intercept: true, }, { - name: "insert", - existing: []*drive.Share{ - {Name: "b"}, - {Name: "d"}, - }, - add: &drive.Share{Name: " C "}, - expect: []*drive.Share{ - {Name: "b"}, - {Name: "c"}, - 
{Name: "d"}, - }, + desc: "don't intercept port 80 on local ip", + dst: "100.100.103.100:80", + intercept: false, }, { - name: "replace", - existing: []*drive.Share{ - {Name: "b", Path: "i"}, - {Name: "d"}, - }, - add: &drive.Share{Name: " B ", Path: "ii"}, - expect: []*drive.Share{ - {Name: "b", Path: "ii"}, - {Name: "d"}, - }, + desc: "intercept port 8080 (Taildrive) on quad100 IPv4", + dst: "100.100.100.100:8080", + intercept: true, }, { - name: "add_bad_name", - add: &drive.Share{Name: "$"}, - expect: drive.ErrInvalidShareName, + desc: "intercept port 8080 (Taildrive) on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:8080", + intercept: true, }, { - name: "add_disabled", - disabled: true, - add: &drive.Share{Name: "a"}, - expect: drive.ErrDriveNotEnabled, + desc: "don't intercept port 8080 on local ip", + dst: "100.100.103.100:8080", + intercept: false, }, { - name: "remove", - existing: []*drive.Share{ - {Name: "a"}, - {Name: "b"}, + desc: "don't intercept port 9080 on quad100 IPv4", + dst: "100.100.100.100:9080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on quad100 IPv6", + dst: "[fd7a:115c:a1e0::53]:9080", + intercept: false, + }, + { + desc: "don't intercept port 9080 on local ip", + dst: "100.100.103.100:9080", + intercept: false, + }, + // VIP service destinations + { + desc: "intercept port 882 (HTTP) on service foo IPv4", + dst: "100.101.101.101:882", + intercept: true, + }, + { + desc: "intercept port 882 (HTTP) on service foo IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:882", + intercept: true, + }, + { + desc: "intercept port 883 (HTTPS) on service foo IPv4", + dst: "100.101.101.101:883", + intercept: true, + }, + { + desc: "intercept port 883 (HTTPS) on service foo IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:883", + intercept: true, + }, + { + desc: "intercept port 990 (TCPForward) on service bar IPv4", + dst: "100.99.99.99:990", + intercept: true, + }, + { + desc: "intercept port 990 (TCPForward) on 
service bar IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:990", + intercept: true, + }, + { + desc: "intercept port 991 (TCPForward with TerminateTLS) on service bar IPv4", + dst: "100.99.99.99:990", + intercept: true, + }, + { + desc: "intercept port 991 (TCPForward with TerminateTLS) on service bar IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:990", + intercept: true, + }, + { + desc: "don't intercept port 4444 on service foo IPv4", + dst: "100.101.101.101:4444", + intercept: false, + }, + { + desc: "don't intercept port 4444 on service foo IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:4444", + intercept: false, + }, + { + desc: "don't intercept port 600 on unknown service IPv4", + dst: "100.22.22.22:883", + intercept: false, + }, + { + desc: "don't intercept port 600 on unknown service IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:883", + intercept: false, + }, + { + desc: "don't intercept port 600 (HTTPS) on service baz IPv4", + dst: "100.133.133.133:600", + intercept: false, + }, + { + desc: "don't intercept port 600 (HTTPS) on service baz IPv6", + dst: "[fd7a:115c:a1e0:ab12:4843:cd96:8585:8585]:600", + intercept: false, + }, + } + + for _, tt := range tests { + t.Run(tt.dst, func(t *testing.T) { + t.Log(tt.desc) + src := netip.MustParseAddrPort("100.100.102.100:51234") + h, _ := b.TCPHandlerForDst(src, netip.MustParseAddrPort(tt.dst)) + if !tt.intercept && h != nil { + t.Error("intercepted traffic we shouldn't have") + } else if tt.intercept && h == nil { + t.Error("failed to intercept traffic we should have") + } + }) + } +} + +func TestDriveManageShares(t *testing.T) { + tests := []struct { + name string + disabled bool + existing []*drive.Share + add *drive.Share + remove string + rename [2]string + expect any + }{ + { + name: "append", + existing: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " E "}, + expect: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + {Name: "e"}, + }, + }, + 
{ + name: "prepend", + existing: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " A "}, + expect: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, + {Name: "d"}, + }, + }, + { + name: "insert", + existing: []*drive.Share{ + {Name: "b"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " C "}, + expect: []*drive.Share{ + {Name: "b"}, + {Name: "c"}, + {Name: "d"}, + }, + }, + { + name: "replace", + existing: []*drive.Share{ + {Name: "b", Path: "i"}, + {Name: "d"}, + }, + add: &drive.Share{Name: " B ", Path: "ii"}, + expect: []*drive.Share{ + {Name: "b", Path: "ii"}, + {Name: "d"}, + }, + }, + { + name: "add_bad_name", + add: &drive.Share{Name: "$"}, + expect: drive.ErrInvalidShareName, + }, + { + name: "add_disabled", + disabled: true, + add: &drive.Share{Name: "a"}, + expect: drive.ErrDriveNotEnabled, + }, + { + name: "remove", + existing: []*drive.Share{ + {Name: "a"}, + {Name: "b"}, {Name: "c"}, }, remove: "b", @@ -2898,9 +3071,11 @@ func TestDriveManageShares(t *testing.T) { b.driveSetSharesLocked(tt.existing) } if !tt.disabled { - self := b.netMap.SelfNode.AsStruct() + nm := ptr.To(*b.currentNode().NetMap()) + self := nm.SelfNode.AsStruct() self.CapMap = tailcfg.NodeCapMap{tailcfg.NodeAttrsTaildriveShare: nil} - b.netMap.SelfNode = self.View() + nm.SelfNode = self.View() + b.currentNode().SetNetMap(nm) b.sys.Set(driveimpl.NewFileSystemForRemote(b.logf)) } b.mu.Unlock() @@ -3044,7 +3219,7 @@ func makePeer(id tailcfg.NodeID, opts ...peerOptFunc) tailcfg.NodeView { ID: id, StableID: tailcfg.StableNodeID(fmt.Sprintf("stable%d", id)), Name: fmt.Sprintf("peer%d", id), - DERP: fmt.Sprintf("127.3.3.40:%d", id), + HomeDERP: int(id), } for _, opt := range opts { opt(node) @@ -3060,13 +3235,13 @@ func withName(name string) peerOptFunc { func withDERP(region int) peerOptFunc { return func(n *tailcfg.Node) { - n.DERP = fmt.Sprintf("127.3.3.40:%d", region) + n.HomeDERP = region } } func withoutDERP() peerOptFunc { return func(n *tailcfg.Node) { - 
n.DERP = "" + n.HomeDERP = 0 } } @@ -3140,12 +3315,10 @@ func deterministicNodeForTest(t testing.TB, want views.Slice[tailcfg.StableNodeI var ret tailcfg.NodeView gotIDs := make([]tailcfg.StableNodeID, got.Len()) - for i := range got.Len() { - nv := got.At(i) + for i, nv := range got.All() { if !nv.Valid() { t.Fatalf("invalid node at index %v", i) } - gotIDs[i] = nv.StableID() if nv.StableID() == use { ret = nv @@ -3823,15 +3996,16 @@ func TestShouldAutoExitNode(t *testing.T) { expectedBool: false, }, } + + syspolicy.RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - msh := &mockSyspolicyHandler{ - t: t, - stringPolicies: map[syspolicy.Key]*string{ - syspolicy.ExitNodeID: ptr.To(tt.exitNodeIDPolicyValue), - }, - } - syspolicy.SetHandlerForTest(t, msh) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.ExitNodeID, tt.exitNodeIDPolicyValue, + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) + got := shouldAutoExitNode() if got != tt.expectedBool { t.Fatalf("expected %v got %v for %v policy value", tt.expectedBool, got, tt.exitNodeIDPolicyValue) @@ -3881,9 +4055,9 @@ func TestReadWriteRouteInfo(t *testing.T) { b := newTestBackend(t) prof1 := ipn.LoginProfile{ID: "id1", Key: "key1"} prof2 := ipn.LoginProfile{ID: "id2", Key: "key2"} - b.pm.knownProfiles["id1"] = &prof1 - b.pm.knownProfiles["id2"] = &prof2 - b.pm.currentProfile = &prof1 + b.pm.knownProfiles["id1"] = prof1.View() + b.pm.knownProfiles["id2"] = prof2.View() + b.pm.currentProfile = prof1.View() // set up routeInfo ri1 := &appc.RouteInfo{} @@ -3907,7 +4081,7 @@ func TestReadWriteRouteInfo(t *testing.T) { } // write the other routeInfo as the other profile - if err := b.pm.SwitchProfile("id2"); err != nil { + if _, _, err := b.pm.SwitchToProfileByID("id2"); err != nil { t.Fatal(err) } if err := b.storeRouteInfo(ri2); err != nil { @@ -3915,7 +4089,7 @@ func TestReadWriteRouteInfo(t *testing.T) { 
} // read the routeInfo of the first profile - if err := b.pm.SwitchProfile("id1"); err != nil { + if _, _, err := b.pm.SwitchToProfileByID("id1"); err != nil { t.Fatal(err) } readRi, err = b.readRouteInfoLocked() @@ -3927,7 +4101,7 @@ func TestReadWriteRouteInfo(t *testing.T) { } // read the routeInfo of the second profile - if err := b.pm.SwitchProfile("id2"); err != nil { + if _, _, err := b.pm.SwitchToProfileByID("id2"); err != nil { t.Fatal(err) } readRi, err = b.readRouteInfoLocked() @@ -3969,17 +4143,13 @@ func TestFillAllowedSuggestions(t *testing.T) { want: []tailcfg.StableNodeID{"ABC", "def", "gHiJ"}, }, } + syspolicy.RegisterWellKnownSettingsForTest(t) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mh := mockSyspolicyHandler{ - t: t, - } - if tt.allowPolicy != nil { - mh.stringArrayPolicies = map[syspolicy.Key][]string{ - syspolicy.AllowedSuggestedExitNodes: tt.allowPolicy, - } - } - syspolicy.SetHandlerForTest(t, &mh) + policyStore := source.NewTestStoreOf(t, source.TestSettingOf( + syspolicy.AllowedSuggestedExitNodes, tt.allowPolicy, + )) + syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore) got := fillAllowedSuggestions() if got == nil { @@ -3998,3 +4168,1174 @@ func TestFillAllowedSuggestions(t *testing.T) { }) } } + +func TestNotificationTargetMatch(t *testing.T) { + tests := []struct { + name string + target notificationTarget + actor ipnauth.Actor + wantMatch bool + }{ + { + name: "AllClients/Nil", + target: allClients, + actor: nil, + wantMatch: true, + }, + { + name: "AllClients/NoUID/NoCID", + target: allClients, + actor: &ipnauth.TestActor{}, + wantMatch: true, + }, + { + name: "AllClients/WithUID/NoCID", + target: allClients, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.NoClientID}, + wantMatch: true, + }, + { + name: "AllClients/NoUID/WithCID", + target: allClients, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: 
"AllClients/WithUID/WithCID", + target: allClients, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByUID/Nil", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: nil, + wantMatch: false, + }, + { + name: "FilterByUID/NoUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{}, + wantMatch: false, + }, + { + name: "FilterByUID/NoUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByUID/SameUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: true, + }, + { + name: "FilterByUID/DifferentUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8"}, + wantMatch: false, + }, + { + name: "FilterByUID/SameUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByUID/DifferentUID/WithCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByCID/Nil", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: nil, + wantMatch: false, + }, + { + name: "FilterByCID/NoUID/NoCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{}, + wantMatch: false, + }, + { + name: "FilterByCID/NoUID/SameCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByCID/NoUID/DifferentCID", + target: notificationTarget{clientID: 
ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByCID/WithUID/NoCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: false, + }, + { + name: "FilterByCID/WithUID/SameCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + }, + { + name: "FilterByCID/WithUID/DifferentCID", + target: notificationTarget{clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/Nil", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4"}, + actor: nil, + wantMatch: false, + }, + { + name: "FilterByUID+CID/NoUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/NoUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/NoUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/SameUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4"}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/SameUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: true, + 
}, + { + name: "FilterByUID+CID/SameUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-1-2-3-4", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/NoCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8"}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/SameCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("A")}, + wantMatch: false, + }, + { + name: "FilterByUID+CID/DifferentUID/DifferentCID", + target: notificationTarget{userID: "S-1-5-21-1-2-3-4", clientID: ipnauth.ClientIDFrom("A")}, + actor: &ipnauth.TestActor{UID: "S-1-5-21-5-6-7-8", CID: ipnauth.ClientIDFrom("B")}, + wantMatch: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotMatch := tt.target.match(tt.actor) + if gotMatch != tt.wantMatch { + t.Errorf("match: got %v; want %v", gotMatch, tt.wantMatch) + } + }) + } +} + +type newTestControlFn func(tb testing.TB, opts controlclient.Options) controlclient.Client + +func newLocalBackendWithTestControl(t *testing.T, enableLogging bool, newControl newTestControlFn) *LocalBackend { + return newLocalBackendWithSysAndTestControl(t, enableLogging, tsd.NewSystem(), newControl) +} + +func newLocalBackendWithSysAndTestControl(t *testing.T, enableLogging bool, sys *tsd.System, newControl newTestControlFn) *LocalBackend { + logf := logger.Discard + if enableLogging { + logf = tstest.WhileTestRunningLogger(t) + } + + if _, hasStore := sys.StateStore.GetOK(); !hasStore { + store := new(mem.Store) + sys.Set(store) + } + if _, hasEngine := sys.Engine.GetOK(); !hasEngine { + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, 
sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) + } + + b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) + if err != nil { + t.Fatalf("NewLocalBackend: %v", err) + } + t.Cleanup(b.Shutdown) + b.DisablePortMapperForTest() + + b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { + return newControl(t, opts), nil + }) + return b +} + +// notificationHandler is any function that can process (e.g., check) a notification. +// It returns whether the notification has been handled or should be passed to the next handler. +// The handler may be called from any goroutine, so it must avoid calling functions +// that are restricted to the goroutine running the test or benchmark function, +// such as [testing.common.FailNow] and [testing.common.Fatalf]. +type notificationHandler func(testing.TB, ipnauth.Actor, *ipn.Notify) bool + +// wantedNotification names a [notificationHandler] that processes a notification +// the test expects and wants to receive. The name is used to report notifications +// that haven't been received within the expected timeout. +type wantedNotification struct { + name string + cond notificationHandler +} + +// notificationWatcher observes [LocalBackend] notifications as the specified actor, +// reporting missing but expected notifications using [testing.common.Error], +// and delegating the handling of unexpected notifications to the [notificationHandler]s. 
+type notificationWatcher struct { + tb testing.TB + lb *LocalBackend + actor ipnauth.Actor + + mu sync.Mutex + mask ipn.NotifyWatchOpt + want []wantedNotification // notifications we want to receive + unexpected []notificationHandler // funcs that are called to check any other notifications + ctxCancel context.CancelFunc // cancels the outstanding [LocalBackend.WatchNotificationsAs] call + got []*ipn.Notify // all notifications, both wanted and unexpected, we've received so far + gotWanted []*ipn.Notify // only the expected notifications; holds nil for any notification that hasn't been received + gotWantedCh chan struct{} // closed when we have received the last wanted notification + doneCh chan struct{} // closed when [LocalBackend.WatchNotificationsAs] returns +} + +func newNotificationWatcher(tb testing.TB, lb *LocalBackend, actor ipnauth.Actor) *notificationWatcher { + return ¬ificationWatcher{tb: tb, lb: lb, actor: actor} +} + +func (w *notificationWatcher) watch(mask ipn.NotifyWatchOpt, wanted []wantedNotification, unexpected ...notificationHandler) { + w.tb.Helper() + + // Cancel any outstanding [LocalBackend.WatchNotificationsAs] calls. 
+ w.mu.Lock() + ctxCancel := w.ctxCancel + doneCh := w.doneCh + w.mu.Unlock() + if doneCh != nil { + ctxCancel() + <-doneCh + } + + doneCh = make(chan struct{}) + gotWantedCh := make(chan struct{}) + ctx, ctxCancel := context.WithCancel(context.Background()) + w.tb.Cleanup(func() { + ctxCancel() + <-doneCh + }) + + w.mu.Lock() + w.mask = mask + w.want = wanted + w.unexpected = unexpected + w.ctxCancel = ctxCancel + w.got = nil + w.gotWanted = make([]*ipn.Notify, len(wanted)) + w.gotWantedCh = gotWantedCh + w.doneCh = doneCh + w.mu.Unlock() + + watchAddedCh := make(chan struct{}) + go func() { + defer close(doneCh) + if len(wanted) == 0 { + close(gotWantedCh) + if len(unexpected) == 0 { + close(watchAddedCh) + return + } + } + + var nextWantIdx int + w.lb.WatchNotificationsAs(ctx, w.actor, w.mask, func() { close(watchAddedCh) }, func(notify *ipn.Notify) (keepGoing bool) { + w.tb.Helper() + + w.mu.Lock() + defer w.mu.Unlock() + w.got = append(w.got, notify) + + wanted := false + for i := nextWantIdx; i < len(w.want); i++ { + if wanted = w.want[i].cond(w.tb, w.actor, notify); wanted { + w.gotWanted[i] = notify + nextWantIdx = i + 1 + break + } + } + + if wanted && nextWantIdx == len(w.want) { + close(w.gotWantedCh) + if len(w.unexpected) == 0 { + // If we have received the last wanted notification, + // and we don't have any handlers for the unexpected notifications, + // we can stop the watcher right away. + return false + } + + } + + if !wanted { + // If we've received a notification we didn't expect, + // it could either be an unwanted notification caused by a bug + // or just a miscellaneous one that's irrelevant for the current test. + // Call unexpected notification handlers, if any, to + // check and fail the test if necessary. 
+ for _, h := range w.unexpected { + if h(w.tb, w.actor, notify) { + break + } + } + } + + return true + }) + + }() + <-watchAddedCh +} + +func (w *notificationWatcher) check() []*ipn.Notify { + w.tb.Helper() + + w.mu.Lock() + cancel := w.ctxCancel + gotWantedCh := w.gotWantedCh + checkUnexpected := len(w.unexpected) != 0 + doneCh := w.doneCh + w.mu.Unlock() + + // Wait for up to 10 seconds to receive expected notifications. + timeout := 10 * time.Second + for { + select { + case <-gotWantedCh: + if checkUnexpected { + gotWantedCh = nil + // But do not wait longer than 500ms for unexpected notifications after + // the expected notifications have been received. + timeout = 500 * time.Millisecond + continue + } + case <-doneCh: + // [LocalBackend.WatchNotificationsAs] has already returned, so no further + // notifications will be received. There's no reason to wait any longer. + case <-time.After(timeout): + } + cancel() + <-doneCh + break + } + + // Report missing notifications, if any, and log all received notifications, + // including both expected and unexpected ones. + w.mu.Lock() + defer w.mu.Unlock() + if hasMissing := slices.Contains(w.gotWanted, nil); hasMissing { + want := make([]string, len(w.want)) + got := make([]string, 0, len(w.want)) + for i, wn := range w.want { + want[i] = wn.name + if w.gotWanted[i] != nil { + got = append(got, wn.name) + } + } + w.tb.Errorf("Notifications(%s): got %q; want %q", actorDescriptionForTest(w.actor), strings.Join(got, ", "), strings.Join(want, ", ")) + for i, n := range w.got { + w.tb.Logf("%d. 
%v", i, n) + } + return nil + } + + return w.gotWanted +} + +func actorDescriptionForTest(actor ipnauth.Actor) string { + var parts []string + if actor != nil { + if name, _ := actor.Username(); name != "" { + parts = append(parts, name) + } + if uid := actor.UserID(); uid != "" { + parts = append(parts, string(uid)) + } + if clientID, _ := actor.ClientID(); clientID != ipnauth.NoClientID { + parts = append(parts, clientID.String()) + } + } + return fmt.Sprintf("Actor{%s}", strings.Join(parts, ", ")) +} + +func TestLoginNotifications(t *testing.T) { + const ( + enableLogging = true + controlURL = "https://localhost:1/" + loginURL = "https://localhost:1/1" + ) + + wantBrowseToURL := wantedNotification{ + name: "BrowseToURL", + cond: func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.BrowseToURL != nil && *n.BrowseToURL != loginURL { + t.Errorf("BrowseToURL (%s): got %q; want %q", actorDescriptionForTest(actor), *n.BrowseToURL, loginURL) + return false + } + return n.BrowseToURL != nil + }, + } + unexpectedBrowseToURL := func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.BrowseToURL != nil { + t.Errorf("Unexpected BrowseToURL(%s): %v", actorDescriptionForTest(actor), n) + return true + } + return false + } + + tests := []struct { + name string + logInAs ipnauth.Actor + urlExpectedBy []ipnauth.Actor + urlUnexpectedBy []ipnauth.Actor + }{ + { + name: "NoObservers", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{}, // ensure that it does not panic if no one is watching + }, + { + name: "SingleUser", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/NoCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}, &ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/OneWithCID", + logInAs: &ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}, + 
urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + }, + { + name: "SameUser/TwoSessions/BothWithCID", + logInAs: &ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("456")}}, + }, + { + name: "DifferentUsers/NoCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A"}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "B"}}, + }, + { + name: "DifferentUsers/SameCID", + logInAs: &ipnauth.TestActor{UID: "A"}, + urlExpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "A", CID: ipnauth.ClientIDFrom("123")}}, + urlUnexpectedBy: []ipnauth.Actor{&ipnauth.TestActor{UID: "B", CID: ipnauth.ClientIDFrom("123")}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + lb := newLocalBackendWithTestControl(t, enableLogging, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if _, err := lb.EditPrefs(&ipn.MaskedPrefs{ControlURLSet: true, Prefs: ipn.Prefs{ControlURL: controlURL}}); err != nil { + t.Fatalf("(*EditPrefs).Start(): %v", err) + } + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + sessions := make([]*notificationWatcher, 0, len(tt.urlExpectedBy)+len(tt.urlUnexpectedBy)) + for _, actor := range tt.urlExpectedBy { + session := newNotificationWatcher(t, lb, actor) + session.watch(0, []wantedNotification{wantBrowseToURL}) + sessions = append(sessions, session) + } + for _, actor := range tt.urlUnexpectedBy { + session := newNotificationWatcher(t, lb, actor) + session.watch(0, nil, unexpectedBrowseToURL) + sessions = append(sessions, session) + } + + if err := 
lb.StartLoginInteractiveAs(context.Background(), tt.logInAs); err != nil { + t.Fatal(err) + } + + lb.cc.(*mockControl).send(nil, loginURL, false, nil) + + var wg sync.WaitGroup + wg.Add(len(sessions)) + for _, sess := range sessions { + go func() { // check all sessions in parallel + sess.check() + wg.Done() + }() + } + wg.Wait() + }) + } +} + +// TestConfigFileReload tests that the LocalBackend reloads its configuration +// when the configuration file changes. +func TestConfigFileReload(t *testing.T) { + type testCase struct { + name string + initial *conffile.Config + updated *conffile.Config + checkFn func(*testing.T, *LocalBackend) + } + + tests := []testCase{ + { + name: "hostname_change", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + Hostname: ptr.To("initial-host"), + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + Hostname: ptr.To("updated-host"), + }, + }, + checkFn: func(t *testing.T, b *LocalBackend) { + if got := b.Prefs().Hostname(); got != "updated-host" { + t.Errorf("hostname = %q; want updated-host", got) + } + }, + }, + { + name: "start_advertising_services", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc", "svc:def"}, + }, + }, + checkFn: func(t *testing.T, b *LocalBackend) { + if got := b.Prefs().AdvertiseServices().AsSlice(); !reflect.DeepEqual(got, []string{"svc:abc", "svc:def"}) { + t.Errorf("AdvertiseServices = %v; want [svc:abc, svc:def]", got) + } + }, + }, + { + name: "change_advertised_services", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc", "svc:def"}, + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc", "svc:ghi"}, + }, + }, + checkFn: func(t *testing.T, b 
*LocalBackend) { + if got := b.Prefs().AdvertiseServices().AsSlice(); !reflect.DeepEqual(got, []string{"svc:abc", "svc:ghi"}) { + t.Errorf("AdvertiseServices = %v; want [svc:abc, svc:ghi]", got) + } + }, + }, + { + name: "unset_advertised_services", + initial: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + AdvertiseServices: []string{"svc:abc"}, + }, + }, + updated: &conffile.Config{ + Parsed: ipn.ConfigVAlpha{ + Version: "alpha0", + }, + }, + checkFn: func(t *testing.T, b *LocalBackend) { + if b.Prefs().AdvertiseServices().Len() != 0 { + t.Errorf("got %d AdvertiseServices wants none", b.Prefs().AdvertiseServices().Len()) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "tailscale.conf") + + // Write initial config + initialJSON, err := json.Marshal(tc.initial.Parsed) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, initialJSON, 0644); err != nil { + t.Fatal(err) + } + + // Create backend with initial config + tc.initial.Path = path + tc.initial.Raw = initialJSON + sys := tsd.NewSystem() + sys.InitialConfig = tc.initial + b := newTestLocalBackendWithSys(t, sys) + + // Update config file + updatedJSON, err := json.Marshal(tc.updated.Parsed) + if err != nil { + t.Fatal(err) + } + if err := os.WriteFile(path, updatedJSON, 0644); err != nil { + t.Fatal(err) + } + + // Trigger reload + if ok, err := b.ReloadConfig(); !ok || err != nil { + t.Fatalf("ReloadConfig() = %v, %v; want true, nil", ok, err) + } + + // Check outcome + tc.checkFn(t, b) + }) + } +} + +func TestGetVIPServices(t *testing.T) { + tests := []struct { + name string + advertised []string + serveConfig *ipn.ServeConfig + want []*tailcfg.VIPService + }{ + { + "advertised-only", + []string{"svc:abc", "svc:def"}, + &ipn.ServeConfig{}, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Active: true, + }, + { + Name: "svc:def", + Active: true, + }, + }, + }, + { + "served-only", + 
[]string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + "served-and-advertised", + []string{"svc:abc"}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Active: true, + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + }, + }, + { + "served-and-advertised-different-service", + []string{"svc:def"}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {Tun: true}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}}, + }, + { + Name: "svc:def", + Active: true, + }, + }, + }, + { + "served-with-port-ranges-one-range-single", + []string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + }}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 80, Last: 80}}}, + }, + }, + }, + { + "served-with-port-ranges-one-range-multiple", + []string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + }}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{{Proto: 6, Ports: tailcfg.PortRange{First: 80, Last: 82}}}, + }, + }, + }, + { + "served-with-port-ranges-multiple-ranges", + []string{}, + &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:abc": {TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + 1212: 
{HTTPS: true}, + 1213: {HTTPS: true}, + 1214: {HTTPS: true}, + }}, + }, + }, + []*tailcfg.VIPService{ + { + Name: "svc:abc", + Ports: []tailcfg.ProtoPortRange{ + {Proto: 6, Ports: tailcfg.PortRange{First: 80, Last: 82}}, + {Proto: 6, Ports: tailcfg.PortRange{First: 1212, Last: 1214}}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb := newLocalBackendWithTestControl(t, false, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + lb.serveConfig = tt.serveConfig.View() + prefs := &ipn.Prefs{ + AdvertiseServices: tt.advertised, + } + got := lb.vipServicesFromPrefsLocked(prefs.View()) + slices.SortFunc(got, func(a, b *tailcfg.VIPService) int { + return strings.Compare(a.Name.String(), b.Name.String()) + }) + if !reflect.DeepEqual(tt.want, got) { + t.Logf("want:") + for _, s := range tt.want { + t.Logf("%+v", s) + } + t.Logf("got:") + for _, s := range got { + t.Logf("%+v", s) + } + t.Fail() + return + } + }) + } +} + +func TestUpdatePrefsOnSysPolicyChange(t *testing.T) { + const enableLogging = false + + type fieldChange struct { + name string + want any + } + + wantPrefsChanges := func(want ...fieldChange) *wantedNotification { + return &wantedNotification{ + name: "Prefs", + cond: func(t testing.TB, actor ipnauth.Actor, n *ipn.Notify) bool { + if n.Prefs != nil { + prefs := reflect.Indirect(reflect.ValueOf(n.Prefs.AsStruct())) + for _, f := range want { + got := prefs.FieldByName(f.name).Interface() + if !reflect.DeepEqual(got, f.want) { + t.Errorf("%v: got %v; want %v", f.name, got, f.want) + } + } + } + return n.Prefs != nil + }, + } + } + + unexpectedPrefsChange := func(t testing.TB, _ ipnauth.Actor, n *ipn.Notify) bool { + if n.Prefs != nil { + t.Errorf("Unexpected Prefs: %v", n.Prefs.Pretty()) + return true + } + return false + } + + tests := []struct { + name string + initialPrefs *ipn.Prefs + stringSettings []source.TestSetting[string] + want *wantedNotification 
+ }{ + { + name: "ShieldsUp/True", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.EnableIncomingConnections, "never")}, + want: wantPrefsChanges(fieldChange{"ShieldsUp", true}), + }, + { + name: "ShieldsUp/False", + initialPrefs: &ipn.Prefs{ShieldsUp: true}, + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.EnableIncomingConnections, "always")}, + want: wantPrefsChanges(fieldChange{"ShieldsUp", false}), + }, + { + name: "ExitNodeID", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.ExitNodeID, "foo")}, + want: wantPrefsChanges(fieldChange{"ExitNodeID", tailcfg.StableNodeID("foo")}), + }, + { + name: "EnableRunExitNode", + stringSettings: []source.TestSetting[string]{source.TestSettingOf(syspolicy.EnableRunExitNode, "always")}, + want: wantPrefsChanges(fieldChange{"AdvertiseRoutes", []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}}), + }, + { + name: "Multiple", + initialPrefs: &ipn.Prefs{ + ExitNodeAllowLANAccess: true, + }, + stringSettings: []source.TestSetting[string]{ + source.TestSettingOf(syspolicy.EnableServerMode, "always"), + source.TestSettingOf(syspolicy.ExitNodeAllowLANAccess, "never"), + source.TestSettingOf(syspolicy.ExitNodeIP, "127.0.0.1"), + }, + want: wantPrefsChanges( + fieldChange{"ForceDaemon", true}, + fieldChange{"ExitNodeAllowLANAccess", false}, + fieldChange{"ExitNodeIP", netip.MustParseAddr("127.0.0.1")}, + ), + }, + { + name: "NoChange", + initialPrefs: &ipn.Prefs{ + CorpDNS: true, + ExitNodeID: "foo", + AdvertiseRoutes: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + stringSettings: []source.TestSetting[string]{ + source.TestSettingOf(syspolicy.EnableTailscaleDNS, "always"), + source.TestSettingOf(syspolicy.ExitNodeID, "foo"), + source.TestSettingOf(syspolicy.EnableRunExitNode, "always"), + }, + want: nil, // syspolicy settings match the preferences; no change notification is expected. 
+ }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + syspolicy.RegisterWellKnownSettingsForTest(t) + store := source.NewTestStoreOf[string](t) + syspolicy.MustRegisterStoreForTest(t, "TestSource", setting.DeviceScope, store) + + lb := newLocalBackendWithTestControl(t, enableLogging, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if tt.initialPrefs != nil { + lb.SetPrefsForTest(tt.initialPrefs) + } + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + nw := newNotificationWatcher(t, lb, &ipnauth.TestActor{}) + if tt.want != nil { + nw.watch(0, []wantedNotification{*tt.want}) + } else { + nw.watch(0, nil, unexpectedPrefsChange) + } + + store.SetStrings(tt.stringSettings...) + + nw.check() + }) + } +} + +func TestUpdateIngressLocked(t *testing.T) { + tests := []struct { + name string + hi *tailcfg.Hostinfo + sc *ipn.ServeConfig + wantIngress bool + wantWireIngress bool + wantControlUpdate bool + }{ + { + name: "no_hostinfo_no_serve_config", + hi: nil, + }, + { + name: "empty_hostinfo_no_serve_config", + hi: &tailcfg.Hostinfo{}, + }, + { + name: "empty_hostinfo_funnel_enabled", + hi: &tailcfg.Hostinfo{}, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + wantIngress: true, + wantWireIngress: false, // implied by wantIngress + wantControlUpdate: true, + }, + { + name: "empty_hostinfo_funnel_disabled", + hi: &tailcfg.Hostinfo{}, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + wantWireIngress: true, // true if there is any AllowFunnel block + wantControlUpdate: true, + }, + { + name: "empty_hostinfo_no_funnel", + hi: &tailcfg.Hostinfo{}, + sc: &ipn.ServeConfig{ + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTPS: true}, + }, + }, + }, + { + name: "funnel_enabled_no_change", + hi: &tailcfg.Hostinfo{ + IngressEnabled: true, + }, + sc: 
&ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + wantIngress: true, + wantWireIngress: false, // implied by wantIngress + }, + { + name: "funnel_disabled_no_change", + hi: &tailcfg.Hostinfo{ + WireIngress: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + wantWireIngress: true, // true if there is any AllowFunnel block + }, + { + name: "funnel_changes_to_disabled", + hi: &tailcfg.Hostinfo{ + IngressEnabled: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + wantWireIngress: true, // true if there is any AllowFunnel block + wantControlUpdate: true, + }, + { + name: "funnel_changes_to_enabled", + hi: &tailcfg.Hostinfo{ + WireIngress: true, + }, + sc: &ipn.ServeConfig{ + AllowFunnel: map[ipn.HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + wantIngress: true, + wantWireIngress: false, // implied by wantIngress + wantControlUpdate: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := newTestLocalBackend(t) + b.hostinfo = tt.hi + b.serveConfig = tt.sc.View() + allDone := make(chan bool, 1) + defer b.goTracker.AddDoneCallback(func() { + b.mu.Lock() + defer b.mu.Unlock() + if b.goTracker.RunningGoroutines() > 0 { + return + } + select { + case allDone <- true: + default: + } + })() + + was := b.goTracker.StartedGoroutines() + b.updateIngressLocked() + + if tt.hi != nil { + if tt.hi.IngressEnabled != tt.wantIngress { + t.Errorf("IngressEnabled = %v, want %v", tt.hi.IngressEnabled, tt.wantIngress) + } + if tt.hi.WireIngress != tt.wantWireIngress { + t.Errorf("WireIngress = %v, want %v", tt.hi.WireIngress, tt.wantWireIngress) + } + } + + startedGoroutine := b.goTracker.StartedGoroutines() != was + if startedGoroutine != tt.wantControlUpdate { + t.Errorf("control update triggered = %v, want %v", startedGoroutine, tt.wantControlUpdate) + } + + if startedGoroutine { + 
select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for goroutine to finish") + case <-allDone: + } + } + }) + } +} + +// TestSrcCapPacketFilter tests that LocalBackend handles packet filters with +// SrcCaps instead of Srcs (IPs) +func TestSrcCapPacketFilter(t *testing.T) { + lb := newLocalBackendWithTestControl(t, false, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + if err := lb.Start(ipn.Options{}); err != nil { + t.Fatalf("(*LocalBackend).Start(): %v", err) + } + + var k key.NodePublic + must.Do(k.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261"))) + + controlClient := lb.cc.(*mockControl) + controlClient.send(nil, "", false, &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + ID: 2, + Key: k, + CapMap: tailcfg.NodeCapMap{"cap-X": nil}, // node 2 has cap + }).View(), + (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("3.3.3.3/32")}, + ID: 3, + Key: k, + CapMap: tailcfg.NodeCapMap{}, // node 3 does not have the cap + }).View(), + }, + PacketFilter: []filtertype.Match{{ + IPProto: views.SliceOf([]ipproto.Proto{ipproto.TCP}), + SrcCaps: []tailcfg.NodeCapability{"cap-X"}, // cap in packet filter rule + Dsts: []filtertype.NetPortRange{{ + Net: netip.MustParsePrefix("1.1.1.1/32"), + Ports: filtertype.PortRange{ + First: 22, + Last: 22, + }, + }}, + }}, + }) + + f := lb.GetFilterForTest() + res := f.Check(netip.MustParseAddr("2.2.2.2"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP) + if res != filter.Accept { + t.Errorf("Check(2.2.2.2, ...) 
= %s, want %s", res, filter.Accept) + } + + res = f.Check(netip.MustParseAddr("3.3.3.3"), netip.MustParseAddr("1.1.1.1"), 22, ipproto.TCP) + if !res.IsDrop() { + t.Error("IsDrop() for node without cap = false, want true") + } +} diff --git a/ipn/ipnlocal/loglines_test.go b/ipn/ipnlocal/loglines_test.go index f70987c0e8ad3..5bea6cabca4c4 100644 --- a/ipn/ipnlocal/loglines_test.go +++ b/ipn/ipnlocal/loglines_test.go @@ -47,10 +47,10 @@ func TestLocalLogLines(t *testing.T) { idA := logid(0xaa) // set up a LocalBackend, super bare bones. No functional data. - sys := new(tsd.System) + sys := tsd.NewSystem() store := new(mem.Store) sys.Set(store) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatal(err) } diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index d20bf94eb971a..36d39a465a654 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -407,7 +407,7 @@ func (b *LocalBackend) tkaApplyDisablementLocked(secret []byte) error { // // b.mu must be held. 
func (b *LocalBackend) chonkPathLocked() string { - return filepath.Join(b.TailscaleVarRoot(), "tka-profiles", string(b.pm.CurrentProfile().ID)) + return filepath.Join(b.TailscaleVarRoot(), "tka-profiles", string(b.pm.CurrentProfile().ID())) } // tkaBootstrapFromGenesisLocked initializes the local (on-disk) state of the @@ -430,8 +430,7 @@ func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, per } bootstrapStateID := fmt.Sprintf("%d:%d", genesis.State.StateID1, genesis.State.StateID2) - for i := range persist.DisallowedTKAStateIDs().Len() { - stateID := persist.DisallowedTKAStateIDs().At(i) + for _, stateID := range persist.DisallowedTKAStateIDs().All() { if stateID == bootstrapStateID { return fmt.Errorf("TKA with stateID of %q is disallowed on this node", stateID) } @@ -456,7 +455,7 @@ func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, per } b.tka = &tkaState{ - profile: b.pm.CurrentProfile().ID, + profile: b.pm.CurrentProfile().ID(), authority: authority, storage: chonk, } @@ -517,9 +516,10 @@ func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { var selfAuthorized bool nodeKeySignature := &tka.NodeKeySignature{} - if b.netMap != nil { - selfAuthorized = b.tka.authority.NodeKeyAuthorized(b.netMap.SelfNode.Key(), b.netMap.SelfNode.KeySignature().AsSlice()) == nil - if err := nodeKeySignature.Unserialize(b.netMap.SelfNode.KeySignature().AsSlice()); err != nil { + nm := b.currentNode().NetMap() + if nm != nil { + selfAuthorized = b.tka.authority.NodeKeyAuthorized(nm.SelfNode.Key(), nm.SelfNode.KeySignature().AsSlice()) == nil + if err := nodeKeySignature.Unserialize(nm.SelfNode.KeySignature().AsSlice()); err != nil { b.logf("failed to decode self node key signature: %v", err) } } @@ -540,9 +540,9 @@ func (b *LocalBackend) NetworkLockStatus() *ipnstate.NetworkLockStatus { } var visible []*ipnstate.TKAPeer - if b.netMap != nil { - visible = make([]*ipnstate.TKAPeer, len(b.netMap.Peers)) - for i, p 
:= range b.netMap.Peers { + if nm != nil { + visible = make([]*ipnstate.TKAPeer, len(nm.Peers)) + for i, p := range nm.Peers { s := tkaStateFromPeer(p) visible[i] = &s } @@ -572,8 +572,7 @@ func tkaStateFromPeer(p tailcfg.NodeView) ipnstate.TKAPeer { TailscaleIPs: make([]netip.Addr, 0, p.Addresses().Len()), NodeKey: p.Key(), } - for i := range p.Addresses().Len() { - addr := p.Addresses().At(i) + for _, addr := range p.Addresses().All() { if addr.IsSingleIP() && tsaddr.IsTailscaleIP(addr.Addr()) { fp.TailscaleIPs = append(fp.TailscaleIPs, addr.Addr()) } @@ -704,12 +703,10 @@ func (b *LocalBackend) NetworkLockForceLocalDisable() error { id1, id2 := b.tka.authority.StateIDs() stateID := fmt.Sprintf("%d:%d", id1, id2) + cn := b.currentNode() newPrefs := b.pm.CurrentPrefs().AsStruct().Clone() // .Persist should always be initialized here. newPrefs.Persist.DisallowedTKAStateIDs = append(newPrefs.Persist.DisallowedTKAStateIDs, stateID) - if err := b.pm.SetPrefs(newPrefs.View(), ipn.NetworkProfile{ - MagicDNSName: b.netMap.MagicDNSSuffix(), - DomainName: b.netMap.DomainName(), - }); err != nil { + if err := b.pm.SetPrefs(newPrefs.View(), cn.NetworkProfile()); err != nil { return fmt.Errorf("saving prefs: %w", err) } diff --git a/ipn/ipnlocal/network-lock_test.go b/ipn/ipnlocal/network-lock_test.go index 4b79136c81ea9..838f16cb9001f 100644 --- a/ipn/ipnlocal/network-lock_test.go +++ b/ipn/ipnlocal/network-lock_test.go @@ -202,7 +202,7 @@ func TestTKADisablementFlow(t *testing.T) { }).View(), ipn.NetworkProfile{})) temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -410,7 +410,7 @@ func TestTKASync(t *testing.T) { } temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", 
string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) // Setup the TKA authority on the node. nodeStorage, err := tka.ChonkDir(tkaPath) @@ -710,7 +710,7 @@ func TestTKADisable(t *testing.T) { }).View(), ipn.NetworkProfile{})) temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} chonk, err := tka.ChonkDir(tkaPath) @@ -770,7 +770,7 @@ func TestTKADisable(t *testing.T) { ccAuto: cc, logf: t.Logf, tka: &tkaState{ - profile: pm.CurrentProfile().ID, + profile: pm.CurrentProfile().ID(), authority: authority, storage: chonk, }, @@ -805,7 +805,7 @@ func TestTKASign(t *testing.T) { key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -890,7 +890,7 @@ func TestTKAForceDisable(t *testing.T) { }).View(), ipn.NetworkProfile{})) temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -989,7 +989,7 @@ func TestTKAAffectedSigs(t *testing.T) { tkaKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { @@ -1124,7 +1124,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { 
compromisedKey := tka.Key{Kind: tka.Key25519, Public: compromisedPriv.Public().Verifier(), Votes: 1} temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID)) + tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) os.Mkdir(tkaPath, 0755) chonk, err := tka.ChonkDir(tkaPath) if err != nil { diff --git a/ipn/ipnlocal/node_backend.go b/ipn/ipnlocal/node_backend.go new file mode 100644 index 0000000000000..fb77f38ebc87a --- /dev/null +++ b/ipn/ipnlocal/node_backend.go @@ -0,0 +1,665 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "cmp" + "net/netip" + "slices" + "sync" + "sync/atomic" + + "go4.org/netipx" + "tailscale.com/ipn" + "tailscale.com/net/dns" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/netmap" + "tailscale.com/types/ptr" + "tailscale.com/types/views" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" + "tailscale.com/util/slicesx" + "tailscale.com/wgengine/filter" +) + +// nodeBackend is node-specific [LocalBackend] state. It is usually the current node. +// +// Its exported methods are safe for concurrent use, but the struct is not a snapshot of state at a given moment; +// its state can change between calls. For example, asking for the same value (e.g., netmap or prefs) twice +// may return different results. Returned values are immutable and safe for concurrent use. +// +// If both the [LocalBackend]'s internal mutex and the [nodeBackend] mutex must be held at the same time, +// the [LocalBackend] mutex must be acquired first. See the comment on the [LocalBackend] field for more details. +// +// Two pointers to different [nodeBackend] instances represent different local nodes. 
+// However, there's currently a bug where a new [nodeBackend] might not be created +// during an implicit node switch (see tailscale/corp#28014). + +// In the future, we might want to include at least the following in this struct (in addition to the current fields). +// However, not everything should be exported or otherwise made available to the outside world (e.g. [ipnext] extensions, +// peer API handlers, etc.). +// - [ipn.State]: when the LocalBackend switches to a different [nodeBackend], it can update the state of the old one. +// - [ipn.LoginProfileView] and [ipn.Prefs]: we should update them when the [profileManager] reports changes to them. +// In the future, [profileManager] (and the corresponding methods of the [LocalBackend]) can be made optional, +// and something else could be used to set them once or update them as needed. +// - [tailcfg.HostinfoView]: it includes certain fields that are tied to the current profile/node/prefs. We should also +// update to build it once instead of mutating it in twelvety different places. +// - [filter.Filter] (normal and jailed, along with the filterHash): the nodeBackend could have a method to (re-)build +// the filter for the current netmap/prefs (see [LocalBackend.updateFilterLocked]), and it needs to track the current +// filters and their hash. +// - Fields related to a requested or required (re-)auth: authURL, authURLTime, authActor, keyExpired, etc. +// - [controlclient.Client]/[*controlclient.Auto]: the current control client. It is ties to a node identity. +// - [tkaState]: it is tied to the current profile / node. +// - Fields related to scheduled node expiration: nmExpiryTimer, numClientStatusCalls, [expiryManager]. +// +// It should not include any fields used by specific features that don't belong in [LocalBackend]. +// Even if they're tied to the local node, instead of moving them here, we should extract the entire feature +// into a separate package and have it install proper hooks. 
+type nodeBackend struct { + // filterAtomic is a stateful packet filter. Immutable once created, but can be + // replaced with a new one. + filterAtomic atomic.Pointer[filter.Filter] + + // TODO(nickkhyl): maybe use sync.RWMutex? + mu sync.Mutex // protects the following fields + + // NetMap is the most recently set full netmap from the controlclient. + // It can't be mutated in place once set. Because it can't be mutated in place, + // delta updates from the control server don't apply to it. Instead, use + // the peers map to get up-to-date information on the state of peers. + // In general, avoid using the netMap.Peers slice. We'd like it to go away + // as of 2023-09-17. + // TODO(nickkhyl): make it an atomic pointer to avoid the need for a mutex? + netMap *netmap.NetworkMap + + // peers is the set of current peers and their current values after applying + // delta node mutations as they come in (with mu held). The map values can be + // given out to callers, but the map itself can be mutated in place (with mu held) + // and must not escape the [nodeBackend]. + peers map[tailcfg.NodeID]tailcfg.NodeView + + // nodeByAddr maps nodes' own addresses (excluding subnet routes) to node IDs. + // It is mutated in place (with mu held) and must not escape the [nodeBackend]. + nodeByAddr map[netip.Addr]tailcfg.NodeID +} + +func newNodeBackend() *nodeBackend { + cn := &nodeBackend{} + // Default filter blocks everything and logs nothing. 
+ noneFilter := filter.NewAllowNone(logger.Discard, &netipx.IPSet{}) + cn.filterAtomic.Store(noneFilter) + return cn +} + +func (nb *nodeBackend) Self() tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return tailcfg.NodeView{} + } + return nb.netMap.SelfNode +} + +func (nb *nodeBackend) SelfUserID() tailcfg.UserID { + self := nb.Self() + if !self.Valid() { + return 0 + } + return self.User() +} + +// SelfHasCap reports whether the specified capability was granted to the self node in the most recent netmap. +func (nb *nodeBackend) SelfHasCap(wantCap tailcfg.NodeCapability) bool { + return nb.SelfHasCapOr(wantCap, false) +} + +// SelfHasCapOr is like [nodeBackend.SelfHasCap], but returns the specified default value +// if the netmap is not available yet. +func (nb *nodeBackend) SelfHasCapOr(wantCap tailcfg.NodeCapability, def bool) bool { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return def + } + return nb.netMap.AllCaps.Contains(wantCap) +} + +func (nb *nodeBackend) NetworkProfile() ipn.NetworkProfile { + nb.mu.Lock() + defer nb.mu.Unlock() + return ipn.NetworkProfile{ + // These are ok to call with nil netMap. + MagicDNSName: nb.netMap.MagicDNSSuffix(), + DomainName: nb.netMap.DomainName(), + } +} + +// TODO(nickkhyl): update it to return a [tailcfg.DERPMapView]? 
+func (nb *nodeBackend) DERPMap() *tailcfg.DERPMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return nil + } + return nb.netMap.DERPMap +} + +func (nb *nodeBackend) NodeByAddr(ip netip.Addr) (_ tailcfg.NodeID, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + nid, ok := nb.nodeByAddr[ip] + return nid, ok +} + +func (nb *nodeBackend) NodeByKey(k key.NodePublic) (_ tailcfg.NodeID, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return 0, false + } + if self := nb.netMap.SelfNode; self.Valid() && self.Key() == k { + return self.ID(), true + } + // TODO(bradfitz,nickkhyl): add nodeByKey like nodeByAddr instead of walking peers. + for _, n := range nb.peers { + if n.Key() == k { + return n.ID(), true + } + } + return 0, false +} + +func (nb *nodeBackend) PeerByID(id tailcfg.NodeID) (_ tailcfg.NodeView, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + n, ok := nb.peers[id] + return n, ok +} + +func (nb *nodeBackend) PeerByStableID(id tailcfg.StableNodeID) (_ tailcfg.NodeView, ok bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + for _, n := range nb.peers { + if n.StableID() == id { + return n, true + } + } + return tailcfg.NodeView{}, false +} + +func (nb *nodeBackend) UserByID(id tailcfg.UserID) (_ tailcfg.UserProfileView, ok bool) { + nb.mu.Lock() + nm := nb.netMap + nb.mu.Unlock() + if nm == nil { + return tailcfg.UserProfileView{}, false + } + u, ok := nm.UserProfiles[id] + return u, ok +} + +// Peers returns all the current peers in an undefined order. +func (nb *nodeBackend) Peers() []tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + return slicesx.MapValues(nb.peers) +} + +func (nb *nodeBackend) PeersForTest() []tailcfg.NodeView { + nb.mu.Lock() + defer nb.mu.Unlock() + ret := slicesx.MapValues(nb.peers) + slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { + return cmp.Compare(a.ID(), b.ID()) + }) + return ret +} + +// AppendMatchingPeers returns base with all peers that match pred appended. 
+// +// It acquires b.mu to read the netmap but releases it before calling pred. +func (nb *nodeBackend) AppendMatchingPeers(base []tailcfg.NodeView, pred func(tailcfg.NodeView) bool) []tailcfg.NodeView { + var peers []tailcfg.NodeView + + nb.mu.Lock() + if nb.netMap != nil { + // All fields on b.netMap are immutable, so this is + // safe to copy and use outside the lock. + peers = nb.netMap.Peers + } + nb.mu.Unlock() + + ret := base + for _, peer := range peers { + // The peers in b.netMap don't contain updates made via + // UpdateNetmapDelta. So only use PeerView in b.netMap for its NodeID, + // and then look up the latest copy in b.peers which is updated in + // response to UpdateNetmapDelta edits. + nb.mu.Lock() + peer, ok := nb.peers[peer.ID()] + nb.mu.Unlock() + if ok && pred(peer) { + ret = append(ret, peer) + } + } + return ret +} + +// PeerCaps returns the capabilities that remote src IP has to +// the current node. +func (nb *nodeBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.peerCapsLocked(src) +} + +func (nb *nodeBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { + if nb.netMap == nil { + return nil + } + filt := nb.filterAtomic.Load() + if filt == nil { + return nil + } + addrs := nb.netMap.GetAddresses() + for i := range addrs.Len() { + a := addrs.At(i) + if !a.IsSingleIP() { + continue + } + dst := a.Addr() + if dst.BitLen() == src.BitLen() { // match on family + return filt.CapsWithValues(src, dst) + } + } + return nil +} + +// PeerHasCap reports whether the peer contains the given capability string, +// with any value(s).
+func (nb *nodeBackend) PeerHasCap(peer tailcfg.NodeView, wantCap tailcfg.PeerCapability) bool { + if !peer.Valid() { + return false + } + + nb.mu.Lock() + defer nb.mu.Unlock() + for _, ap := range peer.Addresses().All() { + if nb.peerHasCapLocked(ap.Addr(), wantCap) { + return true + } + } + return false +} + +func (nb *nodeBackend) peerHasCapLocked(addr netip.Addr, wantCap tailcfg.PeerCapability) bool { + return nb.peerCapsLocked(addr).HasCapability(wantCap) +} + +func (nb *nodeBackend) PeerHasPeerAPI(p tailcfg.NodeView) bool { + return nb.PeerAPIBase(p) != "" +} + +// PeerAPIBase returns the "http://ip:port" URL base to reach peer's PeerAPI, +// or the empty string if the peer is invalid or doesn't support PeerAPI. +func (nb *nodeBackend) PeerAPIBase(p tailcfg.NodeView) string { + nb.mu.Lock() + nm := nb.netMap + nb.mu.Unlock() + return peerAPIBase(nm, p) +} + +func nodeIP(n tailcfg.NodeView, pred func(netip.Addr) bool) netip.Addr { + for _, pfx := range n.Addresses().All() { + if pfx.IsSingleIP() && pred(pfx.Addr()) { + return pfx.Addr() + } + } + return netip.Addr{} +} + +func (nb *nodeBackend) NetMap() *netmap.NetworkMap { + nb.mu.Lock() + defer nb.mu.Unlock() + return nb.netMap +} + +func (nb *nodeBackend) netMapWithPeers() *netmap.NetworkMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return nil + } + nm := ptr.To(*nb.netMap) // shallow clone + nm.Peers = slicesx.MapValues(nb.peers) + slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { + return cmp.Compare(a.ID(), b.ID()) + }) + return nm +} + +func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) { + nb.mu.Lock() + defer nb.mu.Unlock() + nb.netMap = nm + nb.updateNodeByAddrLocked() + nb.updatePeersLocked() +} + +func (nb *nodeBackend) updateNodeByAddrLocked() { + nm := nb.netMap + if nm == nil { + nb.nodeByAddr = nil + return + } + + // Update the nodeByAddr index. 
+ if nb.nodeByAddr == nil { + nb.nodeByAddr = map[netip.Addr]tailcfg.NodeID{} + } + // First pass, mark everything unwanted. + for k := range nb.nodeByAddr { + nb.nodeByAddr[k] = 0 + } + addNode := func(n tailcfg.NodeView) { + for _, ipp := range n.Addresses().All() { + if ipp.IsSingleIP() { + nb.nodeByAddr[ipp.Addr()] = n.ID() + } + } + } + if nm.SelfNode.Valid() { + addNode(nm.SelfNode) + } + for _, p := range nm.Peers { + addNode(p) + } + // Third pass, actually delete the unwanted items. + for k, v := range nb.nodeByAddr { + if v == 0 { + delete(nb.nodeByAddr, k) + } + } +} + +func (nb *nodeBackend) updatePeersLocked() { + nm := nb.netMap + if nm == nil { + nb.peers = nil + return + } + + // First pass, mark everything unwanted. + for k := range nb.peers { + nb.peers[k] = tailcfg.NodeView{} + } + + // Second pass, add everything wanted. + for _, p := range nm.Peers { + mak.Set(&nb.peers, p.ID(), p) + } + + // Third pass, remove deleted things. + for k, v := range nb.peers { + if !v.Valid() { + delete(nb.peers, k) + } + } +} + +func (nb *nodeBackend) UpdateNetmapDelta(muts []netmap.NodeMutation) (handled bool) { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil || len(nb.peers) == 0 { + return false + } + + // Locally cloned mutable nodes, to avoid calling AsStruct (clone) + // multiple times on a node if it's mutated multiple times in this + // call (e.g. its endpoints + online status both change) + var mutableNodes map[tailcfg.NodeID]*tailcfg.Node + + for _, m := range muts { + n, ok := mutableNodes[m.NodeIDBeingMutated()] + if !ok { + nv, ok := nb.peers[m.NodeIDBeingMutated()] + if !ok { + // TODO(bradfitz): unexpected metric? 
+				return false
+			}
+			n = nv.AsStruct()
+			mak.Set(&mutableNodes, nv.ID(), n)
+		}
+		m.Apply(n)
+	}
+	for nid, n := range mutableNodes {
+		nb.peers[nid] = n.View()
+	}
+	return true
+}
+
+// unlockedNodesPermitted reports whether any peer with the UnsignedPeerAPIOnly bool set true has any of its allowed IPs
+// in the specified packet filter.
+//
+// TODO(nickkhyl): It is here temporarily until we can move the whole [LocalBackend.updateFilterLocked] here,
+// but change it so it builds and returns a filter for the current netmap/prefs instead of re-configuring the engine filter.
+// Something like (*nodeBackend).RebuildFilters() (filter, jailedFilter *filter.Filter, changed bool) perhaps?
+func (nb *nodeBackend) unlockedNodesPermitted(packetFilter []filter.Match) bool {
+	nb.mu.Lock()
+	defer nb.mu.Unlock()
+	return packetFilterPermitsUnlockedNodes(nb.peers, packetFilter)
+}
+
+func (nb *nodeBackend) filter() *filter.Filter {
+	return nb.filterAtomic.Load()
+}
+
+func (nb *nodeBackend) setFilter(f *filter.Filter) {
+	nb.filterAtomic.Store(f)
+}
+
+func (nb *nodeBackend) dnsConfigForNetmap(prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config {
+	nb.mu.Lock()
+	defer nb.mu.Unlock()
+	return dnsConfigForNetmap(nb.netMap, nb.peers, prefs, selfExpired, logf, versionOS)
+}
+
+func (nb *nodeBackend) exitNodeCanProxyDNS(exitNodeID tailcfg.StableNodeID) (dohURL string, ok bool) {
+	nb.mu.Lock()
+	defer nb.mu.Unlock()
+	return exitNodeCanProxyDNS(nb.netMap, nb.peers, exitNodeID)
+}
+
+// dnsConfigForNetmap returns a *dns.Config for the given netmap,
+// prefs, client OS version, and cloud hosting environment.
+//
+// The versionOS is a Tailscale-style version ("iOS", "macOS") and not
+// a runtime.GOOS.
+func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg.NodeView, prefs ipn.PrefsView, selfExpired bool, logf logger.Logf, versionOS string) *dns.Config { + if nm == nil { + return nil + } + + // If the current node's key is expired, then we don't program any DNS + // configuration into the operating system. This ensures that if the + // DNS configuration specifies a DNS server that is only reachable over + // Tailscale, we don't break connectivity for the user. + // + // TODO(andrew-d): this also stops returning anything from quad-100; we + // could do the same thing as having "CorpDNS: false" and keep that but + // not program the OS? + if selfExpired { + return &dns.Config{} + } + + dcfg := &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: map[dnsname.FQDN][]netip.Addr{}, + } + + // selfV6Only is whether we only have IPv6 addresses ourselves. + selfV6Only := nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs6) && + !nm.GetAddresses().ContainsFunc(tsaddr.PrefixIs4) + dcfg.OnlyIPv6 = selfV6Only + + wantAAAA := nm.AllCaps.Contains(tailcfg.NodeAttrMagicDNSPeerAAAA) + + // Populate MagicDNS records. We do this unconditionally so that + // quad-100 can always respond to MagicDNS queries, even if the OS + // isn't configured to make MagicDNS resolution truly + // magic. Details in + // https://github.com/tailscale/tailscale/issues/1886. + set := func(name string, addrs views.Slice[netip.Prefix]) { + if addrs.Len() == 0 || name == "" { + return + } + fqdn, err := dnsname.ToFQDN(name) + if err != nil { + return // TODO: propagate error? 
+ } + var have4 bool + for _, addr := range addrs.All() { + if addr.Addr().Is4() { + have4 = true + break + } + } + var ips []netip.Addr + for _, addr := range addrs.All() { + if selfV6Only { + if addr.Addr().Is6() { + ips = append(ips, addr.Addr()) + } + continue + } + // If this node has an IPv4 address, then + // remove peers' IPv6 addresses for now, as we + // don't guarantee that the peer node actually + // can speak IPv6 correctly. + // + // https://github.com/tailscale/tailscale/issues/1152 + // tracks adding the right capability reporting to + // enable AAAA in MagicDNS. + if addr.Addr().Is6() && have4 && !wantAAAA { + continue + } + ips = append(ips, addr.Addr()) + } + dcfg.Hosts[fqdn] = ips + } + set(nm.Name, nm.GetAddresses()) + for _, peer := range peers { + set(peer.Name(), peer.Addresses()) + } + for _, rec := range nm.DNS.ExtraRecords { + switch rec.Type { + case "", "A", "AAAA": + // Treat these all the same for now: infer from the value + default: + // TODO: more + continue + } + ip, err := netip.ParseAddr(rec.Value) + if err != nil { + // Ignore. + continue + } + fqdn, err := dnsname.ToFQDN(rec.Name) + if err != nil { + continue + } + dcfg.Hosts[fqdn] = append(dcfg.Hosts[fqdn], ip) + } + + if !prefs.CorpDNS() { + return dcfg + } + + for _, dom := range nm.DNS.Domains { + fqdn, err := dnsname.ToFQDN(dom) + if err != nil { + logf("[unexpected] non-FQDN search domain %q", dom) + } + dcfg.SearchDomains = append(dcfg.SearchDomains, fqdn) + } + if nm.DNS.Proxied { // actually means "enable MagicDNS" + for _, dom := range magicDNSRootDomains(nm) { + dcfg.Routes[dom] = nil // resolve internally with dcfg.Hosts + } + } + + addDefault := func(resolvers []*dnstype.Resolver) { + dcfg.DefaultResolvers = append(dcfg.DefaultResolvers, resolvers...) + } + + // If we're using an exit node and that exit node is new enough (1.19.x+) + // to run a DoH DNS proxy, then send all our DNS traffic through it. 
+	if dohURL, ok := exitNodeCanProxyDNS(nm, peers, prefs.ExitNodeID()); ok {
+		addDefault([]*dnstype.Resolver{{Addr: dohURL}})
+		return dcfg
+	}
+
+	// If the user has set default resolvers ("override local DNS"), prefer to
+	// use those resolvers as the default, otherwise if there are WireGuard exit
+	// node resolvers, use those as the default.
+	if len(nm.DNS.Resolvers) > 0 {
+		addDefault(nm.DNS.Resolvers)
+	} else {
+		if resolvers, ok := wireguardExitNodeDNSResolvers(nm, peers, prefs.ExitNodeID()); ok {
+			addDefault(resolvers)
+		}
+	}
+
+	for suffix, resolvers := range nm.DNS.Routes {
+		fqdn, err := dnsname.ToFQDN(suffix)
+		if err != nil {
+			logf("[unexpected] non-FQDN route suffix %q", suffix)
+		}
+
+		// Create map entry even if len(resolvers) == 0; Issue 2706.
+		// This lets the control plane send ExtraRecords for which we
+		// can authoritatively answer "name not exists" for when the
+		// control plane also sends this explicit but empty route
+		// marking it as something we handle.
+		//
+		// While we're already populating it, might as well size the
+		// slice appropriately.
+		// Per #9498 the exact requirements of nil vs empty slice remain
+		// unclear, this is a haunted graveyard to be resolved.
+		dcfg.Routes[fqdn] = make([]*dnstype.Resolver, 0, len(resolvers))
+		dcfg.Routes[fqdn] = append(dcfg.Routes[fqdn], resolvers...)
+	}
+
+	// Set FallbackResolvers as the default resolvers in the
+	// scenarios that can't handle a purely split-DNS config. See
+	// https://github.com/tailscale/tailscale/issues/1743 for
+	// details.
+	switch {
+	case len(dcfg.DefaultResolvers) != 0:
+		// Default resolvers already set.
+	case !prefs.ExitNodeID().IsZero():
+		// When using an exit node, we send all DNS traffic to the exit node, so
+		// we don't need a fallback resolver.
+		//
+		// However, if the exit node is too old to run a DoH DNS proxy, then we
+		// need to use a fallback resolver as it's very likely the LAN resolvers
+		// will become unreachable.
+ // + // This is especially important on Apple OSes, where + // adding the default route to the tunnel interface makes + // it "primary", and we MUST provide VPN-sourced DNS + // settings or we break all DNS resolution. + // + // https://github.com/tailscale/tailscale/issues/1713 + addDefault(nm.DNS.FallbackResolvers) + case len(dcfg.Routes) == 0: + // No settings requiring split DNS, no problem. + } + + return dcfg +} diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index aa18c35886648..84aaecf7efb78 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -15,18 +15,15 @@ import ( "net" "net/http" "net/netip" - "net/url" "os" "path/filepath" "runtime" "slices" - "sort" "strconv" "strings" "sync" "time" - "github.com/kortschak/wol" "golang.org/x/net/dns/dnsmessage" "golang.org/x/net/http/httpguts" "tailscale.com/drive" @@ -39,10 +36,9 @@ import ( "tailscale.com/net/netutil" "tailscale.com/net/sockstats" "tailscale.com/tailcfg" - "tailscale.com/taildrop" + "tailscale.com/types/netmap" "tailscale.com/types/views" "tailscale.com/util/clientmetric" - "tailscale.com/util/httphdr" "tailscale.com/util/httpm" "tailscale.com/wgengine/filter" ) @@ -65,8 +61,6 @@ type peerDNSQueryHandler interface { type peerAPIServer struct { b *LocalBackend resolver peerDNSQueryHandler - - taildrop *taildrop.Manager } func (s *peerAPIServer) listen(ip netip.Addr, ifState *netmon.State) (ln net.Listener, err error) { @@ -226,6 +220,29 @@ type peerAPIHandler struct { peerUser tailcfg.UserProfile // profile of peerNode } +// PeerAPIHandler is the interface implemented by [peerAPIHandler] and needed by +// module features registered via tailscale.com/feature/*. 
+type PeerAPIHandler interface { + Peer() tailcfg.NodeView + PeerCaps() tailcfg.PeerCapMap + Self() tailcfg.NodeView + LocalBackend() *LocalBackend + IsSelfUntagged() bool // whether the peer is untagged and the same as this user + RemoteAddr() netip.AddrPort + Logf(format string, a ...any) +} + +func (h *peerAPIHandler) IsSelfUntagged() bool { + return !h.selfNode.IsTagged() && !h.peerNode.IsTagged() && h.isSelf +} +func (h *peerAPIHandler) Peer() tailcfg.NodeView { return h.peerNode } +func (h *peerAPIHandler) Self() tailcfg.NodeView { return h.selfNode } +func (h *peerAPIHandler) RemoteAddr() netip.AddrPort { return h.remoteAddr } +func (h *peerAPIHandler) LocalBackend() *LocalBackend { return h.ps.b } +func (h *peerAPIHandler) Logf(format string, a ...any) { + h.logf(format, a...) +} + func (h *peerAPIHandler) logf(format string, a ...any) { h.ps.b.logf("peerapi: "+format, a...) } @@ -233,11 +250,13 @@ func (h *peerAPIHandler) logf(format string, a ...any) { // isAddressValid reports whether addr is a valid destination address for this // node originating from the peer. func (h *peerAPIHandler) isAddressValid(addr netip.Addr) bool { - if v := h.peerNode.SelfNodeV4MasqAddrForThisPeer(); v != nil { - return *v == addr + if !addr.IsValid() { + return false } - if v := h.peerNode.SelfNodeV6MasqAddrForThisPeer(); v != nil { - return *v == addr + v4MasqAddr, hasMasqV4 := h.peerNode.SelfNodeV4MasqAddrForThisPeer().GetOk() + v6MasqAddr, hasMasqV6 := h.peerNode.SelfNodeV6MasqAddrForThisPeer().GetOk() + if hasMasqV4 || hasMasqV6 { + return addr == v4MasqAddr || addr == v6MasqAddr } pfx := netip.PrefixFrom(addr, addr.BitLen()) return views.SliceContains(h.selfNode.Addresses(), pfx) @@ -300,6 +319,29 @@ func peerAPIRequestShouldGetSecurityHeaders(r *http.Request) bool { return false } +// RegisterPeerAPIHandler registers a PeerAPI handler. +// +// The path should be of the form "/v0/foo". +// +// It panics if the path is already registered. 
+func RegisterPeerAPIHandler(path string, f func(PeerAPIHandler, http.ResponseWriter, *http.Request)) { + if _, ok := peerAPIHandlers[path]; ok { + panic(fmt.Sprintf("duplicate PeerAPI handler %q", path)) + } + peerAPIHandlers[path] = f + if strings.HasSuffix(path, "/") { + peerAPIHandlerPrefixes[path] = f + } +} + +var ( + peerAPIHandlers = map[string]func(PeerAPIHandler, http.ResponseWriter, *http.Request){} // by URL.Path + + // peerAPIHandlerPrefixes are the subset of peerAPIHandlers where + // the map key ends with a slash, indicating a prefix match. + peerAPIHandlerPrefixes = map[string]func(PeerAPIHandler, http.ResponseWriter, *http.Request){} +) + func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err := h.validatePeerAPIRequest(r); err != nil { metricInvalidRequests.Add(1) @@ -312,12 +354,11 @@ func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-Content-Type-Options", "nosniff") } - if strings.HasPrefix(r.URL.Path, "/v0/put/") { - if r.Method == "PUT" { - metricPutCalls.Add(1) + for pfx, ph := range peerAPIHandlerPrefixes { + if strings.HasPrefix(r.URL.Path, pfx) { + ph(h, w, r) + return } - h.handlePeerPut(w, r) - return } if strings.HasPrefix(r.URL.Path, "/dns-query") { metricDNSCalls.Add(1) @@ -344,10 +385,6 @@ func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "/v0/dnsfwd": h.handleServeDNSFwd(w, r) return - case "/v0/wol": - metricWakeOnLANCalls.Add(1) - h.handleWakeOnLAN(w, r) - return case "/v0/interfaces": h.handleServeInterfaces(w, r) return @@ -362,6 +399,14 @@ func (h *peerAPIHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.handleServeIngress(w, r) return } + if ph, ok := peerAPIHandlers[r.URL.Path]; ok { + ph(h, w, r) + return + } + if r.URL.Path != "/" { + http.Error(w, "unsupported peerapi path", http.StatusNotFound) + return + } who := h.peerUser.DisplayName fmt.Fprintf(w, ` @@ -386,7 
+431,7 @@ func (h *peerAPIHandler) handleServeIngress(w http.ResponseWriter, r *http.Reque } logAndError := func(code int, publicMsg string) { h.logf("ingress: bad request from %v: %s", h.remoteAddr, publicMsg) - http.Error(w, publicMsg, http.StatusMethodNotAllowed) + http.Error(w, publicMsg, code) } bad := func(publicMsg string) { logAndError(http.StatusBadRequest, publicMsg) @@ -450,7 +495,7 @@ func (h *peerAPIHandler) handleServeInterfaces(w http.ResponseWriter, r *http.Re fmt.Fprintf(w, "

Could not get the default route: %s

\n", html.EscapeString(err.Error())) } - if hasCGNATInterface, err := netmon.HasCGNATInterface(); hasCGNATInterface { + if hasCGNATInterface, err := h.ps.b.sys.NetMon.Get().HasCGNATInterface(); hasCGNATInterface { fmt.Fprintln(w, "

There is another interface using the CGNAT range.

") } else if err != nil { fmt.Fprintf(w, "

Could not check for CGNAT interfaces: %s

\n", html.EscapeString(err.Error())) @@ -599,15 +644,6 @@ func (h *peerAPIHandler) handleServeSockStats(w http.ResponseWriter, r *http.Req fmt.Fprintln(w, "") } -// canPutFile reports whether h can put a file ("Taildrop") to this node. -func (h *peerAPIHandler) canPutFile() bool { - if h.peerNode.UnsignedPeerAPIOnly() { - // Unsigned peers can't send files. - return false - } - return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityFileSharingSend) -} - // canDebug reports whether h can debug this node (goroutines, metrics, // magicsock internal state, etc). func (h *peerAPIHandler) canDebug() bool { @@ -622,14 +658,6 @@ func (h *peerAPIHandler) canDebug() bool { return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityDebugPeer) } -// canWakeOnLAN reports whether h can send a Wake-on-LAN packet from this node. -func (h *peerAPIHandler) canWakeOnLAN() bool { - if h.peerNode.UnsignedPeerAPIOnly() { - return false - } - return h.isSelf || h.peerHasCap(tailcfg.PeerCapabilityWakeOnLAN) -} - var allowSelfIngress = envknob.RegisterBool("TS_ALLOW_SELF_INGRESS") // canIngress reports whether h can send ingress requests to this node. 
@@ -638,117 +666,13 @@ func (h *peerAPIHandler) canIngress() bool { } func (h *peerAPIHandler) peerHasCap(wantCap tailcfg.PeerCapability) bool { - return h.peerCaps().HasCapability(wantCap) + return h.PeerCaps().HasCapability(wantCap) } -func (h *peerAPIHandler) peerCaps() tailcfg.PeerCapMap { +func (h *peerAPIHandler) PeerCaps() tailcfg.PeerCapMap { return h.ps.b.PeerCaps(h.remoteAddr.Addr()) } -func (h *peerAPIHandler) handlePeerPut(w http.ResponseWriter, r *http.Request) { - if !h.canPutFile() { - http.Error(w, taildrop.ErrNoTaildrop.Error(), http.StatusForbidden) - return - } - if !h.ps.b.hasCapFileSharing() { - http.Error(w, taildrop.ErrNoTaildrop.Error(), http.StatusForbidden) - return - } - rawPath := r.URL.EscapedPath() - prefix, ok := strings.CutPrefix(rawPath, "/v0/put/") - if !ok { - http.Error(w, "misconfigured internals", http.StatusForbidden) - return - } - baseName, err := url.PathUnescape(prefix) - if err != nil { - http.Error(w, taildrop.ErrInvalidFileName.Error(), http.StatusBadRequest) - return - } - enc := json.NewEncoder(w) - switch r.Method { - case "GET": - id := taildrop.ClientID(h.peerNode.StableID()) - if prefix == "" { - // List all the partial files. - files, err := h.ps.taildrop.PartialFiles(id) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - if err := enc.Encode(files); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - h.logf("json.Encoder.Encode error: %v", err) - return - } - } else { - // Stream all the block hashes for the specified file. 
- next, close, err := h.ps.taildrop.HashPartialFile(id, baseName) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - defer close() - for { - switch cs, err := next(); { - case err == io.EOF: - return - case err != nil: - http.Error(w, err.Error(), http.StatusInternalServerError) - h.logf("HashPartialFile.next error: %v", err) - return - default: - if err := enc.Encode(cs); err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - h.logf("json.Encoder.Encode error: %v", err) - return - } - } - } - } - case "PUT": - t0 := h.ps.b.clock.Now() - id := taildrop.ClientID(h.peerNode.StableID()) - - var offset int64 - if rangeHdr := r.Header.Get("Range"); rangeHdr != "" { - ranges, ok := httphdr.ParseRange(rangeHdr) - if !ok || len(ranges) != 1 || ranges[0].Length != 0 { - http.Error(w, "invalid Range header", http.StatusBadRequest) - return - } - offset = ranges[0].Start - } - n, err := h.ps.taildrop.PutFile(taildrop.ClientID(fmt.Sprint(id)), baseName, r.Body, offset, r.ContentLength) - switch err { - case nil: - d := h.ps.b.clock.Since(t0).Round(time.Second / 10) - h.logf("got put of %s in %v from %v/%v", approxSize(n), d, h.remoteAddr.Addr(), h.peerNode.ComputedName) - io.WriteString(w, "{}\n") - case taildrop.ErrNoTaildrop: - http.Error(w, err.Error(), http.StatusForbidden) - case taildrop.ErrInvalidFileName: - http.Error(w, err.Error(), http.StatusBadRequest) - case taildrop.ErrFileExists: - http.Error(w, err.Error(), http.StatusConflict) - default: - http.Error(w, err.Error(), http.StatusInternalServerError) - } - default: - http.Error(w, "expected method GET or PUT", http.StatusMethodNotAllowed) - } -} - -func approxSize(n int64) string { - if n <= 1<<10 { - return "<=1KB" - } - if n <= 1<<20 { - return "<=1MB" - } - return fmt.Sprintf("~%dMB", n>>20) -} - func (h *peerAPIHandler) handleServeGoroutines(w http.ResponseWriter, r *http.Request) { if !h.canDebug() { http.Error(w, "denied; no debug access", 
http.StatusForbidden) @@ -815,61 +739,6 @@ func (h *peerAPIHandler) handleServeDNSFwd(w http.ResponseWriter, r *http.Reques dh.ServeHTTP(w, r) } -func (h *peerAPIHandler) handleWakeOnLAN(w http.ResponseWriter, r *http.Request) { - if !h.canWakeOnLAN() { - http.Error(w, "no WoL access", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "bad method", http.StatusMethodNotAllowed) - return - } - macStr := r.FormValue("mac") - if macStr == "" { - http.Error(w, "missing 'mac' param", http.StatusBadRequest) - return - } - mac, err := net.ParseMAC(macStr) - if err != nil { - http.Error(w, "bad 'mac' param", http.StatusBadRequest) - return - } - var password []byte // TODO(bradfitz): support? does anything use WoL passwords? - st := h.ps.b.sys.NetMon.Get().InterfaceState() - if st == nil { - http.Error(w, "failed to get interfaces state", http.StatusInternalServerError) - return - } - var res struct { - SentTo []string - Errors []string - } - for ifName, ips := range st.InterfaceIPs { - for _, ip := range ips { - if ip.Addr().IsLoopback() || ip.Addr().Is6() { - continue - } - local := &net.UDPAddr{ - IP: ip.Addr().AsSlice(), - Port: 0, - } - remote := &net.UDPAddr{ - IP: net.IPv4bcast, - Port: 0, - } - if err := wol.Wake(mac, password, local, remote); err != nil { - res.Errors = append(res.Errors, err.Error()) - } else { - res.SentTo = append(res.SentTo, ifName) - } - break // one per interface is enough - } - } - sort.Strings(res.SentTo) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) -} - func (h *peerAPIHandler) replyToDNSQueries() bool { if h.isSelf { // If the peer is owned by the same user, just allow it @@ -900,7 +769,7 @@ func (h *peerAPIHandler) replyToDNSQueries() bool { // but an app connector explicitly adds 0.0.0.0/32 (and the // IPv6 equivalent) to make this work (see updateFilterLocked // in LocalBackend). 
- f := b.filterAtomic.Load() + f := b.currentNode().filter() if f == nil { return false } @@ -964,7 +833,11 @@ func (h *peerAPIHandler) handleDNSQuery(w http.ResponseWriter, r *http.Request) // instead to avoid re-parsing the DNS response for improved performance in // the future. if h.ps.b.OfferingAppConnector() { - h.ps.b.ObserveDNSResponse(res) + if err := h.ps.b.ObserveDNSResponse(res); err != nil { + h.logf("ObserveDNSResponse error: %v", err) + // This is not fatal, we probably just failed to parse the upstream + // response. Return it to the caller anyway. + } } if pretty { @@ -1148,7 +1021,7 @@ func (h *peerAPIHandler) handleServeDrive(w http.ResponseWriter, r *http.Request return } - capsMap := h.peerCaps() + capsMap := h.PeerCaps() driveCaps, ok := capsMap[tailcfg.PeerCapabilityTaildrive] if !ok { h.logf("taildrive: not permitted") @@ -1163,7 +1036,7 @@ func (h *peerAPIHandler) handleServeDrive(w http.ResponseWriter, r *http.Request p, err := drive.ParsePermissions(rawPerms) if err != nil { - h.logf("taildrive: error parsing permissions: %w", err.Error()) + h.logf("taildrive: error parsing permissions: %v", err) http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -1222,6 +1095,48 @@ func parseDriveFileExtensionForLog(path string) string { return fileExt } +// peerAPIURL returns an HTTP URL for the peer's peerapi service, +// without a trailing slash. +// +// If ip or port is the zero value then it returns the empty string. +func peerAPIURL(ip netip.Addr, port uint16) string { + if port == 0 || !ip.IsValid() { + return "" + } + return fmt.Sprintf("http://%v", netip.AddrPortFrom(ip, port)) +} + +// peerAPIBase returns the "http://ip:port" URL base to reach peer's peerAPI. +// It returns the empty string if the peer doesn't support the peerapi +// or there's no matching address family based on the netmap's own addresses. 
+func peerAPIBase(nm *netmap.NetworkMap, peer tailcfg.NodeView) string { + if nm == nil || !peer.Valid() || !peer.Hostinfo().Valid() { + return "" + } + + var have4, have6 bool + addrs := nm.GetAddresses() + for _, a := range addrs.All() { + if !a.IsSingleIP() { + continue + } + switch { + case a.Addr().Is4(): + have4 = true + case a.Addr().Is6(): + have6 = true + } + } + p4, p6 := peerAPIPorts(peer) + switch { + case have4 && p4 != 0: + return peerAPIURL(nodeIP(peer, netip.Addr.Is4), p4) + case have6 && p6 != 0: + return peerAPIURL(nodeIP(peer, netip.Addr.Is6), p6) + } + return "" +} + // newFakePeerAPIListener creates a new net.Listener that acts like // it's listening on the provided IP address and on TCP port 1. // @@ -1272,8 +1187,6 @@ var ( metricInvalidRequests = clientmetric.NewCounter("peerapi_invalid_requests") // Non-debug PeerAPI endpoints. - metricPutCalls = clientmetric.NewCounter("peerapi_put") - metricDNSCalls = clientmetric.NewCounter("peerapi_dns") - metricWakeOnLANCalls = clientmetric.NewCounter("peerapi_wol") - metricIngressCalls = clientmetric.NewCounter("peerapi_ingress") + metricDNSCalls = clientmetric.NewCounter("peerapi_dns") + metricIngressCalls = clientmetric.NewCounter("peerapi_ingress") ) diff --git a/ipn/ipnlocal/peerapi_test.go b/ipn/ipnlocal/peerapi_test.go index ff9b627693a8a..d8655afa08aa8 100644 --- a/ipn/ipnlocal/peerapi_test.go +++ b/ipn/ipnlocal/peerapi_test.go @@ -4,36 +4,27 @@ package ipnlocal import ( - "bytes" "context" "encoding/json" - "fmt" - "io" - "io/fs" - "math/rand" "net/http" "net/http/httptest" "net/netip" - "os" - "path/filepath" "slices" "strings" "testing" - "github.com/google/go-cmp/cmp" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/appc" "tailscale.com/appc/appctest" - "tailscale.com/client/tailscale/apitype" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/tailcfg" - "tailscale.com/taildrop" "tailscale.com/tstest" "tailscale.com/types/logger" 
"tailscale.com/types/netmap" + "tailscale.com/util/eventbus" "tailscale.com/util/must" "tailscale.com/util/usermetric" "tailscale.com/wgengine" @@ -74,64 +65,18 @@ func bodyNotContains(sub string) check { } } -func fileHasSize(name string, size int) check { - return func(t *testing.T, e *peerAPITestEnv) { - root := e.ph.ps.taildrop.Dir() - if root == "" { - t.Errorf("no rootdir; can't check whether %q has size %v", name, size) - return - } - path := filepath.Join(root, name) - if fi, err := os.Stat(path); err != nil { - t.Errorf("fileHasSize(%q, %v): %v", name, size, err) - } else if fi.Size() != int64(size) { - t.Errorf("file %q has size %v; want %v", name, fi.Size(), size) - } - } -} - -func fileHasContents(name string, want string) check { - return func(t *testing.T, e *peerAPITestEnv) { - root := e.ph.ps.taildrop.Dir() - if root == "" { - t.Errorf("no rootdir; can't check contents of %q", name) - return - } - path := filepath.Join(root, name) - got, err := os.ReadFile(path) - if err != nil { - t.Errorf("fileHasContents: %v", err) - return - } - if string(got) != want { - t.Errorf("file contents = %q; want %q", got, want) - } - } -} - -func hexAll(v string) string { - var sb strings.Builder - for i := range len(v) { - fmt.Fprintf(&sb, "%%%02x", v[i]) - } - return sb.String() -} - func TestHandlePeerAPI(t *testing.T) { tests := []struct { - name string - isSelf bool // the peer sending the request is owned by us - capSharing bool // self node has file sharing capability - debugCap bool // self node has debug capability - omitRoot bool // don't configure - reqs []*http.Request - checks []check + name string + isSelf bool // the peer sending the request is owned by us + debugCap bool // self node has debug capability + reqs []*http.Request + checks []check }{ { - name: "not_peer_api", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("GET", "/", nil)}, + name: "not_peer_api", + isSelf: true, + reqs: 
[]*http.Request{httptest.NewRequest("GET", "/", nil)}, checks: checks( httpStatus(200), bodyContains("This is my Tailscale device."), @@ -139,10 +84,9 @@ func TestHandlePeerAPI(t *testing.T) { ), }, { - name: "not_peer_api_not_owner", - isSelf: false, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("GET", "/", nil)}, + name: "not_peer_api_not_owner", + isSelf: false, + reqs: []*http.Request{httptest.NewRequest("GET", "/", nil)}, checks: checks( httpStatus(200), bodyContains("This is my Tailscale device."), @@ -173,255 +117,6 @@ func TestHandlePeerAPI(t *testing.T) { bodyContains("ServeHTTP"), ), }, - { - name: "reject_non_owner_put", - isSelf: false, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(http.StatusForbidden), - bodyContains("Taildrop disabled"), - ), - }, - { - name: "owner_without_cap", - isSelf: true, - capSharing: false, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(http.StatusForbidden), - bodyContains("Taildrop disabled"), - ), - }, - { - name: "owner_with_cap_no_rootdir", - omitRoot: true, - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(http.StatusForbidden), - bodyContains("Taildrop disabled; no storage directory"), - ), - }, - { - name: "bad_method", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("POST", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(405), - bodyContains("expected method GET or PUT"), - ), - }, - { - name: "put_zero_length", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", nil)}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasSize("foo", 0), - fileHasContents("foo", ""), - ), - }, - { - name: "put_non_zero_length_content_length", - isSelf: true, - capSharing: true, - reqs: 
[]*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents"))}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasSize("foo", len("contents")), - fileHasContents("foo", "contents"), - ), - }, - { - name: "put_non_zero_length_chunked", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo", struct{ io.Reader }{strings.NewReader("contents")})}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasSize("foo", len("contents")), - fileHasContents("foo", "contents"), - ), - }, - { - name: "bad_filename_partial", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.partial", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_deleted", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo.deleted", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_dot", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/.", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_empty", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_slash", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/foo/bar", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_dot", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("."), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: 
"bad_filename_encoded_slash", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("/"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_backslash", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("\\"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_dotdot", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(".."), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "bad_filename_encoded_dotdot_out", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("foo/../../../../../etc/passwd"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_spaces_and_caps", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("Foo Bar.dat"), strings.NewReader("baz"))}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasContents("Foo Bar.dat", "baz"), - ), - }, - { - name: "put_unicode", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3"), strings.NewReader("ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"))}, - checks: checks( - httpStatus(200), - bodyContains("{}"), - fileHasContents("ĐĸĐžĐŧĐ°Ņ и ĐĩĐŗĐž Đ´Ņ€ŅƒĐˇŅŒŅ.mp3", "ĐŗĐģавĐŊŅ‹Đš ĐžĐˇĐžŅ€ĐŊиĐē"), - ), - }, - { - name: "put_invalid_utf8", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+(hexAll("😜")[:3]), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_null", - isSelf: true, - capSharing: true, - reqs: 
[]*http.Request{httptest.NewRequest("PUT", "/v0/put/%00", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_non_printable", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/%01", nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_colon", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll("nul:"), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, - { - name: "put_invalid_surrounding_whitespace", - isSelf: true, - capSharing: true, - reqs: []*http.Request{httptest.NewRequest("PUT", "/v0/put/"+hexAll(" foo "), nil)}, - checks: checks( - httpStatus(400), - bodyContains("invalid filename"), - ), - }, { name: "host-val/bad-ip", isSelf: true, @@ -449,72 +144,6 @@ func TestHandlePeerAPI(t *testing.T) { httpStatus(200), ), }, - { - name: "duplicate_zero_length", - isSelf: true, - capSharing: true, - reqs: []*http.Request{ - httptest.NewRequest("PUT", "/v0/put/foo", nil), - httptest.NewRequest("PUT", "/v0/put/foo", nil), - }, - checks: checks( - httpStatus(200), - func(t *testing.T, env *peerAPITestEnv) { - got, err := env.ph.ps.taildrop.WaitingFiles() - if err != nil { - t.Fatalf("WaitingFiles error: %v", err) - } - want := []apitype.WaitingFile{{Name: "foo", Size: 0}} - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) - } - }, - ), - }, - { - name: "duplicate_non_zero_length_content_length", - isSelf: true, - capSharing: true, - reqs: []*http.Request{ - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("contents")), - }, - checks: checks( - httpStatus(200), - func(t *testing.T, env *peerAPITestEnv) { - got, err := env.ph.ps.taildrop.WaitingFiles() - if err != nil { - 
t.Fatalf("WaitingFiles error: %v", err) - } - want := []apitype.WaitingFile{{Name: "foo", Size: 8}} - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) - } - }, - ), - }, - { - name: "duplicate_different_files", - isSelf: true, - capSharing: true, - reqs: []*http.Request{ - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("fizz")), - httptest.NewRequest("PUT", "/v0/put/foo", strings.NewReader("buzz")), - }, - checks: checks( - httpStatus(200), - func(t *testing.T, env *peerAPITestEnv) { - got, err := env.ph.ps.taildrop.WaitingFiles() - if err != nil { - t.Fatalf("WaitingFiles error: %v", err) - } - want := []apitype.WaitingFile{{Name: "foo", Size: 4}, {Name: "foo (1)", Size: 4}} - if diff := cmp.Diff(got, want); diff != "" { - t.Fatalf("WaitingFile mismatch (-got +want):\n%s", diff) - } - }, - ), - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -528,11 +157,10 @@ func TestHandlePeerAPI(t *testing.T) { } var e peerAPITestEnv lb := &LocalBackend{ - logf: e.logBuf.Logf, - capFileSharing: tt.capSharing, - netMap: &netmap.NetworkMap{SelfNode: selfNode.View()}, - clock: &tstest.Clock{}, + logf: e.logBuf.Logf, + clock: &tstest.Clock{}, } + lb.currentNode().SetNetMap(&netmap.NetworkMap{SelfNode: selfNode.View()}) e.ph = &peerAPIHandler{ isSelf: tt.isSelf, selfNode: selfNode.View(), @@ -543,16 +171,6 @@ func TestHandlePeerAPI(t *testing.T) { b: lb, }, } - var rootDir string - if !tt.omitRoot { - rootDir = t.TempDir() - if e.ph.ps.taildrop == nil { - e.ph.ps.taildrop = taildrop.ManagerOptions{ - Logf: e.logBuf.Logf, - Dir: rootDir, - }.New() - } - } for _, req := range tt.reqs { e.rr = httptest.NewRecorder() if req.Host == "example.com" { @@ -563,76 +181,10 @@ func TestHandlePeerAPI(t *testing.T) { for _, f := range tt.checks { f(t, &e) } - if t.Failed() && rootDir != "" { - t.Logf("Contents of %s:", rootDir) - des, _ := fs.ReadDir(os.DirFS(rootDir), ".") - for _, de := range des { - fi, 
err := de.Info() - if err != nil { - t.Log(err) - } else { - t.Logf(" %v %5d %s", fi.Mode(), fi.Size(), de.Name()) - } - } - } }) } } -// Windows likes to hold on to file descriptors for some indeterminate -// amount of time after you close them and not let you delete them for -// a bit. So test that we work around that sufficiently. -func TestFileDeleteRace(t *testing.T) { - dir := t.TempDir() - ps := &peerAPIServer{ - b: &LocalBackend{ - logf: t.Logf, - capFileSharing: true, - clock: &tstest.Clock{}, - }, - taildrop: taildrop.ManagerOptions{ - Logf: t.Logf, - Dir: dir, - }.New(), - } - ph := &peerAPIHandler{ - isSelf: true, - peerNode: (&tailcfg.Node{ - ComputedName: "some-peer-name", - }).View(), - selfNode: (&tailcfg.Node{ - Addresses: []netip.Prefix{netip.MustParsePrefix("100.100.100.101/32")}, - }).View(), - ps: ps, - } - buf := make([]byte, 2<<20) - for range 30 { - rr := httptest.NewRecorder() - ph.ServeHTTP(rr, httptest.NewRequest("PUT", "http://100.100.100.101:123/v0/put/foo.txt", bytes.NewReader(buf[:rand.Intn(len(buf))]))) - if res := rr.Result(); res.StatusCode != 200 { - t.Fatal(res.Status) - } - wfs, err := ps.taildrop.WaitingFiles() - if err != nil { - t.Fatal(err) - } - if len(wfs) != 1 { - t.Fatalf("waiting files = %d; want 1", len(wfs)) - } - - if err := ps.taildrop.DeleteFile("foo.txt"); err != nil { - t.Fatal(err) - } - wfs, err = ps.taildrop.WaitingFiles() - if err != nil { - t.Fatal(err) - } - if len(wfs) != 0 { - t.Fatalf("waiting files = %d; want 0", len(wfs)) - } - } -} - func TestPeerAPIReplyToDNSQueries(t *testing.T) { var h peerAPIHandler @@ -643,9 +195,12 @@ func TestPeerAPIReplyToDNSQueries(t *testing.T) { h.isSelf = false h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") + bus := eventbus.New() + defer bus.Close() + ht := new(health.Tracker) reg := new(usermetric.Registry) - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, bus) 
pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) h.ps = &peerAPIServer{ b: &LocalBackend{ @@ -695,9 +250,12 @@ func TestPeerAPIPrettyReplyCNAME(t *testing.T) { var h peerAPIHandler h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") + bus := eventbus.New() + defer bus.Close() + ht := new(health.Tracker) reg := new(usermetric.Registry) - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, bus) pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) var a *appc.AppConnector if shouldStore { @@ -768,10 +326,12 @@ func TestPeerAPIReplyToDNSQueriesAreObserved(t *testing.T) { var h peerAPIHandler h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") + bus := eventbus.New() + defer bus.Close() rc := &appctest.RouteCollector{} ht := new(health.Tracker) reg := new(usermetric.Registry) - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, bus) pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) var a *appc.AppConnector if shouldStore { @@ -833,10 +393,12 @@ func TestPeerAPIReplyToDNSQueriesAreObservedWithCNAMEFlattening(t *testing.T) { var h peerAPIHandler h.remoteAddr = netip.MustParseAddrPort("100.150.151.152:12345") + bus := eventbus.New() + defer bus.Close() ht := new(health.Tracker) reg := new(usermetric.Registry) rc := &appctest.RouteCollector{} - eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg) + eng, _ := wgengine.NewFakeUserspaceEngine(logger.Discard, 0, ht, reg, bus) pm := must.Get(newProfileManager(new(mem.Store), t.Logf, ht)) var a *appc.AppConnector if shouldStore { diff --git a/ipn/ipnlocal/prefs_metrics.go b/ipn/ipnlocal/prefs_metrics.go new file mode 100644 index 0000000000000..fa768ba3ce238 --- /dev/null +++ b/ipn/ipnlocal/prefs_metrics.go @@ -0,0 +1,99 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// 
SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "errors" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/util/clientmetric" +) + +// Counter metrics for edit/change events +var ( + // metricExitNodeEnabled is incremented when the user enables an exit node independent of the node's characteristics. + metricExitNodeEnabled = clientmetric.NewCounter("prefs_exit_node_enabled") + // metricExitNodeEnabledSuggested is incremented when the user enables the suggested exit node. + metricExitNodeEnabledSuggested = clientmetric.NewCounter("prefs_exit_node_enabled_suggested") + // metricExitNodeEnabledMullvad is incremented when the user enables a Mullvad exit node. + metricExitNodeEnabledMullvad = clientmetric.NewCounter("prefs_exit_node_enabled_mullvad") + // metricWantRunningEnabled is incremented when WantRunning transitions from false to true. + metricWantRunningEnabled = clientmetric.NewCounter("prefs_want_running_enabled") + // metricWantRunningDisabled is incremented when WantRunning transitions from true to false. + metricWantRunningDisabled = clientmetric.NewCounter("prefs_want_running_disabled") +) + +type exitNodeProperty string + +const ( + exitNodeTypePreferred exitNodeProperty = "suggested" // The exit node is the last suggested exit node + exitNodeTypeMullvad exitNodeProperty = "mullvad" // The exit node is a Mullvad exit node +) + +// prefsMetricsEditEvent encapsulates information needed to record metrics related +// to any changes to preferences. +type prefsMetricsEditEvent struct { + change *ipn.MaskedPrefs // the preference mask used to update the preferences + pNew ipn.PrefsView // new preferences (after ApplyUpdates) + pOld ipn.PrefsView // old preferences (before ApplyUpdates) + node *nodeBackend // the node the event is associated with + lastSuggestedExitNode tailcfg.StableNodeID // the last suggested exit node +} + +// record records changes to preferences as clientmetrics. 
+func (e *prefsMetricsEditEvent) record() error { + if e.change == nil || e.node == nil { + return errors.New("prefsMetricsEditEvent: missing required fields") + } + + // Record up/down events. + if e.change.WantRunningSet && (e.pNew.WantRunning() != e.pOld.WantRunning()) { + if e.pNew.WantRunning() { + metricWantRunningEnabled.Add(1) + } else { + metricWantRunningDisabled.Add(1) + } + } + + // Record any changes to exit node settings. + if e.change.ExitNodeIDSet || e.change.ExitNodeIPSet { + if exitNodeTypes, ok := e.exitNodeType(e.pNew.ExitNodeID()); ok { + // We have switched to a valid exit node if ok is true. + metricExitNodeEnabled.Add(1) + + // We may have some additional characteristics we should also record. + for _, t := range exitNodeTypes { + switch t { + case exitNodeTypePreferred: + metricExitNodeEnabledSuggested.Add(1) + case exitNodeTypeMullvad: + metricExitNodeEnabledMullvad.Add(1) + } + } + } + } + return nil +} + +// exitNodeType returns the types of exit node for the given stable ID. +// An exit node may have multiple types (it can be both mullvad and preferred +// simultaneously, for example). +// +// This will return isNode as true if the supplied stable ID resolves to a known peer, +// false otherwise. The caller is responsible for ensuring that the id belongs to +// an exit node. 
+func (e *prefsMetricsEditEvent) exitNodeType(id tailcfg.StableNodeID) (props []exitNodeProperty, isNode bool) { + var peer tailcfg.NodeView + + if peer, isNode = e.node.PeerByStableID(id); isNode { + if tailcfg.StableNodeID(id) == e.lastSuggestedExitNode { + props = append(props, exitNodeTypePreferred) + } + if peer.IsWireGuardOnly() { + props = append(props, exitNodeTypeMullvad) + } + } + return props, isNode +} diff --git a/ipn/ipnlocal/profiles.go b/ipn/ipnlocal/profiles.go index b13f921d66095..1d312cfa606b3 100644 --- a/ipn/ipnlocal/profiles.go +++ b/ipn/ipnlocal/profiles.go @@ -17,6 +17,7 @@ import ( "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/ipn" + "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" @@ -24,6 +25,9 @@ import ( var debug = envknob.RegisterBool("TS_DEBUG_PROFILES") +// [profileManager] implements [ipnext.ProfileStore]. +var _ ipnext.ProfileStore = (*profileManager)(nil) + // profileManager is a wrapper around an [ipn.StateStore] that manages // multiple profiles and the current profile. // @@ -35,9 +39,31 @@ type profileManager struct { health *health.Tracker currentUserID ipn.WindowsUserID - knownProfiles map[ipn.ProfileID]*ipn.LoginProfile // always non-nil - currentProfile *ipn.LoginProfile // always non-nil - prefs ipn.PrefsView // always Valid. + knownProfiles map[ipn.ProfileID]ipn.LoginProfileView // always non-nil + currentProfile ipn.LoginProfileView // always Valid (once [newProfileManager] returns). + prefs ipn.PrefsView // always Valid (once [newProfileManager] returns). + + // StateChangeHook is an optional hook that is called when the current profile or prefs change, + // such as due to a profile switch or a change in the profile's preferences. 
+ // It is typically set by the [LocalBackend] to invert the dependency between + // the [profileManager] and the [LocalBackend], so that instead of [LocalBackend] + // asking [profileManager] for the state, we can have [profileManager] call + // [LocalBackend] when the state changes. See also: + // https://github.com/tailscale/tailscale/pull/15791#discussion_r2060838160 + StateChangeHook ipnext.ProfileStateChangeCallback + + // extHost is the bridge between [profileManager] and the registered [ipnext.Extension]s. + // It may be nil in tests. A nil pointer is a valid, no-op host. + extHost *ExtensionHost +} + +// SetExtensionHost sets the [ExtensionHost] for the [profileManager]. +// The specified host will be notified about profile and prefs changes +// and will immediately be notified about the current profile and prefs. +// A nil host is a valid, no-op host. +func (pm *profileManager) SetExtensionHost(host *ExtensionHost) { + pm.extHost = host + host.NotifyProfileChange(pm.currentProfile, pm.prefs, false) } func (pm *profileManager) dlogf(format string, args ...any) { @@ -64,8 +90,7 @@ func (pm *profileManager) SetCurrentUserID(uid ipn.WindowsUserID) { if pm.currentUserID == uid { return } - pm.currentUserID = uid - if err := pm.SwitchToDefaultProfile(); err != nil { + if _, _, err := pm.SwitchToDefaultProfileForUser(uid); err != nil { // SetCurrentUserID should never fail and must always switch to the // user's default profile or create a new profile for the current user. // Until we implement multi-user support and the new permission model, @@ -73,65 +98,158 @@ func (pm *profileManager) SetCurrentUserID(uid ipn.WindowsUserID) { // that when SetCurrentUserID exits, the profile in pm.currentProfile // is either an existing profile owned by the user, or a new, empty profile. 
pm.logf("%q's default profile cannot be used; creating a new one: %v", uid, err) - pm.NewProfileForUser(uid) + pm.SwitchToNewProfileForUser(uid) + } +} + +// SwitchToProfile switches to the specified profile and (temporarily, +// while the "current user" is still a thing on Windows; see tailscale/corp#18342) +// sets its owner as the current user. The profile must be a valid profile +// returned by the [profileManager], such as by [profileManager.Profiles], +// [profileManager.ProfileByID], or [profileManager.NewProfileForUser]. +// +// It is a shorthand for [profileManager.SetCurrentUserID] followed by +// [profileManager.SwitchProfileByID], but it is more efficient as it switches +// directly to the specified profile rather than switching to the user's +// default profile first. It is a no-op if the specified profile is already +// the current profile. +// +// As a special case, if the specified profile view is not valid, it resets +// both the current user and the profile to a new, empty profile not owned +// by any user. +// +// It returns the current profile and whether the call resulted in a profile change, +// or an error if the specified profile does not exist or its prefs could not be loaded. +// +// It may be called during [profileManager] initialization before [newProfileManager] returns +// and must check whether pm.currentProfile is Valid before using it. +func (pm *profileManager) SwitchToProfile(profile ipn.LoginProfileView) (cp ipn.LoginProfileView, changed bool, err error) { + prefs := defaultPrefs + switch { + case !profile.Valid(): + // Create a new profile that is not associated with any user. + profile = pm.NewProfileForUser("") + case profile == pm.currentProfile, + profile.ID() != "" && pm.currentProfile.Valid() && profile.ID() == pm.currentProfile.ID(), + profile.ID() == "" && profile.Equals(pm.currentProfile) && prefs.Equals(pm.prefs): + // The profile is already the current profile; no need to switch. 
+ // + // It includes three cases: + // 1. The target profile and the current profile are aliases referencing the [ipn.LoginProfile]. + // The profile may be either a new (non-persisted) profile or an existing well-known profile. + // 2. The target profile is a well-known, persisted profile with the same ID as the current profile. + // 3. The target and the current profiles are both new (non-persisted) profiles and they are equal. + // At minimum, equality means that the profiles are owned by the same user on platforms that support it + // and the prefs are the same as well. + return pm.currentProfile, false, nil + case profile.ID() == "": + // Copy the specified profile to prevent accidental mutation. + profile = profile.AsStruct().View() + default: + // Find an existing profile by ID and load its prefs. + kp, ok := pm.knownProfiles[profile.ID()] + if !ok { + // The profile ID is not valid; it may have been deleted or never existed. + // As the target profile should have been returned by the [profileManager], + // this is unexpected and might indicate a bug in the code. + return pm.currentProfile, false, fmt.Errorf("[unexpected] %w: %s (%s)", errProfileNotFound, profile.Name(), profile.ID()) + } + profile = kp + if prefs, err = pm.loadSavedPrefs(profile.Key()); err != nil { + return pm.currentProfile, false, fmt.Errorf("failed to load profile prefs for %s (%s): %w", profile.Name(), profile.ID(), err) + } } + + if profile.ID() == "" { // new profile that has never been persisted + metricNewProfile.Add(1) + } else { + metricSwitchProfile.Add(1) + } + + pm.prefs = prefs + pm.updateHealth() + pm.currentProfile = profile + pm.currentUserID = profile.LocalUserID() + if err := pm.setProfileAsUserDefault(profile); err != nil { + // This is not a fatal error; we've already switched to the profile. + // But if updating the default profile fails, we should log it. 
+ pm.logf("failed to set %s (%s) as the default profile: %v", profile.Name(), profile.ID(), err) + } + + if f := pm.StateChangeHook; f != nil { + f(pm.currentProfile, pm.prefs, false) + } + // Do not call pm.extHost.NotifyProfileChange here; it is invoked in + // [LocalBackend.resetForProfileChangeLockedOnEntry] after the netmap reset. + // TODO(nickkhyl): Consider moving it here (or into the stateChangeCb handler + // in [LocalBackend]) once the profile/node state, including the netmap, + // is actually tied to the current profile. + + return profile, true, nil } -// DefaultUserProfileID returns [ipn.ProfileID] of the default (last used) profile for the specified user, -// or an empty string if the specified user does not have a default profile. -func (pm *profileManager) DefaultUserProfileID(uid ipn.WindowsUserID) ipn.ProfileID { +// DefaultUserProfile returns a read-only view of the default (last used) profile for the specified user. +// It returns a read-only view of a new, non-persisted profile if the specified user does not have a default profile. +func (pm *profileManager) DefaultUserProfile(uid ipn.WindowsUserID) ipn.LoginProfileView { // Read the CurrentProfileKey from the store which stores // the selected profile for the specified user. 
b, err := pm.store.ReadState(ipn.CurrentProfileKey(string(uid))) - pm.dlogf("DefaultUserProfileID: ReadState(%q) = %v, %v", string(uid), len(b), err) + pm.dlogf("DefaultUserProfile: ReadState(%q) = %v, %v", string(uid), len(b), err) if err == ipn.ErrStateNotExist || len(b) == 0 { if runtime.GOOS == "windows" { - pm.dlogf("DefaultUserProfileID: windows: migrating from legacy preferences") - profile, err := pm.migrateFromLegacyPrefs(uid, false) + pm.dlogf("DefaultUserProfile: windows: migrating from legacy preferences") + profile, err := pm.migrateFromLegacyPrefs(uid) if err == nil { - return profile.ID + return profile } pm.logf("failed to migrate from legacy preferences: %v", err) } - return "" + return pm.NewProfileForUser(uid) } pk := ipn.StateKey(string(b)) - prof := pm.findProfileByKey(pk) - if prof == nil { - pm.dlogf("DefaultUserProfileID: no profile found for key: %q", pk) - return "" + prof := pm.findProfileByKey(uid, pk) + if !prof.Valid() { + pm.dlogf("DefaultUserProfile: no profile found for key: %q", pk) + return pm.NewProfileForUser(uid) } - return prof.ID + return prof } // checkProfileAccess returns an [errProfileAccessDenied] if the current user // does not have access to the specified profile. -func (pm *profileManager) checkProfileAccess(profile *ipn.LoginProfile) error { - if pm.currentUserID != "" && profile.LocalUserID != pm.currentUserID { +func (pm *profileManager) checkProfileAccess(profile ipn.LoginProfileView) error { + return pm.checkProfileAccessAs(pm.currentUserID, profile) +} + +// checkProfileAccessAs returns an [errProfileAccessDenied] if the specified user +// does not have access to the specified profile. +func (pm *profileManager) checkProfileAccessAs(uid ipn.WindowsUserID, profile ipn.LoginProfileView) error { + if uid != "" && profile.LocalUserID() != uid { return errProfileAccessDenied } return nil } -// allProfiles returns all profiles accessible to the current user. 
+// allProfilesFor returns all profiles accessible to the specified user. // The returned profiles are sorted by Name. -func (pm *profileManager) allProfiles() (out []*ipn.LoginProfile) { +func (pm *profileManager) allProfilesFor(uid ipn.WindowsUserID) []ipn.LoginProfileView { + out := make([]ipn.LoginProfileView, 0, len(pm.knownProfiles)) for _, p := range pm.knownProfiles { - if pm.checkProfileAccess(p) == nil { + if pm.checkProfileAccessAs(uid, p) == nil { out = append(out, p) } } - slices.SortFunc(out, func(a, b *ipn.LoginProfile) int { - return cmp.Compare(a.Name, b.Name) + slices.SortFunc(out, func(a, b ipn.LoginProfileView) int { + return cmp.Compare(a.Name(), b.Name()) }) return out } -// matchingProfiles is like [profileManager.allProfiles], but returns only profiles +// matchingProfiles is like [profileManager.allProfilesFor], but returns only profiles // matching the given predicate. -func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out []*ipn.LoginProfile) { - all := pm.allProfiles() +func (pm *profileManager) matchingProfiles(uid ipn.WindowsUserID, f func(ipn.LoginProfileView) bool) (out []ipn.LoginProfileView) { + all := pm.allProfilesFor(uid) out = all[:0] for _, p := range all { if f(p) { @@ -144,11 +262,11 @@ func (pm *profileManager) matchingProfiles(f func(*ipn.LoginProfile) bool) (out // findMatchingProfiles returns all profiles accessible to the current user // that represent the same node/user as prefs. // The returned profiles are sorted by Name. 
-func (pm *profileManager) findMatchingProfiles(prefs ipn.PrefsView) []*ipn.LoginProfile { - return pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.ControlURL == prefs.ControlURL() && - (p.UserProfile.ID == prefs.Persist().UserProfile().ID || - p.NodeID == prefs.Persist().NodeID()) +func (pm *profileManager) findMatchingProfiles(uid ipn.WindowsUserID, prefs ipn.PrefsView) []ipn.LoginProfileView { + return pm.matchingProfiles(uid, func(p ipn.LoginProfileView) bool { + return p.ControlURL() == prefs.ControlURL() && + (p.UserProfile().ID == prefs.Persist().UserProfile().ID || + p.NodeID() == prefs.Persist().NodeID()) }) } @@ -156,19 +274,19 @@ func (pm *profileManager) findMatchingProfiles(prefs ipn.PrefsView) []*ipn.Login // given name. It returns "" if no such profile exists among profiles // accessible to the current user. func (pm *profileManager) ProfileIDForName(name string) ipn.ProfileID { - p := pm.findProfileByName(name) - if p == nil { + p := pm.findProfileByName(pm.currentUserID, name) + if !p.Valid() { return "" } - return p.ID + return p.ID() } -func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { - out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.Name == name +func (pm *profileManager) findProfileByName(uid ipn.WindowsUserID, name string) ipn.LoginProfileView { + out := pm.matchingProfiles(uid, func(p ipn.LoginProfileView) bool { + return p.Name() == name && pm.checkProfileAccessAs(uid, p) == nil }) if len(out) == 0 { - return nil + return ipn.LoginProfileView{} } if len(out) > 1 { pm.logf("[unexpected] multiple profiles with the same name") @@ -176,12 +294,12 @@ func (pm *profileManager) findProfileByName(name string) *ipn.LoginProfile { return out[0] } -func (pm *profileManager) findProfileByKey(key ipn.StateKey) *ipn.LoginProfile { - out := pm.matchingProfiles(func(p *ipn.LoginProfile) bool { - return p.Key == key +func (pm *profileManager) findProfileByKey(uid ipn.WindowsUserID, key 
ipn.StateKey) ipn.LoginProfileView { + out := pm.matchingProfiles(uid, func(p ipn.LoginProfileView) bool { + return p.Key() == key && pm.checkProfileAccessAs(uid, p) == nil }) if len(out) == 0 { - return nil + return ipn.LoginProfileView{} } if len(out) > 1 { pm.logf("[unexpected] multiple profiles with the same key") @@ -194,19 +312,13 @@ func (pm *profileManager) setUnattendedModeAsConfigured() error { return nil } - if pm.currentProfile.Key != "" && pm.prefs.ForceDaemon() { - return pm.WriteState(ipn.ServerModeStartKey, []byte(pm.currentProfile.Key)) + if pm.currentProfile.Key() != "" && pm.prefs.ForceDaemon() { + return pm.WriteState(ipn.ServerModeStartKey, []byte(pm.currentProfile.Key())) } else { return pm.WriteState(ipn.ServerModeStartKey, nil) } } -// Reset unloads the current profile, if any. -func (pm *profileManager) Reset() { - pm.currentUserID = "" - pm.NewProfile() -} - // SetPrefs sets the current profile's prefs to the provided value. // It also saves the prefs to the [ipn.StateStore]. It stores a copy of the // provided prefs, which may be accessed via [profileManager.CurrentPrefs]. @@ -222,36 +334,67 @@ func (pm *profileManager) SetPrefs(prefsIn ipn.PrefsView, np ipn.NetworkProfile) } // Check if we already have an existing profile that matches the user/node. - if existing := pm.findMatchingProfiles(prefsIn); len(existing) > 0 { + if existing := pm.findMatchingProfiles(pm.currentUserID, prefsIn); len(existing) > 0 { // We already have a profile for this user/node we should reuse it. Also // cleanup any other duplicate profiles. cp = existing[0] existing = existing[1:] for _, p := range existing { // Clear the state. - if err := pm.store.WriteState(p.Key, nil); err != nil { + if err := pm.store.WriteState(p.Key(), nil); err != nil { // We couldn't delete the state, so keep the profile around. continue } // Remove the profile, knownProfiles will be persisted // in [profileManager.setProfilePrefs] below. 
- delete(pm.knownProfiles, p.ID) + delete(pm.knownProfiles, p.ID()) } } - pm.currentProfile = cp - if err := pm.SetProfilePrefs(cp, prefsIn, np); err != nil { + // TODO(nickkhyl): Revisit how we handle implicit switching to a different profile, + // which occurs when prefsIn represents a node/user different from that of the + // currentProfile. It happens when a login (either reauth or user-initiated login) + // is completed with a different node/user identity than the one currently in use. + // + // Currently, we overwrite the existing profile prefs with the ones from prefsIn, + // where prefsIn is the previous profile's prefs with an updated Persist, LoggedOut, + // WantRunning and possibly other fields. This may not be the desired behavior. + // + // Additionally, LocalBackend doesn't treat it as a proper profile switch, meaning that + // [LocalBackend.resetForProfileChangeLockedOnEntry] is not called and certain + // node/profile-specific state may not be reset as expected. + // + // However, [profileManager] notifies [ipnext.Extension]s about the profile change, + // so features migrated from LocalBackend to external packages should not be affected. + // + // See tailscale/corp#28014. + if !cp.Equals(pm.currentProfile) { + const sameNode = false // implicit profile switch + pm.currentProfile = cp + pm.prefs = prefsIn.AsStruct().View() + if f := pm.StateChangeHook; f != nil { + f(cp, prefsIn, sameNode) + } + pm.extHost.NotifyProfileChange(cp, prefsIn, sameNode) + } + cp, err := pm.setProfilePrefs(nil, prefsIn, np) + if err != nil { return err } return pm.setProfileAsUserDefault(cp) - } -// SetProfilePrefs is like [profileManager.SetPrefs], but sets prefs for the specified [ipn.LoginProfile] -// which is not necessarily the [profileManager.CurrentProfile]. It returns an [errProfileAccessDenied] -// if the specified profile is not accessible by the current user. 
-func (pm *profileManager) SetProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.PrefsView, np ipn.NetworkProfile) error { - if err := pm.checkProfileAccess(lp); err != nil { - return err +// setProfilePrefs is like [profileManager.SetPrefs], but sets prefs for the specified [ipn.LoginProfile], +// returning a read-only view of the updated profile on success. If the specified profile is nil, +// it defaults to the current profile. If the profile is not accessible by the current user, +// the method returns an [errProfileAccessDenied]. +func (pm *profileManager) setProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.PrefsView, np ipn.NetworkProfile) (ipn.LoginProfileView, error) { + isCurrentProfile := lp == nil || (lp.ID != "" && lp.ID == pm.currentProfile.ID()) + if isCurrentProfile { + lp = pm.CurrentProfile().AsStruct() + } + + if err := pm.checkProfileAccess(lp.View()); err != nil { + return ipn.LoginProfileView{}, err } // An empty profile.ID indicates that the profile is new, the node info wasn't available, @@ -291,23 +434,42 @@ func (pm *profileManager) SetProfilePrefs(lp *ipn.LoginProfile, prefsIn ipn.Pref lp.UserProfile = up lp.NetworkProfile = np + // Update the current profile view to reflect the changes + // if the specified profile is the current profile. + if isCurrentProfile { + // Always set pm.currentProfile to the new profile view for pointer equality. + // We check it further down the call stack. + lp := lp.View() + sameProfileInfo := lp.Equals(pm.currentProfile) + pm.currentProfile = lp + if !sameProfileInfo { + // But only invoke the callbacks if the profile info has actually changed. 
+ const sameNode = true // just an info update; still the same node + pm.prefs = prefsIn.AsStruct().View() // suppress further callbacks for this change + if f := pm.StateChangeHook; f != nil { + f(lp, prefsIn, sameNode) + } + pm.extHost.NotifyProfileChange(lp, prefsIn, sameNode) + } + } + // An empty profile.ID indicates that the node info is not available yet, // and the profile doesn't need to be saved on disk. if lp.ID != "" { - pm.knownProfiles[lp.ID] = lp + pm.knownProfiles[lp.ID] = lp.View() if err := pm.writeKnownProfiles(); err != nil { - return err + return ipn.LoginProfileView{}, err } // Clone prefsIn and create a read-only view as a safety measure to // prevent accidental preference mutations, both externally and internally. - if err := pm.setProfilePrefsNoPermCheck(lp, prefsIn.AsStruct().View()); err != nil { - return err + if err := pm.setProfilePrefsNoPermCheck(lp.View(), prefsIn.AsStruct().View()); err != nil { + return ipn.LoginProfileView{}, err } } - return nil + return lp.View(), nil } -func newUnusedID(knownProfiles map[ipn.ProfileID]*ipn.LoginProfile) (ipn.ProfileID, ipn.StateKey) { +func newUnusedID(knownProfiles map[ipn.ProfileID]ipn.LoginProfileView) (ipn.ProfileID, ipn.StateKey) { var idb [2]byte for { rand.Read(idb[:]) @@ -326,14 +488,40 @@ func newUnusedID(knownProfiles map[ipn.ProfileID]*ipn.LoginProfile) (ipn.Profile // The method does not perform any additional checks on the specified // profile, such as verifying the caller's access rights or checking // if another profile for the same node already exists. -func (pm *profileManager) setProfilePrefsNoPermCheck(profile *ipn.LoginProfile, clonedPrefs ipn.PrefsView) error { +func (pm *profileManager) setProfilePrefsNoPermCheck(profile ipn.LoginProfileView, clonedPrefs ipn.PrefsView) error { isCurrentProfile := pm.currentProfile == profile if isCurrentProfile { + oldPrefs := pm.prefs pm.prefs = clonedPrefs + + // Sadly, profile prefs can be changed in multiple ways. 
+ // It's pretty chaotic, and in many cases callers use + // unexported methods of the profile manager instead of + // going through [LocalBackend.setPrefsLockedOnEntry] + // or at least using [profileManager.SetPrefs]. + // + // While we should definitely clean this up to improve + // the overall structure of how prefs are set, which would + // also address current and future conflicts, such as + // competing features changing the same prefs, this method + // is currently the central place where we can detect all + // changes to the current profile's prefs. + // + // That said, regardless of the cleanup, we might want + // to keep the profileManager responsible for invoking + // profile- and prefs-related callbacks. + + if !clonedPrefs.Equals(oldPrefs) { + if f := pm.StateChangeHook; f != nil { + f(pm.currentProfile, clonedPrefs, true) + } + pm.extHost.NotifyProfilePrefsChanged(pm.currentProfile, oldPrefs, clonedPrefs) + } + pm.updateHealth() } - if profile.Key != "" { - if err := pm.writePrefsToStore(profile.Key, clonedPrefs); err != nil { + if profile.Key() != "" { + if err := pm.writePrefsToStore(profile.Key(), clonedPrefs); err != nil { return err } } else if !isCurrentProfile { @@ -362,38 +550,33 @@ func (pm *profileManager) writePrefsToStore(key ipn.StateKey, prefs ipn.PrefsVie } // Profiles returns the list of known profiles accessible to the current user. -func (pm *profileManager) Profiles() []ipn.LoginProfile { - allProfiles := pm.allProfiles() - out := make([]ipn.LoginProfile, len(allProfiles)) - for i, p := range allProfiles { - out[i] = *p - } - return out +func (pm *profileManager) Profiles() []ipn.LoginProfileView { + return pm.allProfilesFor(pm.currentUserID) } // ProfileByID returns a profile with the given id, if it is accessible to the current user. // If the profile exists but is not accessible to the current user, it returns an [errProfileAccessDenied]. // If the profile does not exist, it returns an [errProfileNotFound]. 
-func (pm *profileManager) ProfileByID(id ipn.ProfileID) (ipn.LoginProfile, error) { +func (pm *profileManager) ProfileByID(id ipn.ProfileID) (ipn.LoginProfileView, error) { kp, err := pm.profileByIDNoPermCheck(id) if err != nil { - return ipn.LoginProfile{}, err + return ipn.LoginProfileView{}, err } if err := pm.checkProfileAccess(kp); err != nil { - return ipn.LoginProfile{}, err + return ipn.LoginProfileView{}, err } - return *kp, nil + return kp, nil } // profileByIDNoPermCheck is like [profileManager.ProfileByID], but it doesn't // check user's access rights to the profile. -func (pm *profileManager) profileByIDNoPermCheck(id ipn.ProfileID) (*ipn.LoginProfile, error) { - if id == pm.currentProfile.ID { +func (pm *profileManager) profileByIDNoPermCheck(id ipn.ProfileID) (ipn.LoginProfileView, error) { + if id == pm.currentProfile.ID() { return pm.currentProfile, nil } kp, ok := pm.knownProfiles[id] if !ok { - return nil, errProfileNotFound + return ipn.LoginProfileView{}, errProfileNotFound } return kp, nil } @@ -412,55 +595,45 @@ func (pm *profileManager) ProfilePrefs(id ipn.ProfileID) (ipn.PrefsView, error) return pm.profilePrefs(kp) } -func (pm *profileManager) profilePrefs(p *ipn.LoginProfile) (ipn.PrefsView, error) { - if p.ID == pm.currentProfile.ID { +func (pm *profileManager) profilePrefs(p ipn.LoginProfileView) (ipn.PrefsView, error) { + if p.ID() == pm.currentProfile.ID() { return pm.prefs, nil } - return pm.loadSavedPrefs(p.Key) + return pm.loadSavedPrefs(p.Key()) } -// SwitchProfile switches to the profile with the given id. +// SwitchToProfileByID switches to the profile with the given id. +// It returns the current profile and whether the call resulted in a profile change. // If the profile exists but is not accessible to the current user, it returns an [errProfileAccessDenied]. // If the profile does not exist, it returns an [errProfileNotFound]. 
-func (pm *profileManager) SwitchProfile(id ipn.ProfileID) error { - metricSwitchProfile.Add(1) - - kp, ok := pm.knownProfiles[id] - if !ok { - return errProfileNotFound - } - if pm.currentProfile != nil && kp.ID == pm.currentProfile.ID && pm.prefs.Valid() { - return nil +func (pm *profileManager) SwitchToProfileByID(id ipn.ProfileID) (_ ipn.LoginProfileView, changed bool, err error) { + if id == pm.currentProfile.ID() { + return pm.currentProfile, false, nil } - - if err := pm.checkProfileAccess(kp); err != nil { - return fmt.Errorf("%w: profile %q is not accessible to the current user", err, id) - } - prefs, err := pm.loadSavedPrefs(kp.Key) + profile, err := pm.ProfileByID(id) if err != nil { - return err + return pm.currentProfile, false, err } - pm.prefs = prefs - pm.updateHealth() - pm.currentProfile = kp - return pm.setProfileAsUserDefault(kp) + return pm.SwitchToProfile(profile) } -// SwitchToDefaultProfile switches to the default (last used) profile for the current user. -// It creates a new one and switches to it if the current user does not have a default profile, +// SwitchToDefaultProfileForUser switches to the default (last used) profile for the specified user. +// It creates a new one and switches to it if the specified user does not have a default profile, // or returns an error if the default profile is inaccessible or could not be loaded. -func (pm *profileManager) SwitchToDefaultProfile() error { - if id := pm.DefaultUserProfileID(pm.currentUserID); id != "" { - return pm.SwitchProfile(id) - } - pm.NewProfileForUser(pm.currentUserID) - return nil +func (pm *profileManager) SwitchToDefaultProfileForUser(uid ipn.WindowsUserID) (_ ipn.LoginProfileView, changed bool, err error) { + return pm.SwitchToProfile(pm.DefaultUserProfile(uid)) +} + +// SwitchToDefaultProfile is like [profileManager.SwitchToDefaultProfileForUser], but switches +// to the default profile for the current user. 
+func (pm *profileManager) SwitchToDefaultProfile() (_ ipn.LoginProfileView, changed bool, err error) { + return pm.SwitchToDefaultProfileForUser(pm.currentUserID) } // setProfileAsUserDefault sets the specified profile as the default for the current user. // It returns an [errProfileAccessDenied] if the specified profile is not accessible to the current user. -func (pm *profileManager) setProfileAsUserDefault(profile *ipn.LoginProfile) error { - if profile.Key == "" { +func (pm *profileManager) setProfileAsUserDefault(profile ipn.LoginProfileView) error { + if profile.Key() == "" { // The profile has not been persisted yet; ignore it for now. return nil } @@ -468,7 +641,7 @@ func (pm *profileManager) setProfileAsUserDefault(profile *ipn.LoginProfile) err return errProfileAccessDenied } k := ipn.CurrentProfileKey(string(pm.currentUserID)) - return pm.WriteState(k, []byte(profile.Key)) + return pm.WriteState(k, []byte(profile.Key())) } func (pm *profileManager) loadSavedPrefs(key ipn.StateKey) (ipn.PrefsView, error) { @@ -507,10 +680,10 @@ func (pm *profileManager) loadSavedPrefs(key ipn.StateKey) (ipn.PrefsView, error return savedPrefs.View(), nil } -// CurrentProfile returns the current LoginProfile. +// CurrentProfile returns a read-only [ipn.LoginProfileView] of the current profile. // The value may be zero if the profile is not persisted. -func (pm *profileManager) CurrentProfile() ipn.LoginProfile { - return *pm.currentProfile +func (pm *profileManager) CurrentProfile() ipn.LoginProfileView { + return pm.currentProfile } // errProfileNotFound is returned by methods that accept a ProfileID @@ -532,8 +705,7 @@ var errProfileAccessDenied = errors.New("profile access denied") // This is useful for deleting the last profile. In other cases, it is // recommended to call [profileManager.SwitchProfile] first. 
func (pm *profileManager) DeleteProfile(id ipn.ProfileID) error { - metricDeleteProfile.Add(1) - if id == pm.currentProfile.ID { + if id == pm.currentProfile.ID() { return pm.deleteCurrentProfile() } kp, ok := pm.knownProfiles[id] @@ -550,9 +722,9 @@ func (pm *profileManager) deleteCurrentProfile() error { if err := pm.checkProfileAccess(pm.currentProfile); err != nil { return err } - if pm.currentProfile.ID == "" { + if pm.currentProfile.ID() == "" { // Deleting the in-memory only new profile, just create a new one. - pm.NewProfile() + pm.SwitchToNewProfile() return nil } return pm.deleteProfileNoPermCheck(pm.currentProfile) @@ -560,14 +732,15 @@ func (pm *profileManager) deleteCurrentProfile() error { // deleteProfileNoPermCheck is like [profileManager.DeleteProfile], // but it doesn't check user's access rights to the profile. -func (pm *profileManager) deleteProfileNoPermCheck(profile *ipn.LoginProfile) error { - if profile.ID == pm.currentProfile.ID { - pm.NewProfile() +func (pm *profileManager) deleteProfileNoPermCheck(profile ipn.LoginProfileView) error { + if profile.ID() == pm.currentProfile.ID() { + pm.SwitchToNewProfile() } - if err := pm.WriteState(profile.Key, nil); err != nil { + if err := pm.WriteState(profile.Key(), nil); err != nil { return err } - delete(pm.knownProfiles, profile.ID) + delete(pm.knownProfiles, profile.ID()) + metricDeleteProfile.Add(1) return pm.writeKnownProfiles() } @@ -578,8 +751,8 @@ func (pm *profileManager) DeleteAllProfilesForUser() error { currentProfileDeleted := false writeKnownProfiles := func() error { - if currentProfileDeleted || pm.currentProfile.ID == "" { - pm.NewProfile() + if currentProfileDeleted || pm.currentProfile.ID() == "" { + pm.SwitchToNewProfile() } return pm.writeKnownProfiles() } @@ -589,14 +762,14 @@ func (pm *profileManager) DeleteAllProfilesForUser() error { // Skip profiles we don't have access to. 
continue } - if err := pm.WriteState(kp.Key, nil); err != nil { + if err := pm.WriteState(kp.Key(), nil); err != nil { // Write to remove references to profiles we've already deleted, but // return the original error. writeKnownProfiles() return err } - delete(pm.knownProfiles, kp.ID) - if kp.ID == pm.currentProfile.ID { + delete(pm.knownProfiles, kp.ID()) + if kp.ID() == pm.currentProfile.ID() { currentProfileDeleted = true } } @@ -608,6 +781,7 @@ func (pm *profileManager) writeKnownProfiles() error { if err != nil { return err } + metricProfileCount.Set(int64(len(pm.knownProfiles))) return pm.WriteState(ipn.KnownProfilesStateKey, b) } @@ -618,44 +792,25 @@ func (pm *profileManager) updateHealth() { pm.health.SetAutoUpdatePrefs(pm.prefs.AutoUpdate().Check, pm.prefs.AutoUpdate().Apply) } -// NewProfile creates and switches to a new unnamed profile. The new profile is +// SwitchToNewProfile creates and switches to a new unnamed profile. The new profile is // not persisted until [profileManager.SetPrefs] is called with a logged-in user. -func (pm *profileManager) NewProfile() { - pm.NewProfileForUser(pm.currentUserID) +func (pm *profileManager) SwitchToNewProfile() { + pm.SwitchToNewProfileForUser(pm.currentUserID) } -// NewProfileForUser is like [profileManager.NewProfile], but it switches to the +// SwitchToNewProfileForUser is like [profileManager.SwitchToNewProfile], but it switches to the // specified user and sets that user as the profile owner for the new profile. -func (pm *profileManager) NewProfileForUser(uid ipn.WindowsUserID) { - pm.currentUserID = uid - - metricNewProfile.Add(1) - - pm.prefs = defaultPrefs - pm.updateHealth() - pm.currentProfile = &ipn.LoginProfile{LocalUserID: uid} +func (pm *profileManager) SwitchToNewProfileForUser(uid ipn.WindowsUserID) { + pm.SwitchToProfile(pm.NewProfileForUser(uid)) } -// newProfileWithPrefs creates a new profile with the specified prefs and assigns -// the specified uid as the profile owner. 
If switchNow is true, it switches to the -// newly created profile immediately. It returns the newly created profile on success, -// or an error on failure. -func (pm *profileManager) newProfileWithPrefs(uid ipn.WindowsUserID, prefs ipn.PrefsView, switchNow bool) (*ipn.LoginProfile, error) { - metricNewProfile.Add(1) +// zeroProfile is a read-only view of a new, empty profile that is not persisted to the store. +var zeroProfile = (&ipn.LoginProfile{}).View() - profile := &ipn.LoginProfile{LocalUserID: uid} - if err := pm.SetProfilePrefs(profile, prefs, ipn.NetworkProfile{}); err != nil { - return nil, err - } - if switchNow { - pm.currentProfile = profile - pm.prefs = prefs.AsStruct().View() - pm.updateHealth() - if err := pm.setProfileAsUserDefault(profile); err != nil { - return nil, err - } - } - return profile, nil +// NewProfileForUser creates a new profile for the specified user and returns a read-only view of it. +// It neither switches to the new profile nor persists it to the store. +func (pm *profileManager) NewProfileForUser(uid ipn.WindowsUserID) ipn.LoginProfileView { + return (&ipn.LoginProfile{LocalUserID: uid}).View() } // defaultPrefs is the default prefs for a new profile. 
This initializes before @@ -711,8 +866,8 @@ func readAutoStartKey(store ipn.StateStore, goos string) (ipn.StateKey, error) { return ipn.StateKey(autoStartKey), nil } -func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]*ipn.LoginProfile, error) { - var knownProfiles map[ipn.ProfileID]*ipn.LoginProfile +func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]ipn.LoginProfileView, error) { + var knownProfiles map[ipn.ProfileID]ipn.LoginProfileView prfB, err := store.ReadState(ipn.KnownProfilesStateKey) switch err { case nil: @@ -720,7 +875,7 @@ func readKnownProfiles(store ipn.StateStore) (map[ipn.ProfileID]*ipn.LoginProfil return nil, fmt.Errorf("unmarshaling known profiles: %w", err) } case ipn.ErrStateNotExist: - knownProfiles = make(map[ipn.ProfileID]*ipn.LoginProfile) + knownProfiles = make(map[ipn.ProfileID]ipn.LoginProfileView) default: return nil, fmt.Errorf("calling ReadState on state store: %w", err) } @@ -739,6 +894,8 @@ func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, ht *healt return nil, err } + metricProfileCount.Set(int64(len(knownProfiles))) + pm := &profileManager{ goos: goos, store: store, @@ -747,27 +904,9 @@ func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, ht *healt health: ht, } + var initialProfile ipn.LoginProfileView if stateKey != "" { - for _, v := range knownProfiles { - if v.Key == stateKey { - pm.currentProfile = v - } - } - if pm.currentProfile == nil { - if suf, ok := strings.CutPrefix(string(stateKey), "user-"); ok { - pm.currentUserID = ipn.WindowsUserID(suf) - } - pm.NewProfile() - } else { - pm.currentUserID = pm.currentProfile.LocalUserID - } - prefs, err := pm.loadSavedPrefs(stateKey) - if err != nil { - return nil, err - } - if err := pm.setProfilePrefsNoPermCheck(pm.currentProfile, prefs); err != nil { - return nil, err - } + initialProfile = pm.findProfileByKey("", stateKey) // Most platform behavior is controlled by the goos parameter, however // some behavior is 
implied by build tag and fails when run on Windows, // so we explicitly avoid that behavior when running on Windows. @@ -778,28 +917,35 @@ func newProfileManagerWithGOOS(store ipn.StateStore, logf logger.Logf, ht *healt } else if len(knownProfiles) == 0 && goos != "windows" && runtime.GOOS != "windows" { // No known profiles, try a migration. pm.dlogf("no known profiles; trying to migrate from legacy prefs") - if _, err := pm.migrateFromLegacyPrefs(pm.currentUserID, true); err != nil { - return nil, err + if initialProfile, err = pm.migrateFromLegacyPrefs(pm.currentUserID); err != nil { + } - } else { - pm.NewProfile() } - + if !initialProfile.Valid() { + var initialUserID ipn.WindowsUserID + if suf, ok := strings.CutPrefix(string(stateKey), "user-"); ok { + initialUserID = ipn.WindowsUserID(suf) + } + initialProfile = pm.NewProfileForUser(initialUserID) + } + if _, _, err := pm.SwitchToProfile(initialProfile); err != nil { + return nil, err + } return pm, nil } -func (pm *profileManager) migrateFromLegacyPrefs(uid ipn.WindowsUserID, switchNow bool) (*ipn.LoginProfile, error) { +func (pm *profileManager) migrateFromLegacyPrefs(uid ipn.WindowsUserID) (ipn.LoginProfileView, error) { metricMigration.Add(1) sentinel, prefs, err := pm.loadLegacyPrefs(uid) if err != nil { metricMigrationError.Add(1) - return nil, fmt.Errorf("load legacy prefs: %w", err) + return ipn.LoginProfileView{}, fmt.Errorf("load legacy prefs: %w", err) } pm.dlogf("loaded legacy preferences; sentinel=%q", sentinel) - profile, err := pm.newProfileWithPrefs(uid, prefs, switchNow) + profile, err := pm.setProfilePrefs(&ipn.LoginProfile{LocalUserID: uid}, prefs, ipn.NetworkProfile{}) if err != nil { metricMigrationError.Add(1) - return nil, fmt.Errorf("migrating _daemon profile: %w", err) + return ipn.LoginProfileView{}, fmt.Errorf("migrating _daemon profile: %w", err) } pm.completeMigration(sentinel) pm.dlogf("completed legacy preferences migration with sentinel=%q", sentinel) @@ -809,8 +955,8 @@ func 
(pm *profileManager) migrateFromLegacyPrefs(uid ipn.WindowsUserID, switchNo func (pm *profileManager) requiresBackfill() bool { return pm != nil && - pm.currentProfile != nil && - pm.currentProfile.NetworkProfile.RequiresBackfill() + pm.currentProfile.Valid() && + pm.currentProfile.NetworkProfile().RequiresBackfill() } var ( @@ -818,6 +964,7 @@ var ( metricSwitchProfile = clientmetric.NewCounter("profiles_switch") metricDeleteProfile = clientmetric.NewCounter("profiles_delete") metricDeleteAllProfile = clientmetric.NewCounter("profiles_delete_all") + metricProfileCount = clientmetric.NewGauge("profiles_count") metricMigration = clientmetric.NewCounter("profiles_migration") metricMigrationError = clientmetric.NewCounter("profiles_migration_error") diff --git a/ipn/ipnlocal/profiles_test.go b/ipn/ipnlocal/profiles_test.go index 73e4f6535387e..52b095be1a5fe 100644 --- a/ipn/ipnlocal/profiles_test.go +++ b/ipn/ipnlocal/profiles_test.go @@ -7,6 +7,7 @@ import ( "fmt" "os/user" "strconv" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -33,7 +34,7 @@ func TestProfileCurrentUserSwitch(t *testing.T) { newProfile := func(t *testing.T, loginName string) ipn.PrefsView { id++ t.Helper() - pm.NewProfile() + pm.SwitchToNewProfile() p := pm.CurrentPrefs().AsStruct() p.Persist = &persist.Persist{ NodeID: tailcfg.StableNodeID(fmt.Sprint(id)), @@ -52,11 +53,11 @@ func TestProfileCurrentUserSwitch(t *testing.T) { pm.SetCurrentUserID("user1") newProfile(t, "user1") cp := pm.currentProfile - pm.DeleteProfile(cp.ID) - if pm.currentProfile == nil { + pm.DeleteProfile(cp.ID()) + if !pm.currentProfile.Valid() { t.Fatal("currentProfile is nil") - } else if pm.currentProfile.ID != "" { - t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID) + } else if pm.currentProfile.ID() != "" { + t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID()) } if !pm.CurrentPrefs().Equals(defaultPrefs) { t.Fatalf("CurrentPrefs() = %v, want emptyPrefs", 
pm.CurrentPrefs().Pretty()) @@ -67,10 +68,10 @@ func TestProfileCurrentUserSwitch(t *testing.T) { t.Fatal(err) } pm.SetCurrentUserID("user1") - if pm.currentProfile == nil { + if !pm.currentProfile.Valid() { t.Fatal("currentProfile is nil") - } else if pm.currentProfile.ID != "" { - t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID) + } else if pm.currentProfile.ID() != "" { + t.Fatalf("currentProfile.ID = %q, want empty", pm.currentProfile.ID()) } if !pm.CurrentPrefs().Equals(defaultPrefs) { t.Fatalf("CurrentPrefs() = %v, want emptyPrefs", pm.CurrentPrefs().Pretty()) @@ -88,7 +89,7 @@ func TestProfileList(t *testing.T) { newProfile := func(t *testing.T, loginName string) ipn.PrefsView { id++ t.Helper() - pm.NewProfile() + pm.SwitchToNewProfile() p := pm.CurrentPrefs().AsStruct() p.Persist = &persist.Persist{ NodeID: tailcfg.StableNodeID(fmt.Sprint(id)), @@ -110,8 +111,8 @@ func TestProfileList(t *testing.T) { t.Fatalf("got %d profiles, want %d", len(got), len(want)) } for i, w := range want { - if got[i].Name != w { - t.Errorf("got profile %d name %q, want %q", i, got[i].Name, w) + if got[i].Name() != w { + t.Errorf("got profile %d name %q, want %q", i, got[i].Name(), w) } } } @@ -129,10 +130,10 @@ func TestProfileList(t *testing.T) { pm.SetCurrentUserID("user1") checkProfiles(t, "alice", "bob") - if lp := pm.findProfileByKey(carol.Key); lp != nil { + if lp := pm.findProfileByKey("user1", carol.Key()); lp.Valid() { t.Fatalf("found profile for user2 in user1's profile list") } - if lp := pm.findProfileByName(carol.Name); lp != nil { + if lp := pm.findProfileByName("user1", carol.Name()); lp.Valid() { t.Fatalf("found profile for user2 in user1's profile list") } @@ -162,7 +163,7 @@ func TestProfileDupe(t *testing.T) { must.Do(pm.SetPrefs(prefs.View(), ipn.NetworkProfile{})) } login := func(pm *profileManager, p *persist.Persist) { - pm.NewProfile() + pm.SwitchToNewProfile() reauth(pm, p) } @@ -294,7 +295,7 @@ func TestProfileDupe(t *testing.T) { 
profs := pm.Profiles() var got []*persist.Persist for _, p := range profs { - prefs, err := pm.loadSavedPrefs(p.Key) + prefs, err := pm.loadSavedPrefs(p.Key()) if err != nil { t.Fatal(err) } @@ -328,9 +329,9 @@ func TestProfileManagement(t *testing.T) { checkProfiles := func(t *testing.T) { t.Helper() prof := pm.CurrentProfile() - t.Logf("\tCurrentProfile = %q", prof) - if prof.Name != wantCurProfile { - t.Fatalf("CurrentProfile = %q; want %q", prof, wantCurProfile) + t.Logf("\tCurrentProfile = %q", prof.Name()) + if prof.Name() != wantCurProfile { + t.Fatalf("CurrentProfile = %q; want %q", prof.Name(), wantCurProfile) } profiles := pm.Profiles() wantLen := len(wantProfiles) @@ -349,13 +350,13 @@ func TestProfileManagement(t *testing.T) { t.Fatalf("CurrentPrefs = %v; want %v", p.Pretty(), wantProfiles[wantCurProfile].Pretty()) } for _, p := range profiles { - got, err := pm.loadSavedPrefs(p.Key) + got, err := pm.loadSavedPrefs(p.Key()) if err != nil { t.Fatal(err) } // Use Hostname as a proxy for all prefs. 
- if !got.Equals(wantProfiles[p.Name]) { - t.Fatalf("Prefs for profile %q =\n got=%+v\nwant=%v", p, got.Pretty(), wantProfiles[p.Name].Pretty()) + if !got.Equals(wantProfiles[p.Name()]) { + t.Fatalf("Prefs for profile %q =\n got=%+v\nwant=%v", p.Name(), got.Pretty(), wantProfiles[p.Name()].Pretty()) } } } @@ -399,7 +400,7 @@ func TestProfileManagement(t *testing.T) { checkProfiles(t) t.Logf("Create new profile") - pm.NewProfile() + pm.SwitchToNewProfile() wantCurProfile = "" wantProfiles[""] = defaultPrefs checkProfiles(t) @@ -422,7 +423,7 @@ func TestProfileManagement(t *testing.T) { checkProfiles(t) t.Logf("Delete default profile") - if err := pm.DeleteProfile(pm.findProfileByName("user@1.example.com").ID); err != nil { + if err := pm.DeleteProfile(pm.ProfileIDForName("user@1.example.com")); err != nil { t.Fatal(err) } delete(wantProfiles, "user@1.example.com") @@ -438,7 +439,7 @@ func TestProfileManagement(t *testing.T) { checkProfiles(t) t.Logf("Create new profile - 2") - pm.NewProfile() + pm.SwitchToNewProfile() wantCurProfile = "" wantProfiles[""] = defaultPrefs checkProfiles(t) @@ -506,9 +507,9 @@ func TestProfileManagementWindows(t *testing.T) { checkProfiles := func(t *testing.T) { t.Helper() prof := pm.CurrentProfile() - t.Logf("\tCurrentProfile = %q", prof) - if prof.Name != wantCurProfile { - t.Fatalf("CurrentProfile = %q; want %q", prof, wantCurProfile) + t.Logf("\tCurrentProfile = %q", prof.Name()) + if prof.Name() != wantCurProfile { + t.Fatalf("CurrentProfile = %q; want %q", prof.Name(), wantCurProfile) } if p := pm.CurrentPrefs(); !p.Equals(wantProfiles[wantCurProfile]) { t.Fatalf("CurrentPrefs = %+v; want %+v", p.Pretty(), wantProfiles[wantCurProfile].Pretty()) @@ -550,7 +551,7 @@ func TestProfileManagementWindows(t *testing.T) { { t.Logf("Create new profile") - pm.NewProfile() + pm.SwitchToNewProfile() wantCurProfile = "" wantProfiles[""] = defaultPrefs checkProfiles(t) @@ -609,3 +610,535 @@ func TestDefaultPrefs(t *testing.T) { 
t.Errorf("defaultPrefs is %s, want %s; defaultPrefs should only modify WantRunning and LoggedOut, all other defaults should be in ipn.NewPrefs.", p2.Pretty(), p1.Pretty()) } } + +// mutPrefsFn is a function that mutates the prefs. +// Deserialization pre-populates prefs with default (non-zero) values. +// After saving prefs and reading them back, we may not get exactly what we set. +// For this reason, tests apply changes through a helper that mutates +// [ipn.NewPrefs] instead of hard-coding expected values in each case. +type mutPrefsFn func(*ipn.Prefs) + +type profileState struct { + *ipn.LoginProfile + mutPrefs mutPrefsFn +} + +func (s *profileState) prefs() ipn.PrefsView { + prefs := ipn.NewPrefs() // apply changes to the default prefs + s.mutPrefs(prefs) + return prefs.View() +} + +type profileStateChange struct { + *ipn.LoginProfile + mutPrefs mutPrefsFn + sameNode bool +} + +func wantProfileChange(state profileState) profileStateChange { + return profileStateChange{ + LoginProfile: state.LoginProfile, + mutPrefs: state.mutPrefs, + sameNode: false, + } +} + +func wantPrefsChange(state profileState) profileStateChange { + return profileStateChange{ + LoginProfile: state.LoginProfile, + mutPrefs: state.mutPrefs, + sameNode: true, + } +} + +func makeDefaultPrefs(p *ipn.Prefs) { *p = *defaultPrefs.AsStruct() } + +func makeKnownProfileState(id int, nameSuffix string, uid ipn.WindowsUserID, mutPrefs mutPrefsFn) profileState { + lowerNameSuffix := strings.ToLower(nameSuffix) + nid := "node-" + tailcfg.StableNodeID(lowerNameSuffix) + up := tailcfg.UserProfile{ + ID: tailcfg.UserID(id), + LoginName: fmt.Sprintf("user-%s@example.com", lowerNameSuffix), + DisplayName: "User " + nameSuffix, + } + return profileState{ + LoginProfile: &ipn.LoginProfile{ + LocalUserID: uid, + Name: up.LoginName, + ID: ipn.ProfileID(fmt.Sprintf("%04X", id)), + Key: "profile-" + ipn.StateKey(nameSuffix), + NodeID: nid, + UserProfile: up, + }, + mutPrefs: func(p *ipn.Prefs) { + p.Hostname =
"Hostname-" + nameSuffix + if mutPrefs != nil { + mutPrefs(p) // apply any additional changes + } + p.Persist = &persist.Persist{NodeID: nid, UserProfile: up} + }, + } +} + +func TestProfileStateChangeCallback(t *testing.T) { + t.Parallel() + + // A few well-known profiles to use in tests. + emptyProfile := profileState{ + LoginProfile: &ipn.LoginProfile{}, + mutPrefs: makeDefaultPrefs, + } + profile0000 := profileState{ + LoginProfile: &ipn.LoginProfile{ID: "0000", Key: "profile-0000"}, + mutPrefs: makeDefaultPrefs, + } + profileA := makeKnownProfileState(0xA, "A", "", nil) + profileB := makeKnownProfileState(0xB, "B", "", nil) + profileC := makeKnownProfileState(0xC, "C", "", nil) + + aliceUserID := ipn.WindowsUserID("S-1-5-21-1-2-3-4") + aliceEmptyProfile := profileState{ + LoginProfile: &ipn.LoginProfile{LocalUserID: aliceUserID}, + mutPrefs: makeDefaultPrefs, + } + bobUserID := ipn.WindowsUserID("S-1-5-21-3-4-5-6") + bobEmptyProfile := profileState{ + LoginProfile: &ipn.LoginProfile{LocalUserID: bobUserID}, + mutPrefs: makeDefaultPrefs, + } + bobKnownProfile := makeKnownProfileState(0xB0B, "Bob", bobUserID, nil) + + tests := []struct { + name string + initial *profileState // if non-nil, this is the initial profile and prefs to start wit + knownProfiles []profileState // known profiles we can switch to + action func(*profileManager) // action to take on the profile manager + wantChanges []profileStateChange // expected state changes + }{ + { + name: "no-changes", + action: func(*profileManager) { + // do nothing + }, + wantChanges: nil, + }, + { + name: "no-initial/new-profile", + action: func(pm *profileManager) { + // The profile manager is new and started with a new empty profile. + // This should not trigger a state change callback. 
+ pm.SwitchToNewProfile() + }, + wantChanges: nil, + }, + { + name: "no-initial/new-profile-for-user", + action: func(pm *profileManager) { + // But switching to a new profile for a specific user should trigger + // a state change callback. + pm.SwitchToNewProfileForUser(aliceUserID) + }, + wantChanges: []profileStateChange{ + // We want a new empty profile (owned by the specified user) + // and the default prefs. + wantProfileChange(aliceEmptyProfile), + }, + }, + { + name: "with-initial/new-profile", + initial: &profile0000, + action: func(pm *profileManager) { + // And so does switching to a new profile when the initial profile + // is non-empty. + pm.SwitchToNewProfile() + }, + wantChanges: []profileStateChange{ + // We want a new empty profile and the default prefs. + wantProfileChange(emptyProfile), + }, + }, + { + name: "with-initial/new-profile/twice", + initial: &profile0000, + action: func(pm *profileManager) { + // If we switch to a new profile twice, we should only get one state change. + pm.SwitchToNewProfile() + pm.SwitchToNewProfile() + }, + wantChanges: []profileStateChange{ + // We want a new empty profile and the default prefs. + wantProfileChange(emptyProfile), + }, + }, + { + name: "with-initial/new-profile-for-user/twice", + initial: &profile0000, + action: func(pm *profileManager) { + // Unless we switch to a new profile for a specific user, + // in which case we should get a state change twice. + pm.SwitchToNewProfileForUser(aliceUserID) + pm.SwitchToNewProfileForUser(aliceUserID) // no change here + pm.SwitchToNewProfileForUser(bobUserID) + }, + wantChanges: []profileStateChange{ + // Both profiles are empty, but they are owned by different users. 
+ wantProfileChange(aliceEmptyProfile), + wantProfileChange(bobEmptyProfile), + }, + }, + { + name: "with-initial/new-profile/twice/with-prefs-change", + initial: &profile0000, + action: func(pm *profileManager) { + // Or unless we switch to a new profile, change the prefs, + // then switch to a new profile again. Since the current + // profile is not empty after the prefs change, we should + // get state changes for all three actions. + pm.SwitchToNewProfile() + p := pm.CurrentPrefs().AsStruct() + p.WantRunning = true + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + pm.SwitchToNewProfile() + }, + wantChanges: []profileStateChange{ + wantProfileChange(emptyProfile), // new empty profile + wantPrefsChange(profileState{ // prefs change, same profile + LoginProfile: &ipn.LoginProfile{}, + mutPrefs: func(p *ipn.Prefs) { + *p = *defaultPrefs.AsStruct() + p.WantRunning = true + }, + }), + wantProfileChange(emptyProfile), // new empty profile again + }, + }, + { + name: "switch-to-profile/by-id", + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Switching to a known profile by ID should trigger a state change callback. + pm.SwitchToProfileByID(profileB.ID) + }, + wantChanges: []profileStateChange{ + wantProfileChange(profileB), + }, + }, + { + name: "switch-to-profile/by-id/non-existent", + knownProfiles: []profileState{profileA, profileC}, // no profileB + action: func(pm *profileManager) { + // Switching to a non-existent profile should fail and not trigger a state change callback. + pm.SwitchToProfileByID(profileB.ID) + }, + wantChanges: []profileStateChange{}, + }, + { + name: "switch-to-profile/by-id/twice-same", + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // But only for the first switch. + // The second switch to the same profile should not trigger a state change callback. 
+ pm.SwitchToProfileByID(profileB.ID)
+ pm.SwitchToProfileByID(profileB.ID)
+ },
+ wantChanges: []profileStateChange{
+ wantProfileChange(profileB),
+ },
+ },
+ {
+ name: "switch-to-profile/by-id/many",
+ knownProfiles: []profileState{profileA, profileB, profileC},
+ action: func(pm *profileManager) {
+ // Same idea, but with multiple switches.
+ pm.SwitchToProfileByID(profileB.ID) // switch to Profile-B
+ pm.SwitchToProfileByID(profileB.ID) // then to Profile-B again (no change)
+ pm.SwitchToProfileByID(profileC.ID) // then to Profile-C (change)
+ pm.SwitchToProfileByID(profileA.ID) // then to Profile-A (change)
+ pm.SwitchToProfileByID(profileB.ID) // then to Profile-B (change)
+ },
+ wantChanges: []profileStateChange{
+ wantProfileChange(profileB),
+ wantProfileChange(profileC),
+ wantProfileChange(profileA),
+ wantProfileChange(profileB),
+ },
+ },
+ {
+ name: "switch-to-profile/by-view",
+ knownProfiles: []profileState{profileA, profileB, profileC},
+ action: func(pm *profileManager) {
+ // Switching to a known profile by an [ipn.LoginProfileView]
+ // should also trigger a state change callback.
+ pm.SwitchToProfile(profileB.View())
+ },
+ wantChanges: []profileStateChange{
+ wantProfileChange(profileB),
+ },
+ },
+ {
+ name: "switch-to-profile/by-view/empty",
+ initial: &profile0000,
+ action: func(pm *profileManager) {
+ // SwitchToProfile supports switching to an empty profile.
+ emptyProfile := &ipn.LoginProfile{}
+ pm.SwitchToProfile(emptyProfile.View())
+ },
+ wantChanges: []profileStateChange{
+ wantProfileChange(emptyProfile),
+ },
+ },
+ {
+ name: "switch-to-profile/by-view/non-existent",
+ knownProfiles: []profileState{profileA, profileC},
+ action: func(pm *profileManager) {
+ // Switching to an unknown profile by an [ipn.LoginProfileView]
+ // should fail and not trigger a state change callback.
+ pm.SwitchToProfile(profileB.View()) + }, + wantChanges: []profileStateChange{}, + }, + { + name: "switch-to-profile/by-view/empty-for-user", + initial: &profile0000, + action: func(pm *profileManager) { + // And switching to an empty profile for a specific user also works. + pm.SwitchToProfile(bobEmptyProfile.View()) + }, + wantChanges: []profileStateChange{ + wantProfileChange(bobEmptyProfile), + }, + }, + { + name: "switch-to-profile/by-view/invalid", + initial: &profile0000, + action: func(pm *profileManager) { + // Switching to an invalid profile should create and switch + // to a new empty profile. + pm.SwitchToProfile(ipn.LoginProfileView{}) + }, + wantChanges: []profileStateChange{ + wantProfileChange(emptyProfile), + }, + }, + { + name: "delete-profile/current", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Deleting the current profile should switch to a new empty profile. + pm.DeleteProfile(profileA.ID) + }, + wantChanges: []profileStateChange{ + wantProfileChange(emptyProfile), + }, + }, + { + name: "delete-profile/current-with-user", + initial: &bobKnownProfile, + knownProfiles: []profileState{profileA, profileB, profileC, bobKnownProfile}, + action: func(pm *profileManager) { + // Similarly, deleting the current profile for a specific user should switch + // to a new empty profile for that user (at least while the "current user" + // is still a thing on Windows). + pm.DeleteProfile(bobKnownProfile.ID) + }, + wantChanges: []profileStateChange{ + wantProfileChange(bobEmptyProfile), + }, + }, + { + name: "delete-profile/non-current", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // But deleting a non-current profile should not trigger a state change callback. 
+ pm.DeleteProfile(profileB.ID) + }, + wantChanges: []profileStateChange{}, + }, + { + name: "set-prefs/new-profile", + initial: &emptyProfile, // the current profile is empty + action: func(pm *profileManager) { + // The current profile is new and empty, but we can still set p. + // This should trigger a state change callback. + p := pm.CurrentPrefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + // Still an empty profile, but with new prefs. + wantPrefsChange(profileState{ + LoginProfile: emptyProfile.LoginProfile, + mutPrefs: func(p *ipn.Prefs) { + *p = *emptyProfile.prefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + }, + }), + }, + }, + { + name: "set-prefs/current-profile", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + p := pm.CurrentPrefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + wantPrefsChange(profileState{ + LoginProfile: profileA.LoginProfile, // same profile + mutPrefs: func(p *ipn.Prefs) { // but with new prefs + *p = *profileA.prefs().AsStruct() + p.WantRunning = true + p.Hostname = "New-Hostname" + }, + }), + }, + }, + { + name: "set-prefs/current-profile/profile-name", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + p := pm.CurrentPrefs().AsStruct() + p.ProfileName = "This is User A" + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + // Still the same profile, but with a new profile name + // populated from the prefs. The prefs are also updated. 
+ wantPrefsChange(profileState{ + LoginProfile: func() *ipn.LoginProfile { + p := profileA.Clone() + p.Name = "This is User A" + return p + }(), + mutPrefs: func(p *ipn.Prefs) { + *p = *profileA.prefs().AsStruct() + p.ProfileName = "This is User A" + }, + }), + }, + }, + { + name: "set-prefs/implicit-switch/from-new", + initial: &emptyProfile, // a new, empty profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // The user attempted to add a new profile but actually logged in as the same + // node/user as profileB. When [LocalBackend.SetControlClientStatus] calls + // [profileManager.SetPrefs] with the [persist.Persist] for profileB, we + // implicitly switch to that profile instead of creating a duplicate for the + // same node/user. + // + // TODO(nickkhyl): currently, [LocalBackend.SetControlClientStatus] uses the p + // of the current profile, not those of the profile we switch to. This is all wrong + // and should be fixed. But for now, we just test that the state change callback + // is called with the new profile and p. + p := pm.CurrentPrefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + // Calling [profileManager.SetPrefs] like this is effectively a profile switch + // rather than a prefs change. + wantProfileChange(profileState{ + LoginProfile: profileB.LoginProfile, + mutPrefs: func(p *ipn.Prefs) { + *p = *emptyProfile.prefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + }, + }), + }, + }, + { + name: "set-prefs/implicit-switch/from-other", + initial: &profileA, // profileA is the current profile + knownProfiles: []profileState{profileA, profileB, profileC}, + action: func(pm *profileManager) { + // Same idea, but the current profile is profileA rather than a new empty profile. 
+ // Note: this is all wrong. See the comment above and [profileManager.SetPrefs]. + p := pm.CurrentPrefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + pm.SetPrefs(p.View(), ipn.NetworkProfile{}) + }, + wantChanges: []profileStateChange{ + wantProfileChange(profileState{ + LoginProfile: profileB.LoginProfile, + mutPrefs: func(p *ipn.Prefs) { + *p = *profileA.prefs().AsStruct() + p.Persist = profileB.prefs().Persist().AsStruct() + p.WantRunning = true + p.LoggedOut = false + }, + }), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + store := new(mem.Store) + pm, err := newProfileManagerWithGOOS(store, logger.Discard, new(health.Tracker), "linux") + if err != nil { + t.Fatalf("newProfileManagerWithGOOS: %v", err) + } + for _, p := range tt.knownProfiles { + pm.writePrefsToStore(p.Key, p.prefs()) + pm.knownProfiles[p.ID] = p.View() + } + if err := pm.writeKnownProfiles(); err != nil { + t.Fatalf("writeKnownProfiles: %v", err) + } + + if tt.initial != nil { + pm.currentUserID = tt.initial.LocalUserID + pm.currentProfile = tt.initial.View() + pm.prefs = tt.initial.prefs() + } + + type stateChange struct { + Profile *ipn.LoginProfile + Prefs *ipn.Prefs + SameNode bool + } + wantChanges := make([]stateChange, 0, len(tt.wantChanges)) + for _, w := range tt.wantChanges { + wantPrefs := ipn.NewPrefs() + w.mutPrefs(wantPrefs) // apply changes to the default prefs + wantChanges = append(wantChanges, stateChange{ + Profile: w.LoginProfile, + Prefs: wantPrefs, + SameNode: w.sameNode, + }) + } + + gotChanges := make([]stateChange, 0, len(tt.wantChanges)) + pm.StateChangeHook = func(profile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + gotChanges = append(gotChanges, stateChange{ + Profile: profile.AsStruct(), + Prefs: prefs.AsStruct(), + SameNode: sameNode, + }) + } + + tt.action(pm) + + if diff := cmp.Diff(wantChanges, gotChanges, 
defaultCmpOpts...); diff != "" { + t.Errorf("StateChange callbacks: (-want +got): %v", diff) + } + }) + } +} diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 67d521f0968eb..44d63fe54a902 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -54,8 +54,9 @@ var ErrETagMismatch = errors.New("etag mismatch") var serveHTTPContextKey ctxkey.Key[*serveHTTPContext] type serveHTTPContext struct { - SrcAddr netip.AddrPort - DestPort uint16 + SrcAddr netip.AddrPort + ForVIPService tailcfg.ServiceName // "" means local + DestPort uint16 // provides funnel-specific context, nil if not funneled Funnel *funnelFlow @@ -231,7 +232,7 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } } - nm := b.netMap + nm := b.NetMap() if nm == nil { b.logf("netMap is nil") return @@ -242,8 +243,7 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } addrs := nm.GetAddresses() - for i := range addrs.Len() { - a := addrs.At(i) + for _, a := range addrs.All() { for _, p := range ports { addrPort := netip.AddrPortFrom(a.Addr(), p) if _, ok := b.serveListeners[addrPort]; ok { @@ -276,7 +276,13 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string return errors.New("can't reconfigure tailscaled when using a config file; config file is locked") } - nm := b.netMap + if config != nil { + if err := config.CheckValidServicesConfig(); err != nil { + return err + } + } + + nm := b.NetMap() if nm == nil { return errors.New("netMap is nil") } @@ -312,7 +318,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string bs = j } - profileID := b.pm.CurrentProfile().ID + profileID := b.pm.CurrentProfile().ID() confKey := ipn.ServeConfigKey(profileID) if err := b.store.WriteState(confKey, bs); err != nil { return fmt.Errorf("writing ServeConfig to StateStore: %w", err) @@ -327,7 +333,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string 
 if b.serveConfig.Valid() {
 has = b.serveConfig.Foreground().Contains
 }
- prevConfig.Foreground().Range(func(k string, v ipn.ServeConfigView) (cont bool) {
+ for k := range prevConfig.Foreground().All() {
 if !has(k) {
 for _, sess := range b.notifyWatchers {
 if sess.sessionID == k {
@@ -335,8 +341,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string
 }
 }
 }
- return true
- })
+ }
 }
 
 return nil
@@ -434,6 +439,105 @@ func (b *LocalBackend) HandleIngressTCPConn(ingressPeer tailcfg.NodeView, target
 handler(c)
 }
 
+// tcpHandlerForVIPService returns a handler for a TCP connection to a VIP service
+// that is being served via the ipn.ServeConfig. It returns nil if the destination
+// address is not a VIP service or if the VIP service does not have a TCP handler set.
+func (b *LocalBackend) tcpHandlerForVIPService(dstAddr, srcAddr netip.AddrPort) (handler func(net.Conn) error) {
+ b.mu.Lock()
+ sc := b.serveConfig
+ ipVIPServiceMap := b.ipVIPServiceMap
+ b.mu.Unlock()
+
+ if !sc.Valid() {
+ return nil
+ }
+
+ dport := dstAddr.Port()
+
+ dstSvc, ok := ipVIPServiceMap[dstAddr.Addr()]
+ if !ok {
+ return nil
+ }
+
+ tcph, ok := sc.FindServiceTCP(dstSvc, dstAddr.Port())
+ if !ok {
+ b.logf("The destination service doesn't have a TCP handler set.")
+ return nil
+ }
+
+ if tcph.HTTPS() || tcph.HTTP() {
+ hs := &http.Server{
+ Handler: http.HandlerFunc(b.serveWebHandler),
+ BaseContext: func(_ net.Listener) context.Context {
+ return serveHTTPContextKey.WithValue(context.Background(), &serveHTTPContext{
+ SrcAddr: srcAddr,
+ ForVIPService: dstSvc,
+ DestPort: dport,
+ })
+ },
+ }
+ if tcph.HTTPS() {
+ // TODO(kevinliang10): just leaving this TLS cert creation as if we don't have other
+ // hostnames, but for services this getTLSServeCertForPort will need a version that also takes
+ // in the hostname. How to store the TLS cert is still being discussed.
+ hs.TLSConfig = &tls.Config{ + GetCertificate: b.getTLSServeCertForPort(dport, dstSvc), + } + return func(c net.Conn) error { + return hs.ServeTLS(netutil.NewOneConnListener(c, nil), "", "") + } + } + + return func(c net.Conn) error { + return hs.Serve(netutil.NewOneConnListener(c, nil)) + } + } + + if backDst := tcph.TCPForward(); backDst != "" { + return func(conn net.Conn) error { + defer conn.Close() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + backConn, err := b.dialer.SystemDial(ctx, "tcp", backDst) + cancel() + if err != nil { + b.logf("localbackend: failed to TCP proxy port %v (from %v) to %s: %v", dport, srcAddr, backDst, err) + return nil + } + defer backConn.Close() + if sni := tcph.TerminateTLS(); sni != "" { + conn = tls.Server(conn, &tls.Config{ + GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + pair, err := b.GetCertPEM(ctx, sni) + if err != nil { + return nil, err + } + cert, err := tls.X509KeyPair(pair.CertPEM, pair.KeyPEM) + if err != nil { + return nil, err + } + return &cert, nil + }, + }) + } + + errc := make(chan error, 1) + go func() { + _, err := io.Copy(backConn, conn) + errc <- err + }() + go func() { + _, err := io.Copy(conn, backConn) + errc <- err + }() + return <-errc + } + } + + return nil +} + // tcpHandlerForServe returns a handler for a TCP connection to be served via // the ipn.ServeConfig. The funnelFlow can be nil if this is not a funneled // connection. 
@@ -464,7 +568,7 @@ func (b *LocalBackend) tcpHandlerForServe(dport uint16, srcAddr netip.AddrPort, } if tcph.HTTPS() { hs.TLSConfig = &tls.Config{ - GetCertificate: b.getTLSServeCertForPort(dport), + GetCertificate: b.getTLSServeCertForPort(dport, ""), } return func(c net.Conn) error { return hs.ServeTLS(netutil.NewOneConnListener(c, nil), "", "") @@ -528,7 +632,7 @@ func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, hostname := r.Host if r.TLS == nil { - tcd := "." + b.Status().CurrentTailnet.MagicDNSSuffix + tcd := "." + b.CurrentProfile().NetworkProfile().MagicDNSName if host, _, err := net.SplitHostPort(hostname); err == nil { hostname = host } @@ -544,7 +648,7 @@ func (b *LocalBackend) getServeHandler(r *http.Request) (_ ipn.HTTPHandlerView, b.logf("[unexpected] localbackend: no serveHTTPContext in request") return z, "", false } - wsc, ok := b.webServerConfig(hostname, sctx.DestPort) + wsc, ok := b.webServerConfig(hostname, sctx.ForVIPService, sctx.DestPort) if !ok { return z, "", false } @@ -902,7 +1006,7 @@ func allNumeric(s string) bool { return s != "" } -func (b *LocalBackend) webServerConfig(hostname string, port uint16) (c ipn.WebServerConfigView, ok bool) { +func (b *LocalBackend) webServerConfig(hostname string, forVIPService tailcfg.ServiceName, port uint16) (c ipn.WebServerConfigView, ok bool) { key := ipn.HostPort(fmt.Sprintf("%s:%v", hostname, port)) b.mu.Lock() @@ -911,15 +1015,18 @@ func (b *LocalBackend) webServerConfig(hostname string, port uint16) (c ipn.WebS if !b.serveConfig.Valid() { return c, false } + if forVIPService != "" { + return b.serveConfig.FindServiceWeb(forVIPService, key) + } return b.serveConfig.FindWeb(key) } -func (b *LocalBackend) getTLSServeCertForPort(port uint16) func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (b *LocalBackend) getTLSServeCertForPort(port uint16, forVIPService tailcfg.ServiceName) func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { return func(hi 
*tls.ClientHelloInfo) (*tls.Certificate, error) { if hi == nil || hi.ServerName == "" { return nil, errors.New("no SNI ServerName") } - _, ok := b.webServerConfig(hi.ServerName, port) + _, ok := b.webServerConfig(hi.ServerName, forVIPService, port) if !ok { return nil, errors.New("no webserver configured for name/port") } diff --git a/ipn/ipnlocal/serve_test.go b/ipn/ipnlocal/serve_test.go index 73e66c2b9db16..b9370f8778e6b 100644 --- a/ipn/ipnlocal/serve_test.go +++ b/ipn/ipnlocal/serve_test.go @@ -296,6 +296,203 @@ func TestServeConfigForeground(t *testing.T) { } } +// TestServeConfigServices tests the side effects of setting the +// Services field in a ServeConfig. The Services field is a map +// of all services the current service host is serving. Unlike what we +// serve for node itself, there is no foreground and no local handlers +// for the services. So the only things we need to test are if the +// services configured are valid and if they correctly set intercept +// functions for netStack. 
+func TestServeConfigServices(t *testing.T) { + b := newTestBackend(t) + svcIPMap := tailcfg.ServiceIPMappings{ + "svc:foo": []netip.Addr{ + netip.MustParseAddr("100.101.101.101"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:6565:6565"), + }, + "svc:bar": []netip.Addr{ + netip.MustParseAddr("100.99.99.99"), + netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:cd96:626b:628b"), + }, + } + svcIPMapJSON, err := json.Marshal(svcIPMap) + if err != nil { + t.Fatal(err) + } + + b.currentNode().SetNetMap(&netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "example.ts.net", + CapMap: tailcfg.NodeCapMap{ + tailcfg.NodeAttrServiceHost: []tailcfg.RawMessage{tailcfg.RawMessage(svcIPMapJSON)}, + }, + }).View(), + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ + LoginName: "someone@example.com", + DisplayName: "Some One", + ProfilePicURL: "https://example.com/photo.jpg", + }).View(), + }, + }) + + tests := []struct { + name string + conf *ipn.ServeConfig + expectedErr error + packetDstAddrPort []netip.AddrPort + intercepted bool + }{ + { + name: "no-services", + conf: &ipn.ServeConfig{}, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.101.101.101:443"), + }, + intercepted: false, + }, + { + name: "one-incorrectly-configured-service", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + }, + Tun: true, + }, + }, + }, + expectedErr: ipn.ErrServiceConfigHasBothTCPAndTun, + }, + { + // one correctly configured service with packet should be intercepted + name: "one-service-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.101.101.101:80"), + 
netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:80"), + }, + intercepted: true, + }, + { + // one correctly configured service with packet should not be intercepted + name: "one-service-not-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.99.99.99:80"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:80"), + netip.MustParseAddrPort("100.101.101.101:82"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:82"), + }, + intercepted: false, + }, + { + // multiple correctly configured service with packet should be intercepted + name: "multiple-service-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: []netip.AddrPort{ + netip.MustParseAddrPort("100.99.99.99:80"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:80"), + netip.MustParseAddrPort("100.101.101.101:81"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:81"), + }, + intercepted: true, + }, + { + // multiple correctly configured service with packet should not be intercepted + name: "multiple-service-not-intercept-packet", + conf: &ipn.ServeConfig{ + Services: map[tailcfg.ServiceName]*ipn.ServiceConfig{ + "svc:foo": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + }, + }, + "svc:bar": { + TCP: map[uint16]*ipn.TCPPortHandler{ + 80: {HTTP: true}, + 81: {HTTPS: true}, + 82: {HTTPS: true}, + }, + }, + }, + }, + packetDstAddrPort: 
[]netip.AddrPort{ + // ips in capmap but port is not hosting service + netip.MustParseAddrPort("100.99.99.99:77"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:626b:628b]:77"), + netip.MustParseAddrPort("100.101.101.101:85"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6565:6565]:85"), + // ips not in capmap + netip.MustParseAddrPort("100.102.102.102:80"), + netip.MustParseAddrPort("[fd7a:115c:a1e0:ab12:4843:cd96:6666:6666]:80"), + }, + intercepted: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := b.SetServeConfig(tt.conf, "") + if err != nil && tt.expectedErr != nil { + if !errors.Is(err, tt.expectedErr) { + t.Fatalf("expected error %v,\n got %v", tt.expectedErr, err) + } + return + } + if err != nil { + t.Fatal(err) + } + for _, addrPort := range tt.packetDstAddrPort { + if tt.intercepted != b.ShouldInterceptVIPServiceTCPPort(addrPort) { + if tt.intercepted { + t.Fatalf("expected packet to be intercepted") + } else { + t.Fatalf("expected packet not to be intercepted") + } + } + } + }) + } + +} + func TestServeConfigETag(t *testing.T) { b := newTestBackend(t) @@ -680,11 +877,12 @@ func newTestBackend(t *testing.T) *LocalBackend { logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... 
") } - sys := &tsd.System{} + sys := tsd.NewSystem() e, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ SetSubsystem: sys.Set, HealthTracker: sys.HealthTracker(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }) if err != nil { t.Fatal(err) @@ -701,38 +899,40 @@ func newTestBackend(t *testing.T) *LocalBackend { b.SetVarRoot(dir) pm := must.Get(newProfileManager(new(mem.Store), logf, new(health.Tracker))) - pm.currentProfile = &ipn.LoginProfile{ID: "id0"} + pm.currentProfile = (&ipn.LoginProfile{ID: "id0"}).View() b.pm = pm - b.netMap = &netmap.NetworkMap{ + b.currentNode().SetNetMap(&netmap.NetworkMap{ SelfNode: (&tailcfg.Node{ Name: "example.ts.net", }).View(), - UserProfiles: map[tailcfg.UserID]tailcfg.UserProfile{ - tailcfg.UserID(1): { + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ LoginName: "someone@example.com", DisplayName: "Some One", ProfilePicURL: "https://example.com/photo.jpg", - }, + }).View(), }, - } - b.peers = map[tailcfg.NodeID]tailcfg.NodeView{ - 152: (&tailcfg.Node{ - ID: 152, - ComputedName: "some-peer", - User: tailcfg.UserID(1), - }).View(), - 153: (&tailcfg.Node{ - ID: 153, - ComputedName: "some-tagged-peer", - Tags: []string{"tag:server", "tag:test"}, - User: tailcfg.UserID(1), - }).View(), - } - b.nodeByAddr = map[netip.Addr]tailcfg.NodeID{ - netip.MustParseAddr("100.150.151.152"): 152, - netip.MustParseAddr("100.150.151.153"): 153, - } + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 152, + ComputedName: "some-peer", + User: tailcfg.UserID(1), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.152/32"), + }, + }).View(), + (&tailcfg.Node{ + ID: 153, + ComputedName: "some-tagged-peer", + Tags: []string{"tag:server", "tag:test"}, + User: tailcfg.UserID(1), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.150.151.153/32"), + }, + }).View(), + }, + }) return b } diff --git a/ipn/ipnlocal/ssh.go b/ipn/ipnlocal/ssh.go index 
fbeb19bd15bd1..e48b1f2f1286e 100644 --- a/ipn/ipnlocal/ssh.go +++ b/ipn/ipnlocal/ssh.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || (darwin && !ios) || freebsd || openbsd +//go:build (linux && !android) || (darwin && !ios) || freebsd || openbsd || plan9 package ipnlocal @@ -24,10 +24,10 @@ import ( "strings" "sync" - "github.com/tailscale/golang-x-crypto/ssh" "go4.org/mem" + "golang.org/x/crypto/ssh" "tailscale.com/tailcfg" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/mak" ) @@ -80,30 +80,32 @@ func (b *LocalBackend) getSSHUsernames(req *tailcfg.C2NSSHUsernamesRequest) (*ta if err != nil { return nil, err } - lineread.Reader(bytes.NewReader(out), func(line []byte) error { + for line := range lineiter.Bytes(out) { line = bytes.TrimSpace(line) if len(line) == 0 || line[0] == '_' { - return nil + continue } add(string(line)) - return nil - }) + } default: - lineread.File("/etc/passwd", func(line []byte) error { + for lr := range lineiter.File("/etc/passwd") { + line, err := lr.Value() + if err != nil { + break + } line = bytes.TrimSpace(line) if len(line) == 0 || line[0] == '#' || line[0] == '_' { - return nil + continue } if mem.HasSuffix(mem.B(line), mem.S("/nologin")) || mem.HasSuffix(mem.B(line), mem.S("/false")) { - return nil + continue } colon := bytes.IndexByte(line, ':') if colon != -1 { add(string(line[:colon])) } - return nil - }) + } } return res, nil } diff --git a/ipn/ipnlocal/ssh_stub.go b/ipn/ipnlocal/ssh_stub.go index 7875ae3111f58..d129084e4c10c 100644 --- a/ipn/ipnlocal/ssh_stub.go +++ b/ipn/ipnlocal/ssh_stub.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build ios || (!linux && !darwin && !freebsd && !openbsd) +//go:build ios || android || (!linux && !darwin && !freebsd && !openbsd && !plan9) package ipnlocal diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go 
index bebd0152b5a36..5d9e8b169f0a5 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -5,26 +5,46 @@ package ipnlocal import ( "context" + "errors" + "net/netip" + "strings" "sync" "sync/atomic" "testing" "time" qt "github.com/frankban/quicktest" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/control/controlclient" "tailscale.com/envknob" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/store/mem" + "tailscale.com/net/dns" + "tailscale.com/net/netmon" + "tailscale.com/net/packet" + "tailscale.com/net/tsdial" "tailscale.com/tailcfg" "tailscale.com/tsd" "tailscale.com/tstest" + "tailscale.com/types/dnstype" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/persist" + "tailscale.com/types/preftype" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" + "tailscale.com/util/must" "tailscale.com/wgengine" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/magicsock" + "tailscale.com/wgengine/router" + "tailscale.com/wgengine/wgcfg" + "tailscale.com/wgengine/wgint" ) // notifyThrottler receives notifications from an ipn.Backend, blocking @@ -170,6 +190,14 @@ func (cc *mockControl) send(err error, url string, loginFinished bool, nm *netma } } +func (cc *mockControl) authenticated(nm *netmap.NetworkMap) { + if selfUser, ok := nm.UserProfiles[nm.SelfNode.User()]; ok { + cc.persist.UserProfile = *selfUser.AsStruct() + } + cc.persist.NodeID = nm.SelfNode.StableID() + cc.send(nil, "", true, nm) +} + // called records that a particular function name was called. 
func (cc *mockControl) called(s string) { cc.mu.Lock() @@ -295,10 +323,10 @@ func TestStateMachine(t *testing.T) { c := qt.New(t) logf := tstest.WhileTestRunningLogger(t) - sys := new(tsd.System) + sys := tsd.NewSystem() store := new(testStateStorage) sys.Set(store) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatalf("NewFakeUserspaceEngine: %v", err) } @@ -309,6 +337,7 @@ func TestStateMachine(t *testing.T) { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(b.Shutdown) b.DisablePortMapperForTest() var cc, previousCC *mockControl @@ -734,12 +763,10 @@ func TestStateMachine(t *testing.T) { // b.Shutdown() explicitly ourselves. previousCC.assertShutdown(false) - // Note: unpause happens because ipn needs to get at least one netmap - // on startup, otherwise UIs can't show the node list, login - // name, etc when in state ipn.Stopped. - // Arguably they shouldn't try. But they currently do. nn := notifies.drain(2) - cc.assertCalls("New", "Login") + // We already have a netmap for this node, + // and WantRunning is false, so cc should be paused. + cc.assertCalls("New", "Login", "pause") c.Assert(nn[0].Prefs, qt.IsNotNil) c.Assert(nn[1].State, qt.IsNotNil) c.Assert(nn[0].Prefs.WantRunning(), qt.IsFalse) @@ -750,7 +777,11 @@ func TestStateMachine(t *testing.T) { // When logged in but !WantRunning, ipn leaves us unpaused to retrieve // the first netmap. Simulate that netmap being received, after which // it should pause us, to avoid wasting CPU retrieving unnecessarily - // additional netmap updates. + // additional netmap updates. Since our LocalBackend instance already + // has a netmap, we will reset it to nil to simulate the first netmap + // retrieval. 
+ b.setNetMapLocked(nil) + cc.assertCalls("unpause") // // TODO: really the various GUIs and prefs should be refactored to // not require the netmap structure at all when starting while @@ -852,7 +883,7 @@ func TestStateMachine(t *testing.T) { // The last test case is the most common one: restarting when both // logged in and WantRunning. t.Logf("\n\nStart5") - notifies.expect(1) + notifies.expect(2) c.Assert(b.Start(ipn.Options{}), qt.IsNil) { // NOTE: cc.Shutdown() is correct here, since we didn't call @@ -860,30 +891,32 @@ func TestStateMachine(t *testing.T) { previousCC.assertShutdown(false) cc.assertCalls("New", "Login") - nn := notifies.drain(1) + nn := notifies.drain(2) cc.assertCalls() c.Assert(nn[0].Prefs, qt.IsNotNil) c.Assert(nn[0].Prefs.LoggedOut(), qt.IsFalse) c.Assert(nn[0].Prefs.WantRunning(), qt.IsTrue) - c.Assert(b.State(), qt.Equals, ipn.NoState) + // We're logged in and have a valid netmap, so we should + // be in the Starting state. + c.Assert(nn[1].State, qt.IsNotNil) + c.Assert(*nn[1].State, qt.Equals, ipn.Starting) + c.Assert(b.State(), qt.Equals, ipn.Starting) } // Control server accepts our valid key from before. t.Logf("\n\nLoginFinished5") - notifies.expect(1) + notifies.expect(0) cc.send(nil, "", true, &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), }) { - nn := notifies.drain(1) + notifies.drain(0) cc.assertCalls() // NOTE: No LoginFinished message since no interactive // login was needed. - c.Assert(nn[0].State, qt.IsNotNil) - c.Assert(ipn.Starting, qt.Equals, *nn[0].State) // NOTE: No prefs change this time. WantRunning stays true. // We were in Starting in the first place, so that doesn't - // change either. + // change either, so we don't expect any notifications. 
c.Assert(ipn.Starting, qt.Equals, b.State()) } t.Logf("\n\nExpireKey") @@ -929,9 +962,9 @@ func TestStateMachine(t *testing.T) { func TestEditPrefsHasNoKeys(t *testing.T) { logf := tstest.WhileTestRunningLogger(t) - sys := new(tsd.System) + sys := tsd.NewSystem() sys.Set(new(mem.Store)) - e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatalf("NewFakeUserspaceEngine: %v", err) } @@ -942,13 +975,12 @@ func TestEditPrefsHasNoKeys(t *testing.T) { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(b.Shutdown) b.hostinfo = &tailcfg.Hostinfo{OS: "testos"} b.pm.SetPrefs((&ipn.Prefs{ Persist: &persist.Persist{ PrivateNodeKey: key.NewNode(), OldPrivateNodeKey: key.NewNode(), - - LegacyFrontendPrivateMachineKey: key.NewMachine(), }, }).View(), ipn.NetworkProfile{}) if p := b.pm.CurrentPrefs().Persist(); !p.Valid() || p.PrivateNodeKey().IsZero() { @@ -975,10 +1007,6 @@ func TestEditPrefsHasNoKeys(t *testing.T) { t.Errorf("OldPrivateNodeKey = %v; want zero", p.Persist().OldPrivateNodeKey()) } - if !p.Persist().LegacyFrontendPrivateMachineKey().IsZero() { - t.Errorf("LegacyFrontendPrivateMachineKey = %v; want zero", p.Persist().LegacyFrontendPrivateMachineKey()) - } - if !p.Persist().NetworkLockKey().IsZero() { t.Errorf("NetworkLockKey= %v; want zero", p.Persist().NetworkLockKey()) } @@ -1014,15 +1042,16 @@ func TestWGEngineStatusRace(t *testing.T) { t.Skip("test fails") c := qt.New(t) logf := tstest.WhileTestRunningLogger(t) - sys := new(tsd.System) + sys := tsd.NewSystem() sys.Set(new(mem.Store)) - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set) + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.Bus.Get()) c.Assert(err, qt.IsNil) t.Cleanup(eng.Close) sys.Set(eng) b, err := NewLocalBackend(logf, logid.PublicID{}, sys, 0) c.Assert(err, 
qt.IsNil) + t.Cleanup(b.Shutdown) var cc *mockControl b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { @@ -1075,3 +1104,449 @@ func TestWGEngineStatusRace(t *testing.T) { wg.Wait() wantState(ipn.Running) } + +// TestEngineReconfigOnStateChange verifies that wgengine is properly reconfigured +// when the LocalBackend's state changes, such as when the user logs in, switches +// profiles, or disconnects from Tailscale. +func TestEngineReconfigOnStateChange(t *testing.T) { + enableLogging := false + connect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true} + disconnect := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: false}, WantRunningSet: true} + node1 := testNetmapForNode(1, "node-1", []netip.Prefix{netip.MustParsePrefix("100.64.1.1/32")}) + node2 := testNetmapForNode(2, "node-2", []netip.Prefix{netip.MustParsePrefix("100.64.1.2/32")}) + routesWithQuad100 := func(extra ...netip.Prefix) []netip.Prefix { + return append(extra, netip.MustParsePrefix("100.100.100.100/32")) + } + hostsFor := func(nm *netmap.NetworkMap) map[dnsname.FQDN][]netip.Addr { + var hosts map[dnsname.FQDN][]netip.Addr + appendNode := func(n tailcfg.NodeView) { + addrs := make([]netip.Addr, 0, n.Addresses().Len()) + for _, addr := range n.Addresses().All() { + addrs = append(addrs, addr.Addr()) + } + mak.Set(&hosts, must.Get(dnsname.ToFQDN(n.Name())), addrs) + } + if nm != nil && nm.SelfNode.Valid() { + appendNode(nm.SelfNode) + } + for _, n := range nm.Peers { + appendNode(n) + } + return hosts + } + + tests := []struct { + name string + steps func(*testing.T, *LocalBackend, func() *mockControl) + wantState ipn.State + wantCfg *wgcfg.Config + wantRouterCfg *router.Config + wantDNSCfg *dns.Config + }{ + { + name: "Initial", + // The configs are nil until the the LocalBackend is started. 
+ wantState: ipn.NoState, + wantCfg: nil, + wantRouterCfg: nil, + wantDNSCfg: nil, + }, + { + name: "Start", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + }, + // Once started, all configs must be reset and have their zero values. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect", + steps: func(t *testing.T, lb *LocalBackend, _ func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + }, + // Same if WantRunning is true, but the auth is not completed yet. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + }, + // After the auth is completed, the configs must be updated to reflect the node's netmap. + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node1.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + { + name: "Start/Connect/Login/Disconnect", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo2(t)(lb.EditPrefs(disconnect)) + }, + // After disconnecting, all configs must be reset and have their zero values. 
+ wantState: ipn.Stopped, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + }, + // After switching to a new, empty profile, all configs should be reset + // and have their zero values until the auth is completed. + wantState: ipn.NeedsLogin, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/NewProfile/Login", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + mustDo(t)(lb.NewProfile()) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node2) + }, + // Once the auth is completed, the configs must be updated to reflect the node's netmap. 
+ wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node2.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node2.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node2.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node2), + }, + }, + { + name: "Start/Connect/Login/SwitchProfile", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + }, + // After switching to an existing profile, all configs must be reset + // and have their zero values until the (non-interactive) login is completed. + wantState: ipn.NoState, + wantCfg: &wgcfg.Config{}, + wantRouterCfg: &router.Config{}, + wantDNSCfg: &dns.Config{}, + }, + { + name: "Start/Connect/Login/SwitchProfile/NonInteractiveLogin", + steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { + mustDo(t)(lb.Start(ipn.Options{})) + mustDo2(t)(lb.EditPrefs(connect)) + cc().authenticated(node1) + profileID := lb.CurrentProfile().ID() + mustDo(t)(lb.NewProfile()) + cc().authenticated(node2) + mustDo(t)(lb.SwitchProfile(profileID)) + cc().authenticated(node1) // complete the login + }, + // After switching profiles and completing the auth, the configs + // must be updated to reflect the node's netmap. 
+ wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Name: "tailscale", + NodeID: node1.SelfNode.StableID(), + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lb, engine, cc := newLocalBackendWithMockEngineAndControl(t, enableLogging) + + if tt.steps != nil { + tt.steps(t, lb, cc) + } + + if gotState := lb.State(); gotState != tt.wantState { + t.Errorf("State: got %v; want %v", gotState, tt.wantState) + } + + opts := []cmp.Option{ + cmpopts.EquateComparable(key.NodePublic{}, netip.Addr{}, netip.Prefix{}), + } + if diff := cmp.Diff(tt.wantCfg, engine.Config(), opts...); diff != "" { + t.Errorf("wgcfg.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantRouterCfg, engine.RouterConfig(), opts...); diff != "" { + t.Errorf("router.Config(+got -want): %v", diff) + } + if diff := cmp.Diff(tt.wantDNSCfg, engine.DNSConfig(), opts...); diff != "" { + t.Errorf("dns.Config(+got -want): %v", diff) + } + }) + } +} + +func testNetmapForNode(userID tailcfg.UserID, name string, addresses []netip.Prefix) *netmap.NetworkMap { + const ( + domain = "example.com" + magicDNSSuffix = ".test.ts.net" + ) + user := &tailcfg.UserProfile{ + ID: userID, + DisplayName: name, + LoginName: strings.Join([]string{name, domain}, "@"), + } + self := &tailcfg.Node{ + ID: tailcfg.NodeID(1000 + userID), + StableID: tailcfg.StableNodeID("stable-" + name), + User: user.ID, + Name: name + magicDNSSuffix, + Addresses: addresses, + MachineAuthorized: true, + } + return &netmap.NetworkMap{ + SelfNode: self.View(), + Name: self.Name, + Domain: domain, + UserProfiles: 
map[tailcfg.UserID]tailcfg.UserProfileView{ + user.ID: user.View(), + }, + } +} + +func mustDo(t *testing.T) func(error) { + t.Helper() + return func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func mustDo2(t *testing.T) func(any, error) { + t.Helper() + return func(_ any, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } +} + +func newLocalBackendWithMockEngineAndControl(t *testing.T, enableLogging bool) (*LocalBackend, *mockEngine, func() *mockControl) { + t.Helper() + + logf := logger.Discard + if enableLogging { + logf = tstest.WhileTestRunningLogger(t) + } + + dialer := &tsdial.Dialer{Logf: logf} + dialer.SetNetMon(netmon.NewStatic()) + + sys := tsd.NewSystem() + sys.Set(dialer) + sys.Set(dialer.NetMon()) + + magicConn, err := magicsock.NewConn(magicsock.Options{ + Logf: logf, + NetMon: dialer.NetMon(), + Metrics: sys.UserMetricsRegistry(), + HealthTracker: sys.HealthTracker(), + DisablePortMapper: true, + }) + if err != nil { + t.Fatalf("NewConn failed: %v", err) + } + magicConn.SetNetworkUp(dialer.NetMon().InterfaceState().AnyInterfaceUp()) + sys.Set(magicConn) + + engine := newMockEngine() + sys.Set(engine) + t.Cleanup(func() { + engine.Close() + <-engine.Done() + }) + + lb := newLocalBackendWithSysAndTestControl(t, enableLogging, sys, func(tb testing.TB, opts controlclient.Options) controlclient.Client { + return newClient(tb, opts) + }) + return lb, engine, func() *mockControl { return lb.cc.(*mockControl) } +} + +var _ wgengine.Engine = (*mockEngine)(nil) + +// mockEngine implements [wgengine.Engine]. 
+type mockEngine struct { + done chan struct{} // closed when Close is called + + mu sync.Mutex // protects all following fields + closed bool + cfg *wgcfg.Config + routerCfg *router.Config + dnsCfg *dns.Config + + filter, jailedFilter *filter.Filter + + statusCb wgengine.StatusCallback +} + +func newMockEngine() *mockEngine { + return &mockEngine{ + done: make(chan struct{}), + } +} + +func (e *mockEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return errors.New("engine closed") + } + e.cfg = cfg + e.routerCfg = routerCfg + e.dnsCfg = dnsCfg + return nil +} + +func (e *mockEngine) Config() *wgcfg.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.cfg +} + +func (e *mockEngine) RouterConfig() *router.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.routerCfg +} + +func (e *mockEngine) DNSConfig() *dns.Config { + e.mu.Lock() + defer e.mu.Unlock() + return e.dnsCfg +} + +func (e *mockEngine) PeerForIP(netip.Addr) (_ wgengine.PeerForIP, ok bool) { + return wgengine.PeerForIP{}, false +} + +func (e *mockEngine) GetFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.filter +} + +func (e *mockEngine) SetFilter(f *filter.Filter) { + e.mu.Lock() + e.filter = f + e.mu.Unlock() +} + +func (e *mockEngine) GetJailedFilter() *filter.Filter { + e.mu.Lock() + defer e.mu.Unlock() + return e.jailedFilter +} + +func (e *mockEngine) SetJailedFilter(f *filter.Filter) { + e.mu.Lock() + e.jailedFilter = f + e.mu.Unlock() +} + +func (e *mockEngine) SetStatusCallback(cb wgengine.StatusCallback) { + e.mu.Lock() + e.statusCb = cb + e.mu.Unlock() +} + +func (e *mockEngine) RequestStatus() { + e.mu.Lock() + cb := e.statusCb + e.mu.Unlock() + if cb != nil { + cb(&wgengine.Status{AsOf: time.Now()}, nil) + } +} + +func (e *mockEngine) PeerByKey(key.NodePublic) (_ wgint.Peer, ok bool) { + return wgint.Peer{}, false +} + +func (e *mockEngine) 
SetNetworkMap(*netmap.NetworkMap) {} + +func (e *mockEngine) UpdateStatus(*ipnstate.StatusBuilder) {} + +func (e *mockEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { + cb(&ipnstate.PingResult{IP: ip.String(), Err: "not implemented"}) +} + +func (e *mockEngine) InstallCaptureHook(packet.CaptureCallback) {} + +func (e *mockEngine) Close() { + e.mu.Lock() + defer e.mu.Unlock() + if e.closed { + return + } + e.closed = true + close(e.done) +} + +func (e *mockEngine) Done() <-chan struct{} { + return e.done +} diff --git a/ipn/ipnlocal/taildrop.go b/ipn/ipnlocal/taildrop.go deleted file mode 100644 index db7d8e12ab46e..0000000000000 --- a/ipn/ipnlocal/taildrop.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package ipnlocal - -import ( - "maps" - "slices" - "strings" - - "tailscale.com/ipn" -) - -// UpdateOutgoingFiles updates b.outgoingFiles to reflect the given updates and -// sends an ipn.Notify with the full list of outgoingFiles. 
-func (b *LocalBackend) UpdateOutgoingFiles(updates map[string]*ipn.OutgoingFile) { - b.mu.Lock() - if b.outgoingFiles == nil { - b.outgoingFiles = make(map[string]*ipn.OutgoingFile, len(updates)) - } - maps.Copy(b.outgoingFiles, updates) - outgoingFiles := make([]*ipn.OutgoingFile, 0, len(b.outgoingFiles)) - for _, file := range b.outgoingFiles { - outgoingFiles = append(outgoingFiles, file) - } - b.mu.Unlock() - slices.SortFunc(outgoingFiles, func(a, b *ipn.OutgoingFile) int { - t := a.Started.Compare(b.Started) - if t != 0 { - return t - } - return strings.Compare(a.Name, b.Name) - }) - b.send(ipn.Notify{OutgoingFiles: outgoingFiles}) -} diff --git a/ipn/ipnlocal/web_client.go b/ipn/ipnlocal/web_client.go index ccde9f01dced0..18145d1bb7e46 100644 --- a/ipn/ipnlocal/web_client.go +++ b/ipn/ipnlocal/web_client.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/web" "tailscale.com/logtail/backoff" "tailscale.com/net/netutil" @@ -36,16 +36,16 @@ type webClient struct { server *web.Server // or nil, initialized lazily - // lc optionally specifies a LocalClient to use to connect + // lc optionally specifies a local.Client to use to connect // to the localapi for this tailscaled instance. // If nil, a default is used. - lc *tailscale.LocalClient + lc *local.Client } // ConfigureWebClient configures b.web prior to use. -// Specifially, it sets b.web.lc to the provided LocalClient. +// Specifially, it sets b.web.lc to the provided local.Client. // If provided as nil, b.web.lc is cleared out. -func (b *LocalBackend) ConfigureWebClient(lc *tailscale.LocalClient) { +func (b *LocalBackend) ConfigureWebClient(lc *local.Client) { b.webClient.mu.Lock() defer b.webClient.mu.Unlock() b.webClient.lc = lc @@ -116,13 +116,14 @@ func (b *LocalBackend) handleWebClientConn(c net.Conn) error { // for each of the local device's Tailscale IP addresses. 
This is needed to properly // route local traffic when using kernel networking mode. func (b *LocalBackend) updateWebClientListenersLocked() { - if b.netMap == nil { + nm := b.currentNode().NetMap() + if nm == nil { return } - addrs := b.netMap.GetAddresses() - for i := range addrs.Len() { - addrPort := netip.AddrPortFrom(addrs.At(i).Addr(), webClientPort) + addrs := nm.GetAddresses() + for _, pfx := range addrs.All() { + addrPort := netip.AddrPortFrom(pfx.Addr(), webClientPort) if _, ok := b.webClientListeners[addrPort]; ok { continue // already listening } diff --git a/ipn/ipnlocal/web_client_stub.go b/ipn/ipnlocal/web_client_stub.go index 1dfc8c27c3a09..31735de250d54 100644 --- a/ipn/ipnlocal/web_client_stub.go +++ b/ipn/ipnlocal/web_client_stub.go @@ -9,14 +9,14 @@ import ( "errors" "net" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" ) const webClientPort = 5252 type webClient struct{} -func (b *LocalBackend) ConfigureWebClient(lc *tailscale.LocalClient) {} +func (b *LocalBackend) ConfigureWebClient(lc *local.Client) {} func (b *LocalBackend) webClientGetOrInit() error { return errors.New("not implemented") diff --git a/ipn/ipnserver/actor.go b/ipn/ipnserver/actor.go index 761c9816cab27..dd40924bbf542 100644 --- a/ipn/ipnserver/actor.go +++ b/ipn/ipnserver/actor.go @@ -31,7 +31,13 @@ type actor struct { logf logger.Logf ci *ipnauth.ConnIdentity - isLocalSystem bool // whether the actor is the Windows' Local System identity. + clientID ipnauth.ClientID + userID ipn.WindowsUserID // cached Windows user ID of the connected client process. + // accessOverrideReason specifies the reason for overriding certain access restrictions, + // such as permitting a user to disconnect when the always-on mode is enabled, + // provided that such justification is allowed by the policy. + accessOverrideReason string + isLocalSystem bool // whether the actor is the Windows' Local System identity. 
} func newActor(logf logger.Logf, c net.Conn) (*actor, error) { @@ -39,7 +45,62 @@ func newActor(logf logger.Logf, c net.Conn) (*actor, error) { if err != nil { return nil, err } - return &actor{logf: logf, ci: ci, isLocalSystem: connIsLocalSystem(ci)}, nil + var clientID ipnauth.ClientID + if pid := ci.Pid(); pid != 0 { + // Derive [ipnauth.ClientID] from the PID of the connected client process. + // TODO(nickkhyl): This is transient and will be re-worked as we + // progress on tailscale/corp#18342. At minimum, we should use a 2-tuple + // (PID + StartTime) or a 3-tuple (PID + StartTime + UID) to identify + // the client process. This helps prevent security issues where a + // terminated client process's PID could be reused by a different + // process. This is not currently an issue as we allow only one user to + // connect anyway. + // Additionally, we should consider caching authentication results since + // operations like retrieving a username by SID might require network + // connectivity on domain-joined devices and/or be slow. + clientID = ipnauth.ClientIDFrom(pid) + } + return &actor{ + logf: logf, + ci: ci, + clientID: clientID, + userID: ci.WindowsUserID(), + isLocalSystem: connIsLocalSystem(ci), + }, + nil +} + +// actorWithAccessOverride returns a new actor that carries the specified +// reason for overriding certain access restrictions, if permitted by the +// policy. If the reason is "", it returns the base actor. +func actorWithAccessOverride(baseActor *actor, reason string) *actor { + if reason == "" { + return baseActor + } + return &actor{ + logf: baseActor.logf, + ci: baseActor.ci, + clientID: baseActor.clientID, + userID: baseActor.userID, + accessOverrideReason: reason, + isLocalSystem: baseActor.isLocalSystem, + } +} + +// CheckProfileAccess implements [ipnauth.Actor]. 
+func (a *actor) CheckProfileAccess(profile ipn.LoginProfileView, requestedAccess ipnauth.ProfileAccess, auditLogger ipnauth.AuditLogFunc) error { + // TODO(nickkhyl): return errors of more specific types and have them + // translated to the appropriate HTTP status codes in the API handler. + if profile.LocalUserID() != a.UserID() { + return errors.New("the target profile does not belong to the user") + } + switch requestedAccess { + case ipnauth.Disconnect: + // Disconnect is allowed if a user owns the profile and the policy permits it. + return ipnauth.CheckDisconnectPolicy(a, profile, a.accessOverrideReason, auditLogger) + default: + return errors.New("the requested operation is not allowed") + } } // IsLocalSystem implements [ipnauth.Actor]. @@ -54,13 +115,21 @@ func (a *actor) IsLocalAdmin(operatorUID string) bool { // UserID implements [ipnauth.Actor]. func (a *actor) UserID() ipn.WindowsUserID { - return a.ci.WindowsUserID() + return a.userID } func (a *actor) pid() int { return a.ci.Pid() } +// ClientID implements [ipnauth.Actor]. +func (a *actor) ClientID() (_ ipnauth.ClientID, ok bool) { + return a.clientID, a.clientID != ipnauth.NoClientID +} + +// Context implements [ipnauth.Actor]. +func (a *actor) Context() context.Context { return context.Background() } + // Username implements [ipnauth.Actor]. 
func (a *actor) Username() (string, error) { if a.ci == nil { @@ -75,7 +144,7 @@ func (a *actor) Username() (string, error) { } defer tok.Close() return tok.Username() - case "darwin", "linux": + case "darwin", "linux", "illumos", "solaris": uid, ok := a.ci.Creds().UserID() if !ok { return "", errors.New("missing user ID") @@ -91,11 +160,11 @@ func (a *actor) Username() (string, error) { } type actorOrError struct { - actor *actor + actor ipnauth.Actor err error } -func (a actorOrError) unwrap() (*actor, error) { +func (a actorOrError) unwrap() (ipnauth.Actor, error) { return a.actor, a.err } @@ -110,9 +179,15 @@ func contextWithActor(ctx context.Context, logf logger.Logf, c net.Conn) context return actorKey.WithValue(ctx, actorOrError{actor: actor, err: err}) } -// actorFromContext returns an [actor] associated with ctx, +// NewContextWithActorForTest returns a new context that carries the identity +// of the specified actor. It is used in tests only. +func NewContextWithActorForTest(ctx context.Context, actor ipnauth.Actor) context.Context { + return actorKey.WithValue(ctx, actorOrError{actor: actor}) +} + +// actorFromContext returns an [ipnauth.Actor] associated with ctx, // or an error if the context does not carry an actor's identity. -func actorFromContext(ctx context.Context) (*actor, error) { +func actorFromContext(ctx context.Context) (ipnauth.Actor, error) { return actorKey.Value(ctx).unwrap() } diff --git a/ipn/ipnserver/proxyconnect.go b/ipn/ipnserver/proxyconnect.go index 1094a79f9daf9..030c4efe4a6b0 100644 --- a/ipn/ipnserver/proxyconnect.go +++ b/ipn/ipnserver/proxyconnect.go @@ -14,7 +14,7 @@ import ( ) // handleProxyConnectConn handles a CONNECT request to -// log.tailscale.io (or whatever the configured log server is). This +// log.tailscale.com (or whatever the configured log server is). 
This // is intended for use by the Windows GUI client to log via when an // exit node is in use, so the logs don't go out via the exit node and // instead go directly, like tailscaled's. The dialer tried to do that diff --git a/ipn/ipnserver/server.go b/ipn/ipnserver/server.go index 73b5e82abee76..a7ded9c0088ec 100644 --- a/ipn/ipnserver/server.go +++ b/ipn/ipnserver/server.go @@ -7,6 +7,7 @@ package ipnserver import ( "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -20,8 +21,9 @@ import ( "sync/atomic" "unicode" + "tailscale.com/client/tailscale/apitype" "tailscale.com/envknob" - "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/localapi" "tailscale.com/net/netmon" @@ -30,6 +32,7 @@ import ( "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/systemd" + "tailscale.com/util/testenv" ) // Server is an IPN backend and its set of 0 or more active localhost @@ -39,18 +42,11 @@ type Server struct { logf logger.Logf netMon *netmon.Monitor // must be non-nil backendLogID logid.PublicID - // resetOnZero is whether to call bs.Reset on transition from - // 1->0 active HTTP requests. That is, this is whether the backend is - // being run in "client mode" that requires an active GUI - // connection (such as on Windows by default). Even if this - // is true, the ForceDaemon pref can override this. - resetOnZero bool // mu guards the fields that follow. 
// lock order: mu, then LocalBackend.mu mu sync.Mutex - lastUserID ipn.WindowsUserID // tracks last userid; on change, Reset state for paranoia - activeReqs map[*http.Request]*actor + activeReqs map[*http.Request]ipnauth.Actor backendWaiter waiterSet // of LocalBackend waiters zeroReqWaiter waiterSet // of blockUntilZeroConnections waiters } @@ -194,10 +190,22 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { defer onDone() if strings.HasPrefix(r.URL.Path, "/localapi/") { - lah := localapi.NewHandler(lb, s.logf, s.backendLogID) - lah.PermitRead, lah.PermitWrite = ci.Permissions(lb.OperatorUserID()) - lah.PermitCert = ci.CanFetchCerts() - lah.Actor = ci + if actor, ok := ci.(*actor); ok { + reason, err := base64.StdEncoding.DecodeString(r.Header.Get(apitype.RequestReasonHeader)) + if err != nil { + http.Error(w, "invalid reason header", http.StatusBadRequest) + return + } + ci = actorWithAccessOverride(actor, string(reason)) + } + + lah := localapi.NewHandler(ci, lb, s.logf, s.backendLogID) + if actor, ok := ci.(*actor); ok { + lah.PermitRead, lah.PermitWrite = actor.Permissions(lb.OperatorUserID()) + lah.PermitCert = actor.CanFetchCerts() + } else if testenv.InTest() { + lah.PermitRead, lah.PermitWrite = true, true + } lah.ServeHTTP(w, r) return } @@ -230,11 +238,11 @@ func (e inUseOtherUserError) Unwrap() error { return e.error } // The returned error, when non-nil, will be of type inUseOtherUserError. // // s.mu must be held. -func (s *Server) checkConnIdentityLocked(ci *actor) error { +func (s *Server) checkConnIdentityLocked(ci ipnauth.Actor) error { // If clients are already connected, verify they're the same user. // This mostly matters on Windows at the moment. 
if len(s.activeReqs) > 0 { - var active *actor + var active ipnauth.Actor for _, active = range s.activeReqs { break } @@ -251,7 +259,9 @@ func (s *Server) checkConnIdentityLocked(ci *actor) error { if username, err := active.Username(); err == nil { fmt.Fprintf(&b, " by %s", username) } - fmt.Fprintf(&b, ", pid %d", active.pid()) + if active, ok := active.(*actor); ok { + fmt.Fprintf(&b, ", pid %d", active.pid()) + } return inUseOtherUserError{errors.New(b.String())} } } @@ -267,7 +277,7 @@ func (s *Server) checkConnIdentityLocked(ci *actor) error { // // This is primarily used for the Windows GUI, to block until one user's done // controlling the tailscaled process. -func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor *actor) error { +func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor ipnauth.Actor) error { inUse := func() bool { s.mu.Lock() defer s.mu.Unlock() @@ -277,7 +287,18 @@ func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor *actor) erro for inUse() { // Check whenever the connection count drops down to zero. ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) - <-ready + if inUse() { + // If the server was in use at the time of the initial check, + // but disconnected and was removed from the activeReqs map + // by the time we registered a waiter, the ready channel + // will never be closed, resulting in a deadlock. To avoid + // this, we can check again after registering the waiter. + // + // This method is planned for complete removal as part of the + // multi-user improvements in tailscale/corp#18342, + // and this approach should be fine as a temporary solution. + <-ready + } cleanup() if err := ctx.Err(); err != nil { return err @@ -291,6 +312,13 @@ func (s *Server) blockWhileIdentityInUse(ctx context.Context, actor *actor) erro // Unix-like platforms and specifies the ID of a local user // (in the os/user.User.Uid string form) who is allowed // to operate tailscaled without being root or using sudo. 
+// +// Sandboxed macos clients must directly supply, or be able to read, +// an explicit token. Permission is inferred by validating that +// token. Sandboxed macos clients also don't use ipnserver.actor at all +// (and prior to that, they didn't use ipnauth.ConnIdentity) +// +// See safesocket and safesocket_darwin. func (a *actor) Permissions(operatorUID string) (read, write bool) { switch envknob.GOOS() { case "windows": @@ -303,7 +331,7 @@ func (a *actor) Permissions(operatorUID string) (read, write bool) { // checks here. Note that this permission model is being changed in // tailscale/corp#18342. return true, true - case "js": + case "js", "plan9": return true, true } if a.ci.IsUnixSock() { @@ -361,23 +389,13 @@ func (a *actor) CanFetchCerts() bool { // The returned error may be of type [inUseOtherUserError]. // // onDone must be called when the HTTP request is done. -func (s *Server) addActiveHTTPRequest(req *http.Request, actor *actor) (onDone func(), err error) { +func (s *Server) addActiveHTTPRequest(req *http.Request, actor ipnauth.Actor) (onDone func(), err error) { if actor == nil { return nil, errors.New("internal error: nil actor") } lb := s.mustBackend() - // If the connected user changes, reset the backend server state to make - // sure node keys don't leak between users. - var doReset bool - defer func() { - if doReset { - s.logf("identity changed; resetting server") - lb.ResetForClientDisconnect() - } - }() - s.mu.Lock() defer s.mu.Unlock() @@ -392,40 +410,25 @@ func (s *Server) addActiveHTTPRequest(req *http.Request, actor *actor) (onDone f // Tell the LocalBackend about the identity we're now running as, // unless its the SYSTEM user. That user is not a real account and // doesn't have a home directory. 
- uid, err := lb.SetCurrentUser(actor) - if err != nil { - return nil, err - } - if s.lastUserID != uid { - if s.lastUserID != "" { - doReset = true - } - s.lastUserID = uid - } + lb.SetCurrentUser(actor) } } onDone = func() { s.mu.Lock() + defer s.mu.Unlock() delete(s.activeReqs, req) - remain := len(s.activeReqs) - s.mu.Unlock() - - if remain == 0 && s.resetOnZero { - if lb.InServerMode() { - s.logf("client disconnected; staying alive in server mode") - } else { - s.logf("client disconnected; stopping server") - lb.ResetForClientDisconnect() - } + if len(s.activeReqs) != 0 { + // The server is not idle yet. + return } - // Wake up callers waiting for the server to be idle: - if remain == 0 { - s.mu.Lock() - s.zeroReqWaiter.wakeAll() - s.mu.Unlock() + if envknob.GOOS() == "windows" && !actor.IsLocalSystem() { + lb.SetCurrentUser(nil) } + + // Wake up callers waiting for the server to be idle: + s.zeroReqWaiter.wakeAll() } return onDone, nil @@ -445,7 +448,6 @@ func New(logf logger.Logf, logID logid.PublicID, netMon *netmon.Monitor) *Server backendLogID: logID, logf: logf, netMon: netMon, - resetOnZero: envknob.GOOS() == "windows", } } diff --git a/ipn/ipnserver/server_fortest.go b/ipn/ipnserver/server_fortest.go new file mode 100644 index 0000000000000..9aab3b276d31f --- /dev/null +++ b/ipn/ipnserver/server_fortest.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "net/http" + + "tailscale.com/ipn/ipnauth" +) + +// BlockWhileInUseByOtherForTest blocks while the actor can't connect to the server because +// the server is in use by a different actor. It is used in tests only. +func (s *Server) BlockWhileInUseByOtherForTest(ctx context.Context, actor ipnauth.Actor) error { + return s.blockWhileIdentityInUse(ctx, actor) +} + +// BlockWhileInUseForTest blocks until the server becomes idle (no active requests), +// or the specified context is done. 
It returns the context's error if it is done. +// It is used in tests only. +func (s *Server) BlockWhileInUseForTest(ctx context.Context) error { + ready, cleanup := s.zeroReqWaiter.add(&s.mu, ctx) + + s.mu.Lock() + busy := len(s.activeReqs) != 0 + s.mu.Unlock() + + if busy { + <-ready + } + cleanup() + return ctx.Err() +} + +// ServeHTTPForTest responds to a single LocalAPI HTTP request. +// The request's context carries the actor that made the request +// and can be created with [NewContextWithActorForTest]. +// It is used in tests only. +func (s *Server) ServeHTTPForTest(w http.ResponseWriter, r *http.Request) { + s.serveHTTP(w, r) +} diff --git a/ipn/ipnserver/server_test.go b/ipn/ipnserver/server_test.go index b7d5ea144c408..903cb6b738331 100644 --- a/ipn/ipnserver/server_test.go +++ b/ipn/ipnserver/server_test.go @@ -1,46 +1,269 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -package ipnserver +package ipnserver_test import ( "context" + "runtime" + "strconv" "sync" "testing" + + "tailscale.com/client/local" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/lapitest" + "tailscale.com/types/ptr" ) -func TestWaiterSet(t *testing.T) { - var s waiterSet +func TestUserConnectDisconnectNonWindows(t *testing.T) { + enableLogging := false + if runtime.GOOS == "windows" { + setGOOSForTest(t, "linux") + } + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + // UserA connects and starts watching the IPN bus. + clientA := server.ClientWithName("UserA") + watcherA, _ := clientA.WatchIPNBus(ctx, 0) + + // The concept of "current user" is only relevant on Windows + // and it should not be set on non-Windows platforms. + server.CheckCurrentUser(nil) + + // Additionally, a different user should be able to connect and use the LocalAPI. 
+ clientB := server.ClientWithName("UserB") + if _, gotErr := clientB.Status(ctx); gotErr != nil { + t.Fatalf("Status(%q): want nil; got %v", clientB.Username(), gotErr) + } + + // Watching the IPN bus should also work for UserB. + watcherB, _ := clientB.WatchIPNBus(ctx, 0) + + // And if we send a notification, both users should receive it. + wantErrMessage := "test error" + testNotify := ipn.Notify{ErrMessage: ptr.To(wantErrMessage)} + server.Backend().DebugNotify(testNotify) + + if n, err := watcherA.Next(); err != nil { + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientA.Username(), err) + } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientA.Username(), wantErrMessage, gotErrMessage) + } + + if n, err := watcherB.Next(); err != nil { + t.Fatalf("IPNBusWatcher.Next(%q): %v", clientB.Username(), err) + } else if gotErrMessage := n.ErrMessage; gotErrMessage == nil || *gotErrMessage != wantErrMessage { + t.Fatalf("IPNBusWatcher.Next(%q): want %v; got %v", clientB.Username(), wantErrMessage, gotErrMessage) + } +} + +func TestUserConnectDisconnectOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + client := server.ClientWithName("User") + _, cancelWatcher := client.WatchIPNBus(ctx, 0) + + // On Windows, however, the current user should be set to the user that connected. + server.CheckCurrentUser(client.Actor) + + // Cancel the IPN bus watcher request and wait for the server to unblock. + cancelWatcher() + server.BlockWhileInUse(ctx) + + // The current user should not be set after a disconnect, as no one is + // currently using the server. 
+ server.CheckCurrentUser(nil) +} + +func TestIPNAlreadyInUseOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + // UserA connects and starts watching the IPN bus. + clientA := server.ClientWithName("UserA") + clientA.WatchIPNBus(ctx, 0) + + // While UserA is connected, UserB should not be able to connect. + clientB := server.ClientWithName("UserB") + if _, gotErr := clientB.Status(ctx); gotErr == nil { + t.Fatalf("Status(%q): want error; got nil", clientB.Username()) + } else if wantError := "401 Unauthorized: Tailscale already in use by UserA"; gotErr.Error() != wantError { + t.Fatalf("Status(%q): want %q; got %q", clientB.Username(), wantError, gotErr.Error()) + } + + // Current user should still be UserA. + server.CheckCurrentUser(clientA.Actor) +} + +func TestSequentialOSUserSwitchingOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + connectDisconnectAsUser := func(name string) { + // User connects and starts watching the IPN bus. + client := server.ClientWithName(name) + watcher, cancelWatcher := client.WatchIPNBus(ctx, 0) + defer cancelWatcher() + go pumpIPNBus(watcher) - wantLen := func(want int, when string) { - t.Helper() - if got := len(s); got != want { - t.Errorf("%s: len = %v; want %v", when, got, want) + // It should be the current user from the LocalBackend's perspective... + server.CheckCurrentUser(client.Actor) + // until it disconnects. + cancelWatcher() + server.BlockWhileInUse(ctx) + // Now, the current user should be unset. + server.CheckCurrentUser(nil) + } + + // UserA logs in, uses Tailscale for a bit, then logs out. + connectDisconnectAsUser("UserA") + // Same for UserB. 
+ connectDisconnectAsUser("UserB") +} + +func TestConcurrentOSUserSwitchingOnWindows(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") + + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) + + connectDisconnectAsUser := func(name string) { + // User connects and starts watching the IPN bus. + client := server.ClientWithName(name) + watcher, cancelWatcher := client.WatchIPNBus(ctx, ipn.NotifyInitialState) + defer cancelWatcher() + + runtime.Gosched() + + // Get the current user from the LocalBackend's perspective + // as soon as we're connected. + gotUID, gotActor := server.Backend().CurrentUserForTest() + + // Wait for the first notification to arrive. + // It will either be the initial state we've requested via [ipn.NotifyInitialState], + // returned by an actual handler, or a "fake" notification sent by the server + // itself to indicate that it is being used by someone else. + n, err := watcher.Next() + if err != nil { + t.Fatal(err) + } + + // If our user lost the race and the IPN is in use by another user, + // we should just return. For the sake of this test, we're not + // interested in waiting for the server to become idle. + if n.State != nil && *n.State == ipn.InUseOtherUser { + return + } + + // Otherwise, our user should have been the current user since the time we connected. + if gotUID != client.Actor.UserID() { + t.Errorf("CurrentUser(Initial): got UID %q; want %q", gotUID, client.Actor.UserID()) + return + } + if hasActor := gotActor != nil; !hasActor || gotActor != client.Actor { + t.Errorf("CurrentUser(Initial): got %v; want %v", gotActor, client.Actor) + return + } + + // And should still be the current user (as they're still connected)... 
+ server.CheckCurrentUser(client.Actor) + } + + numIterations := 10 + for range numIterations { + numGoRoutines := 100 + var wg sync.WaitGroup + wg.Add(numGoRoutines) + for i := range numGoRoutines { + // User logs in, uses Tailscale for a bit, then logs out + // in parallel with other users doing the same. + go func() { + defer wg.Done() + connectDisconnectAsUser("User-" + strconv.Itoa(i)) + }() + } + wg.Wait() + + if err := server.BlockWhileInUse(ctx); err != nil { + t.Fatalf("BlockUntilIdle: %v", err) } + + server.CheckCurrentUser(nil) } - wantLen(0, "initial") - var mu sync.Mutex - ctx, cancel := context.WithCancel(context.Background()) +} + +func TestBlockWhileIdentityInUse(t *testing.T) { + enableLogging := false + setGOOSForTest(t, "windows") - ready, cleanup := s.add(&mu, ctx) - wantLen(1, "after add") + ctx := context.Background() + server := lapitest.NewServer(t, lapitest.WithLogging(enableLogging)) - select { - case <-ready: - t.Fatal("should not be ready") - default: + // connectWaitDisconnectAsUser connects as a user with the specified name + // and keeps the IPN bus watcher alive until the context is canceled. + // It returns a channel that is closed when done. + connectWaitDisconnectAsUser := func(ctx context.Context, name string) <-chan struct{} { + client := server.ClientWithName(name) + watcher, cancelWatcher := client.WatchIPNBus(ctx, 0) + + done := make(chan struct{}) + go func() { + defer cancelWatcher() + defer close(done) + for { + _, err := watcher.Next() + if err != nil { + // There's either an error or the request has been canceled. + break + } + } + }() + return done } - s.wakeAll() - <-ready - wantLen(1, "after fire") - cleanup() - wantLen(0, "after cleanup") + for range 100 { + // Connect as UserA, and keep the connection alive + // until disconnectUserA is called. 
+ userAContext, disconnectUserA := context.WithCancel(ctx) + userADone := connectWaitDisconnectAsUser(userAContext, "UserA") + disconnectUserA() + // Check if userB can connect. Calling it directly increases + // the likelihood of triggering a deadlock due to a race condition + // in blockWhileIdentityInUse. But the issue also occurs during + // the normal execution path when UserB connects to the IPN server + // while UserA is disconnecting. + userB := server.MakeTestActor("UserB", "ClientB") + server.BlockWhileInUseByOther(ctx, userB) + <-userADone + } +} - // And again but on an already-expired ctx. - cancel() - ready, cleanup = s.add(&mu, ctx) - <-ready // shouldn't block - cleanup() - wantLen(0, "at end") +func setGOOSForTest(tb testing.TB, goos string) { + tb.Helper() + envknob.Setenv("TS_DEBUG_FAKE_GOOS", goos) + tb.Cleanup(func() { envknob.Setenv("TS_DEBUG_FAKE_GOOS", "") }) +} + +func pumpIPNBus(watcher *local.IPNBusWatcher) { + for { + _, err := watcher.Next() + if err != nil { + break + } + } } diff --git a/ipn/ipnserver/waiterset_test.go b/ipn/ipnserver/waiterset_test.go new file mode 100644 index 0000000000000..b7d5ea144c408 --- /dev/null +++ b/ipn/ipnserver/waiterset_test.go @@ -0,0 +1,46 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package ipnserver + +import ( + "context" + "sync" + "testing" +) + +func TestWaiterSet(t *testing.T) { + var s waiterSet + + wantLen := func(want int, when string) { + t.Helper() + if got := len(s); got != want { + t.Errorf("%s: len = %v; want %v", when, got, want) + } + } + wantLen(0, "initial") + var mu sync.Mutex + ctx, cancel := context.WithCancel(context.Background()) + + ready, cleanup := s.add(&mu, ctx) + wantLen(1, "after add") + + select { + case <-ready: + t.Fatal("should not be ready") + default: + } + s.wakeAll() + <-ready + + wantLen(1, "after fire") + cleanup() + wantLen(0, "after cleanup") + + // And again but on an already-expired ctx. 
+ cancel() + ready, cleanup = s.add(&mu, ctx) + <-ready // shouldn't block + cleanup() + wantLen(0, "at end") +} diff --git a/ipn/ipnstate/ipnstate.go b/ipn/ipnstate/ipnstate.go index 9f8bd34f61033..89c6d7e24dbc5 100644 --- a/ipn/ipnstate/ipnstate.go +++ b/ipn/ipnstate/ipnstate.go @@ -216,6 +216,11 @@ type PeerStatusLite struct { } // PeerStatus describes a peer node and its current state. +// WARNING: The fields in PeerStatus are merged by the AddPeer method in the StatusBuilder. +// When adding a new field to PeerStatus, you must update AddPeer to handle merging +// the new field. The AddPeer function is responsible for combining multiple updates +// to the same peer, and any new field that is not merged properly may lead to +// inconsistencies or lost data in the peer status. type PeerStatus struct { ID tailcfg.StableNodeID PublicKey key.NodePublic @@ -270,6 +275,12 @@ type PeerStatus struct { // PeerAPIURL are the URLs of the node's PeerAPI servers. PeerAPIURL []string + // TaildropTargetStatus represents the node's eligibility to have files shared to it. + TaildropTarget TaildropTargetStatus + + // Reason why this peer cannot receive files. Empty if CanReceiveFiles=true + NoFileSharingReason string + // Capabilities are capabilities that the node has. // They're free-form strings, but should be in the form of URLs/URIs // such as: @@ -318,6 +329,21 @@ type PeerStatus struct { Location *tailcfg.Location `json:",omitempty"` } +type TaildropTargetStatus int + +const ( + TaildropTargetUnknown TaildropTargetStatus = iota + TaildropTargetAvailable + TaildropTargetNoNetmapAvailable + TaildropTargetIpnStateNotRunning + TaildropTargetMissingCap + TaildropTargetOffline + TaildropTargetNoPeerInfo + TaildropTargetUnsupportedOS + TaildropTargetNoPeerAPI + TaildropTargetOwnedByOtherUser +) + // HasCap reports whether ps has the given capability. 
func (ps *PeerStatus) HasCap(cap tailcfg.NodeCapability) bool { return ps.CapMap.Contains(cap) @@ -367,7 +393,7 @@ func (sb *StatusBuilder) MutateSelfStatus(f func(*PeerStatus)) { } // AddUser adds a user profile to the status. -func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfile) { +func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfileView) { if sb.locked { log.Printf("[unexpected] ipnstate: AddUser after Locked") return @@ -377,7 +403,7 @@ func (sb *StatusBuilder) AddUser(id tailcfg.UserID, up tailcfg.UserProfile) { sb.st.User = make(map[tailcfg.UserID]tailcfg.UserProfile) } - sb.st.User[id] = up + sb.st.User[id] = *up.AsStruct() } // AddIP adds a Tailscale IP address to the status. @@ -512,6 +538,9 @@ func (sb *StatusBuilder) AddPeer(peer key.NodePublic, st *PeerStatus) { if v := st.Capabilities; v != nil { e.Capabilities = v } + if v := st.TaildropTarget; v != TaildropTargetUnknown { + e.TaildropTarget = v + } e.Location = st.Location } @@ -650,6 +679,8 @@ func osEmoji(os string) string { return "🐡" case "illumos": return "â˜€ī¸" + case "solaris": + return "đŸŒ¤ī¸" } return "đŸ‘Ŋ" } diff --git a/ipn/lapitest/backend.go b/ipn/lapitest/backend.go new file mode 100644 index 0000000000000..ddf48fb2893d8 --- /dev/null +++ b/ipn/lapitest/backend.go @@ -0,0 +1,63 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "testing" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/store/mem" + "tailscale.com/types/logid" + "tailscale.com/wgengine" +) + +// NewBackend returns a new [ipnlocal.LocalBackend] for testing purposes. +// It fails the test if the specified options are invalid or if the backend cannot be created. +func NewBackend(tb testing.TB, opts ...Option) *ipnlocal.LocalBackend { + tb.Helper() + options, err := newOptions(tb, opts...) 
+ if err != nil { + tb.Fatalf("NewBackend: %v", err) + } + return newBackend(options) +} + +func newBackend(opts *options) *ipnlocal.LocalBackend { + tb := opts.TB() + tb.Helper() + + sys := opts.Sys() + if _, ok := sys.StateStore.GetOK(); !ok { + sys.Set(&mem.Store{}) + } + + e, err := wgengine.NewFakeUserspaceEngine(opts.Logf(), sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + opts.tb.Fatalf("NewFakeUserspaceEngine: %v", err) + } + tb.Cleanup(e.Close) + sys.Set(e) + + b, err := ipnlocal.NewLocalBackend(opts.Logf(), logid.PublicID{}, sys, 0) + if err != nil { + tb.Fatalf("NewLocalBackend: %v", err) + } + tb.Cleanup(b.Shutdown) + b.DisablePortMapperForTest() + b.SetControlClientGetterForTesting(opts.MakeControlClient) + return b +} + +// NewUnreachableControlClient is a [NewControlFn] that creates +// a new [controlclient.Client] for an unreachable control server. +func NewUnreachableControlClient(tb testing.TB, opts controlclient.Options) (controlclient.Client, error) { + tb.Helper() + opts.ServerURL = "https://127.0.0.1:1" + cc, err := controlclient.New(opts) + if err != nil { + tb.Fatal(err) + } + return cc, nil +} diff --git a/ipn/lapitest/client.go b/ipn/lapitest/client.go new file mode 100644 index 0000000000000..6d22e938b210e --- /dev/null +++ b/ipn/lapitest/client.go @@ -0,0 +1,71 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "testing" + + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" +) + +// Client wraps a [local.Client] for testing purposes. +// It can be created using [Server.Client], [Server.ClientWithName], +// or [Server.ClientFor] and sends requests as the specified actor +// to the associated [Server]. +type Client struct { + tb testing.TB + // Client is the underlying [local.Client] wrapped by the test client. 
+ // It is configured to send requests to the test server on behalf of the actor. + *local.Client + // Actor represents the user on whose behalf this client is making requests. + // The server uses it to determine the client's identity and permissions. + // The test can mutate the user to alter the actor's identity or permissions + // before making a new request. It is typically an [ipnauth.TestActor], + // unless the [Client] was created with s specific actor using [Server.ClientFor]. + Actor ipnauth.Actor +} + +// Username returns username of the client's owner. +func (c *Client) Username() string { + c.tb.Helper() + name, err := c.Actor.Username() + if err != nil { + c.tb.Fatalf("Client.Username: %v", err) + } + return name +} + +// WatchIPNBus is like [local.Client.WatchIPNBus] but returns a [local.IPNBusWatcher] +// that is closed when the test ends and a cancel function that stops the watcher. +// It fails the test if the underlying WatchIPNBus returns an error. +func (c *Client) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (*local.IPNBusWatcher, context.CancelFunc) { + c.tb.Helper() + ctx, cancelWatcher := context.WithCancel(ctx) + c.tb.Cleanup(cancelWatcher) + watcher, err := c.Client.WatchIPNBus(ctx, mask) + name, _ := c.Actor.Username() + if err != nil { + c.tb.Fatalf("Client.WatchIPNBus(%q): %v", name, err) + } + c.tb.Cleanup(func() { watcher.Close() }) + return watcher, cancelWatcher +} + +// generateSequentialName generates a unique sequential name based on the given prefix and number n. +// It uses a base-26 encoding to create names like "User-A", "User-B", ..., "User-Z", "User-AA", etc. 
+func generateSequentialName(prefix string, n int) string { + n++ + name := "" + const numLetters = 'Z' - 'A' + 1 + for n > 0 { + n-- + remainder := byte(n % numLetters) + name = string([]byte{'A' + remainder}) + name + n = n / numLetters + } + return prefix + "-" + name +} diff --git a/ipn/lapitest/example_test.go b/ipn/lapitest/example_test.go new file mode 100644 index 0000000000000..57479199a8123 --- /dev/null +++ b/ipn/lapitest/example_test.go @@ -0,0 +1,80 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "testing" + + "tailscale.com/ipn" +) + +func TestClientServer(t *testing.T) { + t.Parallel() + + ctx := context.Background() + + // Create a server and two clients. + // Both clients represent the same user to make this work across platforms. + // On Windows we've been restricting the API usage to a single user at a time. + // While we're planning on changing this once a better permission model is in place, + // this test is currently limited to a single user (but more than one client is fine). + // Alternatively, we could override GOOS via envknobs to test as if we're + // on a different platform, but that would make the test depend on global state, etc. + s := NewServer(t, WithLogging(false)) + c1 := s.ClientWithName("User-A") + c2 := s.ClientWithName("User-A") + + // Start watching the IPN bus as the second client. + w2, _ := c2.WatchIPNBus(context.Background(), ipn.NotifyInitialPrefs) + + // We're supposed to get a notification about the initial prefs, + // and WantRunning should be false. + n, err := w2.Next() + for ; err == nil; n, err = w2.Next() { + if n.Prefs == nil { + // Ignore non-prefs notifications. 
+ continue + } + if n.Prefs.WantRunning() { + t.Errorf("WantRunning(initial): got %v, want false", n.Prefs.WantRunning()) + } + break + } + if err != nil { + t.Fatalf("IPNBusWatcher.Next failed: %v", err) + } + + // Now send an EditPrefs request from the first client to set WantRunning to true. + change := &ipn.MaskedPrefs{Prefs: ipn.Prefs{WantRunning: true}, WantRunningSet: true} + gotPrefs, err := c1.EditPrefs(ctx, change) + if err != nil { + t.Fatalf("EditPrefs failed: %v", err) + } + if !gotPrefs.WantRunning { + t.Fatalf("EditPrefs.WantRunning: got %v, want true", gotPrefs.WantRunning) + } + + // We can check the backend directly to see if the prefs were set correctly. + if gotWantRunning := s.Backend().Prefs().WantRunning(); !gotWantRunning { + t.Fatalf("Backend.Prefs.WantRunning: got %v, want true", gotWantRunning) + } + + // And can also wait for the second client with an IPN bus watcher to receive the notification + // about the prefs change. + n, err = w2.Next() + for ; err == nil; n, err = w2.Next() { + if n.Prefs == nil { + // Ignore non-prefs notifications. + continue + } + if !n.Prefs.WantRunning() { + t.Fatalf("WantRunning(changed): got %v, want true", n.Prefs.WantRunning()) + } + break + } + if err != nil { + t.Fatalf("IPNBusWatcher.Next failed: %v", err) + } +} diff --git a/ipn/lapitest/opts.go b/ipn/lapitest/opts.go new file mode 100644 index 0000000000000..6eb1594da2607 --- /dev/null +++ b/ipn/lapitest/opts.go @@ -0,0 +1,170 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lapitest + +import ( + "context" + "errors" + "fmt" + "testing" + + "tailscale.com/control/controlclient" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tsd" + "tailscale.com/tstest" + "tailscale.com/types/lazy" + "tailscale.com/types/logger" +) + +// Option is any optional configuration that can be passed to [NewServer] or [NewBackend]. 
+type Option interface { + apply(*options) error +} + +// options is the merged result of all applied [Option]s. +type options struct { + tb testing.TB + ctx lazy.SyncValue[context.Context] + logf lazy.SyncValue[logger.Logf] + sys lazy.SyncValue[*tsd.System] + newCC lazy.SyncValue[NewControlFn] + backend lazy.SyncValue[*ipnlocal.LocalBackend] +} + +// newOptions returns a new [options] struct with the specified [Option]s applied. +func newOptions(tb testing.TB, opts ...Option) (*options, error) { + options := &options{tb: tb} + for _, opt := range opts { + if err := opt.apply(options); err != nil { + return nil, fmt.Errorf("lapitest: %w", err) + } + } + return options, nil +} + +// TB returns the owning [*testing.T] or [*testing.B]. +func (o *options) TB() testing.TB { + return o.tb +} + +// Context returns the base context to be used by the server. +func (o *options) Context() context.Context { + return o.ctx.Get(context.Background) +} + +// Logf returns the [logger.Logf] to be used for logging. +func (o *options) Logf() logger.Logf { + return o.logf.Get(func() logger.Logf { return logger.Discard }) +} + +// Sys returns the [tsd.System] that contains subsystems to be used +// when creating a new [ipnlocal.LocalBackend]. +func (o *options) Sys() *tsd.System { + return o.sys.Get(func() *tsd.System { return tsd.NewSystem() }) +} + +// Backend returns the [ipnlocal.LocalBackend] to be used by the server. +// If a backend is provided via [WithBackend], it is used as-is. +// Otherwise, a new backend is created with the the [options] in o. +func (o *options) Backend() *ipnlocal.LocalBackend { + return o.backend.Get(func() *ipnlocal.LocalBackend { return newBackend(o) }) +} + +// MakeControlClient returns a new [controlclient.Client] to be used by newly +// created [ipnlocal.LocalBackend]s. It is only used if no backend is provided +// via [WithBackend]. 
+func (o *options) MakeControlClient(opts controlclient.Options) (controlclient.Client, error) { + newCC := o.newCC.Get(func() NewControlFn { return NewUnreachableControlClient }) + return newCC(o.tb, opts) +} + +type loggingOption struct{ enableLogging bool } + +// WithLogging returns an [Option] that enables or disables logging. +func WithLogging(enableLogging bool) Option { + return loggingOption{enableLogging: enableLogging} +} + +func (o loggingOption) apply(opts *options) error { + var logf logger.Logf + if o.enableLogging { + logf = tstest.WhileTestRunningLogger(opts.tb) + } else { + logf = logger.Discard + } + if !opts.logf.Set(logf) { + return errors.New("logging already configured") + } + return nil +} + +type contextOption struct{ ctx context.Context } + +// WithContext returns an [Option] that sets the base context to be used by the [Server]. +func WithContext(ctx context.Context) Option { + return contextOption{ctx: ctx} +} + +func (o contextOption) apply(opts *options) error { + if !opts.ctx.Set(o.ctx) { + return errors.New("context already configured") + } + return nil +} + +type sysOption struct{ sys *tsd.System } + +// WithSys returns an [Option] that sets the [tsd.System] to be used +// when creating a new [ipnlocal.LocalBackend]. +func WithSys(sys *tsd.System) Option { + return sysOption{sys: sys} +} + +func (o sysOption) apply(opts *options) error { + if !opts.sys.Set(o.sys) { + return errors.New("tsd.System already configured") + } + return nil +} + +type backendOption struct{ backend *ipnlocal.LocalBackend } + +// WithBackend returns an [Option] that configures the server to use the specified +// [ipnlocal.LocalBackend] instead of creating a new one. +// It is mutually exclusive with [WithControlClient]. 
+func WithBackend(backend *ipnlocal.LocalBackend) Option { + return backendOption{backend: backend} +} + +func (o backendOption) apply(opts *options) error { + if _, ok := opts.backend.Peek(); ok { + return errors.New("backend cannot be set when control client is already set") + } + if !opts.backend.Set(o.backend) { + return errors.New("backend already set") + } + return nil +} + +// NewControlFn is any function that creates a new [controlclient.Client] +// with the specified options. +type NewControlFn func(tb testing.TB, opts controlclient.Options) (controlclient.Client, error) + +// WithControlClient returns an option that specifies a function to be used +// by the [ipnlocal.LocalBackend] when creating a new [controlclient.Client]. +// It is mutually exclusive with [WithBackend] and is only used if no backend +// has been provided. +func WithControlClient(newControl NewControlFn) Option { + return newControl +} + +func (fn NewControlFn) apply(opts *options) error { + if _, ok := opts.backend.Peek(); ok { + return errors.New("control client cannot be set when backend is already set") + } + if !opts.newCC.Set(fn) { + return errors.New("control client already set") + } + return nil +} diff --git a/ipn/lapitest/server.go b/ipn/lapitest/server.go new file mode 100644 index 0000000000000..d477dc1828549 --- /dev/null +++ b/ipn/lapitest/server.go @@ -0,0 +1,324 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lapitest provides utilities for black-box testing of LocalAPI ([ipnserver]). 
+package lapitest + +import ( + "context" + "fmt" + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnserver" + "tailscale.com/types/logger" + "tailscale.com/types/logid" + "tailscale.com/types/ptr" + "tailscale.com/util/mak" + "tailscale.com/util/rands" +) + +// A Server is an in-process LocalAPI server that can be used in end-to-end tests. +type Server struct { + tb testing.TB + + ctx context.Context + cancelCtx context.CancelFunc + + lb *ipnlocal.LocalBackend + ipnServer *ipnserver.Server + + // mu protects the following fields. + mu sync.Mutex + started bool + httpServer *httptest.Server + actorsByName map[string]*ipnauth.TestActor + lastClientID int +} + +// NewUnstartedServer returns a new [Server] with the specified options without starting it. +func NewUnstartedServer(tb testing.TB, opts ...Option) *Server { + tb.Helper() + options, err := newOptions(tb, opts...) + if err != nil { + tb.Fatalf("invalid options: %v", err) + } + + s := &Server{tb: tb, lb: options.Backend()} + s.ctx, s.cancelCtx = context.WithCancel(options.Context()) + s.ipnServer = newUnstartedIPNServer(options) + s.httpServer = httptest.NewUnstartedServer(http.HandlerFunc(s.serveHTTP)) + s.httpServer.Config.Addr = "http://" + apitype.LocalAPIHost + s.httpServer.Config.BaseContext = func(_ net.Listener) context.Context { return s.ctx } + s.httpServer.Config.ErrorLog = logger.StdLogger(logger.WithPrefix(options.Logf(), "lapitest: ")) + tb.Cleanup(s.Close) + return s +} + +// NewServer starts and returns a new [Server] with the specified options. +func NewServer(tb testing.TB, opts ...Option) *Server { + tb.Helper() + server := NewUnstartedServer(tb, opts...) + server.Start() + return server +} + +// Start starts the server from [NewUnstartedServer]. 
+func (s *Server) Start() { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if !s.started && s.httpServer != nil { + s.httpServer.Start() + s.started = true + } +} + +// Backend returns the underlying [ipnlocal.LocalBackend]. +func (s *Server) Backend() *ipnlocal.LocalBackend { + s.tb.Helper() + return s.lb +} + +// Client returns a new [Client] configured for making requests to the server +// as a new [ipnauth.TestActor] with a unique username and [ipnauth.ClientID]. +func (s *Server) Client() *Client { + s.tb.Helper() + user := s.MakeTestActor("", "") // generate a unique username and client ID + return s.ClientFor(user) +} + +// ClientWithName returns a new [Client] configured for making requests to the server +// as a new [ipnauth.TestActor] with the specified name and a unique [ipnauth.ClientID]. +func (s *Server) ClientWithName(name string) *Client { + s.tb.Helper() + user := s.MakeTestActor(name, "") // generate a unique client ID + return s.ClientFor(user) +} + +// ClientFor returns a new [Client] configured for making requests to the server +// as the specified actor. +func (s *Server) ClientFor(actor ipnauth.Actor) *Client { + s.tb.Helper() + client := &Client{ + tb: s.tb, + Actor: actor, + } + client.Client = &local.Client{Transport: newRoundTripper(client, s.httpServer)} + return client +} + +// MakeTestActor returns a new [ipnauth.TestActor] with the specified name and client ID. +// If the name is empty, a unique sequential name is generated. Likewise, +// if clientID is empty, a unique sequential client ID is generated. +func (s *Server) MakeTestActor(name string, clientID string) *ipnauth.TestActor { + s.tb.Helper() + + s.mu.Lock() + defer s.mu.Unlock() + + // Generate a unique sequential name if the provided name is empty. 
+ if name == "" { + n := len(s.actorsByName) + name = generateSequentialName("User", n) + } + + if clientID == "" { + s.lastClientID += 1 + clientID = fmt.Sprintf("Client-%d", s.lastClientID) + } + + // Create a new base actor if one doesn't already exist for the given name. + baseActor := s.actorsByName[name] + if baseActor == nil { + baseActor = &ipnauth.TestActor{Name: name} + if envknob.GOOS() == "windows" { + // Historically, as of 2025-04-15, IPN does not distinguish between + // different users on non-Windows devices. Therefore, the UID, which is + // an [ipn.WindowsUserID], should only be populated when the actual or + // fake GOOS is Windows. + baseActor.UID = ipn.WindowsUserID(fmt.Sprintf("S-1-5-21-1-0-0-%d", 1001+len(s.actorsByName))) + } + mak.Set(&s.actorsByName, name, baseActor) + s.tb.Cleanup(func() { delete(s.actorsByName, name) }) + } + + // Create a shallow copy of the base actor and assign it the new client ID. + actor := ptr.To(*baseActor) + actor.CID = ipnauth.ClientIDFrom(clientID) + return actor +} + +// BlockWhileInUse blocks until the server becomes idle (no active requests), +// or the context is done. It returns the context's error if it is done. +// It is used in tests only. +func (s *Server) BlockWhileInUse(ctx context.Context) error { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if s.httpServer == nil { + return nil + } + return s.ipnServer.BlockWhileInUseForTest(ctx) +} + +// BlockWhileInUseByOther blocks while the specified actor can't connect to the server +// due to another actor being connected. +// It is used in tests only. +func (s *Server) BlockWhileInUseByOther(ctx context.Context, actor ipnauth.Actor) error { + s.tb.Helper() + s.mu.Lock() + defer s.mu.Unlock() + if s.httpServer == nil { + return nil + } + return s.ipnServer.BlockWhileInUseByOtherForTest(ctx, actor) +} + +// CheckCurrentUser fails the test if the current user does not match the expected user. 
+// It is only used on Windows and will be removed as we progress on tailscale/corp#18342. +func (s *Server) CheckCurrentUser(want ipnauth.Actor) { + s.tb.Helper() + var wantUID ipn.WindowsUserID + if want != nil { + wantUID = want.UserID() + } + lb := s.Backend() + if lb == nil { + s.tb.Fatalf("Backend: nil") + } + gotUID, gotActor := lb.CurrentUserForTest() + if gotUID != wantUID { + s.tb.Errorf("CurrentUser: got UID %q; want %q", gotUID, wantUID) + } + if hasActor := gotActor != nil; hasActor != (want != nil) || (want != nil && gotActor != want) { + s.tb.Errorf("CurrentUser: got %v; want %v", gotActor, want) + } +} + +func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) { + actor, err := getActorForRequest(r) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + s.tb.Errorf("getActorForRequest: %v", err) + return + } + ctx := ipnserver.NewContextWithActorForTest(r.Context(), actor) + s.ipnServer.ServeHTTPForTest(w, r.Clone(ctx)) +} + +// Close shuts down the server and blocks until all outstanding requests on this server have completed. +func (s *Server) Close() { + s.tb.Helper() + s.mu.Lock() + server := s.httpServer + s.httpServer = nil + s.mu.Unlock() + + if server != nil { + server.Close() + } + s.cancelCtx() +} + +// newUnstartedIPNServer returns a new [ipnserver.Server] that exposes +// the specified [ipnlocal.LocalBackend] via LocalAPI, but does not start it. +// The opts carry additional configuration options. +func newUnstartedIPNServer(opts *options) *ipnserver.Server { + opts.TB().Helper() + lb := opts.Backend() + server := ipnserver.New(opts.Logf(), logid.PublicID{}, lb.NetMon()) + server.SetLocalBackend(lb) + return server +} + +// roundTripper is a [http.RoundTripper] that sends requests to a [Server] +// on behalf of the [Client] who owns it. 
+type roundTripper struct { + client *Client + transport http.RoundTripper +} + +// newRoundTripper returns a new [http.RoundTripper] that sends requests +// to the specified server as the specified client. +func newRoundTripper(client *Client, server *httptest.Server) http.RoundTripper { + return &roundTripper{ + client: client, + transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + var std net.Dialer + return std.DialContext(ctx, network, server.Listener.Addr().(*net.TCPAddr).String()) + }}, + } +} + +// requestIDHeaderName is the name of the header used to pass request IDs +// between the client and server. It is used to associate requests with their actors. +const requestIDHeaderName = "TS-Request-ID" + +// RoundTrip implements [http.RoundTripper] by sending the request to the [ipnserver.Server] +// on behalf of the owning [Client]. It registers each request for the duration +// of the call and associates it with the actor sending the request. +func (rt *roundTripper) RoundTrip(r *http.Request) (*http.Response, error) { + reqID, unregister := registerRequest(rt.client.Actor) + defer unregister() + r = r.Clone(r.Context()) + r.Header.Set(requestIDHeaderName, reqID) + return rt.transport.RoundTrip(r) +} + +// getActorForRequest returns the actor for a given request. +// It returns an error if the request is not associated with an actor, +// such as when it wasn't sent by a [roundTripper]. 
+func getActorForRequest(r *http.Request) (ipnauth.Actor, error) { + reqID := r.Header.Get(requestIDHeaderName) + if reqID == "" { + return nil, fmt.Errorf("missing %s header", requestIDHeaderName) + } + actor, ok := getActorByRequestID(reqID) + if !ok { + return nil, fmt.Errorf("unknown request: %s", reqID) + } + return actor, nil +} + +var ( + inFlightRequestsMu sync.Mutex + inFlightRequests map[string]ipnauth.Actor +) + +// registerRequest associates a request with the specified actor and returns a unique request ID +// which can be used to retrieve the actor later. The returned function unregisters the request. +func registerRequest(actor ipnauth.Actor) (requestID string, unregister func()) { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + for { + requestID = rands.HexString(16) + if _, ok := inFlightRequests[requestID]; !ok { + break + } + } + mak.Set(&inFlightRequests, requestID, actor) + return requestID, func() { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + delete(inFlightRequests, requestID) + } +} + +// getActorByRequestID returns the actor associated with the specified request ID. +// It returns the actor and true if found, or nil and false if not. 
+func getActorByRequestID(requestID string) (ipnauth.Actor, bool) { + inFlightRequestsMu.Lock() + defer inFlightRequestsMu.Unlock() + actor, ok := inFlightRequests[requestID] + return actor, ok +} diff --git a/ipn/localapi/debugderp.go b/ipn/localapi/debugderp.go index 85eb031e6fd0a..6636fd2535e4f 100644 --- a/ipn/localapi/debugderp.go +++ b/ipn/localapi/debugderp.go @@ -4,6 +4,7 @@ package localapi import ( + "cmp" "context" "crypto/tls" "encoding/json" @@ -81,7 +82,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { client *http.Client = http.DefaultClient ) checkConn := func(derpNode *tailcfg.DERPNode) bool { - port := firstNonzero(derpNode.DERPPort, 443) + port := cmp.Or(derpNode.DERPPort, 443) var ( hasIPv4 bool @@ -89,7 +90,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { ) // Check IPv4 first - addr := net.JoinHostPort(firstNonzero(derpNode.IPv4, derpNode.HostName), strconv.Itoa(port)) + addr := net.JoinHostPort(cmp.Or(derpNode.IPv4, derpNode.HostName), strconv.Itoa(port)) conn, err := dialer.DialContext(ctx, "tcp4", addr) if err != nil { st.Errors = append(st.Errors, fmt.Sprintf("Error connecting to node %q @ %q over IPv4: %v", derpNode.HostName, addr, err)) @@ -98,7 +99,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { // Upgrade to TLS and verify that works properly. 
tlsConn := tls.Client(conn, &tls.Config{ - ServerName: firstNonzero(derpNode.CertName, derpNode.HostName), + ServerName: cmp.Or(derpNode.CertName, derpNode.HostName), }) if err := tlsConn.HandshakeContext(ctx); err != nil { st.Errors = append(st.Errors, fmt.Sprintf("Error upgrading connection to node %q @ %q to TLS over IPv4: %v", derpNode.HostName, addr, err)) @@ -108,7 +109,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { } // Check IPv6 - addr = net.JoinHostPort(firstNonzero(derpNode.IPv6, derpNode.HostName), strconv.Itoa(port)) + addr = net.JoinHostPort(cmp.Or(derpNode.IPv6, derpNode.HostName), strconv.Itoa(port)) conn, err = dialer.DialContext(ctx, "tcp6", addr) if err != nil { st.Errors = append(st.Errors, fmt.Sprintf("Error connecting to node %q @ %q over IPv6: %v", derpNode.HostName, addr, err)) @@ -117,7 +118,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { // Upgrade to TLS and verify that works properly. 
tlsConn := tls.Client(conn, &tls.Config{ - ServerName: firstNonzero(derpNode.CertName, derpNode.HostName), + ServerName: cmp.Or(derpNode.CertName, derpNode.HostName), // TODO(andrew-d): we should print more // detailed failure information on if/why TLS // verification fails @@ -166,7 +167,7 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { addr = addrs[0] } - addrPort := netip.AddrPortFrom(addr, uint16(firstNonzero(derpNode.STUNPort, 3478))) + addrPort := netip.AddrPortFrom(addr, uint16(cmp.Or(derpNode.STUNPort, 3478))) txID := stun.NewTxID() req := stun.Request(txID) @@ -230,8 +231,14 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { connSuccess := checkConn(derpNode) // Verify that the /generate_204 endpoint works - captivePortalURL := "http://" + derpNode.HostName + "/generate_204" - resp, err := client.Get(captivePortalURL) + captivePortalURL := fmt.Sprintf("http://%s/generate_204?t=%d", derpNode.HostName, time.Now().Unix()) + req, err := http.NewRequest("GET", captivePortalURL, nil) + if err != nil { + st.Warnings = append(st.Warnings, fmt.Sprintf("Internal error creating request for captive portal check: %v", err)) + continue + } + req.Header.Set("Cache-Control", "no-cache, no-store, must-revalidate, no-transform, max-age=0") + resp, err := client.Do(req) if err != nil { st.Warnings = append(st.Warnings, fmt.Sprintf("Error making request to the captive portal check %q; is port 80 blocked?", captivePortalURL)) } else { @@ -292,13 +299,3 @@ func (h *Handler) serveDebugDERPRegion(w http.ResponseWriter, r *http.Request) { // issued in the first place, tell them specifically that the // cert is bad not just that the connection failed. 
} - -func firstNonzero[T comparable](items ...T) T { - var zero T - for _, item := range items { - if item != zero { - return item - } - } - return zero -} diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 528304bab77d4..99cb7c95b4dff 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -14,12 +14,8 @@ import ( "errors" "fmt" "io" - "maps" - "mime" - "mime/multipart" "net" "net/http" - "net/http/httputil" "net/netip" "net/url" "os" @@ -46,7 +42,6 @@ import ( "tailscale.com/net/netutil" "tailscale.com/net/portmapper" "tailscale.com/tailcfg" - "tailscale.com/taildrop" "tailscale.com/tka" "tailscale.com/tstime" "tailscale.com/types/dnstype" @@ -56,38 +51,44 @@ import ( "tailscale.com/types/ptr" "tailscale.com/types/tkatype" "tailscale.com/util/clientmetric" - "tailscale.com/util/httphdr" + "tailscale.com/util/eventbus" "tailscale.com/util/httpm" "tailscale.com/util/mak" "tailscale.com/util/osdiag" - "tailscale.com/util/progresstracking" "tailscale.com/util/rands" - "tailscale.com/util/testenv" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" "tailscale.com/version" "tailscale.com/wgengine/magicsock" ) -type localAPIHandler func(*Handler, http.ResponseWriter, *http.Request) +var ( + metricInvalidRequests = clientmetric.NewCounter("localapi_invalid_requests") + metricDebugMetricsCalls = clientmetric.NewCounter("localapi_debugmetric_requests") + metricUserMetricsCalls = clientmetric.NewCounter("localapi_usermetric_requests") + metricBugReportRequests = clientmetric.NewCounter("localapi_bugreport_requests") +) + +type LocalAPIHandler func(*Handler, http.ResponseWriter, *http.Request) // handler is the set of LocalAPI handlers, keyed by the part of the // Request.URL.Path after "/localapi/v0/". If the key ends with a trailing slash // then it's a prefix match. 
-var handler = map[string]localAPIHandler{ +var handler = map[string]LocalAPIHandler{ // The prefix match handlers end with a slash: "cert/": (*Handler).serveCert, - "file-put/": (*Handler).serveFilePut, - "files/": (*Handler).serveFiles, + "policy/": (*Handler).servePolicy, "profiles/": (*Handler).serveProfiles, // The other /localapi/v0/NAME handlers are exact matches and contain only NAME // without a trailing slash: + "alpha-set-device-attrs": (*Handler).serveSetDeviceAttrs, // see tailscale/corp#24690 "bugreport": (*Handler).serveBugReport, "check-ip-forwarding": (*Handler).serveCheckIPForwarding, "check-prefs": (*Handler).serveCheckPrefs, "check-udp-gro-forwarding": (*Handler).serveCheckUDPGROForwarding, "component-debug-logging": (*Handler).serveComponentDebugLogging, "debug": (*Handler).serveDebug, - "debug-capture": (*Handler).serveDebugCapture, "debug-derp-region": (*Handler).serveDebugDERPRegion, "debug-dial-types": (*Handler).serveDebugDialTypes, "debug-log": (*Handler).serveDebugLog, @@ -98,11 +99,11 @@ var handler = map[string]localAPIHandler{ "derpmap": (*Handler).serveDERPMap, "dev-set-state-store": (*Handler).serveDevSetStateStore, "dial": (*Handler).serveDial, + "disconnect-control": (*Handler).disconnectControl, "dns-osconfig": (*Handler).serveDNSOSConfig, "dns-query": (*Handler).serveDNSQuery, "drive/fileserver-address": (*Handler).serveDriveServerAddr, "drive/shares": (*Handler).serveShares, - "file-targets": (*Handler).serveFileTargets, "goroutines": (*Handler).serveGoroutines, "handle-push-message": (*Handler).serveHandlePushMessage, "id-token": (*Handler).serveIDToken, @@ -148,6 +149,14 @@ var handler = map[string]localAPIHandler{ "whois": (*Handler).serveWhoIs, } +// Register registers a new LocalAPI handler for the given name. 
+func Register(name string, fn LocalAPIHandler) { + if _, ok := handler[name]; ok { + panic("duplicate LocalAPI handler registration: " + name) + } + handler[name] = fn +} + var ( // The clientmetrics package is stateful, but we want to expose a simple // imperative API to local clients, so we need to keep track of @@ -158,10 +167,9 @@ var ( metrics = map[string]*clientmetric.Metric{} ) -// NewHandler creates a new LocalAPI HTTP handler. All parameters except netMon -// are required (if non-nil it's used to do faster interface lookups). -func NewHandler(b *ipnlocal.LocalBackend, logf logger.Logf, logID logid.PublicID) *Handler { - return &Handler{b: b, logf: logf, backendLogID: logID, clock: tstime.StdClock{}} +// NewHandler creates a new LocalAPI HTTP handler. All parameters are required. +func NewHandler(actor ipnauth.Actor, b *ipnlocal.LocalBackend, logf logger.Logf, logID logid.PublicID) *Handler { + return &Handler{Actor: actor, b: b, logf: logf, backendLogID: logID, clock: tstime.StdClock{}} } type Handler struct { @@ -192,6 +200,14 @@ type Handler struct { clock tstime.Clock } +func (h *Handler) Logf(format string, args ...any) { + h.logf(format, args...) +} + +func (h *Handler) LocalBackend() *ipnlocal.LocalBackend { + return h.b +} + func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if h.b == nil { http.Error(w, "server has no local backend", http.StatusInternalServerError) @@ -256,7 +272,7 @@ func (h *Handler) validHost(hostname string) bool { // handlerForPath returns the LocalAPI handler for the provided Request.URI.Path. 
// (the path doesn't include any query parameters) -func handlerForPath(urlPath string) (h localAPIHandler, ok bool) { +func handlerForPath(urlPath string) (h LocalAPIHandler, ok bool) { if urlPath == "/" { return (*Handler).serveLocalAPIRoot, true } @@ -381,6 +397,15 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { // OS-specific details h.logf.JSON(1, "UserBugReportOS", osdiag.SupportInfo(osdiag.LogSupportInfoReasonBugReport)) + // Tailnet lock details + st := h.b.NetworkLockStatus() + if st.Enabled { + h.logf.JSON(1, "UserBugReportTailnetLockStatus", st) + if st.NodeKeySignature != nil { + h.logf("user bugreport tailnet lock signature: %s", st.NodeKeySignature.String()) + } + } + if defBool(r.URL.Query().Get("diagnose"), false) { h.b.Doctor(r.Context(), logger.WithPrefix(h.logf, "diag: ")) } @@ -415,6 +440,8 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { // NOTE(andrew): if we have anything else we want to do while recording // a bugreport, we can add it here. + metricBugReportRequests.Add(1) + // Read from the client; this will also return when the client closes // the connection. var buf [1]byte @@ -443,6 +470,33 @@ func (h *Handler) serveWhoIs(w http.ResponseWriter, r *http.Request) { h.serveWhoIsWithBackend(w, r, h.b) } +// serveSetDeviceAttrs is (as of 2024-12-30) an experimental LocalAPI handler to +// set device attributes via the control plane. +// +// See tailscale/corp#24690. 
+func (h *Handler) serveSetDeviceAttrs(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !h.PermitWrite { + http.Error(w, "set-device-attrs access denied", http.StatusForbidden) + return + } + if r.Method != "PATCH" { + http.Error(w, "only PATCH allowed", http.StatusMethodNotAllowed) + return + } + var req map[string]any + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := h.b.SetDeviceAttrs(ctx, req); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, "{}\n") +} + // localBackendWhoIsMethods is the subset of ipn.LocalBackend as needed // by the localapi WhoIs method. type localBackendWhoIsMethods interface { @@ -560,6 +614,7 @@ func (h *Handler) serveLogTap(w http.ResponseWriter, r *http.Request) { } func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { + metricDebugMetricsCalls.Add(1) // Require write access out of paranoia that the metrics // might contain something sensitive. if !h.PermitWrite { @@ -570,15 +625,10 @@ func (h *Handler) serveMetrics(w http.ResponseWriter, r *http.Request) { clientmetric.WritePrometheusExpositionFormat(w) } -// TODO(kradalby): Remove this once we have landed on a final set of -// metrics to export to clients and consider the metrics stable. -var debugUsermetricsEndpoint = envknob.RegisterBool("TS_DEBUG_USER_METRICS") - +// serveUserMetrics returns user-facing metrics in Prometheus text +// exposition format. 
func (h *Handler) serveUserMetrics(w http.ResponseWriter, r *http.Request) { - if !testenv.InTest() && !debugUsermetricsEndpoint() { - http.Error(w, "usermetrics debug flag not enabled", http.StatusForbidden) - return - } + metricUserMetricsCalls.Add(1) h.b.UserMetricsRegistry().Handler(w, r) } @@ -635,6 +685,13 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { } case "pick-new-derp": err = h.b.DebugPickNewDERP() + case "force-prefer-derp": + var n int + err = json.NewDecoder(r.Body).Decode(&n) + if err != nil { + break + } + h.b.DebugForcePreferDERP(n) case "": err = fmt.Errorf("missing parameter 'action'") default: @@ -774,23 +831,31 @@ func (h *Handler) serveDebugPortmap(w http.ResponseWriter, r *http.Request) { done := make(chan bool, 1) var c *portmapper.Client - c = portmapper.NewClient(logger.WithPrefix(logf, "portmapper: "), h.b.NetMon(), debugKnobs, h.b.ControlKnobs(), func() { - logf("portmapping changed.") - logf("have mapping: %v", c.HaveMapping()) - - if ext, ok := c.GetCachedMappingOrStartCreatingOne(); ok { - logf("cb: mapping: %v", ext) - select { - case done <- true: - default: + c = portmapper.NewClient(portmapper.Config{ + Logf: logger.WithPrefix(logf, "portmapper: "), + NetMon: h.b.NetMon(), + DebugKnobs: debugKnobs, + ControlKnobs: h.b.ControlKnobs(), + OnChange: func() { + logf("portmapping changed.") + logf("have mapping: %v", c.HaveMapping()) + + if ext, ok := c.GetCachedMappingOrStartCreatingOne(); ok { + logf("cb: mapping: %v", ext) + select { + case done <- true: + default: + } + return } - return - } - logf("cb: no mapping") + logf("cb: no mapping") + }, }) defer c.Close() - netMon, err := netmon.New(logger.WithPrefix(logf, "monitor: ")) + bus := eventbus.New() + defer bus.Close() + netMon, err := netmon.New(bus, logger.WithPrefix(logf, "monitor: ")) if err != nil { logf("error creating monitor: %v", err) return @@ -956,6 +1021,22 @@ func (h *Handler) servePprof(w http.ResponseWriter, r *http.Request) { 
servePprofFunc(w, r) } +// disconnectControl is the handler for local API /disconnect-control endpoint that shuts down control client, so that +// node no longer communicates with control. Doing this makes control consider this node inactive. This can be used +// before shutting down a replica of HA subnet router or app connector deployments to ensure that control tells the +// peers to switch over to another replica whilst still maintaining the existing peer connections. +func (h *Handler) disconnectControl(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "use POST", http.StatusMethodNotAllowed) + return + } + h.b.DisconnectControl() +} + func (h *Handler) reloadConfig(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { http.Error(w, "access denied", http.StatusForbidden) @@ -1018,7 +1099,7 @@ func (h *Handler) serveServeConfig(w http.ResponseWriter, r *http.Request) { } configIn := new(ipn.ServeConfig) if err := json.NewDecoder(r.Body).Decode(configIn); err != nil { - writeErrorJSON(w, fmt.Errorf("decoding config: %w", err)) + WriteErrorJSON(w, fmt.Errorf("decoding config: %w", err)) return } @@ -1036,7 +1117,7 @@ func (h *Handler) serveServeConfig(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusPreconditionFailed) return } - writeErrorJSON(w, fmt.Errorf("updating config: %w", err)) + WriteErrorJSON(w, fmt.Errorf("updating config: %w", err)) return } w.WriteHeader(http.StatusOK) @@ -1047,7 +1128,7 @@ func (h *Handler) serveServeConfig(w http.ResponseWriter, r *http.Request) { func authorizeServeConfigForGOOSAndUserContext(goos string, configIn *ipn.ServeConfig, h *Handler) error { switch goos { - case "windows", "linux", "darwin": + case "windows", "linux", "darwin", "illumos", "solaris": default: return nil } @@ -1067,7 +1148,7 @@ func authorizeServeConfigForGOOSAndUserContext(goos string, 
configIn *ipn.ServeC switch goos { case "windows": return errors.New("must be a Windows local admin to serve a path") - case "linux", "darwin": + case "linux", "darwin", "illumos", "solaris": return errors.New("must be root, or be an operator and able to run 'sudo tailscale' to serve a path") default: // We filter goos at the start of the func, this default case @@ -1231,7 +1312,7 @@ func (h *Handler) serveWatchIPNBus(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") ctx := r.Context() enc := json.NewEncoder(w) - h.b.WatchNotifications(ctx, mask, f.Flush, func(roNotify *ipn.Notify) (keepGoing bool) { + h.b.WatchNotificationsAs(ctx, h.Actor, mask, f.Flush, func(roNotify *ipn.Notify) (keepGoing bool) { err := enc.Encode(roNotify) if err != nil { h.logf("json.Encode: %v", err) @@ -1251,7 +1332,7 @@ func (h *Handler) serveLoginInteractive(w http.ResponseWriter, r *http.Request) http.Error(w, "want POST", http.StatusBadRequest) return } - h.b.StartLoginInteractive(r.Context()) + h.b.StartLoginInteractiveAs(r.Context(), h.Actor) w.WriteHeader(http.StatusNoContent) return } @@ -1320,7 +1401,7 @@ func (h *Handler) servePrefs(w http.ResponseWriter, r *http.Request) { return } var err error - prefs, err = h.b.EditPrefs(mp) + prefs, err = h.b.EditPrefsAs(mp, h.Actor) if err != nil { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusBadRequest) @@ -1339,426 +1420,93 @@ func (h *Handler) servePrefs(w http.ResponseWriter, r *http.Request) { e.Encode(prefs) } -type resJSON struct { - Error string `json:",omitempty"` -} - -func (h *Handler) serveCheckPrefs(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "checkprefs access denied", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "unsupported method", http.StatusMethodNotAllowed) - return - } - p := new(ipn.Prefs) - if err := json.NewDecoder(r.Body).Decode(p); err != nil { - http.Error(w, "invalid JSON 
body", http.StatusBadRequest) +func (h *Handler) servePolicy(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "policy access denied", http.StatusForbidden) return } - err := h.b.CheckPrefs(p) - var res resJSON - if err != nil { - res.Error = err.Error() - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(res) -} -func (h *Handler) serveFiles(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "file access denied", http.StatusForbidden) - return - } - suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/files/") + suffix, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/policy/") if !ok { http.Error(w, "misconfigured", http.StatusInternalServerError) return } + + var scope setting.PolicyScope if suffix == "" { - if r.Method != "GET" { - http.Error(w, "want GET to list files", http.StatusBadRequest) - return - } - ctx := r.Context() - if s := r.FormValue("waitsec"); s != "" && s != "0" { - d, err := strconv.Atoi(s) - if err != nil { - http.Error(w, "invalid waitsec", http.StatusBadRequest) - return - } - deadline := time.Now().Add(time.Duration(d) * time.Second) - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, deadline) - defer cancel() - } - wfs, err := h.b.AwaitWaitingFiles(ctx) - if err != nil && ctx.Err() == nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(wfs) + scope = setting.DefaultScope() + } else if err := scope.UnmarshalText([]byte(suffix)); err != nil { + http.Error(w, fmt.Sprintf("%q is not a valid scope", suffix), http.StatusBadRequest) return } - name, err := url.PathUnescape(suffix) + + policy, err := rsop.PolicyFor(scope) if err != nil { - http.Error(w, "bad filename", http.StatusBadRequest) + http.Error(w, err.Error(), http.StatusInternalServerError) return } - if r.Method == "DELETE" { - if err := 
h.b.DeleteFile(name); err != nil { + + var effectivePolicy *setting.Snapshot + switch r.Method { + case "GET": + effectivePolicy = policy.Get() + case "POST": + effectivePolicy, err = policy.Reload() + if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - w.WriteHeader(http.StatusNoContent) - return - } - rc, size, err := h.b.OpenFile(name) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + default: + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) return } - defer rc.Close() - w.Header().Set("Content-Length", fmt.Sprint(size)) - w.Header().Set("Content-Type", "application/octet-stream") - io.Copy(w, rc) -} -func writeErrorJSON(w http.ResponseWriter, err error) { - if err == nil { - err = errors.New("unexpected nil error") - } w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusInternalServerError) - type E struct { - Error string `json:"error"` - } - json.NewEncoder(w).Encode(E{err.Error()}) + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(effectivePolicy) } -func (h *Handler) serveFileTargets(w http.ResponseWriter, r *http.Request) { - if !h.PermitRead { - http.Error(w, "access denied", http.StatusForbidden) - return - } - if r.Method != "GET" { - http.Error(w, "want GET to list targets", http.StatusBadRequest) - return - } - fts, err := h.b.FileTargets() - if err != nil { - writeErrorJSON(w, err) - return - } - mak.NonNilSliceForJSON(&fts) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(fts) +type resJSON struct { + Error string `json:",omitempty"` } -// serveFilePut sends a file to another node. -// -// It's sometimes possible for clients to do this themselves, without -// tailscaled, except in the case of tailscaled running in -// userspace-networking ("netstack") mode, in which case tailscaled -// needs to a do a netstack dial out. 
-// -// Instead, the CLI also goes through tailscaled so it doesn't need to be -// aware of the network mode in use. -// -// macOS/iOS have always used this localapi method to simplify the GUI -// clients. -// -// The Windows client currently (2021-11-30) uses the peerapi (/v0/put/) -// directly, as the Windows GUI always runs in tun mode anyway. -// -// In addition to single file PUTs, this endpoint accepts multipart file -// POSTS encoded as multipart/form-data.The first part should be an -// application/json file that contains a manifest consisting of a JSON array of -// OutgoingFiles which wecan use for tracking progress even before reading the -// file parts. -// -// URL format: -// -// - PUT /localapi/v0/file-put/:stableID/:escaped-filename -// - POST /localapi/v0/file-put/:stableID -func (h *Handler) serveFilePut(w http.ResponseWriter, r *http.Request) { - metricFilePutCalls.Add(1) - +func (h *Handler) serveCheckPrefs(w http.ResponseWriter, r *http.Request) { if !h.PermitWrite { - http.Error(w, "file access denied", http.StatusForbidden) - return - } - - if r.Method != "PUT" && r.Method != "POST" { - http.Error(w, "want PUT to put file", http.StatusBadRequest) - return - } - - fts, err := h.b.FileTargets() - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - upath, ok := strings.CutPrefix(r.URL.EscapedPath(), "/localapi/v0/file-put/") - if !ok { - http.Error(w, "misconfigured", http.StatusInternalServerError) - return - } - var peerIDStr, filenameEscaped string - if r.Method == "PUT" { - ok := false - peerIDStr, filenameEscaped, ok = strings.Cut(upath, "/") - if !ok { - http.Error(w, "bogus URL", http.StatusBadRequest) - return - } - } else { - peerIDStr = upath - } - peerID := tailcfg.StableNodeID(peerIDStr) - - var ft *apitype.FileTarget - for _, x := range fts { - if x.Node.StableID == peerID { - ft = x - break - } - } - if ft == nil { - http.Error(w, "node not found", http.StatusNotFound) + http.Error(w, 
"checkprefs access denied", http.StatusForbidden) return } - dstURL, err := url.Parse(ft.PeerAPIURL) - if err != nil { - http.Error(w, "bogus peer URL", http.StatusInternalServerError) + if r.Method != "POST" { + http.Error(w, "unsupported method", http.StatusMethodNotAllowed) return } - - // Periodically report progress of outgoing files. - outgoingFiles := make(map[string]*ipn.OutgoingFile) - t := time.NewTicker(1 * time.Second) - progressUpdates := make(chan ipn.OutgoingFile) - defer close(progressUpdates) - - go func() { - defer t.Stop() - defer h.b.UpdateOutgoingFiles(outgoingFiles) - for { - select { - case u, ok := <-progressUpdates: - if !ok { - return - } - outgoingFiles[u.ID] = &u - case <-t.C: - h.b.UpdateOutgoingFiles(outgoingFiles) - } - } - }() - - switch r.Method { - case "PUT": - file := ipn.OutgoingFile{ - ID: rands.HexString(30), - PeerID: peerID, - Name: filenameEscaped, - DeclaredSize: r.ContentLength, - } - h.singleFilePut(r.Context(), progressUpdates, w, r.Body, dstURL, file) - case "POST": - h.multiFilePost(progressUpdates, w, r, peerID, dstURL) - default: - http.Error(w, "want PUT to put file", http.StatusBadRequest) + p := new(ipn.Prefs) + if err := json.NewDecoder(r.Body).Decode(p); err != nil { + http.Error(w, "invalid JSON body", http.StatusBadRequest) return } -} - -func (h *Handler) multiFilePost(progressUpdates chan (ipn.OutgoingFile), w http.ResponseWriter, r *http.Request, peerID tailcfg.StableNodeID, dstURL *url.URL) { - _, params, err := mime.ParseMediaType(r.Header.Get("Content-Type")) + err := h.b.CheckPrefs(p) + var res resJSON if err != nil { - http.Error(w, fmt.Sprintf("invalid Content-Type for multipart POST: %s", err), http.StatusBadRequest) - return - } - - ww := &multiFilePostResponseWriter{} - defer func() { - if err := ww.Flush(w); err != nil { - h.logf("error: multiFilePostResponseWriter.Flush(): %s", err) - } - }() - - outgoingFilesByName := make(map[string]ipn.OutgoingFile) - first := true - mr := 
multipart.NewReader(r.Body, params["boundary"]) - for { - part, err := mr.NextPart() - if err == io.EOF { - // No more parts. - return - } else if err != nil { - http.Error(ww, fmt.Sprintf("failed to decode multipart/form-data: %s", err), http.StatusBadRequest) - return - } - - if first { - first = false - if part.Header.Get("Content-Type") != "application/json" { - http.Error(ww, "first MIME part must be a JSON map of filename -> size", http.StatusBadRequest) - return - } - - var manifest []ipn.OutgoingFile - err := json.NewDecoder(part).Decode(&manifest) - if err != nil { - http.Error(ww, fmt.Sprintf("invalid manifest: %s", err), http.StatusBadRequest) - return - } - - for _, file := range manifest { - outgoingFilesByName[file.Name] = file - progressUpdates <- file - } - - continue - } - - if !h.singleFilePut(r.Context(), progressUpdates, ww, part, dstURL, outgoingFilesByName[part.FileName()]) { - return - } - - if ww.statusCode >= 400 { - // put failed, stop immediately - h.logf("error: singleFilePut: failed with status %d", ww.statusCode) - return - } - } -} - -// multiFilePostResponseWriter is a buffering http.ResponseWriter that can be -// reused across multiple singleFilePut calls and then flushed to the client -// when all files have been PUT. 
-type multiFilePostResponseWriter struct { - header http.Header - statusCode int - body *bytes.Buffer -} - -func (ww *multiFilePostResponseWriter) Header() http.Header { - if ww.header == nil { - ww.header = make(http.Header) - } - return ww.header -} - -func (ww *multiFilePostResponseWriter) WriteHeader(statusCode int) { - ww.statusCode = statusCode -} - -func (ww *multiFilePostResponseWriter) Write(p []byte) (int, error) { - if ww.body == nil { - ww.body = bytes.NewBuffer(nil) - } - return ww.body.Write(p) -} - -func (ww *multiFilePostResponseWriter) Flush(w http.ResponseWriter) error { - if ww.header != nil { - maps.Copy(w.Header(), ww.header) - } - if ww.statusCode > 0 { - w.WriteHeader(ww.statusCode) - } - if ww.body != nil { - _, err := io.Copy(w, ww.body) - return err + res.Error = err.Error() } - return nil + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(res) } -func (h *Handler) singleFilePut( - ctx context.Context, - progressUpdates chan (ipn.OutgoingFile), - w http.ResponseWriter, - body io.Reader, - dstURL *url.URL, - outgoingFile ipn.OutgoingFile, -) bool { - outgoingFile.Started = time.Now() - body = progresstracking.NewReader(body, 1*time.Second, func(n int, err error) { - outgoingFile.Sent = int64(n) - progressUpdates <- outgoingFile - }) - - fail := func() { - outgoingFile.Finished = true - outgoingFile.Succeeded = false - progressUpdates <- outgoingFile - } - - // Before we PUT a file we check to see if there are any existing partial file and if so, - // we resume the upload from where we left off by sending the remaining file instead of - // the full file. 
- var offset int64 - var resumeDuration time.Duration - remainingBody := io.Reader(body) - client := &http.Client{ - Transport: h.b.Dialer().PeerAPITransport(), - Timeout: 10 * time.Second, - } - req, err := http.NewRequestWithContext(ctx, "GET", dstURL.String()+"/v0/put/"+outgoingFile.Name, nil) - if err != nil { - http.Error(w, "bogus peer URL", http.StatusInternalServerError) - fail() - return false - } - switch resp, err := client.Do(req); { - case err != nil: - h.logf("could not fetch remote hashes: %v", err) - case resp.StatusCode == http.StatusMethodNotAllowed || resp.StatusCode == http.StatusNotFound: - // noop; implies older peerapi without resume support - case resp.StatusCode != http.StatusOK: - h.logf("fetch remote hashes status code: %d", resp.StatusCode) - default: - resumeStart := time.Now() - dec := json.NewDecoder(resp.Body) - offset, remainingBody, err = taildrop.ResumeReader(body, func() (out taildrop.BlockChecksum, err error) { - err = dec.Decode(&out) - return out, err - }) - if err != nil { - h.logf("reader could not be fully resumed: %v", err) - } - resumeDuration = time.Since(resumeStart).Round(time.Millisecond) - } - - outReq, err := http.NewRequestWithContext(ctx, "PUT", "http://peer/v0/put/"+outgoingFile.Name, remainingBody) - if err != nil { - http.Error(w, "bogus outreq", http.StatusInternalServerError) - fail() - return false +// WriteErrorJSON writes a JSON object (with a single "error" string field) to w +// with the given error. If err is nil, "unexpected nil error" is used for the +// stringification instead. 
+func WriteErrorJSON(w http.ResponseWriter, err error) { + if err == nil { + err = errors.New("unexpected nil error") } - outReq.ContentLength = outgoingFile.DeclaredSize - if offset > 0 { - h.logf("resuming put at offset %d after %v", offset, resumeDuration) - rangeHdr, _ := httphdr.FormatRange([]httphdr.Range{{Start: offset, Length: 0}}) - outReq.Header.Set("Range", rangeHdr) - if outReq.ContentLength >= 0 { - outReq.ContentLength -= offset - } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + type E struct { + Error string `json:"error"` } - - rp := httputil.NewSingleHostReverseProxy(dstURL) - rp.Transport = h.b.Dialer().PeerAPITransport() - rp.ServeHTTP(w, outReq) - - outgoingFile.Finished = true - outgoingFile.Succeeded = true - progressUpdates <- outgoingFile - - return true + json.NewEncoder(w).Encode(E{err.Error()}) } func (h *Handler) serveSetDNS(w http.ResponseWriter, r *http.Request) { @@ -1773,7 +1521,7 @@ func (h *Handler) serveSetDNS(w http.ResponseWriter, r *http.Request) { ctx := r.Context() err := h.b.SetDNS(ctx, r.FormValue("name"), r.FormValue("value")) if err != nil { - writeErrorJSON(w, err) + WriteErrorJSON(w, err) return } w.Header().Set("Content-Type", "application/json") @@ -1864,7 +1612,7 @@ func (h *Handler) servePing(w http.ResponseWriter, r *http.Request) { } res, err := h.b.Ping(ctx, ip, tailcfg.PingType(pingTypeStr), size) if err != nil { - writeErrorJSON(w, err) + WriteErrorJSON(w, err) return } w.Header().Set("Content-Type", "application/json") @@ -2493,8 +2241,8 @@ func (h *Handler) serveProfiles(w http.ResponseWriter, r *http.Request) { switch r.Method { case httpm.GET: profiles := h.b.ListProfiles() - profileIndex := slices.IndexFunc(profiles, func(p ipn.LoginProfile) bool { - return p.ID == profileID + profileIndex := slices.IndexFunc(profiles, func(p ipn.LoginProfileView) bool { + return p.ID() == profileID }) if profileIndex == -1 { http.Error(w, "Profile not found", 
http.StatusNotFound) @@ -2592,21 +2340,6 @@ func defBool(a string, def bool) bool { return v } -func (h *Handler) serveDebugCapture(w http.ResponseWriter, r *http.Request) { - if !h.PermitWrite { - http.Error(w, "debug access denied", http.StatusForbidden) - return - } - if r.Method != "POST" { - http.Error(w, "POST required", http.StatusMethodNotAllowed) - return - } - - w.WriteHeader(http.StatusOK) - w.(http.Flusher).Flush() - h.b.StreamDebugCapture(r.Context(), w) -} - func (h *Handler) serveDebugLog(w http.ResponseWriter, r *http.Request) { if !h.PermitRead { http.Error(w, "debug-log access denied", http.StatusForbidden) @@ -2908,13 +2641,6 @@ func (h *Handler) serveShares(w http.ResponseWriter, r *http.Request) { } } -var ( - metricInvalidRequests = clientmetric.NewCounter("localapi_invalid_requests") - - // User-visible LocalAPI endpoints. - metricFilePutCalls = clientmetric.NewCounter("localapi_file_put") -) - // serveSuggestExitNode serves a POST endpoint for returning a suggested exit node. 
func (h *Handler) serveSuggestExitNode(w http.ResponseWriter, r *http.Request) { if r.Method != "GET" { @@ -2923,7 +2649,7 @@ func (h *Handler) serveSuggestExitNode(w http.ResponseWriter, r *http.Request) { } res, err := h.b.SuggestExitNode() if err != nil { - writeErrorJSON(w, err) + WriteErrorJSON(w, err) return } w.Header().Set("Content-Type", "application/json") diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index fa54a1e756a7e..970f798d05005 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -39,23 +39,6 @@ import ( "tailscale.com/wgengine" ) -var _ ipnauth.Actor = (*testActor)(nil) - -type testActor struct { - uid ipn.WindowsUserID - name string - isLocalSystem bool - isLocalAdmin bool -} - -func (u *testActor) UserID() ipn.WindowsUserID { return u.uid } - -func (u *testActor) Username() (string, error) { return u.name, nil } - -func (u *testActor) IsLocalSystem() bool { return u.isLocalSystem } - -func (u *testActor) IsLocalAdmin(operatorUID string) bool { return u.isLocalAdmin } - func TestValidHost(t *testing.T) { tests := []struct { host string @@ -207,7 +190,7 @@ func TestWhoIsArgTypes(t *testing.T) { func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { newHandler := func(connIsLocalAdmin bool) *Handler { - return &Handler{Actor: &testActor{isLocalAdmin: connIsLocalAdmin}, b: newTestLocalBackend(t)} + return &Handler{Actor: &ipnauth.TestActor{LocalAdmin: connIsLocalAdmin}, b: newTestLocalBackend(t)} } tests := []struct { name string @@ -254,7 +237,7 @@ func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { } for _, tt := range tests { - for _, goos := range []string{"linux", "windows", "darwin"} { + for _, goos := range []string{"linux", "windows", "darwin", "illumos", "solaris"} { t.Run(goos+"-"+tt.name, func(t *testing.T) { err := authorizeServeConfigForGOOSAndUserContext(goos, tt.configIn, tt.h) gotErr := err != nil @@ -353,10 +336,10 @@ func TestServeWatchIPNBus(t 
*testing.T) { func newTestLocalBackend(t testing.TB) *ipnlocal.LocalBackend { var logf logger.Logf = logger.Discard - sys := new(tsd.System) + sys := tsd.NewSystem() store := new(mem.Store) sys.Set(store) - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatalf("NewFakeUserspaceEngine: %v", err) } @@ -366,6 +349,7 @@ func newTestLocalBackend(t testing.TB) *ipnlocal.LocalBackend { if err != nil { t.Fatalf("NewLocalBackend: %v", err) } + t.Cleanup(lb.Shutdown) return lb } diff --git a/ipn/prefs.go b/ipn/prefs.go index 5d61f0119cd23..01275a7e25bdc 100644 --- a/ipn/prefs.go +++ b/ipn/prefs.go @@ -29,6 +29,7 @@ import ( "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/syspolicy" + "tailscale.com/version" ) // DefaultControlURL is the URL base of the control plane @@ -179,6 +180,12 @@ type Prefs struct { // node. AdvertiseRoutes []netip.Prefix + // AdvertiseServices specifies the list of services that this + // node can serve as a destination for. Note that an advertised + // service must still go through the approval process from the + // control server. + AdvertiseServices []string + // NoSNAT specifies whether to source NAT traffic going to // destinations in AdvertiseRoutes. The default is to apply source // NAT, which makes the traffic appear to come from the router @@ -228,6 +235,11 @@ type Prefs struct { // PostureChecking enables the collection of information used for device // posture checks. + // + // Note: this should be named ReportPosture, but it was shipped as + // PostureChecking in some early releases and this JSON field is written to + // disk, so we just keep its old name. 
(akin to CorpDNS which is an internal + // pref name that doesn't match the public interface) PostureChecking bool // NetfilterKind specifies what netfilter implementation to use. @@ -239,6 +251,14 @@ type Prefs struct { // by name. DriveShares []*drive.Share + // RelayServerPort is the UDP port number for the relay server to bind to, + // on all interfaces. A non-nil zero value signifies a random unused port + // should be used. A nil value signifies relay server functionality + // should be disabled. This field is currently experimental, and therefore + // no guarantees are made about its current naming and functionality when + // non-nil/enabled. + RelayServerPort *int `json:",omitempty"` + // AllowSingleHosts was a legacy field that was always true // for the past 4.5 years. It controlled whether Tailscale // peers got /32 or /127 routes for each other. @@ -319,6 +339,7 @@ type MaskedPrefs struct { ForceDaemonSet bool `json:",omitempty"` EggSet bool `json:",omitempty"` AdvertiseRoutesSet bool `json:",omitempty"` + AdvertiseServicesSet bool `json:",omitempty"` NoSNATSet bool `json:",omitempty"` NoStatefulFilteringSet bool `json:",omitempty"` NetfilterModeSet bool `json:",omitempty"` @@ -329,6 +350,7 @@ type MaskedPrefs struct { PostureCheckingSet bool `json:",omitempty"` NetfilterKindSet bool `json:",omitempty"` DriveSharesSet bool `json:",omitempty"` + RelayServerPortSet bool `json:",omitempty"` } // SetsInternal reports whether mp has any of the Internal*Set field bools set @@ -527,6 +549,9 @@ func (p *Prefs) pretty(goos string) string { if len(p.AdvertiseTags) > 0 { fmt.Fprintf(&sb, "tags=%s ", strings.Join(p.AdvertiseTags, ",")) } + if len(p.AdvertiseServices) > 0 { + fmt.Fprintf(&sb, "services=%s ", strings.Join(p.AdvertiseServices, ",")) + } if goos == "linux" { fmt.Fprintf(&sb, "nf=%v ", p.NetfilterMode) } @@ -544,6 +569,9 @@ func (p *Prefs) pretty(goos string) string { } sb.WriteString(p.AutoUpdate.Pretty()) sb.WriteString(p.AppConnector.Pretty()) + if 
p.RelayServerPort != nil { + fmt.Fprintf(&sb, "relayServerPort=%d ", *p.RelayServerPort) + } if p.Persist != nil { sb.WriteString(p.Persist.Pretty()) } else { @@ -570,7 +598,7 @@ func (p PrefsView) Equals(p2 PrefsView) bool { } func (p *Prefs) Equals(p2 *Prefs) bool { - if p == nil && p2 == nil { + if p == p2 { return true } if p == nil || p2 == nil { @@ -596,15 +624,17 @@ func (p *Prefs) Equals(p2 *Prefs) bool { p.OperatorUser == p2.OperatorUser && p.Hostname == p2.Hostname && p.ForceDaemon == p2.ForceDaemon && - compareIPNets(p.AdvertiseRoutes, p2.AdvertiseRoutes) && - compareStrings(p.AdvertiseTags, p2.AdvertiseTags) && + slices.Equal(p.AdvertiseRoutes, p2.AdvertiseRoutes) && + slices.Equal(p.AdvertiseTags, p2.AdvertiseTags) && + slices.Equal(p.AdvertiseServices, p2.AdvertiseServices) && p.Persist.Equals(p2.Persist) && p.ProfileName == p2.ProfileName && p.AutoUpdate.Equals(p2.AutoUpdate) && p.AppConnector == p2.AppConnector && p.PostureChecking == p2.PostureChecking && slices.EqualFunc(p.DriveShares, p2.DriveShares, drive.SharesEqual) && - p.NetfilterKind == p2.NetfilterKind + p.NetfilterKind == p2.NetfilterKind && + compareIntPtrs(p.RelayServerPort, p2.RelayServerPort) } func (au AutoUpdatePrefs) Pretty() string { @@ -624,28 +654,14 @@ func (ap AppConnectorPrefs) Pretty() string { return "" } -func compareIPNets(a, b []netip.Prefix) bool { - if len(a) != len(b) { +func compareIntPtrs(a, b *int) bool { + if (a == nil) != (b == nil) { return false } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} - -func compareStrings(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } + if a == nil { + return true } - return true + return *a == *b } // NewPrefs returns the default preferences to use. @@ -653,7 +669,7 @@ func NewPrefs() *Prefs { // Provide default values for options which might be missing // from the json data for any reason. 
The json can still // override them to false. - return &Prefs{ + p := &Prefs{ // ControlURL is explicitly not set to signal that // it's not yet configured, which relaxes the CLI "up" // safety net features. It will get set to DefaultControlURL @@ -661,7 +677,6 @@ func NewPrefs() *Prefs { // later anyway. ControlURL: "", - RouteAll: true, CorpDNS: true, WantRunning: false, NetfilterMode: preftype.NetfilterOn, @@ -671,6 +686,8 @@ func NewPrefs() *Prefs { Apply: opt.Bool("unset"), }, } + p.RouteAll = p.DefaultRouteAll(runtime.GOOS) + return p } // ControlURLOrDefault returns the coordination server's URL base. @@ -700,6 +717,20 @@ func (p *Prefs) ControlURLOrDefault() string { return DefaultControlURL } +// DefaultRouteAll returns the default value of [Prefs.RouteAll] as a function +// of the platform it's running on. +func (p *Prefs) DefaultRouteAll(goos string) bool { + switch goos { + case "windows", "android", "ios": + return true + case "darwin": + // Only true for macAppStore and macsys, false for darwin tailscaled. + return version.IsSandboxedMacOS() + default: + return false + } +} + // AdminPageURL returns the admin web site URL for the current ControlURL. func (p PrefsView) AdminPageURL() string { return p.Đļ.AdminPageURL() } @@ -989,3 +1020,26 @@ type LoginProfile struct { // into. ControlURL string } + +// Equals reports whether p and p2 are equal. +func (p LoginProfileView) Equals(p2 LoginProfileView) bool { + return p.Đļ.Equals(p2.Đļ) +} + +// Equals reports whether p and p2 are equal. 
+func (p *LoginProfile) Equals(p2 *LoginProfile) bool { + if p == p2 { + return true + } + if p == nil || p2 == nil { + return false + } + return p.ID == p2.ID && + p.Name == p2.Name && + p.NetworkProfile == p2.NetworkProfile && + p.Key == p2.Key && + p.UserProfile.Equal(&p2.UserProfile) && + p.NodeID == p2.NodeID && + p.LocalUserID == p2.LocalUserID && + p.ControlURL == p2.ControlURL +} diff --git a/ipn/prefs_test.go b/ipn/prefs_test.go index dcb999ef56a64..d28d161db422e 100644 --- a/ipn/prefs_test.go +++ b/ipn/prefs_test.go @@ -54,6 +54,7 @@ func TestPrefsEqual(t *testing.T) { "ForceDaemon", "Egg", "AdvertiseRoutes", + "AdvertiseServices", "NoSNAT", "NoStatefulFiltering", "NetfilterMode", @@ -64,6 +65,7 @@ func TestPrefsEqual(t *testing.T) { "PostureChecking", "NetfilterKind", "DriveShares", + "RelayServerPort", "AllowSingleHosts", "Persist", } @@ -72,6 +74,9 @@ func TestPrefsEqual(t *testing.T) { have, prefsHandles) } + relayServerPort := func(port int) *int { + return &port + } nets := func(strs ...string) (ns []netip.Prefix) { for _, s := range strs { n, err := netip.ParsePrefix(s) @@ -330,6 +335,26 @@ func TestPrefsEqual(t *testing.T) { &Prefs{NetfilterKind: ""}, false, }, + { + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + true, + }, + { + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:xenia"}}, + &Prefs{AdvertiseServices: []string{"svc:tux", "svc:amelie"}}, + false, + }, + { + &Prefs{RelayServerPort: relayServerPort(0)}, + &Prefs{RelayServerPort: nil}, + false, + }, + { + &Prefs{RelayServerPort: relayServerPort(0)}, + &Prefs{RelayServerPort: relayServerPort(1)}, + false, + }, } for i, tt := range tests { got := tt.a.Equals(tt.b) @@ -456,13 +481,6 @@ func TestPrefsPretty(t *testing.T) { "darwin", `Prefs{ra=false dns=false want=true tags=tag:foo,tag:bar url="http://localhost:1234" update=off Persist=nil}`, }, - { - Prefs{ - Persist: &persist.Persist{}, - }, - "linux", - 
`Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{lm=, o=, n= u=""}}`, - }, { Prefs{ Persist: &persist.Persist{ @@ -470,7 +488,7 @@ func TestPrefsPretty(t *testing.T) { }, }, "linux", - `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{lm=, o=, n=[B1VKl] u=""}}`, + `Prefs{ra=false dns=false want=false routes=[] nf=off update=off Persist{o=, n=[B1VKl] u=""}}`, }, { Prefs{ diff --git a/ipn/serve.go b/ipn/serve.go index 5c0a97ed3ffa9..ac92287bdc08f 100644 --- a/ipn/serve.go +++ b/ipn/serve.go @@ -6,6 +6,7 @@ package ipn import ( "errors" "fmt" + "iter" "net" "net/netip" "net/url" @@ -15,7 +16,9 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" "tailscale.com/util/mak" + "tailscale.com/util/set" ) // ServeConfigKey returns a StateKey that stores the @@ -24,6 +27,23 @@ func ServeConfigKey(profileID ProfileID) StateKey { return StateKey("_serve/" + profileID) } +// ServiceConfig contains the config information for a single service. +// it contains a bool to indicate if the service is in Tun mode (L3 forwarding). +// If the service is not in Tun mode, the service is configured by the L4 forwarding +// (TCP ports) and/or the L7 forwarding (http handlers) information. +type ServiceConfig struct { + // TCP are the list of TCP port numbers that tailscaled should handle for + // the Tailscale IP addresses. (not subnet routers, etc) + TCP map[uint16]*TCPPortHandler `json:",omitempty"` + + // Web maps from "$SNI_NAME:$PORT" to a set of HTTP handlers + // keyed by mount point ("/", "/foo", etc) + Web map[HostPort]*WebServerConfig `json:",omitempty"` + + // Tun determines if the service should be using L3 forwarding (Tun mode). + Tun bool `json:",omitempty"` +} + // ServeConfig is the JSON type stored in the StateStore for // StateKey "_serve/$PROFILE_ID" as returned by ServeConfigKey. 
type ServeConfig struct { @@ -35,16 +55,20 @@ type ServeConfig struct { // keyed by mount point ("/", "/foo", etc) Web map[HostPort]*WebServerConfig `json:",omitempty"` + // Services maps from service name (in the form "svc:dns-label") to a ServiceConfig, + // which describes the L3, L4, and L7 forwarding information for the service. + Services map[tailcfg.ServiceName]*ServiceConfig `json:",omitempty"` + // AllowFunnel is the set of SNI:port values for which funnel // traffic is allowed, from trusted ingress peers. AllowFunnel map[HostPort]bool `json:",omitempty"` - // Foreground is a map of an IPN Bus session ID to an alternate foreground - // serve config that's valid for the life of that WatchIPNBus session ID. - // This. This allows the config to specify ephemeral configs that are - // used in the CLI's foreground mode to ensure ungraceful shutdowns - // of either the client or the LocalBackend does not expose ports - // that users are not aware of. + // Foreground is a map of an IPN Bus session ID to an alternate foreground serve config that's valid for the + // life of that WatchIPNBus session ID. This allows the config to specify ephemeral configs that are used + // in the CLI's foreground mode to ensure ungraceful shutdowns of either the client or the LocalBackend does not + // expose ports that users are not aware of. In practice this contains any serve config set via 'tailscale + // serve' command run without the '--bg' flag. ServeConfig contained by Foreground is not expected itself to contain + // another Foreground block. Foreground map[string]*ServeConfig `json:",omitempty"` // ETag is the checksum of the serve config that's populated @@ -365,8 +389,7 @@ func (sc *ServeConfig) RemoveTCPForwarding(port uint16) { // View version of ServeConfig.IsFunnelOn. func (v ServeConfigView) IsFunnelOn() bool { return v.ж.IsFunnelOn() } -// IsFunnelOn reports whether if ServeConfig is currently allowing funnel -// traffic for any host:port.
+// IsFunnelOn reports whether any funnel endpoint is currently enabled for this node. func (sc *ServeConfig) IsFunnelOn() bool { if sc == nil { return false @@ -376,6 +399,11 @@ func (sc *ServeConfig) IsFunnelOn() bool { return true } } + for _, conf := range sc.Foreground { + if conf.IsFunnelOn() { + return true + } + } return false } @@ -543,58 +571,78 @@ func ExpandProxyTargetValue(target string, supportedSchemes []string, defaultSch return u.String(), nil } -// RangeOverTCPs ranges over both background and foreground TCPs. -// If the returned bool from the given f is false, then this function stops -// iterating immediately and does not check other foreground configs. -func (v ServeConfigView) RangeOverTCPs(f func(port uint16, _ TCPPortHandlerView) bool) { - parentCont := true - v.TCP().Range(func(k uint16, v TCPPortHandlerView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { - if !parentCont { - return false - } - v.TCP().Range(func(k uint16, v TCPPortHandlerView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - return parentCont - }) -} - -// RangeOverWebs ranges over both background and foreground Webs. -// If the returned bool from the given f is false, then this function stops -// iterating immediately and does not check other foreground configs. -func (v ServeConfigView) RangeOverWebs(f func(_ HostPort, conf WebServerConfigView) bool) { - parentCont := true - v.Web().Range(func(k HostPort, v WebServerConfigView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { - if !parentCont { - return false - } - v.Web().Range(func(k HostPort, v WebServerConfigView) (cont bool) { - parentCont = f(k, v) - return parentCont - }) - return parentCont - }) +// TCPs returns an iterator over both background and foreground TCP +// listeners. +// +// The key is the port number. 
+func (v ServeConfigView) TCPs() iter.Seq2[uint16, TCPPortHandlerView] { + return func(yield func(uint16, TCPPortHandlerView) bool) { + for k, v := range v.TCP().All() { + if !yield(k, v) { + return + } + } + for _, conf := range v.Foreground().All() { + for k, v := range conf.TCP().All() { + if !yield(k, v) { + return + } + } + } + } +} + +// Webs returns an iterator over both background and foreground Web configurations. +func (v ServeConfigView) Webs() iter.Seq2[HostPort, WebServerConfigView] { + return func(yield func(HostPort, WebServerConfigView) bool) { + for k, v := range v.Web().All() { + if !yield(k, v) { + return + } + } + for _, conf := range v.Foreground().All() { + for k, v := range conf.Web().All() { + if !yield(k, v) { + return + } + } + } + for _, service := range v.Services().All() { + for k, v := range service.Web().All() { + if !yield(k, v) { + return + } + } + } + } +} + +// FindServiceTCP return the TCPPortHandlerView for the given service name and port. +func (v ServeConfigView) FindServiceTCP(svcName tailcfg.ServiceName, port uint16) (res TCPPortHandlerView, ok bool) { + svcCfg, ok := v.Services().GetOk(svcName) + if !ok { + return res, ok + } + return svcCfg.TCP().GetOk(port) +} + +func (v ServeConfigView) FindServiceWeb(svcName tailcfg.ServiceName, hp HostPort) (res WebServerConfigView, ok bool) { + if svcCfg, ok := v.Services().GetOk(svcName); ok { + if res, ok := svcCfg.Web().GetOk(hp); ok { + return res, ok + } + } + return res, ok } // FindTCP returns the first TCP that matches with the given port. It // prefers a foreground match first followed by a background search if none // existed. 
func (v ServeConfigView) FindTCP(port uint16) (res TCPPortHandlerView, ok bool) { - v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { - res, ok = v.TCP().GetOk(port) - return !ok - }) - if ok { - return res, ok + for _, conf := range v.Foreground().All() { + if res, ok := conf.TCP().GetOk(port); ok { + return res, ok + } } return v.TCP().GetOk(port) } @@ -603,12 +651,10 @@ func (v ServeConfigView) FindTCP(port uint16) (res TCPPortHandlerView, ok bool) // prefers a foreground match first followed by a background search if none // existed. func (v ServeConfigView) FindWeb(hp HostPort) (res WebServerConfigView, ok bool) { - v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { - res, ok = v.Web().GetOk(hp) - return !ok - }) - if ok { - return res, ok + for _, conf := range v.Foreground().All() { + if res, ok := conf.Web().GetOk(hp); ok { + return res, ok + } } return v.Web().GetOk(hp) } @@ -616,14 +662,15 @@ func (v ServeConfigView) FindWeb(hp HostPort) (res WebServerConfigView, ok bool) // HasAllowFunnel returns whether this config has at least one AllowFunnel // set in the background or foreground configs. 
func (v ServeConfigView) HasAllowFunnel() bool { - return v.AllowFunnel().Len() > 0 || func() bool { - var exists bool - v.Foreground().Range(func(k string, v ServeConfigView) (cont bool) { - exists = v.AllowFunnel().Len() > 0 - return !exists - }) - return exists - }() + if v.AllowFunnel().Len() > 0 { + return true + } + for _, conf := range v.Foreground().All() { + if conf.AllowFunnel().Len() > 0 { + return true + } + } + return false } // FindFunnel reports whether target exists in either the background AllowFunnel @@ -632,12 +679,73 @@ func (v ServeConfigView) HasFunnelForTarget(target HostPort) bool { if v.AllowFunnel().Get(target) { return true } - var exists bool - v.Foreground().Range(func(_ string, v ServeConfigView) (cont bool) { - if exists = v.AllowFunnel().Get(target); exists { - return false + for _, conf := range v.Foreground().All() { + if conf.AllowFunnel().Get(target) { + return true } - return true - }) - return exists + } + return false +} + +// CheckValidServicesConfig reports whether the ServeConfig has +// invalid service configurations. +func (sc *ServeConfig) CheckValidServicesConfig() error { + for svcName, service := range sc.Services { + if err := service.checkValidConfig(); err != nil { + return fmt.Errorf("invalid service configuration for %q: %w", svcName, err) + } + } + return nil +} + +// ServicePortRange returns the list of tailcfg.ProtoPortRange that represents +// the proto/ports pairs that are being served by the service. +// +// Right now Tun mode is the only thing supports UDP, otherwise serve only supports TCP. +func (v ServiceConfigView) ServicePortRange() []tailcfg.ProtoPortRange { + if v.Tun() { + // If the service is in Tun mode, means service accept TCP/UDP on all ports. + return []tailcfg.ProtoPortRange{{Ports: tailcfg.PortRangeAny}} + } + tcp := int(ipproto.TCP) + + // Deduplicate the ports. 
+ servePorts := make(set.Set[uint16]) + for port := range v.TCP().All() { + if port > 0 { + servePorts.Add(uint16(port)) + } + } + dedupedServePorts := servePorts.Slice() + slices.Sort(dedupedServePorts) + + var ranges []tailcfg.ProtoPortRange + for _, p := range dedupedServePorts { + if n := len(ranges); n > 0 && p == ranges[n-1].Ports.Last+1 { + ranges[n-1].Ports.Last = p + continue + } + ranges = append(ranges, tailcfg.ProtoPortRange{ + Proto: tcp, + Ports: tailcfg.PortRange{ + First: p, + Last: p, + }, + }) + } + return ranges +} + +// ErrServiceConfigHasBothTCPAndTun signals that a service +// in Tun mode cannot also has TCP or Web handlers set. +var ErrServiceConfigHasBothTCPAndTun = errors.New("the VIP Service configuration can not set TUN at the same time as TCP or Web") + +// checkValidConfig checks if the service configuration is valid. +// Currently, the only invalid configuration is when the service is in Tun mode +// and has TCP or Web handlers. +func (v *ServiceConfig) checkValidConfig() error { + if v.Tun && (len(v.TCP) > 0 || len(v.Web) > 0) { + return ErrServiceConfigHasBothTCPAndTun + } + return nil } diff --git a/ipn/serve_test.go b/ipn/serve_test.go index e9d8e8f322075..ae1d56eef6b09 100644 --- a/ipn/serve_test.go +++ b/ipn/serve_test.go @@ -182,3 +182,88 @@ func TestExpandProxyTargetDev(t *testing.T) { }) } } + +func TestIsFunnelOn(t *testing.T) { + tests := []struct { + name string + sc *ServeConfig + want bool + }{ + { + name: "nil_config", + }, + { + name: "empty_config", + sc: &ServeConfig{}, + }, + { + name: "funnel_enabled_in_background", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": true, + }, + }, + want: true, + }, + { + name: "funnel_disabled_in_background", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": false, + }, + }, + }, + { + name: "funnel_enabled_in_foreground", + sc: &ServeConfig{ + Foreground: map[string]*ServeConfig{ + "abc123": { + AllowFunnel: map[HostPort]bool{ + 
"tailnet.xyz:443": true, + }, + }, + }, + }, + want: true, + }, + { + name: "funnel_disabled_in_both", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": false, + }, + Foreground: map[string]*ServeConfig{ + "abc123": { + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:8443": false, + }, + }, + }, + }, + }, + { + name: "funnel_enabled_in_both", + sc: &ServeConfig{ + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:443": true, + }, + Foreground: map[string]*ServeConfig{ + "abc123": { + AllowFunnel: map[HostPort]bool{ + "tailnet.xyz:8443": true, + }, + }, + }, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.sc.IsFunnelOn(); got != tt.want { + t.Errorf("ServeConfig.IsFunnelOn() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ipn/store/awsstore/store_aws.go b/ipn/store/awsstore/store_aws.go index 0fb78d45a6a53..40bbbf0370822 100644 --- a/ipn/store/awsstore/store_aws.go +++ b/ipn/store/awsstore/store_aws.go @@ -10,7 +10,9 @@ import ( "context" "errors" "fmt" + "net/url" "regexp" + "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/aws/arn" @@ -28,6 +30,14 @@ const ( var parameterNameRx = regexp.MustCompile(parameterNameRxStr) +// Option defines a functional option type for configuring awsStore. +type Option func(*storeOptions) + +// storeOptions holds optional settings for creating a new awsStore. +type storeOptions struct { + kmsKey string +} + // awsSSMClient is an interface allowing us to mock the couple of // API calls we are leveraging with the AWSStore provider type awsSSMClient interface { @@ -46,6 +56,10 @@ type awsStore struct { ssmClient awsSSMClient ssmARN arn.ARN + // kmsKey is optional. If empty, the parameter is stored in plaintext. + // If non-empty, the parameter is encrypted with this KMS key. 
+ kmsKey string + memory mem.Store } @@ -57,30 +71,80 @@ type awsStore struct { // Tailscaled to only only store new state in-memory and // restarting Tailscaled can fail until you delete your state // from the AWS Parameter Store. -func New(_ logger.Logf, ssmARN string) (ipn.StateStore, error) { - return newStore(ssmARN, nil) +// +// If you want to specify an optional KMS key, +// pass one or more Option objects, e.g. awsstore.WithKeyID("alias/my-key"). +func New(_ logger.Logf, ssmARN string, opts ...Option) (ipn.StateStore, error) { + // Apply all options to an empty storeOptions + var so storeOptions + for _, opt := range opts { + opt(&so) + } + + return newStore(ssmARN, so, nil) +} + +// WithKeyID sets the KMS key to be used for encryption. It can be +// a KeyID, an alias ("alias/my-key"), or a full ARN. +// +// If kmsKey is empty, the Option is a no-op. +func WithKeyID(kmsKey string) Option { + return func(o *storeOptions) { + o.kmsKey = kmsKey + } +} + +// ParseARNAndOpts parses an ARN and optional URL-encoded parameters +// from arg. +func ParseARNAndOpts(arg string) (ssmARN string, opts []Option, err error) { + ssmARN = arg + + // Support optional ?url-encoded-parameters. + if s, q, ok := strings.Cut(arg, "?"); ok { + ssmARN = s + q, err := url.ParseQuery(q) + if err != nil { + return "", nil, err + } + + for k := range q { + switch k { + default: + return "", nil, fmt.Errorf("unknown arn option parameter %q", k) + case "kmsKey": + // We allow an ARN, a key ID, or an alias name for kmsKeyID. + // If it doesn't look like an ARN and doesn't have a '/', + // prepend "alias/" for KMS alias references. + kmsKey := q.Get(k) + if kmsKey != "" && + !strings.Contains(kmsKey, "/") && + !strings.HasPrefix(kmsKey, "arn:") { + kmsKey = "alias/" + kmsKey + } + if kmsKey != "" { + opts = append(opts, WithKeyID(kmsKey)) + } + } + } + } + return ssmARN, opts, nil } // newStore is NewStore, but for tests. If client is non-nil, it's // used instead of making one. 
-func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { +func newStore(ssmARN string, so storeOptions, client awsSSMClient) (ipn.StateStore, error) { s := &awsStore{ ssmClient: client, + kmsKey: so.kmsKey, } var err error - - // Parse the ARN if s.ssmARN, err = arn.Parse(ssmARN); err != nil { return nil, fmt.Errorf("unable to parse the ARN correctly: %v", err) } - - // Validate the ARN corresponds to the SSM service if s.ssmARN.Service != "ssm" { return nil, fmt.Errorf("invalid service %q, expected 'ssm'", s.ssmARN.Service) } - - // Validate the ARN corresponds to a parameter store resource if !parameterNameRx.MatchString(s.ssmARN.Resource) { return nil, fmt.Errorf("invalid resource %q, expected to match %v", s.ssmARN.Resource, parameterNameRxStr) } @@ -96,12 +160,11 @@ func newStore(ssmARN string, client awsSSMClient) (ipn.StateStore, error) { s.ssmClient = ssm.NewFromConfig(cfg) } - // Hydrate cache with the potentially current state + // Preload existing state, if any if err := s.LoadState(); err != nil { return nil, err } return s, nil - } // LoadState attempts to read the state from AWS SSM parameter store key. @@ -172,15 +235,21 @@ func (s *awsStore) persistState() error { // which is free. 
However, if it exceeds 4kb it switches the parameter to advanced tiering // doubling the capacity to 8kb per the following docs: // https://aws.amazon.com/about-aws/whats-new/2019/08/aws-systems-manager-parameter-store-announces-intelligent-tiering-to-enable-automatic-parameter-tier-selection/ - _, err = s.ssmClient.PutParameter( - context.TODO(), - &ssm.PutParameterInput{ - Name: aws.String(s.ParameterName()), - Value: aws.String(string(bs)), - Overwrite: aws.Bool(true), - Tier: ssmTypes.ParameterTierIntelligentTiering, - Type: ssmTypes.ParameterTypeSecureString, - }, - ) + in := &ssm.PutParameterInput{ + Name: aws.String(s.ParameterName()), + Value: aws.String(string(bs)), + Overwrite: aws.Bool(true), + Tier: ssmTypes.ParameterTierIntelligentTiering, + Type: ssmTypes.ParameterTypeSecureString, + } + + // If kmsKey is specified, encrypt with that key + // NOTE: this input allows any alias, keyID or ARN + // If this isn't specified, AWS will use the default KMS key + if s.kmsKey != "" { + in.KeyId = aws.String(s.kmsKey) + } + + _, err = s.ssmClient.PutParameter(context.TODO(), in) return err } diff --git a/ipn/store/awsstore/store_aws_stub.go b/ipn/store/awsstore/store_aws_stub.go deleted file mode 100644 index 8d2156ce948d5..0000000000000 --- a/ipn/store/awsstore/store_aws_stub.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_aws - -package awsstore - -import ( - "fmt" - "runtime" - - "tailscale.com/ipn" - "tailscale.com/types/logger" -) - -func New(logger.Logf, string) (ipn.StateStore, error) { - return nil, fmt.Errorf("AWS store is not supported on %v", runtime.GOOS) -} diff --git a/ipn/store/awsstore/store_aws_test.go b/ipn/store/awsstore/store_aws_test.go index f6c8fedb32dc9..3382635a7d333 100644 --- a/ipn/store/awsstore/store_aws_test.go +++ b/ipn/store/awsstore/store_aws_test.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // 
SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !ts_omit_aws package awsstore @@ -65,7 +65,11 @@ func TestNewAWSStore(t *testing.T) { Resource: "parameter/foo", } - s, err := newStore(storeParameterARN.String(), mc) + opts := storeOptions{ + kmsKey: "arn:aws:kms:eu-west-1:123456789:key/MyCustomKey", + } + + s, err := newStore(storeParameterARN.String(), opts, mc) if err != nil { t.Fatalf("creating aws store failed: %v", err) } @@ -73,7 +77,7 @@ func TestNewAWSStore(t *testing.T) { // Build a brand new file store and check that both IDs written // above are still there. - s2, err := newStore(storeParameterARN.String(), mc) + s2, err := newStore(storeParameterARN.String(), opts, mc) if err != nil { t.Fatalf("creating second aws store failed: %v", err) } @@ -162,3 +166,54 @@ func testStoreSemantics(t *testing.T, store ipn.StateStore) { } } } + +func TestParseARNAndOpts(t *testing.T) { + tests := []struct { + name string + arg string + wantARN string + wantKey string + }{ + { + name: "no-key", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + }, + { + name: "custom-key", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=alias/MyCustomKey", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "alias/MyCustomKey", + }, + { + name: "bare-name", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=Bare", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "alias/Bare", + }, + { + name: "arn-arg", + arg: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam?kmsKey=arn:foo", + wantARN: "arn:aws:ssm:us-east-1:123456789012:parameter/myTailscaleParam", + wantKey: "arn:foo", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + arn, opts, err := ParseARNAndOpts(tt.arg) + if err != nil { + 
t.Fatalf("New: %v", err) + } + if arn != tt.wantARN { + t.Errorf("ARN = %q; want %q", arn, tt.wantARN) + } + var got storeOptions + for _, opt := range opts { + opt(&got) + } + if got.kmsKey != tt.wantKey { + t.Errorf("kmsKey = %q; want %q", got.kmsKey, tt.wantKey) + } + }) + } +} diff --git a/ipn/store/kubestore/store_kube.go b/ipn/store/kubestore/store_kube.go index 00950bd3b2394..14025bbb4150a 100644 --- a/ipn/store/kubestore/store_kube.go +++ b/ipn/store/kubestore/store_kube.go @@ -7,27 +7,66 @@ package kubestore import ( "context" "fmt" + "log" "net" "os" "strings" "time" + "tailscale.com/envknob" "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" "tailscale.com/kube/kubeapi" "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" "tailscale.com/types/logger" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" +) + +const ( + // timeout is the timeout for a single state update that includes calls to the API server to write or read a + // state Secret and emit an Event. + timeout = 30 * time.Second + + reasonTailscaleStateUpdated = "TailscaledStateUpdated" + reasonTailscaleStateLoaded = "TailscaleStateLoaded" + reasonTailscaleStateUpdateFailed = "TailscaleStateUpdateFailed" + reasonTailscaleStateLoadFailed = "TailscaleStateLoadFailed" + eventTypeWarning = "Warning" + eventTypeNormal = "Normal" + + keyTLSCert = "tls.crt" + keyTLSKey = "tls.key" ) // Store is an ipn.StateStore that uses a Kubernetes Secret for persistence. type Store struct { - client kubeclient.Client - canPatch bool - secretName string + client kubeclient.Client + canPatch bool + secretName string // state Secret + certShareMode string // 'ro', 'rw', or empty + podName string + + // memory holds the latest tailscale state. Writes write state to a kube + // Secret and memory, Reads read from memory. + memory mem.Store } -// New returns a new Store that persists to the named secret. 
-func New(_ logger.Logf, secretName string) (*Store, error) { - c, err := kubeclient.New() +// New returns a new Store that persists state to Kubernets Secret(s). +// Tailscale state is stored in a Secret named by the secretName parameter. +// TLS certs are stored and retrieved from state Secret or separate Secrets +// named after TLS endpoints if running in cert share mode. +func New(logf logger.Logf, secretName string) (*Store, error) { + c, err := newClient() + if err != nil { + return nil, err + } + return newWithClient(logf, c, secretName) +} + +func newClient() (kubeclient.Client, error) { + c, err := kubeclient.New("tailscale-state-store") if err != nil { return nil, err } @@ -35,15 +74,43 @@ func New(_ logger.Logf, secretName string) (*Store, error) { // Derive the API server address from the environment variables c.SetURL(fmt.Sprintf("https://%s:%s", os.Getenv("KUBERNETES_SERVICE_HOST"), os.Getenv("KUBERNETES_SERVICE_PORT_HTTPS"))) } + return c, nil +} + +func newWithClient(logf logger.Logf, c kubeclient.Client, secretName string) (*Store, error) { canPatch, _, err := c.CheckSecretPermissions(context.Background(), secretName) if err != nil { return nil, err } - return &Store{ + s := &Store{ client: c, canPatch: canPatch, secretName: secretName, - }, nil + podName: os.Getenv("POD_NAME"), + } + if envknob.IsCertShareReadWriteMode() { + s.certShareMode = "rw" + } else if envknob.IsCertShareReadOnlyMode() { + s.certShareMode = "ro" + } + + // Load latest state from kube Secret if it already exists. + if err := s.loadState(); err != nil && err != ipn.ErrStateNotExist { + return nil, fmt.Errorf("error loading state from kube Secret: %w", err) + } + // If we are in cert share mode, pre-load existing shared certs. 
+ if s.certShareMode == "rw" || s.certShareMode == "ro" { + sel := s.certSecretSelector() + if err := s.loadCerts(context.Background(), sel); err != nil { + // We will attempt to again retrieve the certs from Secrets when a request for an HTTPS endpoint + // is received. + log.Printf("[unexpected] error loading TLS certs: %v", err) + } + } + if s.certShareMode == "ro" { + go s.runCertReload(context.Background(), logf) + } + return s, nil } func (s *Store) SetDialer(d func(ctx context.Context, network, address string) (net.Conn, error)) { @@ -54,86 +121,310 @@ func (s *Store) String() string { return "kube.Store" } // ReadState implements the StateStore interface. func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + return s.memory.ReadState(ipn.StateKey(sanitizeKey(id))) +} - secret, err := s.client.GetSecret(ctx, s.secretName) - if err != nil { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { - return nil, ipn.ErrStateNotExist +// WriteState implements the StateStore interface. +func (s *Store) WriteState(id ipn.StateKey, bs []byte) (err error) { + defer func() { + if err == nil { + s.memory.WriteState(ipn.StateKey(sanitizeKey(id)), bs) } - return nil, err + }() + return s.updateSecret(map[string][]byte{string(id): bs}, s.secretName) +} + +// WriteTLSCertAndKey writes a TLS cert and key to domain.crt, domain.key fields +// of a Tailscale Kubernetes node's state Secret. 
+func (s *Store) WriteTLSCertAndKey(domain string, cert, key []byte) (err error) { + if s.certShareMode == "ro" { + log.Printf("[unexpected] TLS cert and key write in read-only mode") + } + if err := dnsname.ValidHostname(domain); err != nil { + return fmt.Errorf("invalid domain name %q: %w", domain, err) + } + secretName := s.secretName + data := map[string][]byte{ + domain + ".crt": cert, + domain + ".key": key, + } + // If we run in cert share mode, cert and key for a DNS name are written + // to a separate Secret. + if s.certShareMode == "rw" { + secretName = domain + data = map[string][]byte{ + keyTLSCert: cert, + keyTLSKey: key, + } + } + if err := s.updateSecret(data, secretName); err != nil { + return fmt.Errorf("error writing TLS cert and key to Secret: %w", err) } - b, ok := secret.Data[sanitizeKey(id)] - if !ok { - return nil, ipn.ErrStateNotExist + // TODO(irbekrm): certs for write replicas are currently not + // written to memory to avoid out of sync memory state after + // Ingress resources have been recreated. This means that TLS + // certs for write replicas are retrieved from the Secret on + // each HTTPS request. This is a temporary solution till we + // implement a Secret watch. + if s.certShareMode != "rw" { + s.memory.WriteState(ipn.StateKey(domain+".crt"), cert) + s.memory.WriteState(ipn.StateKey(domain+".key"), key) } - return b, nil + return nil } -func sanitizeKey(k ipn.StateKey) string { - // The only valid characters in a Kubernetes secret key are alphanumeric, -, - // _, and . - return strings.Map(func(r rune) rune { - if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' { - return r +// ReadTLSCertAndKey reads a TLS cert and key from memory or from a +// domain-specific Secret. It first checks the in-memory store, if not found in +// memory and running cert store in read-only mode, looks up a Secret. +// Note that write replicas of HA Ingress always retrieve TLS certs from Secrets. 
+func (s *Store) ReadTLSCertAndKey(domain string) (cert, key []byte, err error) { + if err := dnsname.ValidHostname(domain); err != nil { + return nil, nil, fmt.Errorf("invalid domain name %q: %w", domain, err) + } + certKey := domain + ".crt" + keyKey := domain + ".key" + cert, err = s.memory.ReadState(ipn.StateKey(certKey)) + if err == nil { + key, err = s.memory.ReadState(ipn.StateKey(keyKey)) + if err == nil { + return cert, key, nil } - return '_' - }, string(k)) -} + } + if s.certShareMode == "" { + return nil, nil, ipn.ErrStateNotExist + } -// WriteState implements the StateStore interface. -func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - - secret, err := s.client.GetSecret(ctx, s.secretName) + secret, err := s.client.GetSecret(ctx, domain) if err != nil { if kubeclient.IsNotFoundErr(err) { + // TODO(irbekrm): we should return a more specific error + // that wraps ipn.ErrStateNotExist here. + return nil, nil, ipn.ErrStateNotExist + } + return nil, nil, fmt.Errorf("getting TLS Secret %q: %w", domain, err) + } + cert = secret.Data[keyTLSCert] + key = secret.Data[keyTLSKey] + if len(cert) == 0 || len(key) == 0 { + return nil, nil, ipn.ErrStateNotExist + } + // TODO(irbekrm): a read between these two separate writes would + // get a mismatched cert and key. Allow writing both cert and + // key to the memory store in a single, lock-protected operation. + // + // TODO(irbekrm): currently certs for write replicas of HA Ingress get + // retrieved from the cluster Secret on each HTTPS request to avoid a + // situation when after Ingress recreation stale certs are read from + // memory. + // Fix this by watching Secrets to ensure that memory store gets updated + // when Secrets are deleted. 
+ if s.certShareMode == "ro" { + s.memory.WriteState(ipn.StateKey(certKey), cert) + s.memory.WriteState(ipn.StateKey(keyKey), key) + } + return cert, key, nil +} + +func (s *Store) updateSecret(data map[string][]byte, secretName string) (err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer func() { + if err != nil { + if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateUpdateFailed, err.Error()); err != nil { + log.Printf("kubestore: error creating tailscaled state update Event: %v", err) + } + } else { + if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateUpdated, "Successfully updated tailscaled state Secret"); err != nil { + log.Printf("kubestore: error creating tailscaled state Event: %v", err) + } + } + cancel() + }() + secret, err := s.client.GetSecret(ctx, secretName) + if err != nil { + // If the Secret does not exist, create it with the required data. + if kubeclient.IsNotFoundErr(err) && s.canCreateSecret(secretName) { return s.client.CreateSecret(ctx, &kubeapi.Secret{ TypeMeta: kubeapi.TypeMeta{ APIVersion: "v1", Kind: "Secret", }, ObjectMeta: kubeapi.ObjectMeta{ - Name: s.secretName, - }, - Data: map[string][]byte{ - sanitizeKey(id): bs, + Name: secretName, }, + Data: func(m map[string][]byte) map[string][]byte { + d := make(map[string][]byte, len(m)) + for key, val := range m { + d[sanitizeKey(key)] = val + } + return d + }(data), }) } - return err + return fmt.Errorf("error getting Secret %s: %w", secretName, err) } - if s.canPatch { - if len(secret.Data) == 0 { // if user has pre-created a blank Secret - m := []kubeclient.JSONPatch{ + if s.canPatchSecret(secretName) { + var m []kubeclient.JSONPatch + // If the user has pre-created a Secret with no data, we need to ensure the top level /data field. 
+ if len(secret.Data) == 0 { + m = []kubeclient.JSONPatch{ { - Op: "add", - Path: "/data", - Value: map[string][]byte{sanitizeKey(id): bs}, + Op: "add", + Path: "/data", + Value: func(m map[string][]byte) map[string][]byte { + d := make(map[string][]byte, len(m)) + for key, val := range m { + d[sanitizeKey(key)] = val + } + return d + }(data), }, } - if err := s.client.JSONPatchSecret(ctx, s.secretName, m); err != nil { - return fmt.Errorf("error patching Secret %s with a /data field: %v", s.secretName, err) + // If the Secret has data, patch it with the new data. + } else { + for key, val := range data { + m = append(m, kubeclient.JSONPatch{ + Op: "add", + Path: "/data/" + sanitizeKey(key), + Value: val, + }) } - return nil } - m := []kubeclient.JSONPatch{ - { - Op: "add", - Path: "/data/" + sanitizeKey(id), - Value: bs, - }, - } - if err := s.client.JSONPatchSecret(ctx, s.secretName, m); err != nil { - return fmt.Errorf("error patching Secret %s with /data/%s field", s.secretName, sanitizeKey(id)) + if err := s.client.JSONPatchResource(ctx, secretName, kubeclient.TypeSecrets, m); err != nil { + return fmt.Errorf("error patching Secret %s: %w", secretName, err) } return nil } - secret.Data[sanitizeKey(id)] = bs + // No patch permissions, use UPDATE instead. 
+ for key, val := range data { + mak.Set(&secret.Data, sanitizeKey(key), val) + } if err := s.client.UpdateSecret(ctx, secret); err != nil { + return fmt.Errorf("error updating Secret %s: %w", s.secretName, err) + } + return nil +} + +func (s *Store) loadState() (err error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + secret, err := s.client.GetSecret(ctx, s.secretName) + if err != nil { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { + return ipn.ErrStateNotExist + } + if err := s.client.Event(ctx, eventTypeWarning, reasonTailscaleStateLoadFailed, err.Error()); err != nil { + log.Printf("kubestore: error creating Event: %v", err) + } return err } - return err + if err := s.client.Event(ctx, eventTypeNormal, reasonTailscaleStateLoaded, "Successfully loaded tailscaled state from Secret"); err != nil { + log.Printf("kubestore: error creating Event: %v", err) + } + s.memory.LoadFromMap(secret.Data) + return nil +} + +// runCertReload relists and reloads all TLS certs for endpoints shared by this +// node from Secrets other than the state Secret to ensure that renewed certs get eventually loaded. +// It is not critical to reload a cert immediately after +// renewal, so a daily check is acceptable. +// Currently (3/2025) this is only used for the shared HA Ingress certs on 'read' replicas. +// Note that if shared certs are not found in memory on an HTTPS request, we +// do a Secret lookup, so this mechanism does not need to ensure that newly +// added Ingresses' certs get loaded. 
+func (s *Store) runCertReload(ctx context.Context, logf logger.Logf) { + ticker := time.NewTicker(time.Hour * 24) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sel := s.certSecretSelector() + if err := s.loadCerts(ctx, sel); err != nil { + logf("[unexpected] error reloading TLS certs: %v", err) + } + } + } +} + +// loadCerts lists all Secrets matching the provided selector and loads TLS +// certs and keys from those. +func (s *Store) loadCerts(ctx context.Context, sel map[string]string) error { + ss, err := s.client.ListSecrets(ctx, sel) + if err != nil { + return fmt.Errorf("error listing TLS Secrets: %w", err) + } + for _, secret := range ss.Items { + if !hasTLSData(&secret) { + continue + } + // Only load secrets that have valid domain names (ending in .ts.net) + if !strings.HasSuffix(secret.Name, ".ts.net") { + continue + } + s.memory.WriteState(ipn.StateKey(secret.Name)+".crt", secret.Data[keyTLSCert]) + s.memory.WriteState(ipn.StateKey(secret.Name)+".key", secret.Data[keyTLSKey]) + } + return nil +} + +// canCreateSecret returns true if this node should be allowed to create the given +// Secret in its namespace. +func (s *Store) canCreateSecret(secret string) bool { + // Only allow creating the state Secret (and not TLS Secrets). + return secret == s.secretName +} + +// canPatchSecret returns true if this node should be allowed to patch the given +// Secret. +func (s *Store) canPatchSecret(secret string) bool { + // For backwards compatibility reasons, setups where the proxies are not + // given PATCH permissions for state Secrets are allowed. For TLS + // Secrets, we should always have PATCH permissions. + if secret == s.secretName { + return s.canPatch + } + return true +} + +// certSecretSelector returns a label selector that can be used to list all +// Secrets that aren't Tailscale state Secrets and contain TLS certificates for +// HTTPS endpoints that this node serves. 
+// Currently (3/2025) this only applies to the Kubernetes Operator's ingress +// ProxyGroup. +func (s *Store) certSecretSelector() map[string]string { + if s.podName == "" { + return map[string]string{} + } + p := strings.LastIndex(s.podName, "-") + if p == -1 { + return map[string]string{} + } + pgName := s.podName[:p] + return map[string]string{ + kubetypes.LabelSecretType: "certs", + kubetypes.LabelManaged: "true", + "tailscale.com/proxy-group": pgName, + } +} + +// hasTLSData returns true if the provided Secret contains non-empty TLS cert and key. +func hasTLSData(s *kubeapi.Secret) bool { + return len(s.Data[keyTLSCert]) != 0 && len(s.Data[keyTLSKey]) != 0 +} + +// sanitizeKey converts any value that can be converted to a string into a valid Kubernetes Secret key. +// Valid characters are alphanumeric, -, _, and . +// https://kubernetes.io/docs/concepts/configuration/secret/#restriction-names-data. +func sanitizeKey[T ~string](k T) string { + return strings.Map(func(r rune) rune { + if r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' || r == '-' || r == '_' || r == '.' 
{ + return r + } + return '_' + }, string(k)) } diff --git a/ipn/store/kubestore/store_kube_test.go b/ipn/store/kubestore/store_kube_test.go new file mode 100644 index 0000000000000..0d709264e5c08 --- /dev/null +++ b/ipn/store/kubestore/store_kube_test.go @@ -0,0 +1,726 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubestore + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/envknob" + "tailscale.com/ipn" + "tailscale.com/ipn/store/mem" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" +) + +func TestWriteState(t *testing.T) { + tests := []struct { + name string + initial map[string][]byte + key ipn.StateKey + value []byte + wantData map[string][]byte + allowPatch bool + }{ + { + name: "basic_write", + initial: map[string][]byte{ + "existing": []byte("old"), + }, + key: "foo", + value: []byte("bar"), + wantData: map[string][]byte{ + "existing": []byte("old"), + "foo": []byte("bar"), + }, + allowPatch: true, + }, + { + name: "update_existing", + initial: map[string][]byte{ + "foo": []byte("old"), + }, + key: "foo", + value: []byte("new"), + wantData: map[string][]byte{ + "foo": []byte("new"), + }, + allowPatch: true, + }, + { + name: "create_new_secret", + key: "foo", + value: []byte("bar"), + wantData: map[string][]byte{ + "foo": []byte("bar"), + }, + allowPatch: true, + }, + { + name: "patch_denied", + initial: map[string][]byte{ + "foo": []byte("old"), + }, + key: "foo", + value: []byte("new"), + wantData: map[string][]byte{ + "foo": []byte("new"), + }, + allowPatch: false, + }, + { + name: "sanitize_key", + initial: map[string][]byte{ + "clean-key": []byte("old"), + }, + key: "dirty@key", + value: []byte("new"), + wantData: map[string][]byte{ + "clean-key": []byte("old"), + "dirty_key": []byte("new"), + }, + allowPatch: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + secret 
:= tt.initial // track current state + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if secret == nil { + return nil, &kubeapi.Status{Code: 404} + } + return &kubeapi.Secret{Data: secret}, nil + }, + CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) { + return tt.allowPatch, true, nil + }, + CreateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + secret = s.Data + return nil + }, + UpdateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + secret = s.Data + return nil + }, + JSONPatchResourceImpl: func(ctx context.Context, name, resourceType string, patches []kubeclient.JSONPatch) error { + if !tt.allowPatch { + return &kubeapi.Status{Reason: "Forbidden"} + } + if secret == nil { + secret = make(map[string][]byte) + } + for _, p := range patches { + if p.Op == "add" && p.Path == "/data" { + secret = p.Value.(map[string][]byte) + } else if p.Op == "add" && strings.HasPrefix(p.Path, "/data/") { + key := strings.TrimPrefix(p.Path, "/data/") + secret[key] = p.Value.([]byte) + } + } + return nil + }, + } + + s := &Store{ + client: client, + canPatch: tt.allowPatch, + secretName: "ts-state", + memory: mem.Store{}, + } + + err := s.WriteState(tt.key, tt.value) + if err != nil { + t.Errorf("WriteState() error = %v", err) + return + } + + // Verify secret data + if diff := cmp.Diff(secret, tt.wantData); diff != "" { + t.Errorf("secret data mismatch (-got +want):\n%s", diff) + } + + // Verify memory store was updated + got, err := s.memory.ReadState(ipn.StateKey(sanitizeKey(string(tt.key)))) + if err != nil { + t.Errorf("reading from memory store: %v", err) + } + if !cmp.Equal(got, tt.value) { + t.Errorf("memory store key %q = %v, want %v", tt.key, got, tt.value) + } + }) + } +} + +func TestWriteTLSCertAndKey(t *testing.T) { + const ( + testDomain = "my-app.tailnetxyz.ts.net" + testCert = "fake-cert" + testKey = "fake-key" + ) + + tests 
:= []struct { + name string + initial map[string][]byte // pre-existing cert and key + certShareMode string + allowPatch bool // whether client can patch the Secret + wantSecretName string // name of the Secret where cert and key should be written + wantSecretData map[string][]byte + wantMemoryStore map[ipn.StateKey][]byte + }{ + { + name: "basic_write", + initial: map[string][]byte{ + "existing": []byte("old"), + }, + allowPatch: true, + wantSecretName: "ts-state", + wantSecretData: map[string][]byte{ + "existing": []byte("old"), + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "cert_share_mode_write", + certShareMode: "rw", + allowPatch: true, + wantSecretName: "my-app.tailnetxyz.ts.net", + wantSecretData: map[string][]byte{ + "tls.crt": []byte(testCert), + "tls.key": []byte(testKey), + }, + }, + { + name: "cert_share_mode_write_update_existing", + initial: map[string][]byte{ + "tls.crt": []byte("old-cert"), + "tls.key": []byte("old-key"), + }, + certShareMode: "rw", + allowPatch: true, + wantSecretName: "my-app.tailnetxyz.ts.net", + wantSecretData: map[string][]byte{ + "tls.crt": []byte(testCert), + "tls.key": []byte(testKey), + }, + }, + { + name: "update_existing", + initial: map[string][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte("old-cert"), + "my-app.tailnetxyz.ts.net.key": []byte("old-key"), + }, + certShareMode: "", + allowPatch: true, + wantSecretName: "ts-state", + wantSecretData: map[string][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "patch_denied", + certShareMode: "", + allowPatch: 
false, + wantSecretName: "ts-state", + wantSecretData: map[string][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // Set POD_NAME for testing selectors + envknob.Setenv("POD_NAME", "ingress-proxies-1") + defer envknob.Setenv("POD_NAME", "") + + secret := tt.initial // track current state + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if secret == nil { + return nil, &kubeapi.Status{Code: 404} + } + return &kubeapi.Secret{Data: secret}, nil + }, + CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) { + return tt.allowPatch, true, nil + }, + CreateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + if s.Name != tt.wantSecretName { + t.Errorf("CreateSecret called with wrong name, got %q, want %q", s.Name, tt.wantSecretName) + } + secret = s.Data + return nil + }, + UpdateSecretImpl: func(ctx context.Context, s *kubeapi.Secret) error { + if s.Name != tt.wantSecretName { + t.Errorf("UpdateSecret called with wrong name, got %q, want %q", s.Name, tt.wantSecretName) + } + secret = s.Data + return nil + }, + JSONPatchResourceImpl: func(ctx context.Context, name, resourceType string, patches []kubeclient.JSONPatch) error { + if !tt.allowPatch { + return &kubeapi.Status{Reason: "Forbidden"} + } + if name != tt.wantSecretName { + t.Errorf("JSONPatchResource called with wrong name, got %q, want %q", name, tt.wantSecretName) + } + if secret == nil { + secret = make(map[string][]byte) + } + for _, p := range patches { + if p.Op == "add" && p.Path == "/data" { + secret = p.Value.(map[string][]byte) + } else if p.Op == "add" && strings.HasPrefix(p.Path, "/data/") 
{ + key := strings.TrimPrefix(p.Path, "/data/") + secret[key] = p.Value.([]byte) + } + } + return nil + }, + } + + s := &Store{ + client: client, + canPatch: tt.allowPatch, + secretName: tt.wantSecretName, + certShareMode: tt.certShareMode, + memory: mem.Store{}, + } + + err := s.WriteTLSCertAndKey(testDomain, []byte(testCert), []byte(testKey)) + if err != nil { + t.Errorf("WriteTLSCertAndKey() error = '%v'", err) + return + } + + // Verify secret data + if diff := cmp.Diff(secret, tt.wantSecretData); diff != "" { + t.Errorf("secret data mismatch (-got +want):\n%s", diff) + } + + // Verify memory store was updated + for key, want := range tt.wantMemoryStore { + got, err := s.memory.ReadState(key) + if err != nil { + t.Errorf("reading from memory store: %v", err) + continue + } + if !cmp.Equal(got, want) { + t.Errorf("memory store key %q = %v, want %v", key, got, want) + } + } + }) + } +} + +func TestReadTLSCertAndKey(t *testing.T) { + const ( + testDomain = "my-app.tailnetxyz.ts.net" + testCert = "fake-cert" + testKey = "fake-key" + ) + + tests := []struct { + name string + memoryStore map[ipn.StateKey][]byte // pre-existing memory store state + certShareMode string + domain string + secretData map[string][]byte // data to return from mock GetSecret + secretGetErr error // error to return from mock GetSecret + wantCert []byte + wantKey []byte + wantErr error + // what should end up in memory store after the store is created + wantMemoryStore map[ipn.StateKey][]byte + }{ + { + name: "found_in_memory", + memoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + domain: testDomain, + wantCert: []byte(testCert), + wantKey: []byte(testKey), + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "not_found_in_memory", + domain: testDomain, + wantErr: 
ipn.ErrStateNotExist, + }, + { + name: "cert_share_ro_mode_found_in_secret", + certShareMode: "ro", + domain: testDomain, + secretData: map[string][]byte{ + "tls.crt": []byte(testCert), + "tls.key": []byte(testKey), + }, + wantCert: []byte(testCert), + wantKey: []byte(testKey), + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "cert_share_rw_mode_found_in_secret", + certShareMode: "rw", + domain: testDomain, + secretData: map[string][]byte{ + "tls.crt": []byte(testCert), + "tls.key": []byte(testKey), + }, + wantCert: []byte(testCert), + wantKey: []byte(testKey), + }, + { + name: "cert_share_ro_mode_found_in_memory", + certShareMode: "ro", + memoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + domain: testDomain, + wantCert: []byte(testCert), + wantKey: []byte(testKey), + wantMemoryStore: map[ipn.StateKey][]byte{ + "my-app.tailnetxyz.ts.net.crt": []byte(testCert), + "my-app.tailnetxyz.ts.net.key": []byte(testKey), + }, + }, + { + name: "cert_share_ro_mode_not_found", + certShareMode: "ro", + domain: testDomain, + secretGetErr: &kubeapi.Status{Code: 404}, + wantErr: ipn.ErrStateNotExist, + }, + { + name: "cert_share_ro_mode_empty_cert_in_secret", + certShareMode: "ro", + domain: testDomain, + secretData: map[string][]byte{ + "tls.crt": {}, + "tls.key": []byte(testKey), + }, + wantErr: ipn.ErrStateNotExist, + }, + { + name: "cert_share_ro_mode_kube_api_error", + certShareMode: "ro", + domain: testDomain, + secretGetErr: fmt.Errorf("api error"), + wantErr: fmt.Errorf("getting TLS Secret %q: api error", sanitizeKey(testDomain)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if tt.secretGetErr != nil { + 
return nil, tt.secretGetErr + } + return &kubeapi.Secret{Data: tt.secretData}, nil + }, + } + + s := &Store{ + client: client, + secretName: "ts-state", + certShareMode: tt.certShareMode, + memory: mem.Store{}, + } + + // Initialize memory store + for k, v := range tt.memoryStore { + s.memory.WriteState(k, v) + } + + gotCert, gotKey, err := s.ReadTLSCertAndKey(tt.domain) + if tt.wantErr != nil { + if err == nil { + t.Errorf("ReadTLSCertAndKey() error = nil, want error containing %v", tt.wantErr) + return + } + if !strings.Contains(err.Error(), tt.wantErr.Error()) { + t.Errorf("ReadTLSCertAndKey() error = %v, want error containing %v", err, tt.wantErr) + } + return + } + if err != nil { + t.Errorf("ReadTLSCertAndKey() unexpected error: %v", err) + return + } + + if !bytes.Equal(gotCert, tt.wantCert) { + t.Errorf("ReadTLSCertAndKey() gotCert = %v, want %v", gotCert, tt.wantCert) + } + if !bytes.Equal(gotKey, tt.wantKey) { + t.Errorf("ReadTLSCertAndKey() gotKey = %v, want %v", gotKey, tt.wantKey) + } + + // Verify memory store contents after operation + if tt.wantMemoryStore != nil { + for key, want := range tt.wantMemoryStore { + got, err := s.memory.ReadState(key) + if err != nil { + t.Errorf("reading from memory store: %v", err) + continue + } + if !bytes.Equal(got, want) { + t.Errorf("memory store key %q = %v, want %v", key, got, want) + } + } + } + }) + } +} + +func TestNewWithClient(t *testing.T) { + const ( + secretName = "ts-state" + testCert = "fake-cert" + testKey = "fake-key" + ) + + certSecretsLabels := map[string]string{ + "tailscale.com/secret-type": "certs", + "tailscale.com/managed": "true", + "tailscale.com/proxy-group": "ingress-proxies", + } + + // Helper function to create Secret objects for testing + makeSecret := func(name string, labels map[string]string, certSuffix string) kubeapi.Secret { + return kubeapi.Secret{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: name, + Labels: labels, + }, + Data: map[string][]byte{ + "tls.crt": []byte(testCert + 
certSuffix), + "tls.key": []byte(testKey + certSuffix), + }, + } + } + + tests := []struct { + name string + stateSecretContents map[string][]byte // data in state Secret + TLSSecrets []kubeapi.Secret // list of TLS cert Secrets + certMode string + secretGetErr error // error to return from GetSecret + secretsListErr error // error to return from ListSecrets + wantMemoryStoreContents map[ipn.StateKey][]byte + wantErr error + }{ + { + name: "empty_state_secret", + stateSecretContents: map[string][]byte{}, + wantMemoryStoreContents: map[ipn.StateKey][]byte{}, + }, + { + name: "state_secret_not_found", + secretGetErr: &kubeapi.Status{Code: 404}, + wantMemoryStoreContents: map[ipn.StateKey][]byte{}, + }, + { + name: "state_secret_get_error", + secretGetErr: fmt.Errorf("some error"), + wantErr: fmt.Errorf("error loading state from kube Secret: some error"), + }, + { + name: "load_existing_state", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + "baz": []byte("qux"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + "baz": []byte("qux"), + }, + }, + { + name: "load_select_certs_in_read_only_mode", + certMode: "ro", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + TLSSecrets: []kubeapi.Secret{ + makeSecret("app1.tailnetxyz.ts.net", certSecretsLabels, "1"), + makeSecret("app2.tailnetxyz.ts.net", certSecretsLabels, "2"), + makeSecret("some-other-secret", nil, "3"), + makeSecret("app3.other-proxies.ts.net", map[string]string{ + "tailscale.com/secret-type": "certs", + "tailscale.com/managed": "true", + "tailscale.com/proxy-group": "some-other-proxygroup", + }, "4"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + "app1.tailnetxyz.ts.net.crt": []byte(testCert + "1"), + "app1.tailnetxyz.ts.net.key": []byte(testKey + "1"), + "app2.tailnetxyz.ts.net.crt": []byte(testCert + "2"), + "app2.tailnetxyz.ts.net.key": []byte(testKey + "2"), + }, + }, + { + name: 
"load_select_certs_in_read_write_mode", + certMode: "rw", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + TLSSecrets: []kubeapi.Secret{ + makeSecret("app1.tailnetxyz.ts.net", certSecretsLabels, "1"), + makeSecret("app2.tailnetxyz.ts.net", certSecretsLabels, "2"), + makeSecret("some-other-secret", nil, "3"), + makeSecret("app3.other-proxies.ts.net", map[string]string{ + "tailscale.com/secret-type": "certs", + "tailscale.com/managed": "true", + "tailscale.com/proxy-group": "some-other-proxygroup", + }, "4"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + "app1.tailnetxyz.ts.net.crt": []byte(testCert + "1"), + "app1.tailnetxyz.ts.net.key": []byte(testKey + "1"), + "app2.tailnetxyz.ts.net.crt": []byte(testCert + "2"), + "app2.tailnetxyz.ts.net.key": []byte(testKey + "2"), + }, + }, + { + name: "list_cert_secrets_fails", + certMode: "ro", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + secretsListErr: fmt.Errorf("list error"), + // The error is logged but not returned, and state is still loaded + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + }, + }, + { + name: "cert_secrets_not_loaded_when_not_in_share_mode", + certMode: "", + stateSecretContents: map[string][]byte{ + "foo": []byte("bar"), + }, + TLSSecrets: []kubeapi.Secret{ + makeSecret("app1.tailnetxyz.ts.net", certSecretsLabels, "1"), + }, + wantMemoryStoreContents: map[ipn.StateKey][]byte{ + "foo": []byte("bar"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + envknob.Setenv("TS_CERT_SHARE_MODE", tt.certMode) + + t.Setenv("POD_NAME", "ingress-proxies-1") + + client := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + if tt.secretGetErr != nil { + return nil, tt.secretGetErr + } + if name == secretName { + return &kubeapi.Secret{Data: tt.stateSecretContents}, nil + } + return nil, &kubeapi.Status{Code: 404} 
+ }, + CheckSecretPermissionsImpl: func(ctx context.Context, name string) (bool, bool, error) { + return true, true, nil + }, + ListSecretsImpl: func(ctx context.Context, selector map[string]string) (*kubeapi.SecretList, error) { + if tt.secretsListErr != nil { + return nil, tt.secretsListErr + } + var matchingSecrets []kubeapi.Secret + for _, secret := range tt.TLSSecrets { + matches := true + for k, v := range selector { + if secret.Labels[k] != v { + matches = false + break + } + } + if matches { + matchingSecrets = append(matchingSecrets, secret) + } + } + return &kubeapi.SecretList{Items: matchingSecrets}, nil + }, + } + + s, err := newWithClient(t.Logf, client, secretName) + if tt.wantErr != nil { + if err == nil { + t.Errorf("NewWithClient() error = nil, want error containing %v", tt.wantErr) + return + } + if !strings.Contains(err.Error(), tt.wantErr.Error()) { + t.Errorf("NewWithClient() error = %v, want error containing %v", err, tt.wantErr) + } + return + } + + if err != nil { + t.Errorf("NewWithClient() unexpected error: %v", err) + return + } + + // Verify memory store contents + gotJSON, err := s.memory.ExportToJSON() + if err != nil { + t.Errorf("ExportToJSON failed: %v", err) + return + } + var got map[ipn.StateKey][]byte + if err := json.Unmarshal(gotJSON, &got); err != nil { + t.Errorf("failed to unmarshal memory store JSON: %v", err) + return + } + want := tt.wantMemoryStoreContents + if want == nil { + want = map[ipn.StateKey][]byte{} + } + if diff := cmp.Diff(got, want); diff != "" { + t.Errorf("memory store contents mismatch (-got +want):\n%s", diff) + } + }) + } +} diff --git a/ipn/store/mem/store_mem.go b/ipn/store/mem/store_mem.go index f3a308ae5dc4f..6f474ce993b43 100644 --- a/ipn/store/mem/store_mem.go +++ b/ipn/store/mem/store_mem.go @@ -9,8 +9,10 @@ import ( "encoding/json" "sync" + xmaps "golang.org/x/exp/maps" "tailscale.com/ipn" "tailscale.com/types/logger" + "tailscale.com/util/mak" ) // New returns a new Store. 
@@ -28,6 +30,7 @@ type Store struct { func (s *Store) String() string { return "mem.Store" } // ReadState implements the StateStore interface. +// It returns ipn.ErrStateNotExist if the state does not exist. func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { s.mu.Lock() defer s.mu.Unlock() @@ -39,6 +42,7 @@ func (s *Store) ReadState(id ipn.StateKey) ([]byte, error) { } // WriteState implements the StateStore interface. +// It never returns an error. func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { s.mu.Lock() defer s.mu.Unlock() @@ -49,6 +53,19 @@ func (s *Store) WriteState(id ipn.StateKey, bs []byte) error { return nil } +// LoadFromMap loads the in-memory cache from the provided map. +// Any existing content is cleared, and the provided map is +// copied into the cache. +func (s *Store) LoadFromMap(m map[string][]byte) { + s.mu.Lock() + defer s.mu.Unlock() + xmaps.Clear(s.cache) + for k, v := range m { + mak.Set(&s.cache, ipn.StateKey(k), v) + } + return +} + // LoadFromJSON attempts to unmarshal json content into the // in-memory cache. 
func (s *Store) LoadFromJSON(data []byte) error { diff --git a/ipn/store/store_aws.go b/ipn/store/store_aws.go index e164f9de741b0..834b657d34df0 100644 --- a/ipn/store/store_aws.go +++ b/ipn/store/store_aws.go @@ -1,18 +1,22 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build (ts_aws || (linux && (arm64 || amd64))) && !ts_omit_aws +//go:build (ts_aws || (linux && (arm64 || amd64) && !android)) && !ts_omit_aws package store import ( + "tailscale.com/ipn" "tailscale.com/ipn/store/awsstore" + "tailscale.com/types/logger" ) func init() { - registerAvailableExternalStores = append(registerAvailableExternalStores, registerAWSStore) -} - -func registerAWSStore() { - Register("arn:", awsstore.New) + Register("arn:", func(logf logger.Logf, arg string) (ipn.StateStore, error) { + ssmARN, opts, err := awsstore.ParseARNAndOpts(arg) + if err != nil { + return nil, err + } + return awsstore.New(logf, ssmARN, opts...) + }) } diff --git a/ipn/store/store_kube.go b/ipn/store/store_kube.go index 8941620f6649d..7eac75c196990 100644 --- a/ipn/store/store_kube.go +++ b/ipn/store/store_kube.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build (ts_kube || (linux && (arm64 || amd64))) && !ts_omit_kube +//go:build (ts_kube || (linux && (arm64 || amd64) && !android)) && !ts_omit_kube package store @@ -14,10 +14,6 @@ import ( ) func init() { - registerAvailableExternalStores = append(registerAvailableExternalStores, registerKubeStore) -} - -func registerKubeStore() { Register("kube:", func(logf logger.Logf, path string) (ipn.StateStore, error) { secretName := strings.TrimPrefix(path, "kube:") return kubestore.New(logf, secretName) diff --git a/ipn/store/stores.go b/ipn/store/stores.go index 1a87fc548026a..1f98891bff248 100644 --- a/ipn/store/stores.go +++ b/ipn/store/stores.go @@ -26,16 +26,8 @@ import ( // The arg is of the form "prefix:rest", where prefix was previously registered with 
Register. type Provider func(logf logger.Logf, arg string) (ipn.StateStore, error) -var regOnce sync.Once - -var registerAvailableExternalStores []func() - -func registerDefaultStores() { +func init() { Register("mem:", mem.New) - - for _, f := range registerAvailableExternalStores { - f() - } } var knownStores map[string]Provider @@ -55,7 +47,6 @@ var knownStores map[string]Provider // the suffix is a Kubernetes secret name // - In all other cases, the path is treated as a filepath. func New(logf logger.Logf, path string) (ipn.StateStore, error) { - regOnce.Do(registerDefaultStores) for prefix, sf := range knownStores { if strings.HasPrefix(path, prefix) { // We can't strip the prefix here as some NewStoreFunc (like arn:) diff --git a/ipn/store/stores_test.go b/ipn/store/stores_test.go index ea09e6ea63ae4..1f0fc0fef1bff 100644 --- a/ipn/store/stores_test.go +++ b/ipn/store/stores_test.go @@ -4,6 +4,7 @@ package store import ( + "maps" "path/filepath" "testing" @@ -14,10 +15,9 @@ import ( ) func TestNewStore(t *testing.T) { - regOnce.Do(registerDefaultStores) + oldKnownStores := maps.Clone(knownStores) t.Cleanup(func() { - knownStores = map[string]Provider{} - registerDefaultStores() + knownStores = oldKnownStores }) knownStores = map[string]Provider{} diff --git a/k8s-operator/api-proxy/doc.go b/k8s-operator/api-proxy/doc.go new file mode 100644 index 0000000000000..89d8909595fd3 --- /dev/null +++ b/k8s-operator/api-proxy/doc.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package apiproxy contains the Kubernetes API Proxy implementation used by +// k8s-operator and k8s-proxy. 
+package apiproxy diff --git a/k8s-operator/api-proxy/env.go b/k8s-operator/api-proxy/env.go new file mode 100644 index 0000000000000..c0640ab1e16bf --- /dev/null +++ b/k8s-operator/api-proxy/env.go @@ -0,0 +1,29 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package apiproxy + +import ( + "os" + + "tailscale.com/types/opt" +) + +func defaultBool(envName string, defVal bool) bool { + vs := os.Getenv(envName) + if vs == "" { + return defVal + } + v, _ := opt.Bool(vs).Get() + return v +} + +func defaultEnv(envName, defVal string) string { + v := os.Getenv(envName) + if v == "" { + return defVal + } + return v +} diff --git a/cmd/k8s-operator/proxy.go b/k8s-operator/api-proxy/proxy.go similarity index 92% rename from cmd/k8s-operator/proxy.go rename to k8s-operator/api-proxy/proxy.go index 672f07b1f1608..7c7260b94af39 100644 --- a/cmd/k8s-operator/proxy.go +++ b/k8s-operator/api-proxy/proxy.go @@ -3,7 +3,7 @@ //go:build !plan9 -package main +package apiproxy import ( "crypto/tls" @@ -20,7 +20,7 @@ import ( "go.uber.org/zap" "k8s.io/client-go/rest" "k8s.io/client-go/transport" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/client/tailscale/apitype" ksr "tailscale.com/k8s-operator/sessionrecording" "tailscale.com/kube/kubetypes" @@ -37,15 +37,15 @@ var ( whoIsKey = ctxkey.New("", (*apitype.WhoIsResponse)(nil)) ) -type apiServerProxyMode int +type APIServerProxyMode int -func (a apiServerProxyMode) String() string { +func (a APIServerProxyMode) String() string { switch a { - case apiserverProxyModeDisabled: + case APIServerProxyModeDisabled: return "disabled" - case apiserverProxyModeEnabled: + case APIServerProxyModeEnabled: return "auth" - case apiserverProxyModeNoAuth: + case APIServerProxyModeNoAuth: return "noauth" default: return "unknown" @@ -53,12 +53,12 @@ func (a apiServerProxyMode) String() string { } const ( - apiserverProxyModeDisabled apiServerProxyMode = iota 
- apiserverProxyModeEnabled - apiserverProxyModeNoAuth + APIServerProxyModeDisabled APIServerProxyMode = iota + APIServerProxyModeEnabled + APIServerProxyModeNoAuth ) -func parseAPIProxyMode() apiServerProxyMode { +func ParseAPIProxyMode() APIServerProxyMode { haveAuthProxyEnv := os.Getenv("AUTH_PROXY") != "" haveAPIProxyEnv := os.Getenv("APISERVER_PROXY") != "" switch { @@ -67,34 +67,34 @@ func parseAPIProxyMode() apiServerProxyMode { case haveAuthProxyEnv: var authProxyEnv = defaultBool("AUTH_PROXY", false) // deprecated if authProxyEnv { - return apiserverProxyModeEnabled + return APIServerProxyModeEnabled } - return apiserverProxyModeDisabled + return APIServerProxyModeDisabled case haveAPIProxyEnv: var apiProxyEnv = defaultEnv("APISERVER_PROXY", "") // true, false or "noauth" switch apiProxyEnv { case "true": - return apiserverProxyModeEnabled + return APIServerProxyModeEnabled case "false", "": - return apiserverProxyModeDisabled + return APIServerProxyModeDisabled case "noauth": - return apiserverProxyModeNoAuth + return APIServerProxyModeNoAuth default: panic(fmt.Sprintf("unknown APISERVER_PROXY value %q", apiProxyEnv)) } } - return apiserverProxyModeDisabled + return APIServerProxyModeDisabled } // maybeLaunchAPIServerProxy launches the auth proxy, which is a small HTTP server // that authenticates requests using the Tailscale LocalAPI and then proxies // them to the kube-apiserver. 
-func maybeLaunchAPIServerProxy(zlog *zap.SugaredLogger, restConfig *rest.Config, s *tsnet.Server, mode apiServerProxyMode) { - if mode == apiserverProxyModeDisabled { +func MaybeLaunchAPIServerProxy(zlog *zap.SugaredLogger, restConfig *rest.Config, s *tsnet.Server, mode APIServerProxyMode) { + if mode == APIServerProxyModeDisabled { return } startlog := zlog.Named("launchAPIProxy") - if mode == apiserverProxyModeNoAuth { + if mode == APIServerProxyModeNoAuth { restConfig = rest.AnonymousClientConfig(restConfig) } cfg, err := restConfig.TransportConfig() @@ -132,8 +132,8 @@ func maybeLaunchAPIServerProxy(zlog *zap.SugaredLogger, restConfig *rest.Config, // are passed through to the Kubernetes API. // // It never returns. -func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredLogger, mode apiServerProxyMode, host string) { - if mode == apiserverProxyModeDisabled { +func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredLogger, mode APIServerProxyMode, host string) { + if mode == APIServerProxyModeDisabled { return } ln, err := ts.Listen("tcp", ":443") @@ -189,10 +189,10 @@ func runAPIServerProxy(ts *tsnet.Server, rt http.RoundTripper, log *zap.SugaredL // LocalAPI and then proxies them to the Kubernetes API. type apiserverProxy struct { log *zap.SugaredLogger - lc *tailscale.LocalClient + lc *local.Client rp *httputil.ReverseProxy - mode apiServerProxyMode + mode APIServerProxyMode ts *tsnet.Server upstreamURL *url.URL } @@ -285,7 +285,7 @@ func (ap *apiserverProxy) execForProto(w http.ResponseWriter, r *http.Request, p func (h *apiserverProxy) addImpersonationHeadersAsRequired(r *http.Request) { r.URL.Scheme = h.upstreamURL.Scheme r.URL.Host = h.upstreamURL.Host - if h.mode == apiserverProxyModeNoAuth { + if h.mode == APIServerProxyModeNoAuth { // If we are not providing authentication, then we are just // proxying to the Kubernetes API, so we don't need to do // anything else. 
@@ -311,7 +311,7 @@ func (h *apiserverProxy) addImpersonationHeadersAsRequired(r *http.Request) { // Now add the impersonation headers that we want. if err := addImpersonationHeaders(r, h.log); err != nil { - log.Printf("failed to add impersonation headers: " + err.Error()) + log.Print("failed to add impersonation headers: ", err.Error()) } } diff --git a/cmd/k8s-operator/proxy_test.go b/k8s-operator/api-proxy/proxy_test.go similarity index 99% rename from cmd/k8s-operator/proxy_test.go rename to k8s-operator/api-proxy/proxy_test.go index d1d5733e7f49f..71bf65648931c 100644 --- a/cmd/k8s-operator/proxy_test.go +++ b/k8s-operator/api-proxy/proxy_test.go @@ -3,7 +3,7 @@ //go:build !plan9 -package main +package apiproxy import ( "net/http" diff --git a/k8s-operator/api.md b/k8s-operator/api.md index e8a6e248a2934..39e1a97c091d1 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -21,6 +21,22 @@ +#### AppConnector + + + +AppConnector defines a Tailscale app connector node configured via Connector. + + + +_Appears in:_ +- [ConnectorSpec](#connectorspec) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `routes` _[Routes](#routes)_ | Routes are optional preconfigured routes for the domains routed via the app connector.
If not set, routes for the domains will be discovered dynamically.
If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may
also dynamically discover other routes.
https://tailscale.com/kb/1332/apps-best-practices#preconfiguration | | Format: cidr
MinItems: 1
Type: string
| + + #### Connector @@ -86,8 +102,9 @@ _Appears in:_ | `tags` _[Tags](#tags)_ | Tags that the Tailscale node will be tagged with.
Defaults to [tag:k8s].
To autoapprove the subnet routes or exit node defined by a Connector,
you can configure Tailscale ACLs to give these tags the necessary
permissions.
See https://tailscale.com/kb/1337/acl-syntax#autoapprovers.
If you specify custom tags here, you must also make the operator an owner of these tags.
See https://tailscale.com/kb/1236/kubernetes-operator/#setting-up-the-kubernetes-operator.
Tags cannot be changed once a Connector node has been created.
Tag values must be in form ^tag:[a-zA-Z][a-zA-Z0-9-]*$. | | Pattern: `^tag:[a-zA-Z][a-zA-Z0-9-]*$`
Type: string
| | `hostname` _[Hostname](#hostname)_ | Hostname is the tailnet hostname that should be assigned to the
Connector node. If unset, hostname defaults to &lt;connector name&gt;-connector. Hostname can contain lower case letters, numbers and<br />
dashes, it must not start or end with a dash and must be between 2
and 63 characters long. | | Pattern: `^[a-z0-9][a-z0-9-]{0,61}[a-z0-9]$`
Type: string
| | `proxyClass` _string_ | ProxyClass is the name of the ProxyClass custom resource that
contains configuration options that should be applied to the
resources created for this Connector. If unset, the operator will
create resources with the default configuration. | | | -| `subnetRouter` _[SubnetRouter](#subnetrouter)_ | SubnetRouter defines subnet routes that the Connector node should
expose to tailnet. If unset, none are exposed.
https://tailscale.com/kb/1019/subnets/ | | | -| `exitNode` _boolean_ | ExitNode defines whether the Connector node should act as a
Tailscale exit node. Defaults to false.
https://tailscale.com/kb/1103/exit-nodes | | | +| `subnetRouter` _[SubnetRouter](#subnetrouter)_ | SubnetRouter defines subnet routes that the Connector device should
expose to tailnet as a Tailscale subnet router.
https://tailscale.com/kb/1019/subnets/
If this field is unset, the device does not get configured as a Tailscale subnet router.
This field is mutually exclusive with the appConnector field. | | | +| `appConnector` _[AppConnector](#appconnector)_ | AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is
configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the
Connector does not act as an app connector.
Note that you will need to manually configure the permissions and the domains for the app connector via the
Admin panel.
Note also that the main tested and supported use case of this config option is to deploy an app connector on
Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose
cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have
tested or optimised for.
If you are using the app connector to access SaaS applications because you need a predictable egress IP that
can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows
via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT
device with a static IP address.
https://tailscale.com/kb/1281/app-connectors | | | +| `exitNode` _boolean_ | ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false.
This field is mutually exclusive with the appConnector field.
https://tailscale.com/kb/1103/exit-nodes | | | #### ConnectorStatus @@ -106,6 +123,7 @@ _Appears in:_ | `conditions` _[Condition](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#condition-v1-meta) array_ | List of status conditions to indicate the status of the Connector.
Known condition types are `ConnectorReady`. | | | | `subnetRoutes` _string_ | SubnetRoutes are the routes currently exposed to tailnet via this
Connector instance. | | | | `isExitNode` _boolean_ | IsExitNode is set to true if the Connector acts as an exit node. | | | +| `isAppConnector` _boolean_ | IsAppConnector is set to true if the Connector acts as an app connector. | | | | `tailnetIPs` _string array_ | TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6)
assigned to the Connector node. | | | | `hostname` _string_ | Hostname is the fully qualified domain name of the Connector node.
If MagicDNS is enabled in your tailnet, it is the MagicDNS name of the
node. | | | @@ -127,7 +145,8 @@ _Appears in:_ | `image` _string_ | Container image name. By default images are pulled from
docker.io/tailscale/tailscale, but the official images are also
available at ghcr.io/tailscale/tailscale. Specifying image name here
will override any proxy image values specified via the Kubernetes
operator's Helm chart values or PROXY_IMAGE env var in the operator
Deployment.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | | | `imagePullPolicy` _[PullPolicy](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#pullpolicy-v1-core)_ | Image pull policy. One of Always, Never, IfNotPresent. Defaults to Always.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#image | | Enum: [Always Never IfNotPresent]
| | `resources` _[ResourceRequirements](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#resourcerequirements-v1-core)_ | Container resource requirements.
By default Tailscale Kubernetes operator does not apply any resource
requirements. The amount of resources required wil depend on the
amount of resources the operator needs to parse, usage patterns and
cluster size.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#resources | | | -| `securityContext` _[SecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#securitycontext-v1-core)_ | Container security context.
Security context specified here will override the security context by the operator.
By default the operator:
- sets 'privileged: true' for the init container
- set NET_ADMIN capability for tailscale container for proxies that
are created for Services or Connector.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context | | | +| `securityContext` _[SecurityContext](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#securitycontext-v1-core)_ | Container security context.
Security context specified here will override the security context set by the operator.
By default the operator sets the Tailscale container and the Tailscale init container to privileged
for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup.
You can reduce the permissions of the Tailscale container to cap NET_ADMIN by
installing a device plugin in your cluster and configuring the proxies' tun device to be created<br />
by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context | | | +| `debug` _[Debug](#debug)_ | Configuration for enabling extra debug information in the container.
Not recommended for production use. | | | #### DNSConfig @@ -230,6 +249,22 @@ _Appears in:_ | `nameserver` _[NameserverStatus](#nameserverstatus)_ | Nameserver describes the status of nameserver cluster resources. | | | +#### Debug + + + + + + + +_Appears in:_ +- [Container](#container) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `enable` _boolean_ | Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/
and internal debug metrics endpoint at :9001/debug/metrics, where
9001 is a container port named "debug". The endpoints and their responses
may change in backwards incompatible ways in the future, and should not
be considered stable.
In 1.78.x and 1.80.x, this setting will default to the value of
.spec.metrics.enable, and requests to the "metrics" port matching the
mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x,
this setting will default to false, and no requests will be proxied. | | | + + #### Env @@ -278,6 +313,37 @@ _Appears in:_ +#### LabelValue + +_Underlying type:_ _string_ + + + +_Validation:_ +- MaxLength: 63 +- Pattern: `^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$` +- Type: string + +_Appears in:_ +- [Labels](#labels) + + + +#### Labels + +_Underlying type:_ _[map[string]LabelValue](#map[string]labelvalue)_ + + + + + +_Appears in:_ +- [Pod](#pod) +- [ServiceMonitor](#servicemonitor) +- [StatefulSet](#statefulset) + + + #### Metrics @@ -291,7 +357,8 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9001/debug/metrics.
Defaults to false. | | | +| `enable` _boolean_ | Setting enable to true will make the proxy serve Tailscale metrics
at :9002/metrics.
A metrics Service named -metrics will also be created in the operator's namespace and will
serve the metrics at :9002/metrics.
In 1.78.x and 1.80.x, this field also serves as the default value for
.spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both
fields will independently default to false.
Defaults to false. | | | +| `serviceMonitor` _[ServiceMonitor](#servicemonitor)_ | Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics.
The ServiceMonitor will select the metrics Service that gets created when metrics are enabled.
The ingested metrics for each Service monitor will have labels to identify the proxy:
ts_proxy_type: ingress_service\|ingress_resource\|connector\|proxygroup
ts_proxy_parent_name: name of the parent resource (i.e. name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup)<br />
ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped)
job: ts_&lt;proxy type&gt;_[&lt;proxy namespace&gt;]_&lt;proxy name&gt; | | |
container. | | | +| `podLabels` _object (keys:string, values:string)_ | PodLabels are the labels which will be attached to the nameserver
pod. They can be used to define network policies. | | | #### NameserverImage @@ -371,7 +441,7 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `labels` _object (keys:string, values:string)_ | Labels that will be added to the proxy Pod.
Any labels specified here will be merged with the default labels
applied to the Pod by the Tailscale Kubernetes operator.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | +| `labels` _[Labels](#labels)_ | Labels that will be added to the proxy Pod.
Any labels specified here will be merged with the default labels
applied to the Pod by the Tailscale Kubernetes operator.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | | `annotations` _object (keys:string, values:string)_ | Annotations that will be added to the proxy Pod.
Any annotations specified here will be merged with the default
annotations applied to the Pod by the Tailscale Kubernetes operator.
Annotations must be valid Kubernetes annotations.
https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set | | | | `affinity` _[Affinity](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#affinity-v1-core)_ | Proxy Pod's affinity rules.
By default, the Tailscale Kubernetes operator does not apply any affinity rules.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#affinity | | | | `tailscaleContainer` _[Container](#container)_ | Configuration for the proxy container running tailscale. | | | @@ -381,6 +451,7 @@ _Appears in:_ | `nodeName` _string_ | Proxy Pod's node name.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `nodeSelector` _object (keys:string, values:string)_ | Proxy Pod's node selector.
By default Tailscale Kubernetes operator does not apply any node
selector.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | Proxy Pod's tolerations.
By default Tailscale Kubernetes operator does not apply any
tolerations.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | +| `topologySpreadConstraints` _[TopologySpreadConstraint](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#topologyspreadconstraint-v1-core) array_ | Proxy Pod's topology spread constraints.
By default Tailscale Kubernetes operator does not apply any topology spread constraints.
https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ | | | #### ProxyClass @@ -449,6 +520,7 @@ _Appears in:_ | `statefulSet` _[StatefulSet](#statefulset)_ | Configuration parameters for the proxy's StatefulSet. Tailscale
Kubernetes operator deploys a StatefulSet for each of the user
configured proxies (Tailscale Ingress, Tailscale Service, Connector). | | | | `metrics` _[Metrics](#metrics)_ | Configuration for proxy metrics. Metrics are currently not supported
for egress proxies and for Ingress proxies that have been configured
with tailscale.com/experimental-forward-cluster-traffic-via-ingress
annotation. Note that the metrics are currently considered unstable
and will likely change in breaking ways in the future - we only
recommend that you use those for debugging purposes. | | | | `tailscale` _[TailscaleConfig](#tailscaleconfig)_ | TailscaleConfig contains options to configure the tailscale-specific
parameters of proxies. | | | +| `useLetsEncryptStagingEnvironment` _boolean_ | Set UseLetsEncryptStagingEnvironment to true to issue TLS
certificates for any HTTPS endpoints exposed to the tailnet from
LetsEncrypt's staging environment.
https://letsencrypt.org/docs/staging-environment/
This setting only affects Tailscale Ingress resources.
By default Ingress TLS certificates are issued from LetsEncrypt's
production environment.
Changing this setting true -> false will result in any<br />
existing certs being re-issued from the production environment.
Changing this setting false (default) -> true, when certs have already
been provisioned from production environment will NOT result in certs
being re-issued from the staging environment before they need to be
renewed. | | | #### ProxyClassStatus @@ -471,7 +543,16 @@ _Appears in:_ +ProxyGroup defines a set of Tailscale devices that will act as proxies. +Currently only egress ProxyGroups are supported. + +Use the tailscale.com/proxy-group annotation on a Service to specify that +the egress proxy should be implemented by a ProxyGroup instead of a single +dedicated proxy. In addition to running a highly available set of proxies, +ProxyGroup also allows for serving many annotated Services from a single +set of proxies to minimise resource consumption. +More info: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress @@ -522,9 +603,9 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `type` _[ProxyGroupType](#proxygrouptype)_ | Type of the ProxyGroup proxies. Currently the only supported type is egress. | | Enum: [egress]
Type: string
| +| `type` _[ProxyGroupType](#proxygrouptype)_ | Type of the ProxyGroup proxies. Supported types are egress and ingress.
Type is immutable once a ProxyGroup is created. | | Enum: [egress ingress]
Type: string
| | `tags` _[Tags](#tags)_ | Tags that the Tailscale devices will be tagged with. Defaults to [tag:k8s].
If you specify custom tags here, make sure you also make the operator
an owner of these tags.
See https://tailscale.com/kb/1236/kubernetes-operator/#setting-up-the-kubernetes-operator.
Tags cannot be changed once a ProxyGroup device has been created.
Tag values must be in form ^tag:[a-zA-Z][a-zA-Z0-9-]*$. | | Pattern: `^tag:[a-zA-Z][a-zA-Z0-9-]*$`
Type: string
| -| `replicas` _integer_ | Replicas specifies how many replicas to create the StatefulSet with.
Defaults to 2. | | | +| `replicas` _integer_ | Replicas specifies how many replicas to create the StatefulSet with.
Defaults to 2. | | Minimum: 0
| | `hostnamePrefix` _[HostnamePrefix](#hostnameprefix)_ | HostnamePrefix is the hostname prefix to use for tailnet devices created
by the ProxyGroup. Each device will have the integer number from its
StatefulSet pod appended to this prefix to form the full hostname.
HostnamePrefix can contain lower case letters, numbers and dashes, it
must not start with a dash and must be between 1 and 62 characters long. | | Pattern: `^[a-z0-9][a-z0-9-]{0,61}$`
Type: string
| | `proxyClass` _string_ | ProxyClass is the name of the ProxyClass custom resource that contains
configuration options that should be applied to the resources created
for this ProxyGroup. If unset, and there is no default ProxyClass
configured, the operator will create resources with the default
configuration. | | | @@ -553,7 +634,7 @@ _Underlying type:_ _string_ _Validation:_ -- Enum: [egress] +- Enum: [egress ingress] - Type: string _Appears in:_ @@ -565,7 +646,11 @@ _Appears in:_ +Recorder defines a tsrecorder device for recording SSH sessions. By default, +it will store recordings in a local ephemeral volume. If you want to persist +recordings, you can configure an S3-compatible API for storage. +More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder @@ -644,6 +729,24 @@ _Appears in:_ | `imagePullSecrets` _[LocalObjectReference](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#localobjectreference-v1-core) array_ | Image pull Secrets for Recorder Pods.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#PodSpec | | | | `nodeSelector` _object (keys:string, values:string)_ | Node selector rules for Recorder Pods. By default, the operator does
not apply any node selector rules.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | | `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | Tolerations for Recorder Pods. By default, the operator does not apply
any tolerations.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling | | | +| `serviceAccount` _[RecorderServiceAccount](#recorderserviceaccount)_ | Config for the ServiceAccount to create for the Recorder's StatefulSet.
By default, the operator will create a ServiceAccount with the same
name as the Recorder resource.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account | | | + + +#### RecorderServiceAccount + + + + + + + +_Appears in:_ +- [RecorderPod](#recorderpod) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `name` _string_ | Name of the ServiceAccount to create. Defaults to the name of the
Recorder resource.
https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account | | MaxLength: 253
Pattern: `^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$`
Type: string
| +| `annotations` _object (keys:string, values:string)_ | Annotations to add to the ServiceAccount.
https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set
You can use this to add IAM roles to the ServiceAccount (IRSA) instead of
providing static S3 credentials in a Secret.
https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html
For example:
eks.amazonaws.com/role-arn: arn:aws:iam::&lt;account-id&gt;:role/&lt;role-name&gt; | | |
credentials for writing to the configured bucket. Each key-value pair
from the secret's data will be mounted as an environment variable. It
should include keys for AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY if
using a static access key. | | | +#### ServiceMonitor + + + + + + + +_Appears in:_ +- [Metrics](#metrics) + +| Field | Description | Default | Validation | +| --- | --- | --- | --- | +| `enable` _boolean_ | If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. | | | +| `labels` _[Labels](#labels)_ | Labels to add to the ServiceMonitor.
Labels must be valid Kubernetes labels.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | + + #### StatefulSet @@ -812,7 +933,7 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | -| `labels` _object (keys:string, values:string)_ | Labels that will be added to the StatefulSet created for the proxy.
Any labels specified here will be merged with the default labels
applied to the StatefulSet by the Tailscale Kubernetes operator as
well as any other labels that might have been applied by other
actors.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | +| `labels` _[Labels](#labels)_ | Labels that will be added to the StatefulSet created for the proxy.
Any labels specified here will be merged with the default labels
applied to the StatefulSet by the Tailscale Kubernetes operator as
well as any other labels that might have been applied by other
actors.
Label keys and values must be valid Kubernetes label keys and values.
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set | | | | `annotations` _object (keys:string, values:string)_ | Annotations that will be added to the StatefulSet created for the proxy.
Any Annotations specified here will be merged with the default annotations
applied to the StatefulSet by the Tailscale Kubernetes operator as
well as any other annotations that might have been applied by other
actors.
Annotations must be valid Kubernetes annotations.
https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set | | | | `pod` _[Pod](#pod)_ | Configuration for the proxy Pod. | | | diff --git a/k8s-operator/apis/v1alpha1/register.go b/k8s-operator/apis/v1alpha1/register.go index 70b411d120994..0880ac975732e 100644 --- a/k8s-operator/apis/v1alpha1/register.go +++ b/k8s-operator/apis/v1alpha1/register.go @@ -10,6 +10,7 @@ import ( "tailscale.com/k8s-operator/apis" + apiextensionsv1 "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -39,12 +40,18 @@ func init() { localSchemeBuilder.Register(addKnownTypes) GlobalScheme = runtime.NewScheme() + // Add core types if err := scheme.AddToScheme(GlobalScheme); err != nil { panic(fmt.Sprintf("failed to add k8s.io scheme: %s", err)) } + // Add tailscale.com types if err := AddToScheme(GlobalScheme); err != nil { panic(fmt.Sprintf("failed to add tailscale.com scheme: %s", err)) } + // Add apiextensions types (CustomResourceDefinitions/CustomResourceDefinitionLists) + if err := apiextensionsv1.AddToScheme(GlobalScheme); err != nil { + panic(fmt.Sprintf("failed to add apiextensions.k8s.io scheme: %s", err)) + } } // Adds the list of known types to api.Scheme. diff --git a/k8s-operator/apis/v1alpha1/types_connector.go b/k8s-operator/apis/v1alpha1/types_connector.go index 27afd0838a388..b8b7a935e3a3f 100644 --- a/k8s-operator/apis/v1alpha1/types_connector.go +++ b/k8s-operator/apis/v1alpha1/types_connector.go @@ -22,7 +22,9 @@ var ConnectorKind = "Connector" // +kubebuilder:resource:scope=Cluster,shortName=cn // +kubebuilder:printcolumn:name="SubnetRoutes",type="string",JSONPath=`.status.subnetRoutes`,description="CIDR ranges exposed to tailnet by a subnet router defined via this Connector instance." 
// +kubebuilder:printcolumn:name="IsExitNode",type="string",JSONPath=`.status.isExitNode`,description="Whether this Connector instance defines an exit node." +// +kubebuilder:printcolumn:name="IsAppConnector",type="string",JSONPath=`.status.isAppConnector`,description="Whether this Connector instance is an app connector." // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "ConnectorReady")].reason`,description="Status of the deployed Connector resources." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // Connector defines a Tailscale node that will be deployed in the cluster. The // node can be configured to act as a Tailscale subnet router and/or a Tailscale @@ -55,7 +57,8 @@ type ConnectorList struct { } // ConnectorSpec describes a Tailscale node to be deployed in the cluster. -// +kubebuilder:validation:XValidation:rule="has(self.subnetRouter) || self.exitNode == true",message="A Connector needs to be either an exit node or a subnet router, or both." +// +kubebuilder:validation:XValidation:rule="has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true) || has(self.appConnector)",message="A Connector needs to have at least one of exit node, subnet router or app connector configured." +// +kubebuilder:validation:XValidation:rule="!((has(self.subnetRouter) || (has(self.exitNode) && self.exitNode == true)) && has(self.appConnector))",message="The appConnector field is mutually exclusive with exitNode and subnetRouter fields." type ConnectorSpec struct { // Tags that the Tailscale node will be tagged with. // Defaults to [tag:k8s]. @@ -82,13 +85,31 @@ type ConnectorSpec struct { // create resources with the default configuration. // +optional ProxyClass string `json:"proxyClass,omitempty"` - // SubnetRouter defines subnet routes that the Connector node should - // expose to tailnet. If unset, none are exposed. 
+ // SubnetRouter defines subnet routes that the Connector device should + // expose to tailnet as a Tailscale subnet router. // https://tailscale.com/kb/1019/subnets/ + // If this field is unset, the device does not get configured as a Tailscale subnet router. + // This field is mutually exclusive with the appConnector field. // +optional - SubnetRouter *SubnetRouter `json:"subnetRouter"` - // ExitNode defines whether the Connector node should act as a - // Tailscale exit node. Defaults to false. + SubnetRouter *SubnetRouter `json:"subnetRouter,omitempty"` + // AppConnector defines whether the Connector device should act as a Tailscale app connector. A Connector that is + // configured as an app connector cannot be a subnet router or an exit node. If this field is unset, the + // Connector does not act as an app connector. + // Note that you will need to manually configure the permissions and the domains for the app connector via the + // Admin panel. + // Note also that the main tested and supported use case of this config option is to deploy an app connector on + // Kubernetes to access SaaS applications available on the public internet. Using the app connector to expose + // cluster workloads or other internal workloads to tailnet might work, but this is not a use case that we have + // tested or optimised for. + // If you are using the app connector to access SaaS applications because you need a predictable egress IP that + // can be whitelisted, it is also your responsibility to ensure that cluster traffic from the connector flows + // via that predictable IP, for example by enforcing that cluster egress traffic is routed via an egress NAT + // device with a static IP address. + // https://tailscale.com/kb/1281/app-connectors + // +optional + AppConnector *AppConnector `json:"appConnector,omitempty"` + // ExitNode defines whether the Connector device should act as a Tailscale exit node. Defaults to false. 
+ // This field is mutually exclusive with the appConnector field. // https://tailscale.com/kb/1103/exit-nodes // +optional ExitNode bool `json:"exitNode"` @@ -104,6 +125,17 @@ type SubnetRouter struct { AdvertiseRoutes Routes `json:"advertiseRoutes"` } +// AppConnector defines a Tailscale app connector node configured via Connector. +type AppConnector struct { + // Routes are optional preconfigured routes for the domains routed via the app connector. + // If not set, routes for the domains will be discovered dynamically. + // If set, the app connector will immediately be able to route traffic using the preconfigured routes, but may + // also dynamically discover other routes. + // https://tailscale.com/kb/1332/apps-best-practices#preconfiguration + // +optional + Routes Routes `json:"routes"` +} + type Tags []Tag func (tags Tags) Stringify() []string { @@ -156,6 +188,9 @@ type ConnectorStatus struct { // IsExitNode is set to true if the Connector acts as an exit node. // +optional IsExitNode bool `json:"isExitNode"` + // IsAppConnector is set to true if the Connector acts as an app connector. + // +optional + IsAppConnector bool `json:"isAppConnector"` // TailnetIPs is the set of tailnet IP addresses (both IPv4 and IPv6) // assigned to the Connector node. // +optional @@ -187,4 +222,7 @@ const ( // on a ProxyGroup. // Set to true if the service is ready to route cluster traffic. 
EgressSvcReady ConditionType = `TailscaleEgressSvcReady` + + IngressSvcValid ConditionType = `TailscaleIngressSvcValid` + IngressSvcConfigured ConditionType = `TailscaleIngressSvcConfigured` ) diff --git a/k8s-operator/apis/v1alpha1/types_proxyclass.go b/k8s-operator/apis/v1alpha1/types_proxyclass.go index 7f415bc340bd7..899abf096bb86 100644 --- a/k8s-operator/apis/v1alpha1/types_proxyclass.go +++ b/k8s-operator/apis/v1alpha1/types_proxyclass.go @@ -16,6 +16,7 @@ var ProxyClassKind = "ProxyClass" // +kubebuilder:subresource:status // +kubebuilder:resource:scope=Cluster // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "ProxyClassReady")].reason`,description="Status of the ProxyClass." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // ProxyClass describes a set of configuration parameters that can be applied to // proxy resources created by the Tailscale Kubernetes operator. @@ -66,6 +67,21 @@ type ProxyClassSpec struct { // parameters of proxies. // +optional TailscaleConfig *TailscaleConfig `json:"tailscale,omitempty"` + // Set UseLetsEncryptStagingEnvironment to true to issue TLS + // certificates for any HTTPS endpoints exposed to the tailnet from + // LetsEncrypt's staging environment. + // https://letsencrypt.org/docs/staging-environment/ + // This setting only affects Tailscale Ingress resources. + // By default Ingress TLS certificates are issued from LetsEncrypt's + // production environment. + // Changing this setting true -> false, will result in any + // existing certs being re-issued from the production environment. + // Changing this setting false (default) -> true, when certs have already + // been provisioned from production environment will NOT result in certs + // being re-issued from the staging environment before they need to be + // renewed. 
+ // +optional + UseLetsEncryptStagingEnvironment bool `json:"useLetsEncryptStagingEnvironment,omitempty"` } type TailscaleConfig struct { @@ -87,7 +103,7 @@ type StatefulSet struct { // Label keys and values must be valid Kubernetes label keys and values. // https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set // +optional - Labels map[string]string `json:"labels,omitempty"` + Labels Labels `json:"labels,omitempty"` // Annotations that will be added to the StatefulSet created for the proxy. // Any Annotations specified here will be merged with the default annotations // applied to the StatefulSet by the Tailscale Kubernetes operator as @@ -109,7 +125,7 @@ type Pod struct { // Label keys and values must be valid Kubernetes label keys and values. // https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set // +optional - Labels map[string]string `json:"labels,omitempty"` + Labels Labels `json:"labels,omitempty"` // Annotations that will be added to the proxy Pod. // Any annotations specified here will be merged with the default // annotations applied to the Pod by the Tailscale Kubernetes operator. @@ -154,16 +170,68 @@ type Pod struct { // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling // +optional Tolerations []corev1.Toleration `json:"tolerations,omitempty"` + // Proxy Pod's topology spread constraints. + // By default Tailscale Kubernetes operator does not apply any topology spread constraints. 
+ // https://kubernetes.io/docs/concepts/scheduling-eviction/topology-spread-constraints/ // +optional + TopologySpreadConstraints []corev1.TopologySpreadConstraint `json:"topologySpreadConstraints,omitempty"` } +// +kubebuilder:validation:XValidation:rule="!(has(self.serviceMonitor) && self.serviceMonitor.enable && !self.enable)",message="ServiceMonitor can only be enabled if metrics are enabled" type Metrics struct { // Setting enable to true will make the proxy serve Tailscale metrics - // at :9001/debug/metrics. + // at :9002/metrics. + // A metrics Service named -metrics will also be created in the operator's namespace and will + // serve the metrics at :9002/metrics. + // + // In 1.78.x and 1.80.x, this field also serves as the default value for + // .spec.statefulSet.pod.tailscaleContainer.debug.enable. From 1.82.0, both + // fields will independently default to false. + // // Defaults to false. Enable bool `json:"enable"` + // Enable to create a Prometheus ServiceMonitor for scraping the proxy's Tailscale metrics. + // The ServiceMonitor will select the metrics Service that gets created when metrics are enabled. + // The ingested metrics for each Service monitor will have labels to identify the proxy: + // ts_proxy_type: ingress_service|ingress_resource|connector|proxygroup + // ts_proxy_parent_name: name of the parent resource (i.e name of the Connector, Tailscale Ingress, Tailscale Service or ProxyGroup) + // ts_proxy_parent_namespace: namespace of the parent resource (if the parent resource is not cluster scoped) + // job: ts__[]_ + // +optional + ServiceMonitor *ServiceMonitor `json:"serviceMonitor"` +} + +type ServiceMonitor struct { + // If Enable is set to true, a Prometheus ServiceMonitor will be created. Enable can only be set to true if metrics are enabled. + Enable bool `json:"enable"` + // Labels to add to the ServiceMonitor. + // Labels must be valid Kubernetes labels. 
+ // https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set + // +optional + Labels Labels `json:"labels"` +} + +type Labels map[string]LabelValue + +func (l Labels) Parse() map[string]string { + if l == nil { + return nil + } + m := make(map[string]string, len(l)) + for k, v := range l { + m[k] = string(v) + } + return m } +// We do not validate the values of the label keys here - it is done by the ProxyClass +// reconciler because the validation rules are too complex for a CRD validation markers regex. + +// +kubebuilder:validation:Type=string +// +kubebuilder:validation:Pattern=`^(([a-zA-Z0-9][-._a-zA-Z0-9]*)?[a-zA-Z0-9])?$` +// +kubebuilder:validation:MaxLength=63 +type LabelValue string + type Container struct { // List of environment variables to set in the container. // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#environment-variables @@ -197,14 +265,35 @@ type Container struct { // +optional Resources corev1.ResourceRequirements `json:"resources,omitempty"` // Container security context. - // Security context specified here will override the security context by the operator. - // By default the operator: - // - sets 'privileged: true' for the init container - // - set NET_ADMIN capability for tailscale container for proxies that - // are created for Services or Connector. + // Security context specified here will override the security context set by the operator. + // By default the operator sets the Tailscale container and the Tailscale init container to privileged + // for proxies created for Tailscale ingress and egress Service, Connector and ProxyGroup. 
+ // You can reduce the permissions of the Tailscale container to cap NET_ADMIN by + // installing device plugin in your cluster and configuring the proxies tun device to be created + // by the device plugin, see https://github.com/tailscale/tailscale/issues/10814#issuecomment-2479977752 // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#security-context // +optional SecurityContext *corev1.SecurityContext `json:"securityContext,omitempty"` + // Configuration for enabling extra debug information in the container. + // Not recommended for production use. + // +optional + Debug *Debug `json:"debug,omitempty"` +} + +type Debug struct { + // Enable tailscaled's HTTP pprof endpoints at :9001/debug/pprof/ + // and internal debug metrics endpoint at :9001/debug/metrics, where + // 9001 is a container port named "debug". The endpoints and their responses + // may change in backwards incompatible ways in the future, and should not + // be considered stable. + // + // In 1.78.x and 1.80.x, this setting will default to the value of + // .spec.metrics.enable, and requests to the "metrics" port matching the + // mux pattern /debug/ will be forwarded to the "debug" port. In 1.82.x, + // this setting will default to false, and no requests will be proxied. + // + // +optional + Enable bool `json:"enable"` } type Env struct { diff --git a/k8s-operator/apis/v1alpha1/types_proxygroup.go b/k8s-operator/apis/v1alpha1/types_proxygroup.go index 7e5515ba9d66c..ac87cc6caf892 100644 --- a/k8s-operator/apis/v1alpha1/types_proxygroup.go +++ b/k8s-operator/apis/v1alpha1/types_proxygroup.go @@ -13,7 +13,19 @@ import ( // +kubebuilder:subresource:status // +kubebuilder:resource:scope=Cluster,shortName=pg // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "ProxyGroupReady")].reason`,description="Status of the deployed ProxyGroup resources." 
- +// +kubebuilder:printcolumn:name="Type",type="string",JSONPath=`.spec.type`,description="ProxyGroup type." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" + +// ProxyGroup defines a set of Tailscale devices that will act as proxies. +// Currently only egress ProxyGroups are supported. +// +// Use the tailscale.com/proxy-group annotation on a Service to specify that +// the egress proxy should be implemented by a ProxyGroup instead of a single +// dedicated proxy. In addition to running a highly available set of proxies, +// ProxyGroup also allows for serving many annotated Services from a single +// set of proxies to minimise resource consumption. +// +// More info: https://tailscale.com/kb/1438/kubernetes-operator-cluster-egress type ProxyGroup struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` @@ -37,7 +49,9 @@ type ProxyGroupList struct { } type ProxyGroupSpec struct { - // Type of the ProxyGroup proxies. Currently the only supported type is egress. + // Type of the ProxyGroup proxies. Supported types are egress and ingress. + // Type is immutable once a ProxyGroup is created. + // +kubebuilder:validation:XValidation:rule="self == oldSelf",message="ProxyGroup type is immutable" Type ProxyGroupType `json:"type"` // Tags that the Tailscale devices will be tagged with. Defaults to [tag:k8s]. @@ -52,6 +66,7 @@ type ProxyGroupSpec struct { // Replicas specifies how many replicas to create the StatefulSet with. // Defaults to 2. 
// +optional + // +kubebuilder:validation:Minimum=0 Replicas *int32 `json:"replicas,omitempty"` // HostnamePrefix is the hostname prefix to use for tailnet devices created @@ -99,11 +114,12 @@ type TailnetDevice struct { } // +kubebuilder:validation:Type=string -// +kubebuilder:validation:Enum=egress +// +kubebuilder:validation:Enum=egress;ingress type ProxyGroupType string const ( - ProxyGroupTypeEgress ProxyGroupType = "egress" + ProxyGroupTypeEgress ProxyGroupType = "egress" + ProxyGroupTypeIngress ProxyGroupType = "ingress" ) // +kubebuilder:validation:Type=string diff --git a/k8s-operator/apis/v1alpha1/types_recorder.go b/k8s-operator/apis/v1alpha1/types_recorder.go index 3728154b45170..16a610b26d179 100644 --- a/k8s-operator/apis/v1alpha1/types_recorder.go +++ b/k8s-operator/apis/v1alpha1/types_recorder.go @@ -15,7 +15,13 @@ import ( // +kubebuilder:resource:scope=Cluster,shortName=rec // +kubebuilder:printcolumn:name="Status",type="string",JSONPath=`.status.conditions[?(@.type == "RecorderReady")].reason`,description="Status of the deployed Recorder resources." // +kubebuilder:printcolumn:name="URL",type="string",JSONPath=`.status.devices[?(@.url != "")].url`,description="URL on which the UI is exposed if enabled." +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" +// Recorder defines a tsrecorder device for recording SSH sessions. By default, +// it will store recordings in a local ephemeral volume. If you want to persist +// recordings, you can configure an S3-compatible API for storage. 
+// +// More info: https://tailscale.com/kb/1484/kubernetes-operator-deploying-tsrecorder type Recorder struct { metav1.TypeMeta `json:",inline"` metav1.ObjectMeta `json:"metadata,omitempty"` @@ -136,6 +142,36 @@ type RecorderPod struct { // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#scheduling // +optional Tolerations []corev1.Toleration `json:"tolerations,omitempty"` + + // Config for the ServiceAccount to create for the Recorder's StatefulSet. + // By default, the operator will create a ServiceAccount with the same + // name as the Recorder resource. + // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + // +optional + ServiceAccount RecorderServiceAccount `json:"serviceAccount,omitempty"` +} + +type RecorderServiceAccount struct { + // Name of the ServiceAccount to create. Defaults to the name of the + // Recorder resource. + // https://kubernetes.io/docs/reference/kubernetes-api/workload-resources/pod-v1/#service-account + // +kubebuilder:validation:Type=string + // +kubebuilder:validation:Pattern=`^[a-z0-9]([a-z0-9-.]{0,61}[a-z0-9])?$` + // +kubebuilder:validation:MaxLength=253 + // +optional + Name string `json:"name,omitempty"` + + // Annotations to add to the ServiceAccount. + // https://kubernetes.io/docs/concepts/overview/working-with-objects/annotations/#syntax-and-character-set + // + // You can use this to add IAM roles to the ServiceAccount (IRSA) instead of + // providing static S3 credentials in a Secret. 
+ // https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html + // + // For example: + // eks.amazonaws.com/role-arn: arn:aws:iam:::role/ + // +optional + Annotations map[string]string `json:"annotations,omitempty"` } type RecorderContainer struct { diff --git a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go index 60d212279f4f5..1eef25d6d41fb 100644 --- a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go +++ b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go @@ -6,6 +6,7 @@ package v1alpha1 import ( + corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) @@ -18,6 +19,7 @@ var DNSConfigKind = "DNSConfig" // +kubebuilder:subresource:status // +kubebuilder:resource:scope=Cluster,shortName=dc // +kubebuilder:printcolumn:name="NameserverIP",type="string",JSONPath=`.status.nameserver.ip`,description="Service IP address of the nameserver" +// +kubebuilder:printcolumn:name="Age",type="date",JSONPath=".metadata.creationTimestamp" // DNSConfig can be deployed to cluster to make a subset of Tailscale MagicDNS // names resolvable by cluster workloads. Use this if: A) you need to refer to @@ -81,6 +83,17 @@ type Nameserver struct { // Nameserver image. Defaults to tailscale/k8s-nameserver:unstable. // +optional Image *NameserverImage `json:"image,omitempty"` + // Cmd can be used to overwrite the command used when running the nameserver image. + // +optional + Cmd []string `json:"cmd,omitempty"` + // Env can be used to pass environment variables to the nameserver + // container. + // +optional + Env []corev1.EnvVar `json:"env,omitempty"` + // PodLabels are the labels which will be attached to the nameserver + // pod. They can be used to define network policies. 
+ // +optional + PodLabels map[string]string `json:"podLabels,omitempty"` } type NameserverImage struct { diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index ba4ff40e46dd5..265894791256c 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -13,6 +13,26 @@ import ( "k8s.io/apimachinery/pkg/runtime" ) +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *AppConnector) DeepCopyInto(out *AppConnector) { + *out = *in + if in.Routes != nil { + in, out := &in.Routes, &out.Routes + *out = make(Routes, len(*in)) + copy(*out, *in) + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new AppConnector. +func (in *AppConnector) DeepCopy() *AppConnector { + if in == nil { + return nil + } + out := new(AppConnector) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Connector) DeepCopyInto(out *Connector) { *out = *in @@ -85,6 +105,11 @@ func (in *ConnectorSpec) DeepCopyInto(out *ConnectorSpec) { *out = new(SubnetRouter) (*in).DeepCopyInto(*out) } + if in.AppConnector != nil { + in, out := &in.AppConnector, &out.AppConnector + *out = new(AppConnector) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ConnectorSpec. @@ -138,6 +163,11 @@ func (in *Container) DeepCopyInto(out *Container) { *out = new(corev1.SecurityContext) (*in).DeepCopyInto(*out) } + if in.Debug != nil { + in, out := &in.Debug, &out.Debug + *out = new(Debug) + **out = **in + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Container. 
@@ -256,6 +286,21 @@ func (in *DNSConfigStatus) DeepCopy() *DNSConfigStatus { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Debug) DeepCopyInto(out *Debug) { + *out = *in +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Debug. +func (in *Debug) DeepCopy() *Debug { + if in == nil { + return nil + } + out := new(Debug) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Env) DeepCopyInto(out *Env) { *out = *in @@ -271,9 +316,35 @@ func (in *Env) DeepCopy() *Env { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in Labels) DeepCopyInto(out *Labels) { + { + in := &in + *out = make(Labels, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Labels. +func (in Labels) DeepCopy() Labels { + if in == nil { + return nil + } + out := new(Labels) + in.DeepCopyInto(out) + return *out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *Metrics) DeepCopyInto(out *Metrics) { *out = *in + if in.ServiceMonitor != nil { + in, out := &in.ServiceMonitor, &out.ServiceMonitor + *out = new(ServiceMonitor) + (*in).DeepCopyInto(*out) + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Metrics. 
@@ -294,6 +365,25 @@ func (in *Nameserver) DeepCopyInto(out *Nameserver) { *out = new(NameserverImage) **out = **in } + if in.Cmd != nil { + in, out := &in.Cmd, &out.Cmd + *out = make([]string, len(*in)) + copy(*out, *in) + } + if in.Env != nil { + in, out := &in.Env, &out.Env + *out = make([]corev1.EnvVar, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } + if in.PodLabels != nil { + in, out := &in.PodLabels, &out.PodLabels + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Nameserver. @@ -341,7 +431,7 @@ func (in *Pod) DeepCopyInto(out *Pod) { *out = *in if in.Labels != nil { in, out := &in.Labels, &out.Labels - *out = make(map[string]string, len(*in)) + *out = make(Labels, len(*in)) for key, val := range *in { (*out)[key] = val } @@ -392,6 +482,13 @@ func (in *Pod) DeepCopyInto(out *Pod) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.TopologySpreadConstraints != nil { + in, out := &in.TopologySpreadConstraints, &out.TopologySpreadConstraints + *out = make([]corev1.TopologySpreadConstraint, len(*in)) + for i := range *in { + (*in)[i].DeepCopyInto(&(*out)[i]) + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Pod. @@ -474,7 +571,7 @@ func (in *ProxyClassSpec) DeepCopyInto(out *ProxyClassSpec) { if in.Metrics != nil { in, out := &in.Metrics, &out.Metrics *out = new(Metrics) - **out = **in + (*in).DeepCopyInto(*out) } if in.TailscaleConfig != nil { in, out := &in.TailscaleConfig, &out.TailscaleConfig @@ -760,6 +857,7 @@ func (in *RecorderPod) DeepCopyInto(out *RecorderPod) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + in.ServiceAccount.DeepCopyInto(&out.ServiceAccount) } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RecorderPod. 
@@ -772,6 +870,28 @@ func (in *RecorderPod) DeepCopy() *RecorderPod { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *RecorderServiceAccount) DeepCopyInto(out *RecorderServiceAccount) { + *out = *in + if in.Annotations != nil { + in, out := &in.Annotations, &out.Annotations + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new RecorderServiceAccount. +func (in *RecorderServiceAccount) DeepCopy() *RecorderServiceAccount { + if in == nil { + return nil + } + out := new(RecorderServiceAccount) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *RecorderSpec) DeepCopyInto(out *RecorderSpec) { *out = *in @@ -939,12 +1059,34 @@ func (in *S3Secret) DeepCopy() *S3Secret { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *ServiceMonitor) DeepCopyInto(out *ServiceMonitor) { + *out = *in + if in.Labels != nil { + in, out := &in.Labels, &out.Labels + *out = make(Labels, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new ServiceMonitor. +func (in *ServiceMonitor) DeepCopy() *ServiceMonitor { + if in == nil { + return nil + } + out := new(ServiceMonitor) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
func (in *StatefulSet) DeepCopyInto(out *StatefulSet) { *out = *in if in.Labels != nil { in, out := &in.Labels, &out.Labels - *out = make(map[string]string, len(*in)) + *out = make(Labels, len(*in)) for key, val := range *in { (*out)[key] = val } diff --git a/k8s-operator/conditions.go b/k8s-operator/conditions.go index ace0fb7e33a75..abe8f7f9cc6fa 100644 --- a/k8s-operator/conditions.go +++ b/k8s-operator/conditions.go @@ -75,16 +75,6 @@ func RemoveServiceCondition(svc *corev1.Service, conditionType tsapi.ConditionTy }) } -func EgressServiceIsValidAndConfigured(svc *corev1.Service) bool { - for _, typ := range []tsapi.ConditionType{tsapi.EgressSvcValid, tsapi.EgressSvcConfigured} { - cond := GetServiceCondition(svc, typ) - if cond == nil || cond.Status != metav1.ConditionTrue { - return false - } - } - return true -} - // SetRecorderCondition ensures that Recorder status has a condition with the // given attributes. LastTransitionTime gets set every time condition's status // changes. @@ -167,3 +157,14 @@ func DNSCfgIsReady(cfg *tsapi.DNSConfig) bool { cond := cfg.Status.Conditions[idx] return cond.Status == metav1.ConditionTrue && cond.ObservedGeneration == cfg.Generation } + +func SvcIsReady(svc *corev1.Service) bool { + idx := xslices.IndexFunc(svc.Status.Conditions, func(cond metav1.Condition) bool { + return cond.Type == string(tsapi.ProxyReady) + }) + if idx == -1 { + return false + } + cond := svc.Status.Conditions[idx] + return cond.Status == metav1.ConditionTrue +} diff --git a/k8s-operator/sessionrecording/hijacker.go b/k8s-operator/sessionrecording/hijacker.go index f8ef951d415f0..a9ed658964787 100644 --- a/k8s-operator/sessionrecording/hijacker.go +++ b/k8s-operator/sessionrecording/hijacker.go @@ -25,6 +25,7 @@ import ( "tailscale.com/k8s-operator/sessionrecording/spdy" "tailscale.com/k8s-operator/sessionrecording/tsrecorder" "tailscale.com/k8s-operator/sessionrecording/ws" + "tailscale.com/net/netx" "tailscale.com/sessionrecording" 
"tailscale.com/tailcfg" "tailscale.com/tsnet" @@ -102,7 +103,7 @@ type Hijacker struct { // connection succeeds. In case of success, returns a list with a single // successful recording attempt and an error channel. If the connection errors // after having been established, an error is sent down the channel. -type RecorderDialFn func(context.Context, []netip.AddrPort, func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) +type RecorderDialFn func(context.Context, []netip.AddrPort, netx.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) // Hijack hijacks a 'kubectl exec' session and configures for the session // contents to be sent to a recorder. diff --git a/k8s-operator/sessionrecording/hijacker_test.go b/k8s-operator/sessionrecording/hijacker_test.go index 440d9c94294c9..880015b22c2d0 100644 --- a/k8s-operator/sessionrecording/hijacker_test.go +++ b/k8s-operator/sessionrecording/hijacker_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "net" "net/http" "net/netip" "net/url" @@ -20,6 +19,7 @@ import ( "go.uber.org/zap" "tailscale.com/client/tailscale/apitype" "tailscale.com/k8s-operator/sessionrecording/fakes" + "tailscale.com/net/netx" "tailscale.com/tailcfg" "tailscale.com/tsnet" "tailscale.com/tstest" @@ -80,7 +80,7 @@ func Test_Hijacker(t *testing.T) { h := &Hijacker{ connectToRecorder: func(context.Context, []netip.AddrPort, - func(context.Context, string, string) (net.Conn, error), + netx.DialFunc, ) (wc io.WriteCloser, rec []*tailcfg.SSHRecordingAttempt, _ <-chan error, err error) { if tt.failRecorderConnect { err = errors.New("test") diff --git a/k8s-operator/utils.go b/k8s-operator/utils.go index a1f225fe601c8..420d7e49c7ec2 100644 --- a/k8s-operator/utils.go +++ b/k8s-operator/utils.go @@ -32,9 +32,6 @@ type Records struct { // TailscaledConfigFileName returns a tailscaled config file name in // format expected by containerboot for the given 
CapVer. func TailscaledConfigFileName(cap tailcfg.CapabilityVersion) string { - if cap < 95 { - return "tailscaled" - } return fmt.Sprintf("cap-%v.hujson", cap) } diff --git a/kube/egressservices/egressservices.go b/kube/egressservices/egressservices.go index 04a1c362b00c4..2515f1bf3a476 100644 --- a/kube/egressservices/egressservices.go +++ b/kube/egressservices/egressservices.go @@ -13,9 +13,15 @@ import ( "net/netip" ) -// KeyEgressServices is name of the proxy state Secret field that contains the -// currently applied egress proxy config. -const KeyEgressServices = "egress-services" +const ( + // KeyEgressServices is name of the proxy state Secret field that contains the + // currently applied egress proxy config. + KeyEgressServices = "egress-services" + + // KeyHEPPings is the number of times an egress service health check endpoint needs to be pinged to ensure that + // each currently configured backend is hit. In practice, it depends on the number of ProxyGroup replicas. + KeyHEPPings = "hep-pings" +) // Configs contains the desired configuration for egress services keyed by // service name. @@ -24,6 +30,7 @@ type Configs map[string]Config // Config is an egress service configuration. // TODO(irbekrm): version this? type Config struct { + HealthCheckEndpoint string `json:"healthCheckEndpoint"` // TailnetTarget is the target to which cluster traffic for this service // should be proxied. 
TailnetTarget TailnetTarget `json:"tailnetTarget"` diff --git a/kube/egressservices/egressservices_test.go b/kube/egressservices/egressservices_test.go index d6f952ea0a463..806ad91be61cd 100644 --- a/kube/egressservices/egressservices_test.go +++ b/kube/egressservices/egressservices_test.go @@ -55,7 +55,7 @@ func Test_jsonMarshalConfig(t *testing.T) { protocol: "tcp", matchPort: 4003, targetPort: 80, - wantsBs: []byte(`{"tailnetTarget":{"ip":"","fqdn":""},"ports":[{"protocol":"tcp","matchPort":4003,"targetPort":80}]}`), + wantsBs: []byte(`{"healthCheckEndpoint":"","tailnetTarget":{"ip":"","fqdn":""},"ports":[{"protocol":"tcp","matchPort":4003,"targetPort":80}]}`), }, } for _, tt := range tests { diff --git a/kube/ingressservices/ingressservices.go b/kube/ingressservices/ingressservices.go new file mode 100644 index 0000000000000..f79410761af02 --- /dev/null +++ b/kube/ingressservices/ingressservices.go @@ -0,0 +1,53 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package ingressservices contains shared types for exposing Kubernetes Services to tailnet. +// These are split into a separate package for consumption of +// non-Kubernetes shared libraries and binaries. Be mindful of not increasing +// dependency size for those consumers when adding anything new here. +package ingressservices + +import "net/netip" + +// IngressConfigKey is the key at which both the desired ingress firewall +// configuration is stored in the ingress proxies' ConfigMap and at which the +// recorded firewall configuration status is stored in the proxies' state +// Secrets. +const IngressConfigKey = "ingress-config.json" + +// Configs contains the desired configuration for ingress proxies firewall. Map +// keys are Tailscale Service names. +type Configs map[string]Config + +// GetConfig returns the desired configuration for the given Tailscale Service name. 
+func (cfgs *Configs) GetConfig(name string) *Config { + if cfgs == nil { + return nil + } + if cfg, ok := (*cfgs)[name]; ok { + return &cfg + } + return nil +} + +// Status contains the recorded firewall configuration status for a specific +// ingress proxy Pod. +// Pod IPs are used to identify the ingress proxy Pod. +type Status struct { + Configs Configs `json:"configs,omitempty"` + PodIPv4 string `json:"podIPv4,omitempty"` + PodIPv6 string `json:"podIPv6,omitempty"` +} + +// Config is an ingress service configuration. +type Config struct { + IPv4Mapping *Mapping `json:"IPv4Mapping,omitempty"` + IPv6Mapping *Mapping `json:"IPv6Mapping,omitempty"` +} + +// Mapping describes a rule that forwards traffic from Tailscale Service IP to a +// Kubernetes Service IP. +type Mapping struct { + TailscaleServiceIP netip.Addr `json:"TailscaleServiceIP"` + ClusterIP netip.Addr `json:"ClusterIP"` +} diff --git a/kube/kubeapi/api.go b/kube/kubeapi/api.go index 0e42437a69a2a..e62bd6e2b2eb1 100644 --- a/kube/kubeapi/api.go +++ b/kube/kubeapi/api.go @@ -7,7 +7,9 @@ // dependency size for those consumers when adding anything new here. package kubeapi -import "time" +import ( + "time" +) // Note: The API types are copied from k8s.io/api{,machinery} to not introduce a // module dependency on the Kubernetes API as it pulls in many more dependencies. @@ -151,6 +153,65 @@ type Secret struct { Data map[string][]byte `json:"data,omitempty"` } +// SecretList is a list of Secret objects. +type SecretList struct { + TypeMeta `json:",inline"` + ObjectMeta `json:"metadata"` + + Items []Secret `json:"items,omitempty"` +} + +// Event contains a subset of fields from corev1.Event. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L7034 +// It is copied here to avoid having to import kube libraries. 
+type Event struct { + TypeMeta `json:",inline"` + ObjectMeta `json:"metadata"` + Message string `json:"message,omitempty"` + Reason string `json:"reason,omitempty"` + Source EventSource `json:"source,omitempty"` // who is emitting this Event + Type string `json:"type,omitempty"` // Normal or Warning + // InvolvedObject is the subject of the Event. `kubectl describe` will, for most object types, display any + // currently present cluster Events matching the object (but you probably want to set UID for this to work). + InvolvedObject ObjectReference `json:"involvedObject"` + Count int32 `json:"count,omitempty"` // how many times Event was observed + FirstTimestamp time.Time `json:"firstTimestamp,omitempty"` + LastTimestamp time.Time `json:"lastTimestamp,omitempty"` +} + +// EventSource includes a subset of fields from corev1.EventSource. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L7007 +// It is copied here to avoid having to import kube libraries. +type EventSource struct { + // Component is the name of the component that is emitting the Event. + Component string `json:"component,omitempty"` +} + +// ObjectReference contains a subset of fields from corev1.ObjectReference. +// https://github.com/kubernetes/api/blob/6cc44b8953ae704d6d9ec2adf32e7ae19199ea9f/core/v1/types.go#L6902 +// It is copied here to avoid having to import kube libraries. +type ObjectReference struct { + // Kind of the referent. + // More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + // +optional + Kind string `json:"kind,omitempty"` + // Namespace of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/namespaces/ + // +optional + Namespace string `json:"namespace,omitempty"` + // Name of the referent. 
+ // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names + // +optional + Name string `json:"name,omitempty"` + // UID of the referent. + // More info: https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#uids + // +optional + UID string `json:"uid,omitempty"` + // API version of the referent. + // +optional + APIVersion string `json:"apiVersion,omitempty"` +} + // Status is a return value for calls that don't return other objects. type Status struct { TypeMeta `json:",inline"` @@ -186,6 +247,6 @@ type Status struct { Code int `json:"code,omitempty"` } -func (s *Status) Error() string { +func (s Status) Error() string { return s.Message } diff --git a/kube/kubeclient/client.go b/kube/kubeclient/client.go index e8ddec75d1584..332b21106ecfb 100644 --- a/kube/kubeclient/client.go +++ b/kube/kubeclient/client.go @@ -23,16 +23,21 @@ import ( "net/url" "os" "path/filepath" + "strings" "sync" "time" "tailscale.com/kube/kubeapi" + "tailscale.com/tstime" "tailscale.com/util/multierr" ) const ( saPath = "/var/run/secrets/kubernetes.io/serviceaccount" defaultURL = "https://kubernetes.default.svc" + + TypeSecrets = "secrets" + typeEvents = "events" ) // rootPathForTests is set by tests to override the root path to the @@ -55,10 +60,16 @@ func readFile(n string) ([]byte, error) { // It expects to be run inside a cluster. type Client interface { GetSecret(context.Context, string) (*kubeapi.Secret, error) + ListSecrets(context.Context, map[string]string) (*kubeapi.SecretList, error) UpdateSecret(context.Context, *kubeapi.Secret) error CreateSecret(context.Context, *kubeapi.Secret) error + // Event attempts to ensure an event with the specified options associated with the Pod in which we are + // currently running. This is best effort - if the client is not able to create events, this operation will be a + // no-op. 
If there is already an Event with the given reason for the current Pod, it will get updated (only + // count and timestamp are expected to change), else a new event will be created. + Event(_ context.Context, typ, reason, msg string) error StrategicMergePatchSecret(context.Context, string, *kubeapi.Secret, string) error - JSONPatchSecret(context.Context, string, []JSONPatch) error + JSONPatchResource(_ context.Context, resourceName string, resourceType string, patches []JSONPatch) error CheckSecretPermissions(context.Context, string) (bool, bool, error) SetDialer(dialer func(context.Context, string, string) (net.Conn, error)) SetURL(string) @@ -66,15 +77,24 @@ type Client interface { type client struct { mu sync.Mutex + name string url string - ns string + podName string + podUID string + ns string // Pod namespace client *http.Client token string tokenExpiry time.Time + cl tstime.Clock + // hasEventsPerms is true if client can emit Events for the Pod in which it runs. If it is set to false any + // calls to Events() will be a no-op. + hasEventsPerms bool + // kubeAPIRequest sends a request to the kube API server. It can set to a fake in tests. + kubeAPIRequest kubeAPIRequestFunc } // New returns a new client -func New() (Client, error) { +func New(name string) (Client, error) { ns, err := readFile("namespace") if err != nil { return nil, err @@ -87,9 +107,11 @@ func New() (Client, error) { if ok := cp.AppendCertsFromPEM(caCert); !ok { return nil, fmt.Errorf("kube: error in creating root cert pool") } - return &client{ - url: defaultURL, - ns: string(ns), + c := &client{ + url: defaultURL, + ns: string(ns), + name: name, + cl: tstime.DefaultClock{}, client: &http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{ @@ -97,7 +119,10 @@ func New() (Client, error) { }, }, }, - }, nil + } + c.kubeAPIRequest = newKubeAPIRequest(c) + c.setEventPerms() + return c, nil } // SetURL sets the URL to use for the Kubernetes API. 
@@ -115,14 +140,14 @@ func (c *client) SetDialer(dialer func(ctx context.Context, network, addr string func (c *client) expireToken() { c.mu.Lock() defer c.mu.Unlock() - c.tokenExpiry = time.Now() + c.tokenExpiry = c.cl.Now() } func (c *client) getOrRenewToken() (string, error) { c.mu.Lock() defer c.mu.Unlock() tk, te := c.token, c.tokenExpiry - if time.Now().Before(te) { + if c.cl.Now().Before(te) { return tk, nil } @@ -131,17 +156,10 @@ func (c *client) getOrRenewToken() (string, error) { return "", err } c.token = string(tkb) - c.tokenExpiry = time.Now().Add(30 * time.Minute) + c.tokenExpiry = c.cl.Now().Add(30 * time.Minute) return c.token, nil } -func (c *client) secretURL(name string) string { - if name == "" { - return fmt.Sprintf("%s/api/v1/namespaces/%s/secrets", c.url, c.ns) - } - return fmt.Sprintf("%s/api/v1/namespaces/%s/secrets/%s", c.url, c.ns, name) -} - func getError(resp *http.Response) error { if resp.StatusCode == 200 || resp.StatusCode == 201 { // These are the only success codes returned by the Kubernetes API. @@ -161,36 +179,41 @@ func setHeader(key, value string) func(*http.Request) { } } -// doRequest performs an HTTP request to the Kubernetes API. -// If in is not nil, it is expected to be a JSON-encodable object and will be -// sent as the request body. -// If out is not nil, it is expected to be a pointer to an object that can be -// decoded from JSON. -// If the request fails with a 401, the token is expired and a new one is -// requested. 
-func (c *client) doRequest(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error { - req, err := c.newRequest(ctx, method, url, in) - if err != nil { - return err - } - for _, opt := range opts { - opt(req) - } - resp, err := c.client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if err := getError(resp); err != nil { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 401 { - c.expireToken() +type kubeAPIRequestFunc func(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error + +// newKubeAPIRequest returns a function that can perform an HTTP request to the Kubernetes API. +func newKubeAPIRequest(c *client) kubeAPIRequestFunc { + // If in is not nil, it is expected to be a JSON-encodable object and will be + // sent as the request body. + // If out is not nil, it is expected to be a pointer to an object that can be + // decoded from JSON. + // If the request fails with a 401, the token is expired and a new one is + // requested. + f := func(ctx context.Context, method, url string, in, out any, opts ...func(*http.Request)) error { + req, err := c.newRequest(ctx, method, url, in) + if err != nil { + return err } - return err - } - if out != nil { - return json.NewDecoder(resp.Body).Decode(out) + for _, opt := range opts { + opt(req) + } + resp, err := c.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if err := getError(resp); err != nil { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 401 { + c.expireToken() + } + return err + } + if out != nil { + return json.NewDecoder(resp.Body).Decode(out) + } + return nil } - return nil + return f } func (c *client) newRequest(ctx context.Context, method, url string, in any) (*http.Request, error) { @@ -226,25 +249,39 @@ func (c *client) newRequest(ctx context.Context, method, url string, in any) (*h // GetSecret fetches the secret from the Kubernetes API. 
func (c *client) GetSecret(ctx context.Context, name string) (*kubeapi.Secret, error) { s := &kubeapi.Secret{Data: make(map[string][]byte)} - if err := c.doRequest(ctx, "GET", c.secretURL(name), nil, s); err != nil { + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL(name, TypeSecrets, ""), nil, s); err != nil { return nil, err } return s, nil } +// ListSecrets fetches the secret from the Kubernetes API. +func (c *client) ListSecrets(ctx context.Context, selector map[string]string) (*kubeapi.SecretList, error) { + sl := new(kubeapi.SecretList) + s := make([]string, 0, len(selector)) + for key, val := range selector { + s = append(s, key+"="+url.QueryEscape(val)) + } + ss := strings.Join(s, ",") + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL("", TypeSecrets, ss), nil, sl); err != nil { + return nil, err + } + return sl, nil +} + // CreateSecret creates a secret in the Kubernetes API. func (c *client) CreateSecret(ctx context.Context, s *kubeapi.Secret) error { s.Namespace = c.ns - return c.doRequest(ctx, "POST", c.secretURL(""), s, nil) + return c.kubeAPIRequest(ctx, "POST", c.resourceURL("", TypeSecrets, ""), s, nil) } // UpdateSecret updates a secret in the Kubernetes API. func (c *client) UpdateSecret(ctx context.Context, s *kubeapi.Secret) error { - return c.doRequest(ctx, "PUT", c.secretURL(s.Name), s, nil) + return c.kubeAPIRequest(ctx, "PUT", c.resourceURL(s.Name, TypeSecrets, ""), s, nil) } // JSONPatch is a JSON patch operation. -// It currently (2023-03-02) only supports "add" and "remove" operations. +// It currently (2024-11-15) only supports "add", "remove" and "replace" operations. // // https://tools.ietf.org/html/rfc6902 type JSONPatch struct { @@ -253,22 +290,22 @@ type JSONPatch struct { Value any `json:"value,omitempty"` } -// JSONPatchSecret updates a secret in the Kubernetes API using a JSON patch. -// It currently (2023-03-02) only supports "add" and "remove" operations. 
-func (c *client) JSONPatchSecret(ctx context.Context, name string, patch []JSONPatch) error { - for _, p := range patch { +// JSONPatchResource updates a resource in the Kubernetes API using a JSON patch. +// It currently (2024-11-15) only supports "add", "remove" and "replace" operations. +func (c *client) JSONPatchResource(ctx context.Context, name, typ string, patches []JSONPatch) error { + for _, p := range patches { if p.Op != "remove" && p.Op != "add" && p.Op != "replace" { return fmt.Errorf("unsupported JSON patch operation: %q", p.Op) } } - return c.doRequest(ctx, "PATCH", c.secretURL(name), patch, nil, setHeader("Content-Type", "application/json-patch+json")) + return c.kubeAPIRequest(ctx, "PATCH", c.resourceURL(name, typ, ""), patches, nil, setHeader("Content-Type", "application/json-patch+json")) } // StrategicMergePatchSecret updates a secret in the Kubernetes API using a // strategic merge patch. // If a fieldManager is provided, it will be used to track the patch. func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s *kubeapi.Secret, fieldManager string) error { - surl := c.secretURL(name) + surl := c.resourceURL(name, TypeSecrets, "") if fieldManager != "" { uv := url.Values{ "fieldManager": {fieldManager}, @@ -277,7 +314,66 @@ func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s * } s.Namespace = c.ns s.Name = name - return c.doRequest(ctx, "PATCH", surl, s, nil, setHeader("Content-Type", "application/strategic-merge-patch+json")) + return c.kubeAPIRequest(ctx, "PATCH", surl, s, nil, setHeader("Content-Type", "application/strategic-merge-patch+json")) +} + +// Event tries to ensure an Event associated with the Pod in which we are running. It is best effort - the event will be +// created if the kube client on startup was able to determine the name and UID of this Pod from POD_NAME,POD_UID env +// vars and if permissions check for event creation succeeded. 
Events are keyed on opts.Reason- if an Event for the +// current Pod with that reason already exists, its count and first timestamp will be updated, else a new Event will be +// created. +func (c *client) Event(ctx context.Context, typ, reason, msg string) error { + if !c.hasEventsPerms { + return nil + } + name := c.nameForEvent(reason) + ev, err := c.getEvent(ctx, name) + now := c.cl.Now() + if err != nil { + if !IsNotFoundErr(err) { + return err + } + // Event not found - create it + ev := kubeapi.Event{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: name, + Namespace: c.ns, + }, + Type: typ, + Reason: reason, + Message: msg, + Source: kubeapi.EventSource{ + Component: c.name, + }, + InvolvedObject: kubeapi.ObjectReference{ + Name: c.podName, + Namespace: c.ns, + UID: c.podUID, + Kind: "Pod", + APIVersion: "v1", + }, + + FirstTimestamp: now, + LastTimestamp: now, + Count: 1, + } + return c.kubeAPIRequest(ctx, "POST", c.resourceURL("", typeEvents, ""), &ev, nil) + } + // If the Event already exists, we patch its count and last timestamp. This ensures that when users run 'kubectl + // describe pod...', they see the event just once (but with a message of how many times it has appeared over + // last timestamp - first timestamp period of time). 
+ count := ev.Count + 1 + countPatch := JSONPatch{ + Op: "replace", + Value: count, + Path: "/count", + } + tsPatch := JSONPatch{ + Op: "replace", + Value: now, + Path: "/lastTimestamp", + } + return c.JSONPatchResource(ctx, name, typeEvents, []JSONPatch{countPatch, tsPatch}) } // CheckSecretPermissions checks the secret access permissions of the current @@ -293,7 +389,7 @@ func (c *client) StrategicMergePatchSecret(ctx context.Context, name string, s * func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) (canPatch, canCreate bool, err error) { var errs []error for _, verb := range []string{"get", "update"} { - ok, err := c.checkPermission(ctx, verb, secretName) + ok, err := c.checkPermission(ctx, verb, TypeSecrets, secretName) if err != nil { log.Printf("error checking %s permission on secret %s: %v", verb, secretName, err) } else if !ok { @@ -303,12 +399,12 @@ func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) if len(errs) > 0 { return false, false, multierr.New(errs...) } - canPatch, err = c.checkPermission(ctx, "patch", secretName) + canPatch, err = c.checkPermission(ctx, "patch", TypeSecrets, secretName) if err != nil { log.Printf("error checking patch permission on secret %s: %v", secretName, err) return false, false, nil } - canCreate, err = c.checkPermission(ctx, "create", secretName) + canCreate, err = c.checkPermission(ctx, "create", TypeSecrets, secretName) if err != nil { log.Printf("error checking create permission on secret %s: %v", secretName, err) return false, false, nil @@ -316,19 +412,64 @@ func (c *client) CheckSecretPermissions(ctx context.Context, secretName string) return canPatch, canCreate, nil } -// checkPermission reports whether the current pod has permission to use the -// given verb (e.g. get, update, patch, create) on secretName. 
-func (c *client) checkPermission(ctx context.Context, verb, secretName string) (bool, error) { +func IsNotFoundErr(err error) bool { + if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { + return true + } + return false +} + +// setEventPerms checks whether this client will be able to write tailscaled Events to its Pod and updates the state +// accordingly. If it determines that the client can not write Events, any subsequent calls to client.Event will be a +// no-op. +func (c *client) setEventPerms() { + name := os.Getenv("POD_NAME") + uid := os.Getenv("POD_UID") + hasPerms := false + defer func() { + c.podName = name + c.podUID = uid + c.hasEventsPerms = hasPerms + if !hasPerms { + log.Printf(`kubeclient: this client is not able to write tailscaled Events to the Pod in which it is running. + To help with future debugging you can make it able write Events by giving it get,create,patch permissions for Events in the Pod namespace + and setting POD_NAME, POD_UID env vars for the Pod.`) + } + }() + if name == "" || uid == "" { + return + } + for _, verb := range []string{"get", "create", "patch"} { + can, err := c.checkPermission(context.Background(), verb, typeEvents, "") + if err != nil { + log.Printf("kubeclient: error checking Events permissions: %v", err) + return + } + if !can { + return + } + } + hasPerms = true + return +} + +// checkPermission reports whether the current pod has permission to use the given verb (e.g. get, update, patch, +// create) on the given resource type. If name is not an empty string, will check the check will be for resource with +// the given name only. 
+func (c *client) checkPermission(ctx context.Context, verb, typ, name string) (bool, error) { + ra := map[string]any{ + "namespace": c.ns, + "verb": verb, + "resource": typ, + } + if name != "" { + ra["name"] = name + } sar := map[string]any{ "apiVersion": "authorization.k8s.io/v1", "kind": "SelfSubjectAccessReview", "spec": map[string]any{ - "resourceAttributes": map[string]any{ - "namespace": c.ns, - "verb": verb, - "resource": "secrets", - "name": secretName, - }, + "resourceAttributes": ra, }, } var res struct { @@ -337,15 +478,36 @@ func (c *client) checkPermission(ctx context.Context, verb, secretName string) ( } `json:"status"` } url := c.url + "/apis/authorization.k8s.io/v1/selfsubjectaccessreviews" - if err := c.doRequest(ctx, "POST", url, sar, &res); err != nil { + if err := c.kubeAPIRequest(ctx, "POST", url, sar, &res); err != nil { return false, err } return res.Status.Allowed, nil } -func IsNotFoundErr(err error) bool { - if st, ok := err.(*kubeapi.Status); ok && st.Code == 404 { - return true +// resourceURL returns a URL that can be used to interact with the given resource type and, if name is not empty string, +// the named resource of that type. +// Note that this only works for core/v1 resource types. +func (c *client) resourceURL(name, typ, sel string) string { + if name == "" { + url := fmt.Sprintf("%s/api/v1/namespaces/%s/%s", c.url, c.ns, typ) + if sel != "" { + url += "?labelSelector=" + sel + } + return url } - return false + return fmt.Sprintf("%s/api/v1/namespaces/%s/%s/%s", c.url, c.ns, typ, name) +} + +// nameForEvent returns a name for the Event that uniquely identifies Event with that reason for the current Pod. +func (c *client) nameForEvent(reason string) string { + return fmt.Sprintf("%s.%s.%s", c.podName, c.podUID, strings.ToLower(reason)) +} + +// getEvent fetches the event from the Kubernetes API. 
+func (c *client) getEvent(ctx context.Context, name string) (*kubeapi.Event, error) { + e := &kubeapi.Event{} + if err := c.kubeAPIRequest(ctx, "GET", c.resourceURL(name, typeEvents, ""), nil, e); err != nil { + return nil, err + } + return e, nil } diff --git a/kube/kubeclient/client_test.go b/kube/kubeclient/client_test.go new file mode 100644 index 0000000000000..31878befe4106 --- /dev/null +++ b/kube/kubeclient/client_test.go @@ -0,0 +1,151 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubeclient + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/kube/kubeapi" + "tailscale.com/tstest" +) + +func Test_client_Event(t *testing.T) { + cl := &tstest.Clock{} + tests := []struct { + name string + typ string + reason string + msg string + argSets []args + wantErr bool + }{ + { + name: "new_event_gets_created", + typ: "Normal", + reason: "TestReason", + msg: "TestMessage", + argSets: []args{ + { // request to GET event returns not found + wantsMethod: "GET", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + setErr: &kubeapi.Status{Code: 404}, + }, + { // sends POST request to create event + wantsMethod: "POST", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events", + wantsIn: &kubeapi.Event{ + ObjectMeta: kubeapi.ObjectMeta{ + Name: "test-pod.test-uid.testreason", + Namespace: "test-ns", + }, + Type: "Normal", + Reason: "TestReason", + Message: "TestMessage", + Source: kubeapi.EventSource{ + Component: "test-client", + }, + InvolvedObject: kubeapi.ObjectReference{ + Name: "test-pod", + UID: "test-uid", + Namespace: "test-ns", + APIVersion: "v1", + Kind: "Pod", + }, + FirstTimestamp: cl.Now(), + LastTimestamp: cl.Now(), + Count: 1, + }, + }, + }, + }, + { + name: "existing_event_gets_patched", + typ: "Warning", + reason: "TestReason", + msg: "TestMsg", + argSets: []args{ + { // request to GET event 
does not error - this is enough to assume that event exists + wantsMethod: "GET", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + setOut: []byte(`{"count":2}`), + }, + { // sends PATCH request to update the event + wantsMethod: "PATCH", + wantsURL: "test-apiserver/api/v1/namespaces/test-ns/events/test-pod.test-uid.testreason", + wantsIn: []JSONPatch{ + {Op: "replace", Path: "/count", Value: int32(3)}, + {Op: "replace", Path: "/lastTimestamp", Value: cl.Now()}, + }, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &client{ + cl: cl, + name: "test-client", + podName: "test-pod", + podUID: "test-uid", + url: "test-apiserver", + ns: "test-ns", + kubeAPIRequest: fakeKubeAPIRequest(t, tt.argSets), + hasEventsPerms: true, + } + if err := c.Event(context.Background(), tt.typ, tt.reason, tt.msg); (err != nil) != tt.wantErr { + t.Errorf("client.Event() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +// args is a set of values for testing a single call to client.kubeAPIRequest. +type args struct { + // wantsMethod is the expected value of 'method' arg. + wantsMethod string + // wantsURL is the expected value of 'url' arg. + wantsURL string + // wantsIn is the expected value of 'in' arg. + wantsIn any + // setOut can be set to a byte slice representing valid JSON. If set 'out' arg will get set to the unmarshalled + // JSON object. + setOut []byte + // setErr is the error that kubeAPIRequest will return. + setErr error +} + +// fakeKubeAPIRequest can be used to test that a series of calls to client.kubeAPIRequest gets called with expected +// values and to set these calls to return preconfigured values. 'argSets' should be set to a slice of expected +// arguments and should-be return values of a series of kubeAPIRequest calls. 
+func fakeKubeAPIRequest(t *testing.T, argSets []args) kubeAPIRequestFunc { + count := 0 + f := func(ctx context.Context, gotMethod, gotUrl string, gotIn, gotOut any, opts ...func(*http.Request)) error { + t.Helper() + if count >= len(argSets) { + t.Fatalf("unexpected call to client.kubeAPIRequest, expected %d calls, but got a %dth call", len(argSets), count+1) + } + a := argSets[count] + if gotMethod != a.wantsMethod { + t.Errorf("[%d] got method %q, wants method %q", count, gotMethod, a.wantsMethod) + } + if gotUrl != a.wantsURL { + t.Errorf("[%d] got URL %q, wants URL %q", count, gotUrl, a.wantsURL) + } + if d := cmp.Diff(gotIn, a.wantsIn); d != "" { + t.Errorf("[%d] unexpected payload (-want + got):\n%s", count, d) + } + if len(a.setOut) != 0 { + if err := json.Unmarshal(a.setOut, gotOut); err != nil { + t.Fatalf("[%d] error unmarshalling output: %v", count, err) + } + } + count++ + return a.setErr + } + return f +} diff --git a/kube/kubeclient/fake_client.go b/kube/kubeclient/fake_client.go index 3cef3d27ee0df..c21dc2bf89e61 100644 --- a/kube/kubeclient/fake_client.go +++ b/kube/kubeclient/fake_client.go @@ -15,6 +15,10 @@ var _ Client = &FakeClient{} type FakeClient struct { GetSecretImpl func(context.Context, string) (*kubeapi.Secret, error) CheckSecretPermissionsImpl func(ctx context.Context, name string) (bool, bool, error) + CreateSecretImpl func(context.Context, *kubeapi.Secret) error + UpdateSecretImpl func(context.Context, *kubeapi.Secret) error + JSONPatchResourceImpl func(context.Context, string, string, []JSONPatch) error + ListSecretsImpl func(context.Context, map[string]string) (*kubeapi.SecretList, error) } func (fc *FakeClient) CheckSecretPermissions(ctx context.Context, name string) (bool, bool, error) { @@ -29,8 +33,22 @@ func (fc *FakeClient) SetDialer(dialer func(ctx context.Context, network, addr s func (fc *FakeClient) StrategicMergePatchSecret(context.Context, string, *kubeapi.Secret, string) error { return nil } -func (fc *FakeClient) 
JSONPatchSecret(context.Context, string, []JSONPatch) error { +func (fc *FakeClient) Event(context.Context, string, string, string) error { return nil } -func (fc *FakeClient) UpdateSecret(context.Context, *kubeapi.Secret) error { return nil } -func (fc *FakeClient) CreateSecret(context.Context, *kubeapi.Secret) error { return nil } + +func (fc *FakeClient) JSONPatchResource(ctx context.Context, resource, name string, patches []JSONPatch) error { + return fc.JSONPatchResourceImpl(ctx, resource, name, patches) +} +func (fc *FakeClient) UpdateSecret(ctx context.Context, secret *kubeapi.Secret) error { + return fc.UpdateSecretImpl(ctx, secret) +} +func (fc *FakeClient) CreateSecret(ctx context.Context, secret *kubeapi.Secret) error { + return fc.CreateSecretImpl(ctx, secret) +} +func (fc *FakeClient) ListSecrets(ctx context.Context, selector map[string]string) (*kubeapi.SecretList, error) { + if fc.ListSecretsImpl != nil { + return fc.ListSecretsImpl(ctx, selector) + } + return nil, nil +} diff --git a/kube/kubetypes/metrics.go b/kube/kubetypes/metrics.go deleted file mode 100644 index b183f1f6f79f7..0000000000000 --- a/kube/kubetypes/metrics.go +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package kubetypes - -const ( - // Hostinfo App values for the Tailscale Kubernetes Operator components. 
- AppOperator = "k8s-operator" - AppAPIServerProxy = "k8s-operator-proxy" - AppIngressProxy = "k8s-operator-ingress-proxy" - AppIngressResource = "k8s-operator-ingress-resource" - AppEgressProxy = "k8s-operator-egress-proxy" - AppConnector = "k8s-operator-connector-resource" - - // Clientmetrics for Tailscale Kubernetes Operator components - MetricIngressProxyCount = "k8s_ingress_proxies" // L3 - MetricIngressResourceCount = "k8s_ingress_resources" // L7 - MetricEgressProxyCount = "k8s_egress_proxies" - MetricConnectorResourceCount = "k8s_connector_resources" - MetricConnectorWithSubnetRouterCount = "k8s_connector_subnetrouter_resources" - MetricConnectorWithExitNodeCount = "k8s_connector_exitnode_resources" - MetricNameserverCount = "k8s_nameserver_resources" - MetricRecorderCount = "k8s_recorder_resources" - MetricEgressServiceCount = "k8s_egress_service_resources" - MetricProxyGroupCount = "k8s_proxygroup_resources" -) diff --git a/kube/kubetypes/types.go b/kube/kubetypes/types.go new file mode 100644 index 0000000000000..6f96875dddd0f --- /dev/null +++ b/kube/kubetypes/types.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package kubetypes + +const ( + // Hostinfo App values for the Tailscale Kubernetes Operator components. 
+ AppOperator = "k8s-operator" + AppAPIServerProxy = "k8s-operator-proxy" + AppIngressProxy = "k8s-operator-ingress-proxy" + AppIngressResource = "k8s-operator-ingress-resource" + AppEgressProxy = "k8s-operator-egress-proxy" + AppConnector = "k8s-operator-connector-resource" + AppProxyGroupEgress = "k8s-operator-proxygroup-egress" + AppProxyGroupIngress = "k8s-operator-proxygroup-ingress" + + // Clientmetrics for Tailscale Kubernetes Operator components + MetricIngressProxyCount = "k8s_ingress_proxies" // L3 + MetricIngressResourceCount = "k8s_ingress_resources" // L7 + MetricIngressPGResourceCount = "k8s_ingress_pg_resources" // L7 on ProxyGroup + MetricServicePGResourceCount = "k8s_service_pg_resources" // L3 on ProxyGroup + MetricEgressProxyCount = "k8s_egress_proxies" + MetricConnectorResourceCount = "k8s_connector_resources" + MetricConnectorWithSubnetRouterCount = "k8s_connector_subnetrouter_resources" + MetricConnectorWithExitNodeCount = "k8s_connector_exitnode_resources" + MetricConnectorWithAppConnectorCount = "k8s_connector_appconnector_resources" + MetricNameserverCount = "k8s_nameserver_resources" + MetricRecorderCount = "k8s_recorder_resources" + MetricEgressServiceCount = "k8s_egress_service_resources" + MetricProxyGroupEgressCount = "k8s_proxygroup_egress_resources" + MetricProxyGroupIngressCount = "k8s_proxygroup_ingress_resources" + + // Keys that containerboot writes to state file that can be used to determine its state. + // fields set in Tailscale state Secret. These are mostly used by the Tailscale Kubernetes operator to determine + // the state of this tailscale device. + KeyDeviceID string = "device_id" // node stable ID of the device + KeyDeviceFQDN string = "device_fqdn" // device's tailnet hostname + KeyDeviceIPs string = "device_ips" // device's tailnet IPs + KeyPodUID string = "pod_uid" // Pod UID + // KeyCapVer contains Tailscale capability version of this proxy instance. 
+	KeyCapVer string = "tailscale_capver"
+	// KeyHTTPSEndpoint is a name of a field that can be set to the value of any HTTPS endpoint currently exposed by
+	// this device to the tailnet. This is used by the Kubernetes operator Ingress proxy to communicate to the operator
+	// that cluster workloads behind the Ingress can now be accessed via the given DNS name over HTTPS.
+	KeyHTTPSEndpoint string = "https_endpoint"
+	ValueNoHTTPS     string = "no-https"
+
+	// Pod's IPv4 address header key as returned by containerboot health check endpoint.
+	PodIPv4Header string = "Pod-IPv4"
+
+	EgessServicesPreshutdownEP = "/internal-egress-services-preshutdown"
+
+	LabelManaged    = "tailscale.com/managed"
+	LabelSecretType = "tailscale.com/secret-type" // "config", "state", "certs"
+)
diff --git a/licenses/README.md b/licenses/README.md
new file mode 100644
index 0000000000000..46fe8b77f050c
--- /dev/null
+++ b/licenses/README.md
@@ -0,0 +1,35 @@
+# Licenses
+
+This directory contains a list of dependencies, and their licenses, that are included in the Tailscale clients.
+These lists are generated using the [go-licenses] tool to analyze all Go packages in the Tailscale binaries,
+as well as a set of custom output templates that includes any additional non-Go dependencies.
+For example, the clients for macOS and iOS include some additional Swift libraries.
+
+These lists are updated roughly every week, so it is possible to see the dependencies in a given release by looking at the release tag.
+For example, the dependencies for the 1.80.0 release of the macOS client can be seen at
+<https://github.com/tailscale/tailscale/blob/v1.80.0/licenses/apple.md>.
+
+[go-licenses]: https://github.com/google/go-licenses
+
+## Other formats
+
+The go-licenses tool can output other formats like CSV, but that wouldn't include the non-Go dependencies.
+We can generate a CSV file if that's really needed by running a regex over the markdown files: + +```sh +cat apple.md | grep "^ -" | sed -E "s/- \[(.*)\]\(.*?\) \(\[(.*)\]\((.*)\)\)/\1,\2,\3/" +``` + +## Reviewer instructions + +The majority of changes in this directory are from updating dependency versions. +In that case, only the URL for the license file will change to reflect the new version. +Occasionally, a dependency is added or removed, or the import path is changed. + +New dependencies require the closest review to ensure the license is acceptable. +Because we generate the license reports **after** dependencies are changed, +the new dependency would have already gone through one review when it was initially added. +This is just a secondary review to double-check the license. If in doubt, ask legal. + +Always do a normal GitHub code review on the license PR with a brief summary of what changed. +For example, see #13936 or #14064. Then approve and merge the PR. diff --git a/licenses/android.md b/licenses/android.md index 94aeb3fc0615f..37961b74c44fe 100644 --- a/licenses/android.md +++ b/licenses/android.md @@ -9,77 +9,72 @@ Client][]. See also the dependencies in the [Tailscale CLI][]. 
- [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.1.0/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.26.5/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.16.16/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.14.11/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.2.10/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.5.10/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.7.2/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/internal/sync/singleflight/LICENSE)) - - 
[github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.10.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.10.10/service/internal/presigned-url/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.29.5/config/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.58/credentials/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.27/feature/ec2/imds/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.31/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.31/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) 
([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.2/internal/ini/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.2/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.12.12/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.44.7/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.18.7/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.21.7/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.26.7/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.19.0/LICENSE)) - - 
[github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.19.0/internal/sync/singleflight/LICENSE)) - - [github.com/bits-and-blooms/bitset](https://pkg.go.dev/github.com/bits-and-blooms/bitset) ([BSD-3-Clause](https://github.com/bits-and-blooms/bitset/blob/v1.13.0/LICENSE)) - - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.24.14/service/sso/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.28.13/service/ssooidc/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.33.13/service/sts/LICENSE.txt)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) 
([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.11.1/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.18.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/d3c622f1b874/LICENSE)) - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/76236955d466/LICENSE)) - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - - [github.com/illarion/gonotify/v2](https://pkg.go.dev/github.com/illarion/gonotify/v2) ([MIT](https://github.com/illarion/gonotify/blob/v2.0.3/LICENSE)) + - [github.com/illarion/gonotify/v3](https://pkg.go.dev/github.com/illarion/gonotify/v3) ([MIT](https://github.com/illarion/gonotify/blob/v3.0.2/LICENSE)) - 
[github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/8c70d406f6d2/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.4/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.4/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.4/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.11/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.11/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.11/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - 
[github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) + - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/sdnotify](https://pkg.go.dev/github.com/mdlayher/sdnotify) ([MIT](https://github.com/mdlayher/sdnotify/blob/v1.0.0/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.21/LICENSE)) - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - [github.com/tailscale/goupnp](https://pkg.go.dev/github.com/tailscale/goupnp) ([BSD-2-Clause](https://github.com/tailscale/goupnp/blob/c64d0f06ea05/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) 
([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) + - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) - [github.com/tailscale/tailscale-android/libtailscale](https://pkg.go.dev/github.com/tailscale/tailscale-android/libtailscale) ([BSD-3-Clause](https://github.com/tailscale/tailscale-android/blob/HEAD/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/0b8b35511f19/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - - [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE)) - - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE)) + - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/intern](https://pkg.go.dev/go4.org/intern) ([BSD-3-Clause](https://github.com/go4org/intern/blob/ae77deb06f29/LICENSE)) - - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/4f986261bf13/LICENSE)) + - 
[go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - [go4.org/unsafe/assume-no-moving-gc](https://pkg.go.dev/go4.org/unsafe/assume-no-moving-gc) ([BSD-3-Clause](https://github.com/go4org/unsafe-assume-no-moving-gc/blob/e7c30c78aeb2/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.26.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/1b970713:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.35.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/939b2ce7:LICENSE)) - [golang.org/x/mobile](https://pkg.go.dev/golang.org/x/mobile) ([BSD-3-Clause](https://cs.opensource.google/go/x/mobile/+/81131f64:LICENSE)) - - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.20.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.28.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.23.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.23.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.17.0:LICENSE)) - - 
[golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) - - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/v0.24.0:LICENSE)) - - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) + - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.23.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.36.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.11.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.30.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.29.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.22.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.10.0:LICENSE)) + - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/v0.30.0:LICENSE)) + - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) - [inet.af/netaddr](https://pkg.go.dev/inet.af/netaddr) ([BSD-3-Clause](Unknown)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/apple.md b/licenses/apple.md index 751082d5b220f..5a017076e9f38 100644 --- 
a/licenses/apple.md +++ b/licenses/apple.md @@ -12,48 +12,45 @@ See also the dependencies in the [Tailscale CLI][]. - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.1.0/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.27.28/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.28/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.12/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.16/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.16/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.1/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) 
([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.11.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.11.18/service/internal/presigned-url/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.29.5/config/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.58/credentials/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.27/feature/ec2/imds/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.31/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.31/internal/endpoints/v2/LICENSE.txt)) + - 
[github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.2/internal/ini/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.2/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.12.12/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.45.0/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.22.5/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.26.5/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.30.4/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) 
([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.20.4/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.20.4/internal/sync/singleflight/LICENSE)) - - [github.com/bits-and-blooms/bitset](https://pkg.go.dev/github.com/bits-and-blooms/bitset) ([BSD-3-Clause](https://github.com/bits-and-blooms/bitset/blob/v1.13.0/LICENSE)) - - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.24.14/service/sso/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.28.13/service/ssooidc/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.33.13/service/sts/LICENSE.txt)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) 
([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.11.1/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.18.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/d3c622f1b874/LICENSE)) - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/76236955d466/LICENSE)) - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - 
[github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - - [github.com/illarion/gonotify/v2](https://pkg.go.dev/github.com/illarion/gonotify/v2) ([MIT](https://github.com/illarion/gonotify/blob/v2.0.3/LICENSE)) + - [github.com/illarion/gonotify/v3](https://pkg.go.dev/github.com/illarion/gonotify/v3) ([MIT](https://github.com/illarion/gonotify/blob/v3.0.2/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/15c9b8791914/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.8/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.8/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.8/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) 
([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.0/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.0/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.0/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) + - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/sdnotify](https://pkg.go.dev/github.com/mdlayher/sdnotify) ([MIT](https://github.com/mdlayher/sdnotify/blob/v1.0.0/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) @@ -61,27 +58,25 @@ See also the dependencies in the [Tailscale CLI][]. 
- [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.21/LICENSE)) - [github.com/prometheus-community/pro-bing](https://pkg.go.dev/github.com/prometheus-community/pro-bing) ([MIT](https://github.com/prometheus-community/pro-bing/blob/v0.4.0/LICENSE)) - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - [github.com/tailscale/goupnp](https://pkg.go.dev/github.com/tailscale/goupnp) ([BSD-2-Clause](https://github.com/tailscale/goupnp/blob/c64d0f06ea05/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) + - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/91a0587fb251/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - - [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE)) - - 
[github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE)) + - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) - - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) - - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) 
([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.35.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/939b2ce7:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.36.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.12.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.31.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.30.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.22.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.10.0:LICENSE)) + - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) ## Additional Dependencies diff --git a/licenses/tailscale.md b/licenses/tailscale.md index b1303d2a6dd8e..206734fb41f47 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -17,56 +17,50 @@ Some packages may only be included on certain architectures or operating systems - [github.com/akutz/memconn](https://pkg.go.dev/github.com/akutz/memconn) ([Apache-2.0](https://github.com/akutz/memconn/blob/v0.1.0/LICENSE)) - [github.com/alexbrainman/sspi](https://pkg.go.dev/github.com/alexbrainman/sspi) 
([BSD-3-Clause](https://github.com/alexbrainman/sspi/blob/1a75b4708caa/LICENSE)) - [github.com/anmitsu/go-shlex](https://pkg.go.dev/github.com/anmitsu/go-shlex) ([MIT](https://github.com/anmitsu/go-shlex/blob/38f4b401e2be/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.26.5/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.16.16/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.14.11/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.2.10/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.5.10/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.7.2/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.24.1/internal/sync/singleflight/LICENSE)) - - 
[github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.10.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.10.10/service/internal/presigned-url/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.29.5/config/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.58/credentials/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.27/feature/ec2/imds/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.31/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.31/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) 
([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.2/internal/ini/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.2/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.12.12/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.44.7/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.18.7/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.21.7/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.26.7/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.19.0/LICENSE)) - - 
[github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.19.0/internal/sync/singleflight/LICENSE)) - - [github.com/bits-and-blooms/bitset](https://pkg.go.dev/github.com/bits-and-blooms/bitset) ([BSD-3-Clause](https://github.com/bits-and-blooms/bitset/blob/v1.13.0/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.24.14/service/sso/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.28.13/service/ssooidc/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.33.13/service/sts/LICENSE.txt)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) - [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) ([ISC](https://github.com/coder/websocket/blob/v1.8.12/LICENSE.txt)) - - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/creack/pty](https://pkg.go.dev/github.com/creack/pty) ([MIT](https://github.com/creack/pty/blob/v1.1.23/LICENSE)) - [github.com/dblohm7/wingoes](https://pkg.go.dev/github.com/dblohm7/wingoes) 
([BSD-3-Clause](https://github.com/dblohm7/wingoes/blob/a09d6be7affa/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.11.1/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.18.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/d3c622f1b874/LICENSE)) - [github.com/go-ole/go-ole](https://pkg.go.dev/github.com/go-ole/go-ole) ([MIT](https://github.com/go-ole/go-ole/blob/v1.3.0/LICENSE)) - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/76236955d466/LICENSE)) - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) - - 
[github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - - [github.com/gorilla/csrf](https://pkg.go.dev/github.com/gorilla/csrf) ([BSD-3-Clause](https://github.com/gorilla/csrf/blob/v1.7.2/LICENSE)) + - [github.com/gorilla/csrf](https://pkg.go.dev/github.com/gorilla/csrf) ([BSD-3-Clause](https://github.com/gorilla/csrf/blob/9dd6af1f6d30/LICENSE)) - [github.com/gorilla/securecookie](https://pkg.go.dev/github.com/gorilla/securecookie) ([BSD-3-Clause](https://github.com/gorilla/securecookie/blob/v1.1.2/LICENSE)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - - [github.com/illarion/gonotify/v2](https://pkg.go.dev/github.com/illarion/gonotify/v2) ([MIT](https://github.com/illarion/gonotify/blob/v2.0.3/LICENSE)) + - [github.com/illarion/gonotify/v3](https://pkg.go.dev/github.com/illarion/gonotify/v3) ([MIT](https://github.com/illarion/gonotify/blob/v3.0.2/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/8c70d406f6d2/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/kballard/go-shellquote](https://pkg.go.dev/github.com/kballard/go-shellquote) 
([MIT](https://github.com/kballard/go-shellquote/blob/95032a82bc51/LICENSE)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.4/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.4/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.4/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.11/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.11/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.11/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/kr/fs](https://pkg.go.dev/github.com/kr/fs) ([BSD-3-Clause](https://github.com/kr/fs/blob/v0.1.0/LICENSE)) - [github.com/mattn/go-colorable](https://pkg.go.dev/github.com/mattn/go-colorable) ([MIT](https://github.com/mattn/go-colorable/blob/v0.1.13/LICENSE)) - [github.com/mattn/go-isatty](https://pkg.go.dev/github.com/mattn/go-isatty) ([MIT](https://github.com/mattn/go-isatty/blob/v0.0.20/LICENSE)) - - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - - 
[github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) - [github.com/mdlayher/sdnotify](https://pkg.go.dev/github.com/mdlayher/sdnotify) ([MIT](https://github.com/mdlayher/sdnotify/blob/v1.0.0/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) @@ -79,34 +73,30 @@ Some packages may only be included on certain architectures or operating systems - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - [github.com/tailscale/certstore](https://pkg.go.dev/github.com/tailscale/certstore) ([MIT](https://github.com/tailscale/certstore/blob/d3fa0460f47e/LICENSE.md)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - - [github.com/tailscale/golang-x-crypto](https://pkg.go.dev/github.com/tailscale/golang-x-crypto) ([BSD-3-Clause](https://github.com/tailscale/golang-x-crypto/blob/3fde5e568aa4/LICENSE)) - - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/b535050b2aa4/LICENSE)) - - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/5db17b287bf1/LICENSE)) + - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) 
([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) + - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/d4cd19a26976/LICENSE)) - [github.com/tailscale/wf](https://pkg.go.dev/github.com/tailscale/wf) ([BSD-3-Clause](https://github.com/tailscale/wf/blob/6fbb0a674ee6/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/799c1978fafc/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/91a0587fb251/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - - [github.com/tcnksm/go-httpstat](https://pkg.go.dev/github.com/tcnksm/go-httpstat) ([MIT](https://github.com/tcnksm/go-httpstat/blob/v0.2.0/LICENSE)) - [github.com/toqueteos/webbrowser](https://pkg.go.dev/github.com/toqueteos/webbrowser) ([MIT](https://github.com/toqueteos/webbrowser/blob/v1.2.0/LICENSE.md)) - - [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.12.0/LICENSE)) - - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/a3c409a6018e/LICENSE)) - - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) + - [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.14.0/LICENSE)) + - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) 
([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/4f986261bf13/LICENSE)) + - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/1b970713:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.16.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.7.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) - - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.5.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.35.0:LICENSE)) + - 
[golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/939b2ce7:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.36.0:LICENSE)) + - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.26.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.11.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.31.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.29.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.22.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.10.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) - - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/64c016c92987/LICENSE)) - - [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.30.3/LICENSE)) + - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/9414b50a5633/LICENSE)) + - [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) 
([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.32.0/LICENSE)) - [sigs.k8s.io/yaml](https://pkg.go.dev/sigs.k8s.io/yaml) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/LICENSE)) - [sigs.k8s.io/yaml/goyaml.v2](https://pkg.go.dev/sigs.k8s.io/yaml/goyaml.v2) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.4.0/goyaml.v2/LICENSE)) - [software.sslmate.com/src/go-pkcs12](https://pkg.go.dev/software.sslmate.com/src/go-pkcs12) ([BSD-3-Clause](https://github.com/SSLMate/go-pkcs12/blob/v0.4.0/LICENSE)) diff --git a/licenses/windows.md b/licenses/windows.md index 2a8e4e621a4a6..e47bc3227b3f9 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -13,70 +13,78 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/alexbrainman/sspi](https://pkg.go.dev/github.com/alexbrainman/sspi) ([BSD-3-Clause](https://github.com/alexbrainman/sspi/blob/1a75b4708caa/LICENSE)) - [github.com/apenwarr/fixconsole](https://pkg.go.dev/github.com/apenwarr/fixconsole) ([Apache-2.0](https://github.com/apenwarr/fixconsole/blob/5a9f6489cc29/LICENSE)) - [github.com/apenwarr/w32](https://pkg.go.dev/github.com/apenwarr/w32) ([BSD-3-Clause](https://github.com/apenwarr/w32/blob/aa00fece76ab/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.27.28/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.28/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) 
([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.12/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.16/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.16/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.1/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.30.4/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.11.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.11.18/service/internal/presigned-url/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) 
([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.29.5/config/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.17.58/credentials/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.16.27/feature/ec2/imds/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.3.31/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.6.31/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.2/internal/ini/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.36.0/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.12.2/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) 
([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.12.12/service/internal/presigned-url/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.45.0/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.22.5/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.26.5/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.30.4/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.20.4/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.20.4/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.24.14/service/sso/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.28.13/service/ssooidc/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) 
([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.33.13/service/sts/LICENSE.txt)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.22.2/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.22.2/internal/sync/singleflight/LICENSE)) + - [github.com/beorn7/perks/quantile](https://pkg.go.dev/github.com/beorn7/perks/quantile) ([MIT](https://github.com/beorn7/perks/blob/v1.0.1/LICENSE)) + - [github.com/cespare/xxhash/v2](https://pkg.go.dev/github.com/cespare/xxhash/v2) ([MIT](https://github.com/cespare/xxhash/blob/v2.3.0/LICENSE.txt)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/dblohm7/wingoes](https://pkg.go.dev/github.com/dblohm7/wingoes) ([BSD-3-Clause](https://github.com/dblohm7/wingoes/blob/b75a8a7d7eb0/LICENSE)) - [github.com/djherbis/times](https://pkg.go.dev/github.com/djherbis/times) ([MIT](https://github.com/djherbis/times/blob/v1.6.0/LICENSE)) - - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.6.0/LICENSE)) - - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/2e55bd4e08b0/LICENSE)) + - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.7.0/LICENSE)) + - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/d3c622f1b874/LICENSE)) - 
[github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/41bb18bfe9da/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.2/LICENSE)) + - [github.com/google/go-cmp/cmp](https://pkg.go.dev/github.com/google/go-cmp/cmp) ([BSD-3-Clause](https://github.com/google/go-cmp/blob/v0.7.0/LICENSE)) - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - [github.com/gregjones/httpcache](https://pkg.go.dev/github.com/gregjones/httpcache) ([MIT](https://github.com/gregjones/httpcache/blob/901d90724c79/LICENSE.txt)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - - [github.com/josharian/native](https://pkg.go.dev/github.com/josharian/native) ([MIT](https://github.com/josharian/native/blob/5c7d0dd6ab86/license)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.17.8/LICENSE)) - - 
[github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.17.8/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.17.8/zstd/internal/xxhash/LICENSE.txt)) - - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/v1.7.2/LICENSE.md)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.0/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.0/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.0/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/miekg/dns](https://pkg.go.dev/github.com/miekg/dns) ([BSD-3-Clause](https://github.com/miekg/dns/blob/v1.1.58/LICENSE)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) + - [github.com/munnerz/goautoneg](https://pkg.go.dev/github.com/munnerz/goautoneg) ([BSD-3-Clause](https://github.com/munnerz/goautoneg/blob/a7dc8b61c822/LICENSE)) - [github.com/nfnt/resize](https://pkg.go.dev/github.com/nfnt/resize) 
([ISC](https://github.com/nfnt/resize/blob/83c6a9932646/LICENSE)) - [github.com/peterbourgon/diskv](https://pkg.go.dev/github.com/peterbourgon/diskv) ([MIT](https://github.com/peterbourgon/diskv/blob/v2.0.1/LICENSE)) + - [github.com/prometheus/client_golang/prometheus](https://pkg.go.dev/github.com/prometheus/client_golang/prometheus) ([Apache-2.0](https://github.com/prometheus/client_golang/blob/v1.19.1/LICENSE)) + - [github.com/prometheus/client_model/go](https://pkg.go.dev/github.com/prometheus/client_model/go) ([Apache-2.0](https://github.com/prometheus/client_model/blob/v0.6.1/LICENSE)) + - [github.com/prometheus/common](https://pkg.go.dev/github.com/prometheus/common) ([Apache-2.0](https://github.com/prometheus/common/blob/v0.55.0/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/20486734a56a/LICENSE)) + - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/ec1d1c113d33/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/52804fd3056a/LICENSE)) - - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/6580b55d49ca/LICENSE)) + - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) 
([BSD-3-Clause](https://github.com/tailscale/walk/blob/b2c15a420186/LICENSE)) + - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/5992cb43ca35/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/tc-hib/winres](https://pkg.go.dev/github.com/tc-hib/winres) ([0BSD](https://github.com/tc-hib/winres/blob/v0.2.1/LICENSE)) - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.4/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.25.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/fe59bbe5:LICENSE)) - - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.18.0:LICENSE)) - - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.19.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.27.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.8.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) 
([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.22.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.22.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.16.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.35.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/939b2ce7:LICENSE)) + - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.24.0:LICENSE)) + - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.23.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.36.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.12.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.31.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.30.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.22.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) + - 
[google.golang.org/protobuf](https://pkg.go.dev/google.golang.org/protobuf) ([BSD-3-Clause](https://github.com/protocolbuffers/protobuf-go/blob/v1.35.1/LICENSE)) - [gopkg.in/Knetic/govaluate.v3](https://pkg.go.dev/gopkg.in/Knetic/govaluate.v3) ([MIT](https://github.com/Knetic/govaluate/blob/v3.0.0/LICENSE)) + - [gopkg.in/yaml.v3](https://pkg.go.dev/gopkg.in/yaml.v3) ([MIT](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) ## Additional Dependencies diff --git a/logpolicy/logpolicy.go b/logpolicy/logpolicy.go index 0d2af77f2d703..fc259a417197b 100644 --- a/logpolicy/logpolicy.go +++ b/logpolicy/logpolicy.go @@ -9,6 +9,7 @@ package logpolicy import ( "bufio" "bytes" + "cmp" "context" "crypto/tls" "encoding/json" @@ -41,6 +42,7 @@ import ( "tailscale.com/net/netknob" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/tlsdial" "tailscale.com/net/tshttpproxy" "tailscale.com/paths" @@ -230,6 +232,9 @@ func LogsDir(logf logger.Logf) string { logf("logpolicy: using $STATE_DIRECTORY, %q", systemdStateDir) return systemdStateDir } + case "js": + logf("logpolicy: no logs directory in the browser") + return "" } // Default to e.g. /var/lib/tailscale or /var/db/tailscale on Unix. @@ -446,25 +451,76 @@ func tryFixLogStateLocation(dir, cmdname string, logf logger.Logf) { } } -// New returns a new log policy (a logger and its instance ID) for a given -// collection name. -// -// The netMon parameter is optional. It should be specified in environments where -// Tailscaled is manipulating the routing table. -// -// The logf parameter is optional; if non-nil, information logs (e.g. when -// migrating state) are sent to that logger, and global changes to the log -// package are avoided. If nil, logs will be printed using log.Printf. +// Deprecated: Use [Options.New] instead. 
func New(collection string, netMon *netmon.Monitor, health *health.Tracker, logf logger.Logf) *Policy { - return NewWithConfigPath(collection, "", "", netMon, health, logf) + return Options{ + Collection: collection, + NetMon: netMon, + Health: health, + Logf: logf, + }.New() } -// NewWithConfigPath is identical to New, but uses the specified directory and -// command name. If either is empty, it derives them automatically. -// -// The netMon parameter is optional. It should be specified in environments where -// Tailscaled is manipulating the routing table. +// Deprecated: Use [Options.New] instead. func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, health *health.Tracker, logf logger.Logf) *Policy { + return Options{ + Collection: collection, + Dir: dir, + CmdName: cmdName, + NetMon: netMon, + Health: health, + Logf: logf, + }.New() +} + +// Options is used to construct a [Policy]. +type Options struct { + // Collection is a required collection to upload logs under. + // Collection is a namespace for the type of logs. + // For example, logs for a node use "tailnode.log.tailscale.io". + Collection string + + // Dir is an optional directory to store the log configuration. + // If empty, [LogsDir] is used. + Dir string + + // CmdName is an optional name of the current binary. + // If empty, [version.CmdName] is used. + CmdName string + + // NetMon is an optional parameter for monitoring. + // If non-nil, it's used to do faster interface lookups. + NetMon *netmon.Monitor + + // Health is an optional parameter for health status. + // If non-nil, it's used to construct the default HTTP client. + Health *health.Tracker + + // Logf is an optional logger to use. + // If nil, [log.Printf] will be used instead. + Logf logger.Logf + + // HTTPC is an optional client to use to upload logs. + // If nil, [TransportOptions.New] is used to construct a new client + // with that particular transport sending logs to the default logs server.
+ HTTPC *http.Client + + // MaxBufferSize is the maximum size of the log buffer. + // This controls the amount of logs that can be temporarily stored + // before the logs can be successfully uploaded. + // If zero, a default buffer size is chosen. + MaxBufferSize int + + // MaxUploadSize is the maximum size per upload. + // This should only be set by clients that have been authenticated + // with the logging service as having a higher upload limit. + // If zero, a default upload size is chosen. + MaxUploadSize int +} + +// init initializes the log policy and returns a logtail.Config and the +// Policy. +func (opts Options) init(disableLogging bool) (*logtail.Config, *Policy) { if hostinfo.IsNATLabGuestVM() { // In NATLab Gokrazy instances, tailscaled comes up concurently with // DHCP and the doesn't have DNS for a while. Wait for DHCP first. @@ -492,23 +548,23 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, earlyErrBuf.WriteByte('\n') } - if dir == "" { - dir = LogsDir(earlyLogf) + if opts.Dir == "" { + opts.Dir = LogsDir(earlyLogf) } - if cmdName == "" { - cmdName = version.CmdName() + if opts.CmdName == "" { + opts.CmdName = version.CmdName() } - useStdLogger := logf == nil + useStdLogger := opts.Logf == nil if useStdLogger { - logf = log.Printf + opts.Logf = log.Printf } - tryFixLogStateLocation(dir, cmdName, logf) + tryFixLogStateLocation(opts.Dir, opts.CmdName, opts.Logf) - cfgPath := filepath.Join(dir, fmt.Sprintf("%s.log.conf", cmdName)) + cfgPath := filepath.Join(opts.Dir, fmt.Sprintf("%s.log.conf", opts.CmdName)) if runtime.GOOS == "windows" { - switch cmdName { + switch opts.CmdName { case "tailscaled": // Tailscale 1.14 and before stored state under %LocalAppData% // (usually "C:\WINDOWS\system32\config\systemprofile\AppData\Local" @@ -539,7 +595,7 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, cfgPath = paths.TryConfigFileMigration(earlyLogf, oldPath, cfgPath) case "tailscale-ipn": for
_, oldBase := range []string{"wg64.log.conf", "wg32.log.conf"} { - oldConf := filepath.Join(dir, oldBase) + oldConf := filepath.Join(opts.Dir, oldBase) if fi, err := os.Stat(oldConf); err == nil && fi.Mode().IsRegular() { cfgPath = paths.TryConfigFileMigration(earlyLogf, oldConf, cfgPath) break @@ -552,44 +608,54 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, if err != nil { earlyLogf("logpolicy.ConfigFromFile %v: %v", cfgPath, err) } - if err := newc.Validate(collection); err != nil { + if err := newc.Validate(opts.Collection); err != nil { earlyLogf("logpolicy.Config.Validate for %v: %v", cfgPath, err) - newc = NewConfig(collection) + newc = NewConfig(opts.Collection) if err := newc.Save(cfgPath); err != nil { earlyLogf("logpolicy.Config.Save for %v: %v", cfgPath, err) } } conf := logtail.Config{ - Collection: newc.Collection, - PrivateID: newc.PrivateID, - Stderr: logWriter{console}, - CompressLogs: true, - HTTPC: &http.Client{Transport: NewLogtailTransport(logtail.DefaultHost, netMon, health, logf)}, + Collection: newc.Collection, + PrivateID: newc.PrivateID, + Stderr: logWriter{console}, + CompressLogs: true, + MaxUploadSize: opts.MaxUploadSize, } - if collection == logtail.CollectionNode { + if opts.Collection == logtail.CollectionNode { conf.MetricsDelta = clientmetric.EncodeLogTailMetricsDelta conf.IncludeProcID = true conf.IncludeProcSequence = true } - if envknob.NoLogsNoSupport() || testenv.InTest() { - logf("You have disabled logging. Tailscale will not be able to provide support.") + if disableLogging { + opts.Logf("You have disabled logging. Tailscale will not be able to provide support.") conf.HTTPC = &http.Client{Transport: noopPretendSuccessTransport{}} } else { // Only attach an on-disk filch buffer if we are going to be sending logs. // No reason to persist them locally just to drop them later. 
- attachFilchBuffer(&conf, dir, cmdName, logf) + attachFilchBuffer(&conf, opts.Dir, opts.CmdName, opts.MaxBufferSize, opts.Logf) + conf.HTTPC = opts.HTTPC + logHost := logtail.DefaultHost if val := getLogTarget(); val != "" { - logf("You have enabled a non-default log target. Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.") + opts.Logf("You have enabled a non-default log target. Doing without being told to by Tailscale staff or your network administrator will make getting support difficult.") conf.BaseURL = val u, _ := url.Parse(val) - conf.HTTPC = &http.Client{Transport: NewLogtailTransport(u.Host, netMon, health, logf)} + logHost = u.Host } + if conf.HTTPC == nil { + conf.HTTPC = &http.Client{Transport: TransportOptions{ + Host: logHost, + NetMon: opts.NetMon, + Health: opts.Health, + Logf: opts.Logf, + }.New()} + } } - lw := logtail.NewLogger(conf, logf) + lw := logtail.NewLogger(conf, opts.Logf) var logOutput io.Writer = lw @@ -607,28 +673,36 @@ func NewWithConfigPath(collection, dir, cmdName string, netMon *netmon.Monitor, log.SetOutput(logOutput) } - logf("Program starting: v%v, Go %v: %#v", + opts.Logf("Program starting: v%v, Go %v: %#v", version.Long(), goVersion(), os.Args) - logf("LogID: %v", newc.PublicID) + opts.Logf("LogID: %v", newc.PublicID) if earlyErrBuf.Len() != 0 { - logf("%s", earlyErrBuf.Bytes()) + opts.Logf("%s", earlyErrBuf.Bytes()) } - return &Policy{ + return &conf, &Policy{ Logtail: lw, PublicID: newc.PublicID, - Logf: logf, + Logf: opts.Logf, } } +// New returns a new log policy (a logger and its instance ID). +func (opts Options) New() *Policy { + disableLogging := envknob.NoLogsNoSupport() || testenv.InTest() || runtime.GOOS == "plan9" + _, policy := opts.init(disableLogging) + return policy +} + // attachFilchBuffer creates an on-disk ring buffer using filch and attaches // it to the logtail config. 
Note that this is optional; if no buffer is set, // logtail will use an in-memory buffer. -func attachFilchBuffer(conf *logtail.Config, dir, cmdName string, logf logger.Logf) { +func attachFilchBuffer(conf *logtail.Config, dir, cmdName string, maxFileSize int, logf logger.Logf) { filchOptions := filch.Options{ ReplaceStderr: redirectStderrToLogPanics(), + MaxFileSize: maxFileSize, } filchPrefix := filepath.Join(dir, cmdName) @@ -705,7 +779,7 @@ func (p *Policy) Shutdown(ctx context.Context) error { // // The netMon parameter is optional. It should be specified in environments where // Tailscaled is manipulating the routing table. -func MakeDialFunc(netMon *netmon.Monitor, logf logger.Logf) func(ctx context.Context, netw, addr string) (net.Conn, error) { +func MakeDialFunc(netMon *netmon.Monitor, logf logger.Logf) netx.DialFunc { if netMon == nil { netMon = netmon.NewStatic() } @@ -760,23 +834,48 @@ func dialContext(ctx context.Context, netw, addr string, netMon *netmon.Monitor, return c, err } -// NewLogtailTransport returns an HTTP Transport particularly suited to uploading -// logs to the given host name. See DialContext for details on how it works. -// -// The netMon parameter is optional. It should be specified in environments where -// Tailscaled is manipulating the routing table. -// -// The logf parameter is optional; if non-nil, logs are printed using the -// provided function; if nil, log.Printf will be used instead. +// Deprecated: Use [TransportOptions.New] instead. func NewLogtailTransport(host string, netMon *netmon.Monitor, health *health.Tracker, logf logger.Logf) http.RoundTripper { + return TransportOptions{Host: host, NetMon: netMon, Health: health, Logf: logf}.New() +} + +// TransportOptions is used to construct an [http.RoundTripper]. +type TransportOptions struct { + // Host is the optional hostname of the logs server. + // If empty, then [logtail.DefaultHost] is used. + Host string + + // NetMon is an optional parameter for monitoring. 
+ // If non-nil, it's used to do faster interface lookups. + NetMon *netmon.Monitor + + // Health is an optional parameter for health status. + // If non-nil, it's used to construct the default HTTP client. + Health *health.Tracker + + // Logf is an optional logger to use. + // If nil, [log.Printf] will be used instead. + Logf logger.Logf + + // TLSClientConfig is an optional TLS configuration to use. + // If non-nil, the configuration will be cloned. + TLSClientConfig *tls.Config +} + +// New returns an HTTP Transport particularly suited to uploading logs +// to the given host name. See [DialContext] for details on how it works. +func (opts TransportOptions) New() http.RoundTripper { if testenv.InTest() { return noopPretendSuccessTransport{} } - if netMon == nil { - netMon = netmon.NewStatic() + if opts.NetMon == nil { + opts.NetMon = netmon.NewStatic() } // Start with a copy of http.DefaultTransport and tweak it a bit. tr := http.DefaultTransport.(*http.Transport).Clone() + if opts.TLSClientConfig != nil { + tr.TLSClientConfig = opts.TLSClientConfig.Clone() + } tr.Proxy = tshttpproxy.ProxyFromEnvironment tshttpproxy.SetTransportGetProxyConnectHeader(tr) @@ -787,10 +886,10 @@ func NewLogtailTransport(host string, netMon *netmon.Monitor, health *health.Tra tr.DisableCompression = true // Log whenever we dial: - if logf == nil { - logf = log.Printf + if opts.Logf == nil { + opts.Logf = log.Printf } - tr.DialContext = MakeDialFunc(netMon, logf) + tr.DialContext = MakeDialFunc(opts.NetMon, opts.Logf) // We're uploading logs ideally infrequently, with specific timing that will // change over time. Try to keep the connection open, to avoid repeatedly @@ -812,8 +911,9 @@ func NewLogtailTransport(host string, netMon *netmon.Monitor, health *health.Tra tr.TLSNextProto = map[string]func(authority string, c *tls.Conn) http.RoundTripper{} } - tr.TLSClientConfig = tlsdial.Config(host, health, tr.TLSClientConfig) - // Force TLS 1.3 since we know log.tailscale.io supports it. 
+ host := cmp.Or(opts.Host, logtail.DefaultHost) + tr.TLSClientConfig = tlsdial.Config(host, opts.Health, tr.TLSClientConfig) + // Force TLS 1.3 since we know log.tailscale.com supports it. tr.TLSClientConfig.MinVersion = tls.VersionTLS13 return tr diff --git a/logpolicy/logpolicy_test.go b/logpolicy/logpolicy_test.go index fdbfe4506e038..28f03448a225d 100644 --- a/logpolicy/logpolicy_test.go +++ b/logpolicy/logpolicy_test.go @@ -4,33 +4,83 @@ package logpolicy import ( + "net/http" "os" "reflect" "testing" + + "tailscale.com/logtail" ) -func TestLogHost(t *testing.T) { +func resetLogTarget() { + os.Unsetenv("TS_LOG_TARGET") v := reflect.ValueOf(&getLogTargetOnce).Elem() - reset := func() { - v.Set(reflect.Zero(v.Type())) - } - defer reset() + v.Set(reflect.Zero(v.Type())) +} + +func TestLogHost(t *testing.T) { + defer resetLogTarget() tests := []struct { env string want string }{ - {"", "log.tailscale.io"}, + {"", logtail.DefaultHost}, {"http://foo.com", "foo.com"}, {"https://foo.com", "foo.com"}, {"https://foo.com/", "foo.com"}, {"https://foo.com:123/", "foo.com"}, } for _, tt := range tests { - reset() + resetLogTarget() os.Setenv("TS_LOG_TARGET", tt.env) if got := LogHost(); got != tt.want { t.Errorf("for env %q, got %q, want %q", tt.env, got, tt.want) } } } +func TestOptions(t *testing.T) { + defer resetLogTarget() + + tests := []struct { + name string + opts func() Options + wantBaseURL string + }{ + { + name: "default", + opts: func() Options { return Options{} }, + wantBaseURL: "", + }, + { + name: "custom_baseurl", + opts: func() Options { + os.Setenv("TS_LOG_TARGET", "http://localhost:1234") + return Options{} + }, + wantBaseURL: "http://localhost:1234", + }, + { + name: "custom_httpc_and_baseurl", + opts: func() Options { + os.Setenv("TS_LOG_TARGET", "http://localhost:12345") + return Options{HTTPC: &http.Client{Transport: noopPretendSuccessTransport{}}} + }, + wantBaseURL: "http://localhost:12345", + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + resetLogTarget() + config, policy := tt.opts().init(false) + if policy == nil { + t.Fatal("unexpected nil policy") + } + if config.BaseURL != tt.wantBaseURL { + t.Errorf("got %q, want %q", config.BaseURL, tt.wantBaseURL) + } + policy.Close() + }) + } +} diff --git a/logtail/api.md b/logtail/api.md index 8ec0b69c0f331..20726e2096704 100644 --- a/logtail/api.md +++ b/logtail/api.md @@ -6,14 +6,14 @@ retrieving, and processing log entries. # Overview HTTP requests are received at the service **base URL** -[https://log.tailscale.io](https://log.tailscale.io), and return JSON-encoded +[https://log.tailscale.com](https://log.tailscale.com), and return JSON-encoded responses using standard HTTP response codes. Authorization for the configuration and retrieval APIs is done with a secret API key passed as the HTTP basic auth username. Secret keys are generated via the web UI at base URL. An example of using basic auth with curl: - curl -u : https://log.tailscale.io/collections + curl -u : https://log.tailscale.com/collections In the future, an HTTP header will allow using MessagePack instead of JSON. 
diff --git a/logtail/example/logadopt/logadopt.go b/logtail/example/logadopt/logadopt.go index 984a8a35adc7a..eba3f93112d62 100644 --- a/logtail/example/logadopt/logadopt.go +++ b/logtail/example/logadopt/logadopt.go @@ -25,7 +25,7 @@ func main() { } log.SetFlags(0) - req, err := http.NewRequest("POST", "https://log.tailscale.io/instances", strings.NewReader(url.Values{ + req, err := http.NewRequest("POST", "https://log.tailscale.com/instances", strings.NewReader(url.Values{ "collection": []string{*collection}, "instances": []string{*publicID}, "adopt": []string{"true"}, diff --git a/logtail/example/logreprocess/demo.sh b/logtail/example/logreprocess/demo.sh index 4ec819a67450d..583929c12b4fe 100755 --- a/logtail/example/logreprocess/demo.sh +++ b/logtail/example/logreprocess/demo.sh @@ -13,7 +13,7 @@ # # Then generate a LOGTAIL_API_KEY and two test collections by visiting: # -# https://log.tailscale.io +# https://log.tailscale.com # # Then set the three variables below. trap 'rv=$?; [ "$rv" = 0 ] || echo "-- exiting with code $rv"; exit $rv' EXIT diff --git a/logtail/example/logreprocess/logreprocess.go b/logtail/example/logreprocess/logreprocess.go index 5dbf765788165..aae65df9f1321 100644 --- a/logtail/example/logreprocess/logreprocess.go +++ b/logtail/example/logreprocess/logreprocess.go @@ -37,7 +37,7 @@ func main() { }() } - req, err := http.NewRequest("GET", "https://log.tailscale.io/c/"+*collection+"?stream=true", nil) + req, err := http.NewRequest("GET", "https://log.tailscale.com/c/"+*collection+"?stream=true", nil) if err != nil { log.Fatal(err) } diff --git a/logtail/logtail.go b/logtail/logtail.go index 9df164273d74c..b355addd20b82 100644 --- a/logtail/logtail.go +++ b/logtail/logtail.go @@ -1,11 +1,12 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package logtail sends logs to log.tailscale.io. +// Package logtail sends logs to log.tailscale.com. 
package logtail import ( "bytes" + "cmp" "context" "crypto/rand" "encoding/binary" @@ -14,9 +15,7 @@ import ( "log" mrand "math/rand/v2" "net/http" - "net/netip" "os" - "regexp" "runtime" "slices" "strconv" @@ -28,7 +27,6 @@ import ( "tailscale.com/envknob" "tailscale.com/net/netmon" "tailscale.com/net/sockstats" - "tailscale.com/net/tsaddr" "tailscale.com/tstime" tslogger "tailscale.com/types/logger" "tailscale.com/types/logid" @@ -55,7 +53,7 @@ const bufferSize = 4 << 10 // DefaultHost is the default host name to upload logs to when // Config.BaseURL isn't provided. -const DefaultHost = "log.tailscale.io" +const DefaultHost = "log.tailscale.com" const defaultFlushDelay = 2 * time.Second @@ -69,7 +67,7 @@ type Config struct { Collection string // collection name, a domain name PrivateID logid.PrivateID // private ID for the primary log stream CopyPrivateID logid.PrivateID // private ID for a log stream that is a superset of this log stream - BaseURL string // if empty defaults to "https://log.tailscale.io" + BaseURL string // if empty defaults to "https://log.tailscale.com" HTTPC *http.Client // if empty defaults to http.DefaultClient SkipClientTime bool // if true, client_time is not written to logs LowMemory bool // if true, logtail minimizes memory use @@ -78,6 +76,7 @@ type Config struct { StderrLevel int // max verbosity level to write to stderr; 0 means the non-verbose messages only Buffer Buffer // temp storage, if nil a MemoryBuffer CompressLogs bool // whether to compress the log uploads + MaxUploadSize int // maximum upload size; 0 means using the default // MetricsDelta, if non-nil, is a func that returns an encoding // delta in clientmetrics to upload alongside existing logs. 
@@ -157,6 +156,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { url: cfg.BaseURL + "/c/" + cfg.Collection + "/" + cfg.PrivateID.String() + urlSuffix, lowMem: cfg.LowMemory, buffer: cfg.Buffer, + maxUploadSize: cfg.MaxUploadSize, skipClientTime: cfg.SkipClientTime, drainWake: make(chan struct{}, 1), sentinel: make(chan int32, 16), @@ -192,6 +192,7 @@ type Logger struct { skipClientTime bool netMonitor *netmon.Monitor buffer Buffer + maxUploadSize int drainWake chan struct{} // signal to speed up drain drainBuf []byte // owned by drainPending for reuse flushDelayFn func() time.Duration // negative or zero return value to upload aggressively, or >0 to batch at this delay @@ -213,6 +214,7 @@ type Logger struct { procSequence uint64 flushTimer tstime.TimerController // used when flushDelay is >0 writeBuf [bufferSize]byte // owned by Write for reuse + bytesBuf bytes.Buffer // owned by appendTextOrJSONLocked for reuse jsonDec jsontext.Decoder // owned by appendTextOrJSONLocked for reuse shutdownStartMu sync.Mutex // guards the closing of shutdownStart @@ -324,7 +326,7 @@ func (l *Logger) drainPending() (b []byte) { } }() - maxLen := maxSize + maxLen := cmp.Or(l.maxUploadSize, maxSize) if l.lowMem { // When operating in a low memory environment, it is better to upload // in multiple operations than it is to allocate a large body and OOM. @@ -506,7 +508,7 @@ func (l *Logger) upload(ctx context.Context, body []byte, origlen int) (retryAft } if runtime.GOOS == "js" { // We once advertised we'd accept optional client certs (for internal use) - // on log.tailscale.io but then Tailscale SSH js/wasm clients prompted + // on log.tailscale.com but then Tailscale SSH js/wasm clients prompted // users (on some browsers?) to pick a client cert. We'll fix the server's // TLS ServerHello, but we can also fix it client side for good measure. 
// @@ -725,9 +727,16 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // whether it contains the reserved "logtail" name at the top-level. var logtailKeyOffset, logtailValOffset, logtailValLength int validJSON := func() bool { - // TODO(dsnet): Avoid allocation of bytes.Buffer struct. + // The jsontext.NewDecoder API operates on an io.Reader, for which + // bytes.Buffer provides a means to convert a []byte into an io.Reader. + // However, bytes.NewBuffer normally allocates unless + // we immediately shallow copy it into a pre-allocated Buffer struct. + // See https://go.dev/issue/67004. + l.bytesBuf = *bytes.NewBuffer(src) + defer func() { l.bytesBuf = bytes.Buffer{} }() // avoid pinning src + dec := &l.jsonDec - dec.Reset(bytes.NewBuffer(src)) + dec.Reset(&l.bytesBuf) if tok, err := dec.ReadToken(); tok.Kind() != '{' || err != nil { return false } @@ -767,9 +776,10 @@ func (l *Logger) appendTextOrJSONLocked(dst, src []byte, level int) []byte { // That's okay as the Tailscale log service limit is actually 2*maxSize. // However, so long as logging applications aim to target the maxSize limit, // there should be no trouble eventually uploading logs. - if len(src) > maxSize { + maxLen := cmp.Or(l.maxUploadSize, maxSize) + if len(src) > maxLen { errDetail := fmt.Sprintf("entry too large: %d bytes", len(src)) - errData := appendTruncatedString(nil, src, maxSize/len(`\uffff`)) // escaping could increase size + errData := appendTruncatedString(nil, src, maxLen/len(`\uffff`)) // escaping could increase size dst = append(dst, '{') dst = l.appendMetadata(dst, l.skipClientTime, true, l.procID, l.procSequence, errDetail, errData, level) @@ -820,8 +830,6 @@ func (l *Logger) Logf(format string, args ...any) { fmt.Fprintf(l, format, args...) } -var obscureIPs = envknob.RegisterBool("TS_OBSCURE_LOGGED_IPS") - // Write logs an encoded JSON blob. 
// // If the []byte passed to Write is not an encoded JSON blob, @@ -846,10 +854,6 @@ func (l *Logger) Write(buf []byte) (int, error) { } } - if obscureIPs() { - buf = redactIPs(buf) - } - l.writeLock.Lock() defer l.writeLock.Unlock() @@ -858,40 +862,6 @@ func (l *Logger) Write(buf []byte) (int, error) { return inLen, err } -var ( - regexMatchesIPv6 = regexp.MustCompile(`([0-9a-fA-F]{1,4}):([0-9a-fA-F]{1,4}):([0-9a-fA-F:]{1,4})*`) - regexMatchesIPv4 = regexp.MustCompile(`(\d{1,3})\.(\d{1,3})\.\d{1,3}\.\d{1,3}`) -) - -// redactIPs is a helper function used in Write() to redact IPs (other than tailscale IPs). -// This function takes a log line as a byte slice and -// uses regex matching to parse and find IP addresses. Based on if the IP address is IPv4 or -// IPv6, it parses and replaces the end of the addresses with an "x". This function returns the -// log line with the IPs redacted. -func redactIPs(buf []byte) []byte { - out := regexMatchesIPv6.ReplaceAllFunc(buf, func(b []byte) []byte { - ip, err := netip.ParseAddr(string(b)) - if err != nil || tsaddr.IsTailscaleIP(ip) { - return b // don't change this one - } - - prefix := bytes.Split(b, []byte(":")) - return bytes.Join(append(prefix[:2], []byte("x")), []byte(":")) - }) - - out = regexMatchesIPv4.ReplaceAllFunc(out, func(b []byte) []byte { - ip, err := netip.ParseAddr(string(b)) - if err != nil || tsaddr.IsTailscaleIP(ip) { - return b // don't change this one - } - - prefix := bytes.Split(b, []byte(".")) - return bytes.Join(append(prefix[:2], []byte("x.x")), []byte(".")) - }) - - return []byte(out) -} - var ( openBracketV = []byte("[v") v1 = []byte("[v1] ") diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index 3ea6304067bfd..b8c46c44840bc 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -15,7 +15,6 @@ import ( "time" "github.com/go-json-experiment/json/jsontext" - "tailscale.com/envknob" "tailscale.com/tstest" "tailscale.com/tstime" "tailscale.com/util/must" @@ -316,85 +315,6 
@@ func TestLoggerWriteResult(t *testing.T) { t.Errorf("mismatch.\n got: %#q\nwant: %#q", back, want) } } -func TestRedact(t *testing.T) { - envknob.Setenv("TS_OBSCURE_LOGGED_IPS", "true") - tests := []struct { - in string - want string - }{ - // tests for ipv4 addresses - { - "120.100.30.47", - "120.100.x.x", - }, - { - "192.167.0.1/65", - "192.167.x.x/65", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 10.0.0.222:41641 mtu=1360 tx=d81a8a35a0ce", - "node [5Btdd] d:e89a3384f526d251 now using 10.0.x.x:41641 mtu=1360 tx=d81a8a35a0ce", - }, - //tests for ipv6 addresses - { - "2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "2001:0db8:x", - }, - { - "2345:0425:2CA1:0000:0000:0567:5673:23b5", - "2345:0425:x", - }, - { - "2601:645:8200:edf0::c9de/64", - "2601:645:x/64", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 2051:0000:140F::875B:131C mtu=1360 tx=d81a8a35a0ce", - "node [5Btdd] d:e89a3384f526d251 now using 2051:0000:x mtu=1360 tx=d81a8a35a0ce", - }, - { - "2601:645:8200:edf0::c9de/64 2601:645:8200:edf0:1ce9:b17d:71f5:f6a3/64", - "2601:645:x/64 2601:645:x/64", - }, - //tests for tailscale ip addresses - { - "100.64.5.6", - "100.64.5.6", - }, - { - "fd7a:115c:a1e0::/96", - "fd7a:115c:a1e0::/96", - }, - //tests for ipv6 and ipv4 together - { - "192.167.0.1 2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "192.167.x.x 2001:0db8:x", - }, - { - "node [5Btdd] d:e89a3384f526d251 now using 10.0.0.222:41641 mtu=1360 tx=d81a8a35a0ce 2345:0425:2CA1::0567:5673:23b5", - "node [5Btdd] d:e89a3384f526d251 now using 10.0.x.x:41641 mtu=1360 tx=d81a8a35a0ce 2345:0425:x", - }, - { - "100.64.5.6 2091:0db8:85a3:0000:0000:8a2e:0370:7334", - "100.64.5.6 2091:0db8:x", - }, - { - "192.167.0.1 120.100.30.47 2041:0000:140F::875B:131B", - "192.167.x.x 120.100.x.x 2041:0000:x", - }, - { - "fd7a:115c:a1e0::/96 192.167.0.1 2001:0db8:85a3:0000:0000:8a2e:0370:7334", - "fd7a:115c:a1e0::/96 192.167.x.x 2001:0db8:x", - }, - } - - for _, tt := range tests { - gotBuf := redactIPs([]byte(tt.in)) - if 
string(gotBuf) != tt.want { - t.Errorf("for %q,\n got: %#q\nwant: %#q\n", tt.in, gotBuf, tt.want) - } - } -} func TestAppendMetadata(t *testing.T) { var l Logger diff --git a/maths/ewma.go b/maths/ewma.go new file mode 100644 index 0000000000000..0897b73e4727f --- /dev/null +++ b/maths/ewma.go @@ -0,0 +1,72 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package maths contains additional mathematical functions or structures not +// found in the standard library. +package maths + +import ( + "math" + "time" +) + +// EWMA is an exponentially weighted moving average supporting updates at +// irregular intervals with at most nanosecond resolution. +// The zero value will compute a half-life of 1 second. +// It is not safe for concurrent use. +// TODO(raggi): de-duplicate with tstime/rate.Value, which has a more complex +// and synchronized interface and does not provide direct access to the stable +// value. +type EWMA struct { + value float64 // current value of the average + lastTime int64 // time of last update in unix nanos + halfLife float64 // half-life in seconds +} + +// NewEWMA creates a new EWMA with the specified half-life. If halfLifeSeconds +// is 0, it defaults to 1. +func NewEWMA(halfLifeSeconds float64) *EWMA { + return &EWMA{ + halfLife: halfLifeSeconds, + } +} + +// Update adds a new sample to the average. If t is zero or precedes the last +// update, the update is ignored. 
+func (e *EWMA) Update(value float64, t time.Time) { + if t.IsZero() { + return + } + hl := e.halfLife + if hl == 0 { + hl = 1 + } + tn := t.UnixNano() + if e.lastTime == 0 { + e.value = value + e.lastTime = tn + return + } + + dt := (time.Duration(tn-e.lastTime) * time.Nanosecond).Seconds() + if dt < 0 { + // drop out of order updates + return + } + + // decay = 2^(-dt/halfLife) + decay := math.Exp2(-dt / hl) + e.value = e.value*decay + value*(1-decay) + e.lastTime = tn +} + +// Get returns the current value of the average +func (e *EWMA) Get() float64 { + return e.value +} + +// Reset clears the EWMA to its initial state +func (e *EWMA) Reset() { + e.value = 0 + e.lastTime = 0 +} diff --git a/maths/ewma_test.go b/maths/ewma_test.go new file mode 100644 index 0000000000000..307078a38ebdf --- /dev/null +++ b/maths/ewma_test.go @@ -0,0 +1,178 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package maths + +import ( + "slices" + "testing" + "time" +) + +// some real world latency samples. 
+var ( + latencyHistory1 = []int{ + 14, 12, 15, 6, 19, 12, 13, 13, 13, 16, 17, 11, 17, 11, 14, 15, 14, 15, + 16, 16, 17, 14, 12, 16, 18, 14, 14, 11, 15, 15, 25, 11, 15, 14, 12, 15, + 13, 12, 13, 15, 11, 13, 15, 14, 14, 15, 12, 15, 18, 12, 15, 22, 12, 13, + 10, 14, 16, 15, 16, 11, 14, 17, 18, 20, 16, 11, 16, 14, 5, 15, 17, 12, + 15, 11, 15, 20, 12, 17, 12, 17, 15, 12, 12, 11, 14, 15, 11, 20, 14, 13, + 11, 12, 13, 13, 11, 13, 11, 15, 13, 13, 14, 12, 11, 12, 12, 14, 11, 13, + 12, 12, 12, 19, 14, 13, 13, 14, 11, 12, 10, 11, 15, 12, 14, 11, 11, 14, + 14, 12, 12, 11, 14, 12, 11, 12, 14, 11, 12, 15, 12, 14, 12, 12, 21, 16, + 21, 12, 16, 9, 11, 16, 14, 13, 14, 12, 13, 16, + } + latencyHistory2 = []int{ + 18, 20, 21, 21, 20, 23, 18, 18, 20, 21, 20, 19, 22, 18, 20, 20, 19, 21, + 21, 22, 22, 19, 18, 22, 22, 19, 20, 17, 16, 11, 25, 16, 18, 21, 17, 22, + 19, 18, 22, 21, 20, 18, 22, 17, 17, 20, 19, 10, 19, 16, 19, 25, 17, 18, + 15, 20, 21, 20, 23, 22, 22, 22, 19, 22, 22, 17, 22, 20, 20, 19, 21, 22, + 20, 19, 17, 22, 16, 16, 20, 22, 17, 19, 21, 16, 20, 22, 19, 21, 20, 19, + 13, 14, 23, 19, 16, 10, 19, 15, 15, 17, 16, 18, 14, 16, 18, 22, 20, 18, + 18, 21, 15, 19, 18, 19, 18, 20, 17, 19, 21, 19, 20, 19, 20, 20, 17, 14, + 17, 17, 18, 21, 20, 18, 18, 17, 16, 17, 17, 20, 22, 19, 20, 21, 21, 20, + 21, 24, 20, 18, 12, 17, 18, 17, 19, 19, 19, + } +) + +func TestEWMALatencyHistory(t *testing.T) { + type result struct { + t time.Time + v float64 + s int + } + + for _, latencyHistory := range [][]int{latencyHistory1, latencyHistory2} { + startTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + halfLife := 30.0 + + ewma := NewEWMA(halfLife) + + var results []result + sum := 0.0 + for i, latency := range latencyHistory { + t := startTime.Add(time.Duration(i) * time.Second) + ewma.Update(float64(latency), t) + sum += float64(latency) + + results = append(results, result{t, ewma.Get(), latency}) + } + mean := sum / float64(len(latencyHistory)) + min := float64(slices.Min(latencyHistory)) + 
max := float64(slices.Max(latencyHistory)) + + t.Logf("EWMA Latency History (half-life: %.1f seconds):", halfLife) + t.Logf("Mean latency: %.2f ms", mean) + t.Logf("Range: [%.1f, %.1f]", min, max) + + t.Log("Samples: ") + sparkline := []rune("▁▂▃▄▅▆▇█") + var sampleLine []rune + for _, r := range results { + idx := int(((float64(r.s) - min) / (max - min)) * float64(len(sparkline)-1)) + if idx >= len(sparkline) { + idx = len(sparkline) - 1 + } + sampleLine = append(sampleLine, sparkline[idx]) + } + t.Log(string(sampleLine)) + + t.Log("EWMA: ") + var ewmaLine []rune + for _, r := range results { + idx := int(((r.v - min) / (max - min)) * float64(len(sparkline)-1)) + if idx >= len(sparkline) { + idx = len(sparkline) - 1 + } + ewmaLine = append(ewmaLine, sparkline[idx]) + } + t.Log(string(ewmaLine)) + t.Log("") + + t.Logf("Time | Sample | Value | Value - Sample") + t.Logf("") + + for _, result := range results { + t.Logf("%10s | % 6d | % 5.2f | % 5.2f", result.t.Format("15:04:05"), result.s, result.v, result.v-float64(result.s)) + } + + // check that all results are greater than the min, and less than the max of the input, + // and they're all close to the mean. 
+ for _, result := range results { + if result.v < float64(min) || result.v > float64(max) { + t.Errorf("result %f out of range [%f, %f]", result.v, min, max) + } + + if result.v < mean*0.9 || result.v > mean*1.1 { + t.Errorf("result %f not close to mean %f", result.v, mean) + } + } + } +} + +func TestHalfLife(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + ewma := NewEWMA(30.0) + ewma.Update(10, start) + ewma.Update(0, start.Add(30*time.Second)) + + if ewma.Get() != 5 { + t.Errorf("expected 5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(60*time.Second)) + if ewma.Get() != 7.5 { + t.Errorf("expected 7.5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(90*time.Second)) + if ewma.Get() != 8.75 { + t.Errorf("expected 8.75, got %f", ewma.Get()) + } +} + +func TestZeroValue(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + var ewma EWMA + ewma.Update(10, start) + ewma.Update(0, start.Add(time.Second)) + + if ewma.Get() != 5 { + t.Errorf("expected 5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(2*time.Second)) + if ewma.Get() != 7.5 { + t.Errorf("expected 7.5, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(3*time.Second)) + if ewma.Get() != 8.75 { + t.Errorf("expected 8.75, got %f", ewma.Get()) + } +} + +func TestReset(t *testing.T) { + start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + + ewma := NewEWMA(30.0) + ewma.Update(10, start) + ewma.Update(0, start.Add(30*time.Second)) + + if ewma.Get() != 5 { + t.Errorf("expected 5, got %f", ewma.Get()) + } + + ewma.Reset() + + if ewma.Get() != 0 { + t.Errorf("expected 0, got %f", ewma.Get()) + } + + ewma.Update(10, start.Add(90*time.Second)) + if ewma.Get() != 10 { + t.Errorf("expected 10, got %f", ewma.Get()) + } +} diff --git a/metrics/metrics.go b/metrics/metrics.go index a07ddccae5107..d1b1c06c9dc2c 100644 --- a/metrics/metrics.go +++ b/metrics/metrics.go @@ -11,6 +11,9 @@ import ( "io" "slices" "strings" + "sync" + + 
"tailscale.com/syncs" ) // Set is a string-to-Var map variable that satisfies the expvar.Var @@ -37,6 +40,8 @@ type Set struct { type LabelMap struct { Label string expvar.Map + // shardedIntMu orders the initialization of new shardedint keys + shardedIntMu sync.Mutex } // SetInt64 sets the *Int value stored under the given map key. @@ -44,6 +49,19 @@ func (m *LabelMap) SetInt64(key string, v int64) { m.Get(key).Set(v) } +// Add adds delta to the any int-like value stored under the given map key. +func (m *LabelMap) Add(key string, delta int64) { + type intAdder interface { + Add(delta int64) + } + o := m.Map.Get(key) + if o == nil { + m.Map.Add(key, delta) + return + } + o.(intAdder).Add(delta) +} + // Get returns a direct pointer to the expvar.Int for key, creating it // if necessary. func (m *LabelMap) Get(key string) *expvar.Int { @@ -51,6 +69,23 @@ func (m *LabelMap) Get(key string) *expvar.Int { return m.Map.Get(key).(*expvar.Int) } +// GetShardedInt returns a direct pointer to the syncs.ShardedInt for key, +// creating it if necessary. +func (m *LabelMap) GetShardedInt(key string) *syncs.ShardedInt { + i := m.Map.Get(key) + if i == nil { + m.shardedIntMu.Lock() + defer m.shardedIntMu.Unlock() + i = m.Map.Get(key) + if i != nil { + return i.(*syncs.ShardedInt) + } + i = syncs.NewShardedInt() + m.Set(key, i) + } + return i.(*syncs.ShardedInt) +} + // GetIncrFunc returns a function that increments the expvar.Int named by key. 
// // Most callers should not need this; it exists to satisfy an diff --git a/metrics/metrics_test.go b/metrics/metrics_test.go index 45bf39e56efd2..a808d5a73eb3e 100644 --- a/metrics/metrics_test.go +++ b/metrics/metrics_test.go @@ -21,6 +21,15 @@ func TestLabelMap(t *testing.T) { if g, w := m.Get("bar").Value(), int64(2); g != w { t.Errorf("bar = %v; want %v", g, w) } + m.GetShardedInt("sharded").Add(5) + if g, w := m.GetShardedInt("sharded").Value(), int64(5); g != w { + t.Errorf("sharded = %v; want %v", g, w) + } + m.Add("sharded", 1) + if g, w := m.GetShardedInt("sharded").Value(), int64(6); g != w { + t.Errorf("sharded = %v; want %v", g, w) + } + m.Add("neverbefore", 1) } func TestCurrentFileDescriptors(t *testing.T) { diff --git a/net/bakedroots/bakedroots.go b/net/bakedroots/bakedroots.go new file mode 100644 index 0000000000000..42e70c0dd2abb --- /dev/null +++ b/net/bakedroots/bakedroots.go @@ -0,0 +1,149 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package bakedroots contains WebPKI CA roots we bake into the tailscaled binary, +// lest the system's CA roots be missing them (or entirely empty). +package bakedroots + +import ( + "crypto/x509" + "sync" + + "tailscale.com/util/testenv" ) + +// Get returns the baked-in roots. +// +// As of 2025-01-21, this includes the LetsEncrypt ISRG Root X1 and ISRG Root X2 roots. +func Get() *x509.CertPool { + roots.once.Do(func() { + roots.parsePEM(append( + []byte(letsEncryptX1), + letsEncryptX2..., + )) + }) + return roots.p } + +// testingTB is a subset of testing.TB needed +// to verify the caller isn't in a parallel test. +type testingTB interface { + // Setenv panics if it's in a parallel test. + Setenv(k, v string) } + +// ResetForTest resets the cached roots for testing, +// optionally setting them to caPEM if non-nil.
+func ResetForTest(tb testingTB, caPEM []byte) { + if !testenv.InTest() { + panic("not in test") + } + tb.Setenv("ASSERT_NOT_PARALLEL_TEST", "1") // panics if tb's Parallel was called + + roots = rootsOnce{} + if caPEM != nil { + roots.once.Do(func() { roots.parsePEM(caPEM) }) + } +} + +var roots rootsOnce + +type rootsOnce struct { + once sync.Once + p *x509.CertPool +} + +func (r *rootsOnce) parsePEM(caPEM []byte) { + p := x509.NewCertPool() + if !p.AppendCertsFromPEM(caPEM) { + panic("bogus PEM") + } + r.p = p +} + +/* +letsEncryptX1 is the LetsEncrypt X1 root: + +Certificate: + + Data: + Version: 3 (0x2) + Serial Number: + 82:10:cf:b0:d2:40:e3:59:44:63:e0:bb:63:82:8b:00 + Signature Algorithm: sha256WithRSAEncryption + Issuer: C = US, O = Internet Security Research Group, CN = ISRG Root X1 + Validity + Not Before: Jun 4 11:04:38 2015 GMT + Not After : Jun 4 11:04:38 2035 GMT + Subject: C = US, O = Internet Security Research Group, CN = ISRG Root X1 + Subject Public Key Info: + Public Key Algorithm: rsaEncryption + RSA Public-Key: (4096 bit) + +We bake it into the binary as a fallback verification root, +in case the system we're running on doesn't have it. +(Tailscale runs on some ancient devices.) + +To test that this code is working on Debian/Ubuntu: + +$ sudo mv /usr/share/ca-certificates/mozilla/ISRG_Root_X1.crt{,.old} +$ sudo update-ca-certificates + +Then restart tailscaled. To also test dnsfallback's use of it, nuke +your /etc/resolv.conf and it should still start & run fine. 
+*/ +const letsEncryptX1 = ` +-----BEGIN CERTIFICATE----- +MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw +TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh +cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 +WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu +ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY +MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc +h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ +0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U +A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW +T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH +B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC +B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv +KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn +OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn +jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw +qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI +rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV +HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq +hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL +ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ +3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK +NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 +ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur +TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC +jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc +oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq +4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA +mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d +emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= +-----END CERTIFICATE----- 
+` + +// letsEncryptX2 is the ISRG Root X2. +// +// Subject: O = Internet Security Research Group, CN = ISRG Root X2 +// Key type: ECDSA P-384 +// Validity: until 2040-09-17 (generated 2020-09-04) +const letsEncryptX2 = ` +-----BEGIN CERTIFICATE----- +MIICGzCCAaGgAwIBAgIQQdKd0XLq7qeAwSxs6S+HUjAKBggqhkjOPQQDAzBPMQsw +CQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJuZXQgU2VjdXJpdHkgUmVzZWFyY2gg +R3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBYMjAeFw0yMDA5MDQwMDAwMDBaFw00 +MDA5MTcxNjAwMDBaME8xCzAJBgNVBAYTAlVTMSkwJwYDVQQKEyBJbnRlcm5ldCBT +ZWN1cml0eSBSZXNlYXJjaCBHcm91cDEVMBMGA1UEAxMMSVNSRyBSb290IFgyMHYw +EAYHKoZIzj0CAQYFK4EEACIDYgAEzZvVn4CDCuwJSvMWSj5cz3es3mcFDR0HttwW ++1qLFNvicWDEukWVEYmO6gbf9yoWHKS5xcUy4APgHoIYOIvXRdgKam7mAHf7AlF9 +ItgKbppbd9/w+kHsOdx1ymgHDB/qo0IwQDAOBgNVHQ8BAf8EBAMCAQYwDwYDVR0T +AQH/BAUwAwEB/zAdBgNVHQ4EFgQUfEKWrt5LSDv6kviejM9ti6lyN5UwCgYIKoZI +zj0EAwMDaAAwZQIwe3lORlCEwkSHRhtFcP9Ymd70/aTSVaYgLXTWNLxBo1BfASdW +tL4ndQavEi51mI38AjEAi/V3bNTIZargCyzuFJ0nN6T5U6VR5CmD1/iQMVtCnwr1 +/q4AaOeMSQ+2b1tbFfLn +-----END CERTIFICATE----- +` diff --git a/net/bakedroots/bakedroots_test.go b/net/bakedroots/bakedroots_test.go new file mode 100644 index 0000000000000..8ba502a7827e0 --- /dev/null +++ b/net/bakedroots/bakedroots_test.go @@ -0,0 +1,32 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package bakedroots + +import ( + "slices" + "testing" ) + +func TestBakedInRoots(t *testing.T) { + ResetForTest(t, nil) + p := Get() + got := p.Subjects() + if len(got) != 2 { + t.Errorf("subjects = %v; want 2", len(got)) + } + + // TODO(bradfitz): is there a way to easily make this test prettier without + // writing a DER decoder? I'm not seeing how.
+ var name []string + for _, der := range got { + name = append(name, string(der)) + } + want := []string{ + "0O1\v0\t\x06\x03U\x04\x06\x13\x02US1)0'\x06\x03U\x04\n\x13 Internet Security Research Group1\x150\x13\x06\x03U\x04\x03\x13\fISRG Root X1", + "0O1\v0\t\x06\x03U\x04\x06\x13\x02US1)0'\x06\x03U\x04\n\x13 Internet Security Research Group1\x150\x13\x06\x03U\x04\x03\x13\fISRG Root X2", + } + if !slices.Equal(name, want) { + t.Errorf("subjects = %q; want %q", name, want) + } +} diff --git a/net/captivedetection/captivedetection.go b/net/captivedetection/captivedetection.go index c6e8bca3a19a2..a06362a5b4d1d 100644 --- a/net/captivedetection/captivedetection.go +++ b/net/captivedetection/captivedetection.go @@ -11,6 +11,7 @@ import ( "net" "net/http" "runtime" + "strconv" "strings" "sync" "syscall" @@ -23,6 +24,7 @@ import ( // Detector checks whether the system is behind a captive portal. type Detector struct { + clock func() time.Time // httpClient is the HTTP client that is used for captive portal detection. It is configured // to not follow redirects, have a short timeout and no keep-alive. @@ -52,6 +54,13 @@ func NewDetector(logf logger.Logf) *Detector { return d } +func (d *Detector) Now() time.Time { + if d.clock != nil { + return d.clock() + } + return time.Now() +} + // Timeout is the timeout for captive portal detection requests. Because the captive portal intercepting our requests // is usually located on the LAN, this is a relatively short timeout. 
const Timeout = 3 * time.Second @@ -136,26 +145,31 @@ func interfaceNameDoesNotNeedCaptiveDetection(ifName string, goos string) bool { func (d *Detector) detectOnInterface(ctx context.Context, ifIndex int, endpoints []Endpoint) bool { defer d.httpClient.CloseIdleConnections() - d.logf("[v2] %d available captive portal detection endpoints: %v", len(endpoints), endpoints) + use := min(len(endpoints), 5) + endpoints = endpoints[:use] + d.logf("[v2] %d available captive portal detection endpoints; trying %v", len(endpoints), use) // We try to detect the captive portal more quickly by making requests to multiple endpoints concurrently. var wg sync.WaitGroup resultCh := make(chan bool, len(endpoints)) - for i, e := range endpoints { - if i >= 5 { - // Try a maximum of 5 endpoints, break out (returning false) if we run of attempts. - break - } + // Once any goroutine detects a captive portal, we shut down the others. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + for _, e := range endpoints { wg.Add(1) go func(endpoint Endpoint) { defer wg.Done() found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, ifIndex) if err != nil { - d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + if ctx.Err() == nil { + d.logf("[v1] checkCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + } return } if found { + cancel() // one match is good enough resultCh <- true } }(e) @@ -182,10 +196,16 @@ func (d *Detector) verifyCaptivePortalEndpoint(ctx context.Context, e Endpoint, ctx, cancel := context.WithTimeout(ctx, Timeout) defer cancel() - req, err := http.NewRequestWithContext(ctx, "GET", e.URL.String(), nil) + u := *e.URL + v := u.Query() + v.Add("t", strconv.Itoa(int(d.Now().Unix()))) + u.RawQuery = v.Encode() + + req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) if err != nil { return false, err } + req.Header.Set("Cache-Control", "no-cache, no-store, must-revalidate, no-transform, max-age=0") // 
Attach the Tailscale challenge header if the endpoint supports it. Not all captive portal detection endpoints // support this, so we only attach it if the endpoint does. diff --git a/net/captivedetection/captivedetection_test.go b/net/captivedetection/captivedetection_test.go index e74273afd922e..064a86c8c35e5 100644 --- a/net/captivedetection/captivedetection_test.go +++ b/net/captivedetection/captivedetection_test.go @@ -5,12 +5,21 @@ package captivedetection import ( "context" + "net/http" + "net/http/httptest" + "net/url" "runtime" + "strconv" "sync" + "sync/atomic" "testing" + "time" - "tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/derp/derphttp" "tailscale.com/net/netmon" + "tailscale.com/syncs" + "tailscale.com/tstest/nettest" + "tailscale.com/util/must" ) func TestAvailableEndpointsAlwaysAtLeastTwo(t *testing.T) { @@ -36,25 +45,110 @@ func TestDetectCaptivePortalReturnsFalse(t *testing.T) { } } -func TestAllEndpointsAreUpAndReturnExpectedResponse(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13019") +func TestEndpointsAreUpAndReturnExpectedResponse(t *testing.T) { + nettest.SkipIfNoNetwork(t) + d := NewDetector(t.Logf) endpoints := availableEndpoints(nil, 0, t.Logf, runtime.GOOS) + t.Logf("testing %d endpoints", len(endpoints)) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + var good atomic.Bool var wg sync.WaitGroup + sem := syncs.NewSemaphore(5) for _, e := range endpoints { wg.Add(1) go func(endpoint Endpoint) { defer wg.Done() - found, err := d.verifyCaptivePortalEndpoint(context.Background(), endpoint, 0) - if err != nil { - t.Errorf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) + + if !sem.AcquireContext(ctx) { + return + } + defer sem.Release() + + found, err := d.verifyCaptivePortalEndpoint(ctx, endpoint, 0) + if err != nil && ctx.Err() == nil { + t.Logf("verifyCaptivePortalEndpoint failed with endpoint %v: %v", endpoint, err) } if found { - 
t.Errorf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint) + t.Logf("verifyCaptivePortalEndpoint with endpoint %v says we're behind a captive portal, but we aren't", endpoint) + return } + good.Store(true) + t.Logf("endpoint good: %v", endpoint) + cancel() }(e) } wg.Wait() + + if !good.Load() { + t.Errorf("no good endpoints found") + } +} + +func TestCaptivePortalRequest(t *testing.T) { + d := NewDetector(t.Logf) + now := time.Now() + d.clock = func() time.Time { return now } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "GET" { + t.Errorf("expected GET, got %q", r.Method) + } + if r.URL.Path != "/generate_204" { + t.Errorf("expected /generate_204, got %q", r.URL.Path) + } + q := r.URL.Query() + if got, want := q.Get("t"), strconv.Itoa(int(now.Unix())); got != want { + t.Errorf("timestamp param; got %v, want %v", got, want) + } + w.Header().Set("X-Tailscale-Response", "response "+r.Header.Get("X-Tailscale-Challenge")) + + w.WriteHeader(http.StatusNoContent) + })) + defer s.Close() + + e := Endpoint{ + URL: must.Get(url.Parse(s.URL + "/generate_204")), + StatusCode: 204, + ExpectedContent: "", + SupportsTailscaleChallenge: true, + } + + found, err := d.verifyCaptivePortalEndpoint(ctx, e, 0) + if err != nil { + t.Fatalf("verifyCaptivePortalEndpoint = %v, %v", found, err) + } + if found { + t.Errorf("verifyCaptivePortalEndpoint = %v, want false", found) + } +} + +func TestAgainstDERPHandler(t *testing.T) { + d := NewDetector(t.Logf) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := httptest.NewServer(http.HandlerFunc(derphttp.ServeNoContent)) + defer s.Close() + e := Endpoint{ + URL: must.Get(url.Parse(s.URL + "/generate_204")), + StatusCode: 204, + ExpectedContent: "", + SupportsTailscaleChallenge: true, + } + found, err := 
d.verifyCaptivePortalEndpoint(ctx, e, 0) + if err != nil { + t.Fatalf("verifyCaptivePortalEndpoint = %v, %v", found, err) + } + if found { + t.Errorf("verifyCaptivePortalEndpoint = %v, want false", found) + } } diff --git a/net/captivedetection/endpoints.go b/net/captivedetection/endpoints.go index 450ed4a1cae4a..57b3e53351a1a 100644 --- a/net/captivedetection/endpoints.go +++ b/net/captivedetection/endpoints.go @@ -89,7 +89,7 @@ func availableEndpoints(derpMap *tailcfg.DERPMap, preferredDERPRegionID int, log // Use the DERP IPs as captive portal detection endpoints. Using IPs is better than hostnames // because they do not depend on DNS resolution. for _, region := range derpMap.Regions { - if region.Avoid { + if region.Avoid || region.NoMeasureNoHome { continue } for _, node := range region.Nodes { diff --git a/net/dns/config.go b/net/dns/config.go index 67d3d753c9a8d..b2f4e6dbd9dc2 100644 --- a/net/dns/config.go +++ b/net/dns/config.go @@ -10,6 +10,8 @@ import ( "net/netip" "sort" + "tailscale.com/control/controlknobs" + "tailscale.com/envknob" "tailscale.com/net/dns/publicdns" "tailscale.com/net/dns/resolver" "tailscale.com/net/tsaddr" @@ -47,11 +49,28 @@ type Config struct { OnlyIPv6 bool } -func (c *Config) serviceIP() netip.Addr { +var magicDNSDualStack = envknob.RegisterBool("TS_DEBUG_MAGIC_DNS_DUAL_STACK") + +// serviceIPs returns the list of service IPs where MagicDNS is reachable. +// +// The provided knobs may be nil. +func (c *Config) serviceIPs(knobs *controlknobs.Knobs) []netip.Addr { if c.OnlyIPv6 { - return tsaddr.TailscaleServiceIPv6() + return []netip.Addr{tsaddr.TailscaleServiceIPv6()} } - return tsaddr.TailscaleServiceIP() + + // TODO(bradfitz,mikeodr,raggi): include IPv6 here too; tailscale/tailscale#15404 + // And add a controlknobs knob to disable dual stack. + // + // For now, opt-in for testing. 
+ if magicDNSDualStack() { + return []netip.Addr{ + tsaddr.TailscaleServiceIP(), + tsaddr.TailscaleServiceIPv6(), + } + } + + return []netip.Addr{tsaddr.TailscaleServiceIP()} } // WriteToBufioWriter write a debug version of c for logs to w, omitting diff --git a/net/dns/debian_resolvconf.go b/net/dns/debian_resolvconf.go index 3ffc796e06d1b..63fd80c1274e8 100644 --- a/net/dns/debian_resolvconf.go +++ b/net/dns/debian_resolvconf.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || freebsd || openbsd +//go:build (linux && !android) || freebsd || openbsd package dns diff --git a/net/dns/direct.go b/net/dns/direct.go index aaff18fcb7848..f23723d9a1515 100644 --- a/net/dns/direct.go +++ b/net/dns/direct.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !android && !ios + package dns import ( diff --git a/net/dns/direct_linux.go b/net/dns/direct_linux.go index bdeefb352498b..0558f0f51b253 100644 --- a/net/dns/direct_linux.go +++ b/net/dns/direct_linux.go @@ -1,26 +1,32 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package dns import ( "bytes" "context" + "fmt" - "github.com/illarion/gonotify/v2" + "github.com/illarion/gonotify/v3" "tailscale.com/health" ) func (m *directManager) runFileWatcher() { - ctx, cancel := context.WithCancel(m.ctx) - defer cancel() - in, err := gonotify.NewInotify(ctx) - if err != nil { - // Oh well, we tried. This is all best effort for now, to - // surface warnings to users. - m.logf("dns: inotify new: %v", err) - return + if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil { + // This is all best effort for now, so surface warnings to users. 
+ m.logf("dns: inotify: %s", err) } +} + +// watchFile sets up an inotify watch for a given directory and +// calls the callback function every time a particular file is changed. +// The filename should be located in the provided directory. +func watchFile(ctx context.Context, dir, filename string, cb func()) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() const events = gonotify.IN_ATTRIB | gonotify.IN_CLOSE_WRITE | @@ -29,30 +35,20 @@ func (m *directManager) runFileWatcher() { gonotify.IN_MODIFY | gonotify.IN_MOVE - if err := in.AddWatch("/etc/", events); err != nil { - m.logf("dns: inotify addwatch: %v", err) - return + watcher, err := gonotify.NewDirWatcher(ctx, events, dir) + if err != nil { + return fmt.Errorf("NewDirWatcher: %w", err) } + for { - events, err := in.Read() - if ctx.Err() != nil { - return - } - if err != nil { - m.logf("dns: inotify read: %v", err) - return - } - var match bool - for _, ev := range events { - if ev.Name == resolvConf { - match = true - break + select { + case event := <-watcher.C: + if event.Name == filename { + cb() } + case <-ctx.Done(): + return ctx.Err() } - if !match { - continue - } - m.checkForFileTrample() } } diff --git a/net/dns/direct_linux_test.go b/net/dns/direct_linux_test.go new file mode 100644 index 0000000000000..079d060ed3323 --- /dev/null +++ b/net/dns/direct_linux_test.go @@ -0,0 +1,56 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "context" + "errors" + "fmt" + "os" + "sync/atomic" + "testing" + "time" + + "golang.org/x/sync/errgroup" +) + +func TestWatchFile(t *testing.T) { + dir := t.TempDir() + filepath := dir + "/test.txt" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + var callbackCalled atomic.Bool + callbackDone := make(chan bool) + callback := func() { + callbackDone <- true + callbackCalled.Store(true) + } + + var eg errgroup.Group + eg.Go(func() error { return watchFile(ctx, 
dir, filepath, callback) }) + + // Keep writing until we get a callback. + func() { + for i := range 10000 { + if err := os.WriteFile(filepath, []byte(fmt.Sprintf("write%d", i)), 0644); err != nil { + t.Fatal(err) + } + select { + case <-callbackDone: + return + case <-time.After(10 * time.Millisecond): + } + } + }() + + cancel() + if err := eg.Wait(); err != nil && !errors.Is(err, context.Canceled) { + t.Error(err) + } + if !callbackCalled.Load() { + t.Error("callback was not called") + } +} diff --git a/net/dns/direct_notlinux.go b/net/dns/direct_notlinux.go index c221ca1beaa59..a73a35e5ead2b 100644 --- a/net/dns/direct_notlinux.go +++ b/net/dns/direct_notlinux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux && !android && !ios package dns diff --git a/net/dns/manager.go b/net/dns/manager.go index 51a0fa12cba63..5d6f225ce032f 100644 --- a/net/dns/manager.go +++ b/net/dns/manager.go @@ -8,6 +8,7 @@ import ( "context" "encoding/binary" "errors" + "fmt" "io" "net" "net/netip" @@ -18,22 +19,24 @@ import ( "sync/atomic" "time" - xmaps "golang.org/x/exp/maps" "tailscale.com/control/controlknobs" "tailscale.com/health" "tailscale.com/net/dns/resolver" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" "tailscale.com/syncs" - "tailscale.com/tstime/rate" "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/util/clientmetric" "tailscale.com/util/dnsname" + "tailscale.com/util/slicesx" ) var ( errFullQueue = errors.New("request queue full") + // ErrNoDNSConfig is returned by RecompileDNSConfig when the Manager + // has no existing DNS configuration. 
+ ErrNoDNSConfig = errors.New("no DNS configuration") ) // maxActiveQueries returns the maximal number of DNS requests that can @@ -59,10 +62,8 @@ type Manager struct { knobs *controlknobs.Knobs // or nil goos string // if empty, gets set to runtime.GOOS - mu sync.Mutex // guards following - // config is the last configuration we successfully compiled or nil if there - // was any failure applying the last configuration. - config *Config + mu sync.Mutex // guards following + config *Config // Tracks the last viable DNS configuration set by Set. nil on failures other than compilation failures or if set has never been called. } // NewManagers created a new manager from the given config. @@ -89,25 +90,6 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker, goos: goos, } - // Rate limit our attempts to correct our DNS configuration. - limiter := rate.NewLimiter(1.0/5.0, 1) - - // This will recompile the DNS config, which in turn will requery the system - // DNS settings. The recovery func should triggered only when we are missing - // upstream nameservers and require them to forward a query. - m.resolver.SetMissingUpstreamRecovery(func() { - m.mu.Lock() - defer m.mu.Unlock() - if m.config == nil { - return - } - - if limiter.Allow() { - m.logf("DNS resolution failed due to missing upstream nameservers. Recompiling DNS configuration.") - m.setLocked(*m.config) - } - }) - m.ctx, m.ctxCancel = context.WithCancel(context.Background()) m.logf("using %T", m.os) return m @@ -116,6 +98,26 @@ func NewManager(logf logger.Logf, oscfg OSConfigurator, health *health.Tracker, // Resolver returns the Manager's DNS Resolver. func (m *Manager) Resolver() *resolver.Resolver { return m.resolver } +// RecompileDNSConfig recompiles the last attempted DNS configuration, which has +// the side effect of re-querying the OS's interface nameservers. This should be used +// on platforms where the interface nameservers can change. 
Darwin, for example, +// where the nameservers aren't always available when we process a major interface +// change event, or platforms where the nameservers may change while tunnel is up. +// +// This should be called if it is determined that [OSConfigurator.GetBaseConfig] may +// give a better or different result than when [Manager.Set] was last called. The +// logic for making that determination is up to the caller. +// +// It returns [ErrNoDNSConfig] if [Manager.Set] has never been called. +func (m *Manager) RecompileDNSConfig() error { + m.mu.Lock() + defer m.mu.Unlock() + if m.config != nil { + return m.setLocked(*m.config) + } + return ErrNoDNSConfig +} + func (m *Manager) Set(cfg Config) error { m.mu.Lock() defer m.mu.Unlock() @@ -133,15 +135,15 @@ func (m *Manager) GetBaseConfig() (OSConfig, error) { func (m *Manager) setLocked(cfg Config) error { syncs.AssertLocked(&m.mu) - // On errors, the 'set' config is cleared. - m.config = nil - m.logf("Set: %v", logger.ArgWriter(func(w *bufio.Writer) { cfg.WriteToBufioWriter(w) })) rcfg, ocfg, err := m.compileConfig(cfg) if err != nil { + // On a compilation failure, set m.config set for later reuse by + // [Manager.RecompileDNSConfig] and return the error. 
+ m.config = &cfg return err } @@ -153,14 +155,16 @@ func (m *Manager) setLocked(cfg Config) error { })) if err := m.resolver.SetConfig(rcfg); err != nil { + m.config = nil return err } if err := m.os.SetDNS(ocfg); err != nil { - m.health.SetDNSOSHealth(err) + m.config = nil + m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()}) return err } - m.health.SetDNSOSHealth(nil) + m.health.SetHealthy(osConfigurationSetWarnable) m.config = &cfg return nil @@ -203,7 +207,7 @@ func compileHostEntries(cfg Config) (hosts []*HostEntry) { if len(hostsMap) == 0 { return nil } - hosts = xmaps.Values(hostsMap) + hosts = slicesx.MapValues(hostsMap) slices.SortFunc(hosts, func(a, b *HostEntry) int { if len(a.Hosts) == 0 && len(b.Hosts) == 0 { return 0 @@ -217,6 +221,26 @@ func compileHostEntries(cfg Config) (hosts []*HostEntry) { return hosts } +var osConfigurationReadWarnable = health.Register(&health.Warnable{ + Code: "dns-read-os-config-failed", + Title: "Failed to read system DNS configuration", + Text: func(args health.Args) string { + return fmt.Sprintf("Tailscale failed to fetch the DNS configuration of your device: %v", args[health.ArgError]) + }, + Severity: health.SeverityLow, + DependsOn: []*health.Warnable{health.NetworkStatusWarnable}, +}) + +var osConfigurationSetWarnable = health.Register(&health.Warnable{ + Code: "dns-set-os-config-failed", + Title: "Failed to set system DNS configuration", + Text: func(args health.Args) string { + return fmt.Sprintf("Tailscale failed to set the DNS configuration of your device: %v", args[health.ArgError]) + }, + Severity: health.SeverityMedium, + DependsOn: []*health.Warnable{health.NetworkStatusWarnable}, +}) + // compileConfig converts cfg into a quad-100 resolver configuration // and an OS-level configuration. 
func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig, err error) { @@ -225,8 +249,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // the OS. rcfg.Hosts = cfg.Hosts routes := map[dnsname.FQDN][]*dnstype.Resolver{} // assigned conditionally to rcfg.Routes below. + var propagateHostsToOS bool for suffix, resolvers := range cfg.Routes { if len(resolvers) == 0 { + propagateHostsToOS = true rcfg.LocalDomains = append(rcfg.LocalDomains, suffix) } else { routes[suffix] = resolvers @@ -235,13 +261,13 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // Similarly, the OS always gets search paths. ocfg.SearchDomains = cfg.SearchDomains - if m.goos == "windows" { + if propagateHostsToOS && m.goos == "windows" { ocfg.Hosts = compileHostEntries(cfg) } // Deal with trivial configs first. switch { - case !cfg.needsOSResolver(): + case !cfg.needsOSResolver() || runtime.GOOS == "plan9": // Set search domains, but nothing else. This also covers the // case where cfg is entirely zero, in which case these // configs clear all Tailscale DNS settings. @@ -264,7 +290,7 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // through quad-100. rcfg.Routes = routes rcfg.Routes["."] = cfg.DefaultResolvers - ocfg.Nameservers = []netip.Addr{cfg.serviceIP()} + ocfg.Nameservers = cfg.serviceIPs(m.knobs) return rcfg, ocfg, nil } @@ -302,7 +328,7 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // or routes + MagicDNS, or just MagicDNS, or on an OS that cannot // split-DNS. Install a split config pointing at quad-100. 
rcfg.Routes = routes - ocfg.Nameservers = []netip.Addr{cfg.serviceIP()} + ocfg.Nameservers = cfg.serviceIPs(m.knobs) var baseCfg *OSConfig // base config; non-nil if/when known @@ -312,7 +338,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // that as the forwarder for all DNS traffic that quad-100 doesn't handle. if isApple || !m.os.SupportsSplitDNS() { // If the OS can't do native split-dns, read out the underlying - // resolver config and blend it into our config. + // resolver config and blend it into our config. On apple platforms, [OSConfigurator.GetBaseConfig] + // has a tendency to temporarily fail if called immediately following + // an interface change. These failures should be retried if/when the OS + // indicates that the DNS configuration has changed via [RecompileDNSConfig]. cfg, err := m.os.GetBaseConfig() if err == nil { baseCfg = &cfg @@ -320,9 +349,10 @@ func (m *Manager) compileConfig(cfg Config) (rcfg resolver.Config, ocfg OSConfig // This is currently (2022-10-13) expected on certain iOS and macOS // builds. 
} else { - m.health.SetDNSOSHealth(err) + m.health.SetUnhealthy(osConfigurationReadWarnable, health.Args{health.ArgError: err.Error()}) return resolver.Config{}, OSConfig{}, err } + m.health.SetHealthy(osConfigurationReadWarnable) } if baseCfg == nil { diff --git a/net/dns/manager_default.go b/net/dns/manager_default.go index 11dea5ca888b1..dbe985cacdfc9 100644 --- a/net/dns/manager_default.go +++ b/net/dns/manager_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux && !freebsd && !openbsd && !windows && !darwin +//go:build (!linux || android) && !freebsd && !openbsd && !windows && !darwin && !illumos && !solaris && !plan9 package dns diff --git a/net/dns/manager_linux.go b/net/dns/manager_linux.go index 3ba3022b62757..6bd368f50e330 100644 --- a/net/dns/manager_linux.go +++ b/net/dns/manager_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package dns import ( diff --git a/net/dns/manager_plan9.go b/net/dns/manager_plan9.go new file mode 100644 index 0000000000000..ca179f27fcc8a --- /dev/null +++ b/net/dns/manager_plan9.go @@ -0,0 +1,181 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO: man 6 ndb | grep -e 'suffix.*same line' +// to detect Russ's https://9fans.topicbox.com/groups/9fans/T9c9d81b5801a0820/ndb-suffix-specific-dns-changes + +package dns + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "os" + "regexp" + "strings" + "unicode" + + "tailscale.com/control/controlknobs" + "tailscale.com/health" + "tailscale.com/types/logger" + "tailscale.com/util/set" +) + +func NewOSConfigurator(logf logger.Logf, ht *health.Tracker, knobs *controlknobs.Knobs, interfaceName string) (OSConfigurator, error) { + return &plan9DNSManager{ + logf: logf, + ht: ht, + knobs: knobs, + }, nil +} + +type plan9DNSManager struct { + logf logger.Logf + ht 
// netNDBBytesWithoutTailscale returns raw (the contents of /net/ndb) with any
// Tailscale-added bits removed.
//
// Lines previously added by tailscaled are tracked with
// "#tailscaled-added-line:" marker comments: each marker names (verbatim,
// whitespace-trimmed) one line to drop, and the marker itself is dropped too.
// As a fallback, any dns=/dnsdomain= line referencing a *.ts.net name is
// removed even without a marker.
func netNDBBytesWithoutTailscale(raw []byte) ([]byte, error) {
	var ret bytes.Buffer
	// removeLine holds trimmed lines named by marker comments that are still
	// pending removal; each entry removes at most one subsequent line.
	removeLine := make(map[string]struct{})
	bs := bufio.NewScanner(bytes.NewReader(raw))
	for bs.Scan() {
		t := bs.Text()
		if rest, ok := strings.CutPrefix(t, "#tailscaled-added-line:"); ok {
			removeLine[strings.TrimSpace(rest)] = struct{}{}
			continue
		}
		trimmed := strings.TrimSpace(t)
		if _, ok := removeLine[trimmed]; ok {
			delete(removeLine, trimmed)
			continue
		}

		// Also remove any DNS line referencing *.ts.net. This is
		// Tailscale-specific (and won't work with, say, Headscale), but the
		// Headscale case will be covered by the #tailscaled-added-line logic
		// above, assuming the user didn't delete those comments.
		if (strings.HasPrefix(trimmed, "dns=") || strings.Contains(trimmed, "dnsdomain=")) &&
			strings.HasSuffix(trimmed, ".ts.net") {
			continue
		}

		ret.WriteString(t)
		ret.WriteByte('\n')
	}
	return ret.Bytes(), bs.Err()
}
// setNDBSuffix adds lines to tsFree (the contents of /net/ndb already cleaned
// of Tailscale-added lines) to add the optional DNS search domain (e.g.
// "foo.ts.net") and the Tailscale DNS server (100.100.100.100) to it.
//
// The new lines are inserted immediately after the first "\tdns=" line, and a
// trimmed copy of each is recorded at the top of the output in a
// "#tailscaled-added-line:" marker comment so netNDBBytesWithoutTailscale can
// undo the edit later. A trailing "." on suffix is ignored; an empty suffix
// returns tsFree unmodified.
func setNDBSuffix(tsFree []byte, suffix string) []byte {
	suffix = strings.TrimSuffix(suffix, ".")
	if suffix == "" {
		return tsFree
	}
	var (
		buf   bytes.Buffer
		added []string // trimmed copies of the lines we injected
	)
	addLine := func(s string) {
		added = append(added, strings.TrimSpace(s))
		buf.WriteString(s)
	}
	bs := bufio.NewScanner(bytes.NewReader(tsFree))
	for bs.Scan() {
		buf.Write(bs.Bytes())
		buf.WriteByte('\n')

		// Inject our entries right after the first existing dns= line;
		// len(added) == 0 ensures we only do this once.
		if len(added) == 0 && strings.HasPrefix(bs.Text(), "\tdns=") {
			addLine(fmt.Sprintf("\tdns=100.100.100.100 suffix=%s\n", suffix))
			addLine(fmt.Sprintf("\tdnsdomain=%s\n", suffix))
		}
	}
	bufTrim := bytes.TrimLeftFunc(buf.Bytes(), unicode.IsSpace)
	if len(added) == 0 {
		// No dns= line found; nothing was injected, so no markers either.
		return bufTrim
	}
	var ret bytes.Buffer
	for _, s := range added {
		ret.WriteString("#tailscaled-added-line: ")
		ret.WriteString(s)
		ret.WriteString("\n")
	}
	ret.WriteString("\n")
	ret.Write(bufTrim)
	return ret.Bytes()
}
= regexp.MustCompile(`\bdns=(\d+\.\d+\.\d+\.\d+)\b`) + +func (m *plan9DNSManager) GetBaseConfig() (OSConfig, error) { + var oc OSConfig + f, err := os.Open("/net/ndb") + if err != nil { + return oc, err + } + defer f.Close() + bs := bufio.NewScanner(f) + for bs.Scan() { + m := dnsRegex.FindSubmatch(bs.Bytes()) + if m == nil { + continue + } + addr, err := netip.ParseAddr(string(m[1])) + if err != nil { + continue + } + oc.Nameservers = append(oc.Nameservers, addr) + } + if err := bs.Err(); err != nil { + return oc, err + } + + return oc, nil +} diff --git a/net/dns/manager_plan9_test.go b/net/dns/manager_plan9_test.go new file mode 100644 index 0000000000000..806fdb68ed6ba --- /dev/null +++ b/net/dns/manager_plan9_test.go @@ -0,0 +1,86 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build plan9 + +package dns + +import "testing" + +func TestNetNDBBytesWithoutTailscale(t *testing.T) { + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "no-tailscale", + raw: "# This is a comment\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n", + want: "# This is a comment\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n", + }, + { + name: "remove-by-comments", + raw: "# This is a comment\n#tailscaled-added-line: dns=100.100.100.100\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tdns=100.100.100.100\n\tsys=gnot\n", + want: "# This is a comment\nip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n", + }, + { + name: "remove-by-ts.net", + raw: "Some line\n\tdns=100.100.100.100 suffix=foo.ts.net\n\tfoo=bar\n", + want: "Some line\n\tfoo=bar\n", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := netNDBBytesWithoutTailscale([]byte(tt.raw)) + if err != nil { + t.Fatal(err) + } + if string(got) != tt.want { + t.Errorf("GOT:\n%s\n\nWANT:\n%s\n", string(got), tt.want) + } + }) + } +} + +func 
TestSetNDBSuffix(t *testing.T) { + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "set", + raw: "ip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2\n\tsys=gnot\n\tdns=100.100.100.100\n\n# foo\n", + want: `#tailscaled-added-line: dns=100.100.100.100 suffix=foo.ts.net +#tailscaled-added-line: dnsdomain=foo.ts.net + +ip=10.0.2.15 ipmask=255.255.255.0 ipgw=10.0.2.2 + sys=gnot + dns=100.100.100.100 + dns=100.100.100.100 suffix=foo.ts.net + dnsdomain=foo.ts.net + +# foo +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := setNDBSuffix([]byte(tt.raw), "foo.ts.net") + if string(got) != tt.want { + t.Errorf("wrong value\n GOT %q:\n%s\n\nWANT %q:\n%s\n", got, got, tt.want, tt.want) + } + }) + } + +} diff --git a/net/dns/manager_solaris.go b/net/dns/manager_solaris.go new file mode 100644 index 0000000000000..1f48efb9e61a1 --- /dev/null +++ b/net/dns/manager_solaris.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "tailscale.com/control/controlknobs" + "tailscale.com/health" + "tailscale.com/types/logger" +) + +func NewOSConfigurator(logf logger.Logf, health *health.Tracker, _ *controlknobs.Knobs, iface string) (OSConfigurator, error) { + return newDirectManager(logf, health), nil +} diff --git a/net/dns/manager_test.go b/net/dns/manager_test.go index 366e08bbf8644..522f9636abefe 100644 --- a/net/dns/manager_test.go +++ b/net/dns/manager_test.go @@ -4,6 +4,7 @@ package dns import ( + "errors" "net/netip" "runtime" "strings" @@ -24,8 +25,9 @@ type fakeOSConfigurator struct { SplitDNS bool BaseConfig OSConfig - OSConfig OSConfig - ResolverConfig resolver.Config + OSConfig OSConfig + ResolverConfig resolver.Config + GetBaseConfigErr *error } func (c *fakeOSConfigurator) SetDNS(cfg OSConfig) error { @@ -45,6 +47,9 @@ func (c *fakeOSConfigurator) SupportsSplitDNS() bool { } func (c 
*fakeOSConfigurator) GetBaseConfig() (OSConfig, error) { + if c.GetBaseConfigErr != nil { + return OSConfig{}, *c.GetBaseConfigErr + } return c.BaseConfig, nil } @@ -836,6 +841,76 @@ func TestManager(t *testing.T) { }, goos: "darwin", }, + { + name: "populate-hosts-magicdns", + in: Config{ + Routes: upstreams( + "corp.com", "2.2.2.2", + "ts.com", ""), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + SearchDomains: fqdns("ts.com", "universe.tf"), + }, + split: true, + os: OSConfig{ + Hosts: []*HostEntry{ + { + Addr: netip.MustParseAddr("2.3.4.5"), + Hosts: []string{ + "bradfitz.ts.com.", + "bradfitz", + }, + }, + { + Addr: netip.MustParseAddr("1.2.3.4"), + Hosts: []string{ + "dave.ts.com.", + "dave", + }, + }, + }, + Nameservers: mustIPs("100.100.100.100"), + SearchDomains: fqdns("ts.com", "universe.tf"), + MatchDomains: fqdns("corp.com", "ts.com"), + }, + rs: resolver.Config{ + Routes: upstreams("corp.com.", "2.2.2.2"), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + LocalDomains: fqdns("ts.com."), + }, + goos: "windows", + }, + { + // Regression test for https://github.com/tailscale/tailscale/issues/14428 + name: "nopopulate-hosts-nomagicdns", + in: Config{ + Routes: upstreams( + "corp.com", "2.2.2.2", + "ts.com", "1.1.1.1"), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + SearchDomains: fqdns("ts.com", "universe.tf"), + }, + split: true, + os: OSConfig{ + Nameservers: mustIPs("100.100.100.100"), + SearchDomains: fqdns("ts.com", "universe.tf"), + MatchDomains: fqdns("corp.com", "ts.com"), + }, + rs: resolver.Config{ + Routes: upstreams( + "corp.com.", "2.2.2.2", + "ts.com", "1.1.1.1"), + Hosts: hosts( + "dave.ts.com.", "1.2.3.4", + "bradfitz.ts.com.", "2.3.4.5"), + }, + goos: "windows", + }, } trIP := cmp.Transformer("ipStr", func(ip netip.Addr) string { return ip.String() }) @@ -949,3 +1024,50 @@ func upstreams(strs ...string) (ret 
map[dnsname.FQDN][]*dnstype.Resolver) { } return ret } + +func TestConfigRecompilation(t *testing.T) { + fakeErr := errors.New("fake os configurator error") + f := &fakeOSConfigurator{} + f.GetBaseConfigErr = &fakeErr + f.BaseConfig = OSConfig{ + Nameservers: mustIPs("1.1.1.1"), + } + + config := Config{ + Routes: upstreams("ts.net", "69.4.2.0", "foo.ts.net", ""), + SearchDomains: fqdns("foo.ts.net"), + } + + m := NewManager(t.Logf, f, new(health.Tracker), tsdial.NewDialer(netmon.NewStatic()), nil, nil, "darwin") + + var managerConfig *resolver.Config + m.resolver.TestOnlySetHook(func(cfg resolver.Config) { + managerConfig = &cfg + }) + + // Initial set should error out and store the config + if err := m.Set(config); err == nil { + t.Fatalf("Want non-nil error. Got nil") + } + if m.config == nil { + t.Fatalf("Want persisted config. Got nil.") + } + if managerConfig != nil { + t.Fatalf("Want nil managerConfig. Got %v", managerConfig) + } + + // Clear the error. We should take the happy path now and + // set m.manager's Config. + f.GetBaseConfigErr = nil + + // Recompilation without an error should succeed and set m.config and m.manager's [resolver.Config] + if err := m.RecompileDNSConfig(); err != nil { + t.Fatalf("Want nil error. Got err %v", err) + } + if m.config == nil { + t.Fatalf("Want non-nil config. Got nil") + } + if managerConfig == nil { + t.Fatalf("Want non nil managerConfig. 
Got nil") + } +} diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index 250a2557350dd..6ed5d3ba61f7e 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -8,10 +8,12 @@ import ( "bytes" "errors" "fmt" + "maps" "net/netip" "os" "os/exec" "path/filepath" + "slices" "sort" "strings" "sync" @@ -27,6 +29,9 @@ import ( "tailscale.com/health" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/syspolicy" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" "tailscale.com/util/winutil" ) @@ -43,6 +48,8 @@ type windowsManager struct { nrptDB *nrptRuleDatabase wslManager *wslManager + unregisterPolicyChangeCb func() // called when the manager is closing + mu sync.Mutex closing bool } @@ -62,6 +69,11 @@ func NewOSConfigurator(logf logger.Logf, health *health.Tracker, knobs *controlk ret.nrptDB = newNRPTRuleDatabase(logf) } + var err error + if ret.unregisterPolicyChangeCb, err = syspolicy.RegisterChangeCallback(ret.sysPolicyChanged); err != nil { + logf("error registering policy change callback: %v", err) // non-fatal + } + go func() { // Log WSL status once at startup. 
if distros, err := wslDistros(); err != nil { @@ -140,9 +152,8 @@ func (m *windowsManager) setSplitDNS(resolvers []netip.Addr, domains []dnsname.F return m.nrptDB.WriteSplitDNSConfig(servers, domains) } -func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) { - b := bytes.ReplaceAll(prevHostsFile, []byte("\r\n"), []byte("\n")) - sc := bufio.NewScanner(bytes.NewReader(b)) +func setTailscaleHosts(logf logger.Logf, prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) { + sc := bufio.NewScanner(bytes.NewReader(prevHostsFile)) const ( header = "# TailscaleHostsSectionStart" footer = "# TailscaleHostsSectionEnd" @@ -151,6 +162,32 @@ func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) "# This section contains MagicDNS entries for Tailscale.", "# Do not edit this section manually.", } + + prevEntries := make(map[netip.Addr][]string) + addPrevEntry := func(line string) { + if line == "" || line[0] == '#' { + return + } + + parts := strings.Split(line, " ") + if len(parts) < 1 { + return + } + + addr, err := netip.ParseAddr(parts[0]) + if err != nil { + logf("Parsing address from hosts: %v", err) + return + } + + prevEntries[addr] = parts[1:] + } + + nextEntries := make(map[netip.Addr][]string, len(hosts)) + for _, he := range hosts { + nextEntries[he.Addr] = he.Hosts + } + var out bytes.Buffer var inSection bool for sc.Scan() { @@ -164,26 +201,34 @@ func setTailscaleHosts(prevHostsFile []byte, hosts []*HostEntry) ([]byte, error) continue } if inSection { + addPrevEntry(line) continue } - fmt.Fprintln(&out, line) + fmt.Fprintf(&out, "%s\r\n", line) } if err := sc.Err(); err != nil { return nil, err } + + unchanged := maps.EqualFunc(prevEntries, nextEntries, func(a, b []string) bool { + return slices.Equal(a, b) + }) + if unchanged { + return nil, nil + } + if len(hosts) > 0 { - fmt.Fprintln(&out, header) + fmt.Fprintf(&out, "%s\r\n", header) for _, c := range comments { - fmt.Fprintln(&out, c) + fmt.Fprintf(&out, 
"%s\r\n", c) } - fmt.Fprintln(&out) + fmt.Fprintf(&out, "\r\n") for _, he := range hosts { - fmt.Fprintf(&out, "%s %s\n", he.Addr, strings.Join(he.Hosts, " ")) + fmt.Fprintf(&out, "%s %s\r\n", he.Addr, strings.Join(he.Hosts, " ")) } - fmt.Fprintln(&out) - fmt.Fprintln(&out, footer) + fmt.Fprintf(&out, "\r\n%s\r\n", footer) } - return bytes.ReplaceAll(out.Bytes(), []byte("\n"), []byte("\r\n")), nil + return out.Bytes(), nil } // setHosts sets the hosts file to contain the given host entries. @@ -197,10 +242,15 @@ func (m *windowsManager) setHosts(hosts []*HostEntry) error { if err != nil { return err } - outB, err := setTailscaleHosts(b, hosts) + outB, err := setTailscaleHosts(m.logf, b, hosts) if err != nil { return err } + if outB == nil { + // No change to hosts file, therefore no write necessary. + return nil + } + const fileMode = 0 // ignored on windows. // This can fail spuriously with an access denied error, so retry it a @@ -322,11 +372,9 @@ func (m *windowsManager) SetDNS(cfg OSConfig) error { // configuration only, routing one set of things to the "split" // resolver and the rest to the primary. - // Unconditionally disable dynamic DNS updates and NetBIOS on our - // interfaces. - if err := m.disableDynamicUpdates(); err != nil { - m.logf("disableDynamicUpdates error: %v\n", err) - } + // Reconfigure DNS registration according to the [syspolicy.DNSRegistration] + // policy setting, and unconditionally disable NetBIOS on our interfaces. 
+ m.reconfigureDNSRegistration() if err := m.disableNetBIOS(); err != nil { m.logf("disableNetBIOS error: %v\n", err) } @@ -445,6 +493,10 @@ func (m *windowsManager) Close() error { m.closing = true m.mu.Unlock() + if m.unregisterPolicyChangeCb != nil { + m.unregisterPolicyChangeCb() + } + err := m.SetDNS(OSConfig{}) if m.nrptDB != nil { m.nrptDB.Close() @@ -453,15 +505,62 @@ func (m *windowsManager) Close() error { return err } -// disableDynamicUpdates sets the appropriate registry values to prevent the -// Windows DHCP client from sending dynamic DNS updates for our interface to -// AD domain controllers. -func (m *windowsManager) disableDynamicUpdates() error { +// sysPolicyChanged is a callback triggered by [syspolicy] when it detects +// a change in one or more syspolicy settings. +func (m *windowsManager) sysPolicyChanged(policy *rsop.PolicyChange) { + if policy.HasChanged(syspolicy.EnableDNSRegistration) { + m.reconfigureDNSRegistration() + } +} + +// reconfigureDNSRegistration configures the DNS registration settings +// using the [syspolicy.DNSRegistration] policy setting, if it is set. +// If the policy is not configured, it disables DNS registration. +func (m *windowsManager) reconfigureDNSRegistration() { + // Disable DNS registration by default (if the policy setting is not configured). + // This is primarily for historical reasons and to avoid breaking existing + // setups that rely on this behavior. + enableDNSRegistration, err := syspolicy.GetPreferenceOptionOrDefault(syspolicy.EnableDNSRegistration, setting.NeverByPolicy) + if err != nil { + m.logf("error getting DNSRegistration policy setting: %v", err) // non-fatal; we'll use the default + } + + if enableDNSRegistration.Show() { + // "Show" reports whether the policy setting is configured as "user-decides". + // The name is a bit unfortunate in this context, as we don't actually "show" anything. 
+ // Still, if the admin configured the policy as "user-decides", we shouldn't modify + // the adapter's settings and should leave them up to the user (admin rights required) + // or the system defaults. + return + } + + // Otherwise, if the policy setting is configured as "always" or "never", + // we should configure the adapter accordingly. + if err := m.configureDNSRegistration(enableDNSRegistration.IsAlways()); err != nil { + m.logf("error configuring DNS registration: %v", err) + } +} + +// configureDNSRegistration sets the appropriate registry values to allow or prevent +// the Windows DHCP client from registering Tailscale IP addresses with DNS +// and sending dynamic updates for our interface to AD domain controllers. +func (m *windowsManager) configureDNSRegistration(enabled bool) error { prefixen := []winutil.RegistryPathPrefix{ winutil.IPv4TCPIPInterfacePrefix, winutil.IPv6TCPIPInterfacePrefix, } + var ( + registrationEnabled = uint32(0) + disableDynamicUpdate = uint32(1) + maxNumberOfAddressesToRegister = uint32(0) + ) + if enabled { + registrationEnabled = 1 + disableDynamicUpdate = 0 + maxNumberOfAddressesToRegister = 1 + } + for _, prefix := range prefixen { k, err := m.openInterfaceKey(prefix) if err != nil { @@ -469,13 +568,13 @@ func (m *windowsManager) disableDynamicUpdates() error { } defer k.Close() - if err := k.SetDWordValue("RegistrationEnabled", 0); err != nil { + if err := k.SetDWordValue("RegistrationEnabled", registrationEnabled); err != nil { return err } - if err := k.SetDWordValue("DisableDynamicUpdate", 1); err != nil { + if err := k.SetDWordValue("DisableDynamicUpdate", disableDynamicUpdate); err != nil { return err } - if err := k.SetDWordValue("MaxNumberOfAddressesToRegister", 0); err != nil { + if err := k.SetDWordValue("MaxNumberOfAddressesToRegister", maxNumberOfAddressesToRegister); err != nil { return err } } diff --git a/net/dns/manager_windows_test.go b/net/dns/manager_windows_test.go index 62c4dd9fbb740..edcf24ec04240 
100644 --- a/net/dns/manager_windows_test.go +++ b/net/dns/manager_windows_test.go @@ -15,6 +15,7 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" + "tailscale.com/types/logger" "tailscale.com/util/dnsname" "tailscale.com/util/winutil" "tailscale.com/util/winutil/gp" @@ -24,9 +25,56 @@ const testGPRuleID = "{7B1B6151-84E6-41A3-8967-62F7F7B45687}" func TestHostFileNewLines(t *testing.T) { in := []byte("#foo\r\n#bar\n#baz\n") - want := []byte("#foo\r\n#bar\r\n#baz\r\n") + want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n") - got, err := setTailscaleHosts(in, nil) + he := []*HostEntry{ + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.1"), + Hosts: []string{"aaron"}, + }, + } + got, err := setTailscaleHosts(logger.Discard, in, he) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(got, want) { + t.Errorf("got %q, want %q\n", got, want) + } +} + +func TestHostFileUnchanged(t *testing.T) { + in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron\r\n\r\n# TailscaleHostsSectionEnd\r\n") + + he := []*HostEntry{ + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.1"), + Hosts: []string{"aaron"}, + }, + } + got, err := setTailscaleHosts(logger.Discard, in, he) + if err != nil { + t.Fatal(err) + } + if got != nil { + t.Errorf("got %q, want nil\n", got) + } +} + +func TestHostFileChanged(t *testing.T) { + in := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n\r\n# TailscaleHostsSectionEnd\r\n") + want := []byte("#foo\r\n#bar\r\n#baz\r\n# TailscaleHostsSectionStart\r\n# This section contains 
MagicDNS entries for Tailscale.\r\n# Do not edit this section manually.\r\n\r\n192.168.1.1 aaron1\r\n192.168.1.2 aaron2\r\n\r\n# TailscaleHostsSectionEnd\r\n") + + he := []*HostEntry{ + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.1"), + Hosts: []string{"aaron1"}, + }, + &HostEntry{ + Addr: netip.MustParseAddr("192.168.1.2"), + Hosts: []string{"aaron2"}, + }, + } + got, err := setTailscaleHosts(logger.Discard, in, he) if err != nil { t.Fatal(err) } diff --git a/net/dns/nm.go b/net/dns/nm.go index adb33cdb7967a..97557e33aa9bf 100644 --- a/net/dns/nm.go +++ b/net/dns/nm.go @@ -1,12 +1,13 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android package dns import ( "context" + "encoding/binary" "fmt" "net" "net/netip" @@ -14,7 +15,6 @@ import ( "time" "github.com/godbus/dbus/v5" - "github.com/josharian/native" "tailscale.com/net/tsaddr" "tailscale.com/util/dnsname" ) @@ -137,7 +137,7 @@ func (m *nmManager) trySet(ctx context.Context, config OSConfig) error { for _, ip := range config.Nameservers { b := ip.As16() if ip.Is4() { - dnsv4 = append(dnsv4, native.Endian.Uint32(b[12:])) + dnsv4 = append(dnsv4, binary.NativeEndian.Uint32(b[12:])) } else { dnsv6 = append(dnsv6, b[:]) } diff --git a/net/dns/openresolv.go b/net/dns/openresolv.go index 0b5c87a3b3534..c9562b6a91d13 100644 --- a/net/dns/openresolv.go +++ b/net/dns/openresolv.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || freebsd || openbsd +//go:build (linux && !android) || freebsd || openbsd package dns diff --git a/net/dns/resolvd.go b/net/dns/resolvd.go index 9b067eb07b178..ad1a99c111997 100644 --- a/net/dns/resolvd.go +++ b/net/dns/resolvd.go @@ -57,6 +57,7 @@ func (m *resolvdManager) SetDNS(config OSConfig) error { if len(newSearch) > 1 { newResolvConf = append(newResolvConf, []byte(strings.Join(newSearch, " "))...) 
// searchLinesRE matches a "search ..." directive in resolv.conf-style
// content. Compiled once at package scope rather than per call.
//
// NOTE(review): with (?ms), `.` also matches newlines, so the greedy `.+$`
// extends the match from the first search directive through the last line
// boundary of the input — following lines and the trailing newline are
// swallowed too, not just the single search line. In files produced by
// resolvd the search line is last, where this removes the line together with
// its newline; confirm callers never pass content with lines after "search".
var searchLinesRE = regexp.MustCompile(`(?ms)^search\s+.+$`)

// removeSearchLines returns orig with the search directive (and, per the
// note on searchLinesRE, everything after it) removed.
func removeSearchLines(orig []byte) []byte {
	return searchLinesRE.ReplaceAll(orig, []byte(""))
}
- m.health.SetDNSOSHealth(err) + if err != nil { m.logf("failed to configure systemd-resolved: %v", err) + m.health.SetUnhealthy(osConfigurationSetWarnable, health.Args{health.ArgError: err.Error()}) + } else { + m.health.SetHealthy(osConfigurationSetWarnable) } } } diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index c00dea1aea8c4..c87fbd5041a93 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -17,6 +17,7 @@ import ( "net/http" "net/netip" "net/url" + "runtime" "sort" "strings" "sync" @@ -31,6 +32,7 @@ import ( "tailscale.com/net/dnscache" "tailscale.com/net/neterror" "tailscale.com/net/netmon" + "tailscale.com/net/netx" "tailscale.com/net/sockstats" "tailscale.com/net/tsdial" "tailscale.com/types/dnstype" @@ -243,12 +245,6 @@ type forwarder struct { // /etc/resolv.conf is missing/corrupt, and the peerapi ExitDNS stub // resolver lookup. cloudHostFallback []resolverAndDelay - - // missingUpstreamRecovery, if non-nil, is set called when a SERVFAIL is - // returned due to missing upstream resolvers. - // - // This should attempt to properly (re)set the upstream resolvers. 
- missingUpstreamRecovery func() } func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, health *health.Tracker, knobs *controlknobs.Knobs) *forwarder { @@ -256,13 +252,12 @@ func newForwarder(logf logger.Logf, netMon *netmon.Monitor, linkSel ForwardLinkS panic("nil netMon") } f := &forwarder{ - logf: logger.WithPrefix(logf, "forward: "), - netMon: netMon, - linkSel: linkSel, - dialer: dialer, - health: health, - controlKnobs: knobs, - missingUpstreamRecovery: func() {}, + logf: logger.WithPrefix(logf, "forward: "), + netMon: netMon, + linkSel: linkSel, + dialer: dialer, + health: health, + controlKnobs: knobs, } f.ctx, f.ctxCancel = context.WithCancel(context.Background()) return f @@ -739,18 +734,35 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn return out, nil } -func (f *forwarder) getDialerType() dnscache.DialContextFunc { - if f.controlKnobs != nil && f.controlKnobs.UserDialUseRoutes.Load() { - // It is safe to use UserDial as it dials external servers without going through Tailscale - // and closes connections on interface change in the same way as SystemDial does, - // thus preventing DNS resolution issues when switching between WiFi and cellular, - // but can also dial an internal DNS server on the Tailnet or via a subnet router. - // - // TODO(nickkhyl): Update tsdial.Dialer to reuse the bart.Table we create in net/tstun.Wrapper - // to avoid having two bart tables in memory, especially on iOS. Once that's done, - // we can get rid of the nodeAttr/control knob and always use UserDial for DNS. - // - // See https://github.com/tailscale/tailscale/issues/12027. +var optDNSForwardUseRoutes = envknob.RegisterOptBool("TS_DEBUG_DNS_FORWARD_USE_ROUTES") + +// ShouldUseRoutes reports whether the DNS resolver should consider routes when dialing +// upstream nameservers via TCP. 
+// +// If true, routes should be considered ([tsdial.Dialer.UserDial]), otherwise defer +// to the system routes ([tsdial.Dialer.SystemDial]). +// +// TODO(nickkhyl): Update [tsdial.Dialer] to reuse the bart.Table we create in net/tstun.Wrapper +// to avoid having two bart tables in memory, especially on iOS. Once that's done, +// we can get rid of the nodeAttr/control knob and always use UserDial for DNS. +// +// See tailscale/tailscale#12027. +func ShouldUseRoutes(knobs *controlknobs.Knobs) bool { + switch runtime.GOOS { + case "android", "ios": + // On mobile platforms with lower memory limits (e.g., 50MB on iOS), + // this behavior is still gated by the "user-dial-routes" nodeAttr. + return knobs != nil && knobs.UserDialUseRoutes.Load() + default: + // On all other platforms, it is the default behavior, + // but it can be overridden with the "TS_DEBUG_DNS_FORWARD_USE_ROUTES" env var. + doNotUseRoutes := optDNSForwardUseRoutes().EqualBool(false) + return !doNotUseRoutes + } +} + +func (f *forwarder) getDialerType() netx.DialFunc { + if ShouldUseRoutes(f.controlKnobs) { return f.dialer.UserDial } return f.dialer.SystemDial @@ -943,13 +955,6 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: ""}) f.logf("no upstream resolvers set, returning SERVFAIL") - // Attempt to recompile the DNS configuration - // If we are being asked to forward queries and we have no - // nameservers, the network is in a bad state. 
- if f.missingUpstreamRecovery != nil { - f.missingUpstreamRecovery() - } - res, err := servfailResponse(query) if err != nil { return err diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index e341186ecf45e..f7cda15f6a000 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -27,7 +27,9 @@ import ( "tailscale.com/health" "tailscale.com/net/netmon" "tailscale.com/net/tsdial" + "tailscale.com/tstest" "tailscale.com/types/dnstype" + "tailscale.com/util/eventbus" ) func (rr resolverAndDelay) String() string { @@ -276,6 +278,8 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on tb.Fatal("cannot skip both UDP and TCP servers") } + logf := tstest.WhileTestRunningLogger(tb) + tcpResponse := make([]byte, len(response)+2) binary.BigEndian.PutUint16(tcpResponse, uint16(len(response))) copy(tcpResponse[2:], response) @@ -329,13 +333,13 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on // Read the length header, then the buffer var length uint16 if err := binary.Read(conn, binary.BigEndian, &length); err != nil { - tb.Logf("error reading length header: %v", err) + logf("error reading length header: %v", err) return } req := make([]byte, length) n, err := io.ReadFull(conn, req) if err != nil { - tb.Logf("error reading query: %v", err) + logf("error reading query: %v", err) return } req = req[:n] @@ -343,7 +347,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on // Write response if _, err := conn.Write(tcpResponse); err != nil { - tb.Logf("error writing response: %v", err) + logf("error writing response: %v", err) return } } @@ -367,7 +371,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on handleUDP := func(addr netip.AddrPort, req []byte) { onRequest(false, req) if _, err := udpLn.WriteToUDPAddrPort(response, addr); err != nil { - tb.Logf("error writing response: %v", 
err) + logf("error writing response: %v", err) } } @@ -390,7 +394,7 @@ func runDNSServer(tb testing.TB, opts *testDNSServerOptions, response []byte, on tb.Cleanup(func() { tcpLn.Close() udpLn.Close() - tb.Logf("waiting for listeners to finish...") + logf("waiting for listeners to finish...") wg.Wait() }) return @@ -450,7 +454,10 @@ func makeLargeResponse(tb testing.TB, domain string) (request, response []byte) } func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports ...uint16) ([]byte, error) { - netMon, err := netmon.New(tb.Logf) + logf := tstest.WhileTestRunningLogger(tb) + bus := eventbus.New() + defer bus.Close() + netMon, err := netmon.New(bus, logf) if err != nil { tb.Fatal(err) } @@ -458,7 +465,7 @@ func runTestQuery(tb testing.TB, request []byte, modify func(*forwarder), ports var dialer tsdial.Dialer dialer.SetNetMon(netMon) - fwd := newForwarder(tb.Logf, netMon, nil, &dialer, new(health.Tracker), nil) + fwd := newForwarder(logf, netMon, nil, &dialer, new(health.Tracker), nil) if modify != nil { modify(fwd) } diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 43ba0acf194f2..33fa9c3c07d4c 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -251,15 +251,6 @@ func New(logf logger.Logf, linkSel ForwardLinkSelector, dialer *tsdial.Dialer, h return r } -// SetMissingUpstreamRecovery sets a callback to be called upon encountering -// a SERVFAIL due to missing upstream resolvers. -// -// This call should only happen before the resolver is used. It is not safe -// for concurrent use. -func (r *Resolver) SetMissingUpstreamRecovery(f func()) { - r.forwarder.missingUpstreamRecovery = f -} - func (r *Resolver) TestOnlySetHook(hook func(Config)) { r.saveConfigForTests = hook } func (r *Resolver) SetConfig(cfg Config) error { @@ -384,7 +375,7 @@ func (r *Resolver) HandlePeerDNSQuery(ctx context.Context, q []byte, from netip. // but for now that's probably good enough. 
Later we'll // want to blend in everything from scutil --dns. fallthrough - case "linux", "freebsd", "openbsd", "illumos", "ios": + case "linux", "freebsd", "openbsd", "illumos", "solaris", "ios": nameserver, err := stubResolverForOS() if err != nil { r.logf("stubResolverForOS: %v", err) diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index d7b9fb360eaf0..de08450d2e3eb 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -31,6 +31,7 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/util/dnsname" + "tailscale.com/util/eventbus" ) var ( @@ -1059,7 +1060,10 @@ func TestForwardLinkSelection(t *testing.T) { // routes differently. specialIP := netaddr.IPv4(1, 2, 3, 4) - netMon, err := netmon.New(logger.WithPrefix(t.Logf, ".... netmon: ")) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, ".... netmon: ")) if err != nil { t.Fatal(err) } diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 2cbea6c0fd896..96550cbb17fca 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -19,6 +19,7 @@ import ( "time" "tailscale.com/envknob" + "tailscale.com/net/netx" "tailscale.com/types/logger" "tailscale.com/util/cloudenv" "tailscale.com/util/singleflight" @@ -355,10 +356,8 @@ func (r *Resolver) addIPCache(host string, ip, ip6 netip.Addr, allIPs []netip.Ad } } -type DialContextFunc func(ctx context.Context, network, address string) (net.Conn, error) - // Dialer returns a wrapped DialContext func that uses the provided dnsCache. -func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { +func Dialer(fwd netx.DialFunc, dnsCache *Resolver) netx.DialFunc { d := &dialer{ fwd: fwd, dnsCache: dnsCache, @@ -369,7 +368,7 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { // dialer is the config and accumulated state for a dial func returned by Dialer. 
type dialer struct { - fwd DialContextFunc + fwd netx.DialFunc dnsCache *Resolver mu sync.Mutex @@ -653,7 +652,7 @@ func v6addrs(aa []netip.Addr) (ret []netip.Addr) { // TLSDialer is like Dialer but returns a func suitable for using with net/http.Transport.DialTLSContext. // It returns a *tls.Conn type on success. // On TLS cert validation failure, it can invoke a backup DNS resolution strategy. -func TLSDialer(fwd DialContextFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) DialContextFunc { +func TLSDialer(fwd netx.DialFunc, dnsCache *Resolver, tlsConfigBase *tls.Config) netx.DialFunc { tcpDialer := Dialer(fwd, dnsCache) return func(ctx context.Context, network, address string) (net.Conn, error) { host, _, err := net.SplitHostPort(address) diff --git a/net/dnsfallback/dnsfallback_test.go b/net/dnsfallback/dnsfallback_test.go index 16f5027d4850f..7f881057450e7 100644 --- a/net/dnsfallback/dnsfallback_test.go +++ b/net/dnsfallback/dnsfallback_test.go @@ -15,6 +15,7 @@ import ( "tailscale.com/net/netmon" "tailscale.com/tailcfg" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) func TestGetDERPMap(t *testing.T) { @@ -185,7 +186,10 @@ func TestLookup(t *testing.T) { logf, closeLogf := logger.LogfCloser(t.Logf) defer closeLogf() - netMon, err := netmon.New(logf) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logf) if err != nil { t.Fatal(err) } diff --git a/net/ipset/ipset.go b/net/ipset/ipset.go index 622fd61d05c16..27c1e27ed4180 100644 --- a/net/ipset/ipset.go +++ b/net/ipset/ipset.go @@ -82,8 +82,8 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool pathForTest("bart") // Built a bart table. 
t := &bart.Table[struct{}]{} - for i := range addrs.Len() { - t.Insert(addrs.At(i), struct{}{}) + for _, p := range addrs.All() { + t.Insert(p, struct{}{}) } return bartLookup(t) } @@ -99,8 +99,8 @@ func NewContainsIPFunc(addrs views.Slice[netip.Prefix]) func(ip netip.Addr) bool // General case: pathForTest("ip-map") m := set.Set[netip.Addr]{} - for i := range addrs.Len() { - m.Add(addrs.At(i).Addr()) + for _, p := range addrs.All() { + m.Add(p.Addr()) } return ipInMap(m) } diff --git a/net/ktimeout/ktimeout_linux_test.go b/net/ktimeout/ktimeout_linux_test.go index a367bfd4a5a95..df41567454f4b 100644 --- a/net/ktimeout/ktimeout_linux_test.go +++ b/net/ktimeout/ktimeout_linux_test.go @@ -4,17 +4,22 @@ package ktimeout import ( + "context" "net" "testing" "time" - "golang.org/x/net/nettest" "golang.org/x/sys/unix" "tailscale.com/util/must" ) func TestSetUserTimeout(t *testing.T) { - l := must.Get(nettest.NewLocalListener("tcp")) + lc := net.ListenConfig{} + // As of 2025-02-19, MPTCP does not support TCP_USER_TIMEOUT socket option + // set in ktimeout.UserTimeout above. + lc.SetMultipathTCP(false) + + l := must.Get(lc.Listen(context.Background(), "tcp", "localhost:0")) defer l.Close() var err error diff --git a/net/memnet/memnet.go b/net/memnet/memnet.go index c8799bc17035e..7c2435684059e 100644 --- a/net/memnet/memnet.go +++ b/net/memnet/memnet.go @@ -6,3 +6,82 @@ // in tests and other situations where you don't want to use the // network. package memnet + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + + "tailscale.com/net/netx" +) + +var _ netx.Network = (*Network)(nil) + +// Network implements [Network] using an in-memory network, usually +// used for testing. +// +// As of 2025-04-08, it only supports TCP. +// +// Its zero value is a valid [netx.Network] implementation. 
+type Network struct { + mu sync.Mutex + lns map[string]*Listener // address -> listener +} + +func (m *Network) Listen(network, address string) (net.Listener, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Listen called with unsupported network %q", network) + } + ap, err := netip.ParseAddrPort(address) + if err != nil { + return nil, fmt.Errorf("memNetwork: Listen called with invalid address %q: %w", address, err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + if m.lns == nil { + m.lns = make(map[string]*Listener) + } + port := ap.Port() + for { + if port == 0 { + port = 33000 + } + key := net.JoinHostPort(ap.Addr().String(), fmt.Sprint(port)) + _, ok := m.lns[key] + if ok { + if ap.Port() != 0 { + return nil, fmt.Errorf("memNetwork: Listen called with duplicate address %q", address) + } + port++ + continue + } + ln := Listen(key) + m.lns[key] = ln + return ln, nil + } +} + +func (m *Network) NewLocalTCPListener() net.Listener { + ln, err := m.Listen("tcp", "127.0.0.1:0") + if err != nil { + panic(fmt.Sprintf("memNetwork: failed to create local TCP listener: %v", err)) + } + return ln +} + +func (m *Network) Dial(ctx context.Context, network, address string) (net.Conn, error) { + if network != "tcp" && network != "tcp4" && network != "tcp6" { + return nil, fmt.Errorf("memNetwork: Dial called with unsupported network %q", network) + } + m.mu.Lock() + ln, ok := m.lns[address] + m.mu.Unlock() + if !ok { + return nil, fmt.Errorf("memNetwork: Dial called on unknown address %q", address) + } + return ln.Dial(ctx, network, address) +} diff --git a/net/netcheck/netcheck.go b/net/netcheck/netcheck.go index bebf4c9b05461..c9f03966beedd 100644 --- a/net/netcheck/netcheck.go +++ b/net/netcheck/netcheck.go @@ -23,9 +23,9 @@ import ( "syscall" "time" - "github.com/tcnksm/go-httpstat" "tailscale.com/derp/derphttp" "tailscale.com/envknob" + "tailscale.com/hostinfo" "tailscale.com/net/captivedetection" 
"tailscale.com/net/dnscache" "tailscale.com/net/neterror" @@ -85,13 +85,14 @@ const ( // Report contains the result of a single netcheck. type Report struct { - UDP bool // a UDP STUN round trip completed - IPv6 bool // an IPv6 STUN round trip completed - IPv4 bool // an IPv4 STUN round trip completed - IPv6CanSend bool // an IPv6 packet was able to be sent - IPv4CanSend bool // an IPv4 packet was able to be sent - OSHasIPv6 bool // could bind a socket to ::1 - ICMPv4 bool // an ICMPv4 round trip completed + Now time.Time // the time the report was run + UDP bool // a UDP STUN round trip completed + IPv6 bool // an IPv6 STUN round trip completed + IPv4 bool // an IPv4 STUN round trip completed + IPv6CanSend bool // an IPv6 packet was able to be sent + IPv4CanSend bool // an IPv4 packet was able to be sent + OSHasIPv6 bool // could bind a socket to ::1 + ICMPv4 bool // an ICMPv4 round trip completed // MappingVariesByDestIP is whether STUN results depend which // STUN server you're talking to (on IPv4). @@ -172,25 +173,14 @@ func (r *Report) Clone() *Report { return nil } r2 := *r - r2.RegionLatency = cloneDurationMap(r2.RegionLatency) - r2.RegionV4Latency = cloneDurationMap(r2.RegionV4Latency) - r2.RegionV6Latency = cloneDurationMap(r2.RegionV6Latency) + r2.RegionLatency = maps.Clone(r2.RegionLatency) + r2.RegionV4Latency = maps.Clone(r2.RegionV4Latency) + r2.RegionV6Latency = maps.Clone(r2.RegionV6Latency) r2.GlobalV4Counters = maps.Clone(r2.GlobalV4Counters) r2.GlobalV6Counters = maps.Clone(r2.GlobalV6Counters) return &r2 } -func cloneDurationMap(m map[int]time.Duration) map[int]time.Duration { - if m == nil { - return nil - } - m2 := make(map[int]time.Duration, len(m)) - for k, v := range m { - m2[k] = v - } - return m2 -} - // Client generates Reports describing the result of both passive and active // network configuration probing. 
It provides two different modes of report, a // full report (see MakeNextReportFull) and a more lightweight incremental @@ -235,6 +225,10 @@ type Client struct { // If false, the default net.Resolver will be used, with no caching. UseDNSCache bool + // if non-zero, force this DERP region to be preferred in all reports where + // the DERP is found to be reachable. + ForcePreferredDERP int + // For tests testEnoughRegions int testCaptivePortalDelay time.Duration @@ -391,10 +385,14 @@ type probePlan map[string][]probe // sortRegions returns the regions of dm first sorted // from fastest to slowest (based on the 'last' report), // end in regions that have no data. -func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) { +func sortRegions(dm *tailcfg.DERPMap, last *Report, preferredDERP int) (prev []*tailcfg.DERPRegion) { prev = make([]*tailcfg.DERPRegion, 0, len(dm.Regions)) for _, reg := range dm.Regions { - if reg.Avoid { + if reg.NoMeasureNoHome { + continue + } + // include an otherwise avoid region if it is the current preferred region + if reg.Avoid && reg.RegionID != preferredDERP { continue } prev = append(prev, reg) @@ -419,9 +417,19 @@ func sortRegions(dm *tailcfg.DERPMap, last *Report) (prev []*tailcfg.DERPRegion) // a full report, all regions are scanned.) const numIncrementalRegions = 3 -// makeProbePlan generates the probe plan for a DERPMap, given the most -// recent report and whether IPv6 is configured on an interface. -func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (plan probePlan) { +// makeProbePlan generates the probe plan for a DERPMap, given the most recent +// report and the current home DERP. preferredDERP is passed independently of +// last (report) because last is currently nil'd to indicate a desire for a full +// netcheck. +// +// TODO(raggi,jwhited): refactor the callers and this function to be more clear +// about full vs. 
incremental netchecks, and remove the need for the history +// hiding. This was avoided in an incremental change due to exactly this kind of +// distant coupling. +// TODO(raggi): change from "preferred DERP" from a historical report to "home +// DERP" as in what DERP is the current home connection, this would further +// reduce flap events. +func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report, preferredDERP int) (plan probePlan) { if last == nil || len(last.RegionLatency) == 0 { return makeProbePlanInitial(dm, ifState) } @@ -432,9 +440,34 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl had4 := len(last.RegionV4Latency) > 0 had6 := len(last.RegionV6Latency) > 0 hadBoth := have6if && had4 && had6 - for ri, reg := range sortRegions(dm, last) { - if ri == numIncrementalRegions { - break + // #13969 ensure that the home region is always probed. + // If a netcheck has unstable latency, such as a user with large amounts of + // bufferbloat or a highly congested connection, there are cases where a full + // netcheck may observe a one-off high latency to the current home DERP. Prior + // to the forced inclusion of the home DERP, this would result in an + // incremental netcheck following such an event to cause a home DERP move, with + // restoration back to the home DERP on the next full netcheck ~5 minutes later + // - which is highly disruptive when it causes shifts in geo routed subnet + // routers. By always including the home DERP in the incremental netcheck, we + // ensure that the home DERP is always probed, even if it observed a recent + // poor latency sample. This inclusion enables the latency history checks in + // home DERP selection to still take effect. + // planContainsHome indicates whether the home DERP has been added to the probePlan, + // if there is no prior home, then there's no home to additionally include.
+ planContainsHome := preferredDERP == 0 + for ri, reg := range sortRegions(dm, last, preferredDERP) { + regIsHome := reg.RegionID == preferredDERP + if ri >= numIncrementalRegions { + // planned at least numIncrementalRegions regions and that includes the + // last home region (or there was none), plan complete. + if planContainsHome { + break + } + // planned at least numIncrementalRegions regions, but not the home region, + // check if this is the home region, if not, skip it. + if !regIsHome { + continue + } } var p4, p6 []probe do4 := have4if @@ -445,7 +478,7 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl tries := 1 isFastestTwo := ri < 2 - if isFastestTwo { + if isFastestTwo || regIsHome { tries = 2 } else if hadBoth { // For dual stack machines, make the 3rd & slower nodes alternate @@ -456,14 +489,15 @@ func makeProbePlan(dm *tailcfg.DERPMap, ifState *netmon.State, last *Report) (pl do4, do6 = false, true } } - if !isFastestTwo && !had6 { + if !regIsHome && !isFastestTwo && !had6 { do6 = false } - if reg.RegionID == last.PreferredDERP { + if regIsHome { // But if we already had a DERP home, try extra hard to // make sure it's there so we don't flip flop around. tries = 4 + planContainsHome = true } for try := 0; try < tries; try++ { @@ -503,7 +537,7 @@ func makeProbePlanInitial(dm *tailcfg.DERPMap, ifState *netmon.State) (plan prob plan = make(probePlan) for _, reg := range dm.Regions { - if len(reg.Nodes) == 0 { + if reg.NoMeasureNoHome || len(reg.Nodes) == 0 { continue } @@ -742,6 +776,12 @@ func (o *GetReportOpts) getLastDERPActivity(region int) time.Time { return o.GetLastDERPActivity(region) } +func (c *Client) SetForcePreferredDERP(region int) { + c.mu.Lock() + defer c.mu.Unlock() + c.ForcePreferredDERP = region +} + // GetReport gets a report. The 'opts' argument is optional and can be nil. 
// Callers are discouraged from passing a ctx with an arbitrary deadline as this // may cause GetReport to return prematurely before all reporting methods have @@ -788,9 +828,10 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe c.curState = rs last := c.last - // Even if we're doing a non-incremental update, we may want to try our - // preferred DERP region for captive portal detection. Save that, if we - // have it. + // Extract preferredDERP from the last report, if available. This will be used + // in captive portal detection and DERP flapping suppression. Ideally this would + // be the current active home DERP rather than the last report preferred DERP, + // but only the latter is presently available. var preferredDERP int if last != nil { preferredDERP = last.PreferredDERP @@ -823,7 +864,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe c.curState = nil }() - if runtime.GOOS == "js" || runtime.GOOS == "tamago" { + if runtime.GOOS == "js" || runtime.GOOS == "tamago" || (runtime.GOOS == "plan9" && hostinfo.IsInVM86()) { if err := c.runHTTPOnlyChecks(ctx, last, rs, dm); err != nil { return nil, err } @@ -847,7 +888,7 @@ func (c *Client) GetReport(ctx context.Context, dm *tailcfg.DERPMap, opts *GetRe var plan probePlan if opts == nil || !opts.OnlyTCP443 { - plan = makeProbePlan(dm, ifState, last) + plan = makeProbePlan(dm, ifState, last, preferredDERP) } // If we're doing a full probe, also check for a captive portal. We @@ -1005,7 +1046,7 @@ func (c *Client) finishAndStoreReport(rs *reportState, dm *tailcfg.DERPMap) *Rep } // runHTTPOnlyChecks is the netcheck done by environments that can -// only do HTTP requests, such as ws/wasm. +// only do HTTP requests, such as js/wasm. 
func (c *Client) runHTTPOnlyChecks(ctx context.Context, last *Report, rs *reportState, dm *tailcfg.DERPMap) error { var regions []*tailcfg.DERPRegion if rs.incremental && last != nil { @@ -1017,9 +1058,25 @@ } if len(regions) == 0 { for _, dr := range dm.Regions { + if dr.NoMeasureNoHome { + continue + } regions = append(regions, dr) } } + + if len(regions) == 1 && hostinfo.IsInVM86() { + // If we only have 1 region, that's probably because we're in a + // network-limited v86 environment, don't actually probe it. Just fake + // some results. + rg := regions[0] + if len(rg.Nodes) > 0 { + node := rg.Nodes[0] + rs.addNodeLatency(node, netip.AddrPort{}, 999*time.Millisecond) + return nil + } + } + c.logf("running HTTP-only netcheck against %v regions", len(regions)) var wg sync.WaitGroup @@ -1061,10 +1118,11 @@ return nil } +// measureHTTPSLatency measures HTTP request latency to the DERP region, but +// only returns success if an HTTPS request to the region succeeds. func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegion) (time.Duration, netip.Addr, error) { metricHTTPSend.Add(1) - var result httpstat.Result - ctx, cancel := context.WithTimeout(httpstat.WithHTTPStat(ctx, &result), httpsProbeTimeout) + ctx, cancel := context.WithTimeout(ctx, httpsProbeTimeout) defer cancel() var ip netip.Addr @@ -1072,6 +1130,8 @@ dc := derphttp.NewNetcheckClient(c.logf, c.NetMon) defer dc.Close() + // DialRegionTLS may dial multiple times if a node is not available, as such + // it does not have stable timing to measure.
tlsConn, tcpConn, node, err := dc.DialRegionTLS(ctx, reg) if err != nil { return 0, ip, err @@ -1089,6 +1149,8 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio connc := make(chan *tls.Conn, 1) connc <- tlsConn + // make an HTTP request to measure, as this enables us to account for MITM + // overhead in e.g. corp environments that have HTTP MITM in front of DERP. tr := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return nil, errors.New("unexpected DialContext dial") @@ -1104,12 +1166,17 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio } hc := &http.Client{Transport: tr} + // This is the request that will be measured, the request and response + // should be small enough to fit into a single packet each way unless the + // connection has already become unstable. req, err := http.NewRequestWithContext(ctx, "GET", "https://"+node.HostName+"/derp/latency-check", nil) if err != nil { return 0, ip, err } + startTime := c.timeNow() resp, err := hc.Do(req) + reqDur := c.timeNow().Sub(startTime) if err != nil { return 0, ip, err } @@ -1126,17 +1193,22 @@ func (c *Client) measureHTTPSLatency(ctx context.Context, reg *tailcfg.DERPRegio if err != nil { return 0, ip, err } - result.End(c.timeNow()) - // TODO: decide best timing heuristic here. - // Maybe the server should return the tcpinfo_rtt? - return result.ServerProcessing, ip, nil + // return the connection duration, not the request duration, as this is the + // best approximation of the RTT latency to the node. Note that the + // connection setup performs happy-eyeballs and TLS so there are additional + // overheads. + return reqDur, ip, nil } func (c *Client) measureAllICMPLatency(ctx context.Context, rs *reportState, need []*tailcfg.DERPRegion) error { if len(need) == 0 { return nil } + if runtime.GOOS == "plan9" { + // ICMP isn't implemented. 
+ return nil + } ctx, done := context.WithTimeout(ctx, icmpProbeTimeout) defer done() @@ -1182,17 +1254,19 @@ func (c *Client) measureICMPLatency(ctx context.Context, reg *tailcfg.DERPRegion // Try pinging the first node in the region node := reg.Nodes[0] - // Get the IPAddr by asking for the UDP address that we would use for - // STUN and then using that IP. - // - // TODO(andrew-d): this is a bit ugly - nodeAddr := c.nodeAddr(ctx, node, probeIPv4) - if !nodeAddr.IsValid() { + if node.STUNPort < 0 { + // If STUN is disabled on a node, interpret that as meaning don't measure latency. + return 0, false, nil + } + const unusedPort = 0 + stunAddrPort, ok := c.nodeAddrPort(ctx, node, unusedPort, probeIPv4) + if !ok { return 0, false, fmt.Errorf("no address for node %v (v4-for-icmp)", node.Name) } + ip := stunAddrPort.Addr() addr := &net.IPAddr{ - IP: net.IP(nodeAddr.Addr().AsSlice()), - Zone: nodeAddr.Addr().Zone(), + IP: net.IP(ip.AsSlice()), + Zone: ip.Zone(), } // Use the unique node.Name field as the packet data to reduce the @@ -1236,6 +1310,9 @@ func (c *Client) logConciseReport(r *Report, dm *tailcfg.DERPMap) { if r.CaptivePortal != "" { fmt.Fprintf(w, " captiveportal=%v", r.CaptivePortal) } + if c.ForcePreferredDERP != 0 { + fmt.Fprintf(w, " force=%v", c.ForcePreferredDERP) + } fmt.Fprintf(w, " derp=%v", r.PreferredDERP) if r.PreferredDERP != 0 { fmt.Fprintf(w, " derpdist=") @@ -1297,6 +1374,7 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report, c.prev = map[time.Time]*Report{} } now := c.timeNow() + r.Now = now.UTC() c.prev[now] = r c.last = r @@ -1393,6 +1471,21 @@ func (c *Client) addReportHistoryAndSetPreferredDERP(rs *reportState, r *Report, // which undoes any region change we made above. r.PreferredDERP = prevDERP } + if c.ForcePreferredDERP != 0 { + // If the forced DERP region probed successfully, or has recent traffic, + // use it. 
+ _, haveLatencySample := r.RegionLatency[c.ForcePreferredDERP] + var recentActivity bool + if lastHeard := rs.opts.getLastDERPActivity(c.ForcePreferredDERP); !lastHeard.IsZero() { + now := c.timeNow() + recentActivity = lastHeard.After(rs.start) + recentActivity = recentActivity || lastHeard.After(now.Add(-PreferredDERPFrameTime)) + } + + if haveLatencySample || recentActivity { + r.PreferredDERP = c.ForcePreferredDERP + } + } } func updateLatency(m map[int]time.Duration, regionID int, d time.Duration) { @@ -1438,8 +1531,8 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe return } - addr := c.nodeAddr(ctx, node, probe.proto) - if !addr.IsValid() { + addr, ok := c.nodeAddrPort(ctx, node, node.STUNPort, probe.proto) + if !ok { c.logf("netcheck.runProbe: named node %q has no %v address", probe.node, probe.proto) return } @@ -1488,12 +1581,20 @@ func (rs *reportState) runProbe(ctx context.Context, dm *tailcfg.DERPMap, probe c.vlogf("sent to %v", addr) } -// proto is 4 or 6 -// If it returns nil, the node is skipped. -func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeProto) (ap netip.AddrPort) { - port := cmp.Or(n.STUNPort, 3478) +// nodeAddrPort returns the IP:port to send a STUN queries to for a given node. +// +// The provided port should be n.STUNPort, which may be negative to disable STUN. +// If STUN is disabled for this node, it returns ok=false. +// The port parameter is separate for the ICMP caller to provide a fake value. +// +// proto is [probeIPv4] or [probeIPv6]. 
+func (c *Client) nodeAddrPort(ctx context.Context, n *tailcfg.DERPNode, port int, proto probeProto) (_ netip.AddrPort, ok bool) { + var zero netip.AddrPort if port < 0 || port > 1<<16-1 { - return + return zero, false + } + if port == 0 { + port = 3478 } if n.STUNTestIP != "" { ip, err := netip.ParseAddr(n.STUNTestIP) @@ -1506,7 +1607,7 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP if proto == probeIPv6 && ip.Is4() { return } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } switch proto { @@ -1514,20 +1615,20 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP if n.IPv4 != "" { ip, _ := netip.ParseAddr(n.IPv4) if !ip.Is4() { - return + return zero, false } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } case probeIPv6: if n.IPv6 != "" { ip, _ := netip.ParseAddr(n.IPv6) if !ip.Is6() { - return + return zero, false } - return netip.AddrPortFrom(ip, uint16(port)) + return netip.AddrPortFrom(ip, uint16(port)), true } default: - return + return zero, false } // The default lookup function if we don't set UseDNSCache is to use net.DefaultResolver. 
@@ -1569,13 +1670,13 @@ func (c *Client) nodeAddr(ctx context.Context, n *tailcfg.DERPNode, proto probeP addrs, err := lookupIPAddr(ctx, n.HostName) for _, a := range addrs { if (a.Is4() && probeIsV4) || (a.Is6() && !probeIsV4) { - return netip.AddrPortFrom(a, uint16(port)) + return netip.AddrPortFrom(a, uint16(port)), true } } if err != nil { c.logf("netcheck: DNS lookup error for %q (node %q region %v): %v", n.HostName, n.Name, n.RegionID, err) } - return + return zero, false } func regionHasDERPNode(r *tailcfg.DERPRegion) bool { diff --git a/net/netcheck/netcheck_test.go b/net/netcheck/netcheck_test.go index 02076f8d468e1..3affa614dae88 100644 --- a/net/netcheck/netcheck_test.go +++ b/net/netcheck/netcheck_test.go @@ -28,6 +28,9 @@ func newTestClient(t testing.TB) *Client { c := &Client{ NetMon: netmon.NewStatic(), Logf: t.Logf, + TimeNow: func() time.Time { + return time.Unix(1729624521, 0) + }, } return c } @@ -38,7 +41,7 @@ func TestBasic(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() if err := c.Standalone(ctx, "127.0.0.1:0"); err != nil { @@ -52,6 +55,9 @@ func TestBasic(t *testing.T) { if !r.UDP { t.Error("want UDP") } + if r.Now.IsZero() { + t.Error("Now is zero") + } if len(r.RegionLatency) != 1 { t.Errorf("expected 1 key in DERPLatency; got %+v", r.RegionLatency) } @@ -117,7 +123,7 @@ func TestWorksWhenUDPBlocked(t *testing.T) { c := newTestClient(t) - ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + ctx, cancel := context.WithCancel(context.Background()) defer cancel() r, err := c.GetReport(ctx, dm, nil) @@ -130,6 +136,14 @@ func TestWorksWhenUDPBlocked(t *testing.T) { want := newReport() + // The Now field can't be compared with reflect.DeepEqual; check using + // the Equal method and then overwrite it so that the comparison below + // succeeds. 
+ if !r.Now.Equal(c.TimeNow()) { + t.Errorf("Now = %v; want %v", r.Now, c.TimeNow()) + } + want.Now = r.Now + // The IPv4CanSend flag gets set differently across platforms. // On Windows this test detects false, while on Linux detects true. // That's not relevant to this test, so just accept what we're @@ -187,6 +201,7 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { steps []step homeParams *tailcfg.DERPHomeParams opts *GetReportOpts + forcedDERP int // if non-zero, force this DERP to be the preferred one wantDERP int // want PreferredDERP on final step wantPrevLen int // wanted len(c.prev) }{ @@ -343,12 +358,74 @@ func TestAddReportHistoryAndSetPreferredDERP(t *testing.T) { wantPrevLen: 3, wantDERP: 2, // moved to d2 since d1 is gone }, + { + name: "preferred_derp_hysteresis_no_switch_pct", + steps: []step{ + {0 * time.Second, report("d1", 34*time.Millisecond, "d2", 35*time.Millisecond)}, + {1 * time.Second, report("d1", 34*time.Millisecond, "d2", 23*time.Millisecond)}, + }, + wantPrevLen: 2, + wantDERP: 1, // diff is 11ms, but d2 is greater than 2/3s of d1 + }, + { + name: "forced_two", + steps: []step{ + {time.Second, report("d1", 2, "d2", 3)}, + {2 * time.Second, report("d1", 4, "d2", 3)}, + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 2, + }, + { + name: "forced_two_unavailable", + steps: []step{ + {time.Second, report("d1", 2, "d2", 1)}, + {2 * time.Second, report("d1", 4)}, + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 1, + }, + { + name: "forced_two_no_probe_recent_activity", + steps: []step{ + {time.Second, report("d1", 2)}, + {2 * time.Second, report("d1", 4)}, + }, + opts: &GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + 2: startTime.Add(time.Second), + }), + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 2, + }, + { + name: "forced_two_no_probe_no_recent_activity", + steps: []step{ + {time.Second, report("d1", 2)}, + {PreferredDERPFrameTime + time.Second, report("d1", 4)}, + }, + opts: 
&GetReportOpts{ + GetLastDERPActivity: mkLDAFunc(map[int]time.Time{ + 1: startTime, + 2: startTime, + }), + }, + forcedDERP: 2, + wantPrevLen: 2, + wantDERP: 1, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fakeTime := startTime c := &Client{ - TimeNow: func() time.Time { return fakeTime }, + TimeNow: func() time.Time { return fakeTime }, + ForcePreferredDERP: tt.forcedDERP, } dm := &tailcfg.DERPMap{HomeParams: tt.homeParams} rs := &reportState{ @@ -378,7 +455,7 @@ func TestMakeProbePlan(t *testing.T) { basicMap := &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{}, } - for rid := 1; rid <= 5; rid++ { + for rid := 1; rid <= 6; rid++ { var nodes []*tailcfg.DERPNode for nid := 0; nid < rid; nid++ { nodes = append(nodes, &tailcfg.DERPNode{ @@ -390,8 +467,9 @@ func TestMakeProbePlan(t *testing.T) { }) } basicMap.Regions[rid] = &tailcfg.DERPRegion{ - RegionID: rid, - Nodes: nodes, + RegionID: rid, + Nodes: nodes, + NoMeasureNoHome: rid == 6, } } @@ -576,6 +654,40 @@ func TestMakeProbePlan(t *testing.T) { "region-3-v4": []probe{p("3a", 4)}, }, }, + { + // #13969: ensure that the prior/current home region is always included in + // probe plans, so that we don't flap between regions due to a single major + // netcheck having excluded the home region due to a spuriously high sample. 
+ name: "ensure_home_region_inclusion", + dm: basicMap, + have6if: true, + last: &Report{ + RegionLatency: map[int]time.Duration{ + 1: 50 * time.Millisecond, + 2: 20 * time.Millisecond, + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + RegionV4Latency: map[int]time.Duration{ + 1: 50 * time.Millisecond, + 2: 20 * time.Millisecond, + }, + RegionV6Latency: map[int]time.Duration{ + 3: 30 * time.Millisecond, + 4: 40 * time.Millisecond, + }, + PreferredDERP: 1, + }, + want: probePlan{ + "region-1-v4": []probe{p("1a", 4), p("1a", 4, 60*ms), p("1a", 4, 220*ms), p("1a", 4, 330*ms)}, + "region-1-v6": []probe{p("1a", 6), p("1a", 6, 60*ms), p("1a", 6, 220*ms), p("1a", 6, 330*ms)}, + "region-2-v4": []probe{p("2a", 4), p("2b", 4, 24*ms)}, + "region-2-v6": []probe{p("2a", 6), p("2b", 6, 24*ms)}, + "region-3-v4": []probe{p("3a", 4), p("3b", 4, 36*ms)}, + "region-3-v6": []probe{p("3a", 6), p("3b", 6, 36*ms)}, + "region-4-v4": []probe{p("4a", 4)}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -583,7 +695,11 @@ func TestMakeProbePlan(t *testing.T) { HaveV6: tt.have6if, HaveV4: !tt.no4, } - got := makeProbePlan(tt.dm, ifState, tt.last) + preferredDERP := 0 + if tt.last != nil { + preferredDERP = tt.last.PreferredDERP + } + got := makeProbePlan(tt.dm, ifState, tt.last, preferredDERP) if !reflect.DeepEqual(got, tt.want) { t.Errorf("unexpected plan; got:\n%v\nwant:\n%v\n", got, tt.want) } @@ -756,7 +872,7 @@ func TestSortRegions(t *testing.T) { report.RegionLatency[3] = time.Second * time.Duration(6) report.RegionLatency[4] = time.Second * time.Duration(0) report.RegionLatency[5] = time.Second * time.Duration(2) - sortedMap := sortRegions(unsortedMap, report) + sortedMap := sortRegions(unsortedMap, report, 0) // Sorting by latency this should result in rid: 5, 2, 1, 3 // rid 4 with latency 0 should be at the end @@ -826,8 +942,8 @@ func TestNodeAddrResolve(t *testing.T) { c.UseDNSCache = tt t.Run("IPv4", func(t *testing.T) { - ap := 
c.nodeAddr(ctx, dn, probeIPv4) - if !ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dn, dn.STUNPort, probeIPv4) + if !ok { t.Fatal("expected valid AddrPort") } if !ap.Addr().Is4() { @@ -841,8 +957,8 @@ func TestNodeAddrResolve(t *testing.T) { t.Skipf("IPv6 may not work on this machine") } - ap := c.nodeAddr(ctx, dn, probeIPv6) - if !ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dn, dn.STUNPort, probeIPv6) + if !ok { t.Fatal("expected valid AddrPort") } if !ap.Addr().Is6() { @@ -851,8 +967,8 @@ func TestNodeAddrResolve(t *testing.T) { t.Logf("got IPv6 addr: %v", ap) }) t.Run("IPv6 Failure", func(t *testing.T) { - ap := c.nodeAddr(ctx, dnV4Only, probeIPv6) - if ap.IsValid() { + ap, ok := c.nodeAddrPort(ctx, dnV4Only, dn.STUNPort, probeIPv6) + if ok { t.Fatalf("expected no addr but got: %v", ap) } t.Logf("correctly got invalid addr") @@ -872,3 +988,30 @@ func TestReportTimeouts(t *testing.T) { t.Errorf("ReportTimeout (%v) cannot be less than httpsProbeTimeout (%v)", ReportTimeout, httpsProbeTimeout) } } + +func TestNoUDPNilGetReportOpts(t *testing.T) { + blackhole, err := net.ListenPacket("udp4", "127.0.0.1:0") + if err != nil { + t.Fatalf("failed to open blackhole STUN listener: %v", err) + } + defer blackhole.Close() + + dm := stuntest.DERPMapOf(blackhole.LocalAddr().String()) + for _, region := range dm.Regions { + for _, n := range region.Nodes { + n.STUNOnly = false // exercise ICMP & HTTPS probing + } + } + + c := newTestClient(t) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + r, err := c.GetReport(ctx, dm, nil) + if err != nil { + t.Fatal(err) + } + if r.UDP { + t.Fatal("unexpected working UDP") + } +} diff --git a/net/netkernelconf/netkernelconf_default.go b/net/netkernelconf/netkernelconf_default.go index ec1b2e619807c..3e160e5edf5b0 100644 --- a/net/netkernelconf/netkernelconf_default.go +++ b/net/netkernelconf/netkernelconf_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: 
BSD-3-Clause -//go:build !linux +//go:build !linux || android package netkernelconf diff --git a/net/netkernelconf/netkernelconf_linux.go b/net/netkernelconf/netkernelconf_linux.go index 51ed8ea993a46..2a4f0a049f56d 100644 --- a/net/netkernelconf/netkernelconf_linux.go +++ b/net/netkernelconf/netkernelconf_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package netkernelconf import ( diff --git a/net/netmon/interfaces_android.go b/net/netmon/interfaces_android.go index a96423eb6bfeb..26104e879a393 100644 --- a/net/netmon/interfaces_android.go +++ b/net/netmon/interfaces_android.go @@ -5,7 +5,6 @@ package netmon import ( "bytes" - "errors" "log" "net/netip" "os/exec" @@ -15,7 +14,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/net/netaddr" "tailscale.com/syncs" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) var ( @@ -34,11 +33,6 @@ func init() { var procNetRouteErr atomic.Bool -// errStopReading is a sentinel error value used internally by -// lineread.File callers to stop reading. It doesn't escape to -// callers/users. -var errStopReading = errors.New("stop reading") - /* Parse 10.0.0.1 out of: @@ -54,44 +48,42 @@ func likelyHomeRouterIPAndroid() (ret netip.Addr, myIP netip.Addr, ok bool) { } lineNum := 0 var f []mem.RO - err := lineread.File(procNetRoutePath, func(line []byte) error { + for lr := range lineiter.File(procNetRoutePath) { + line, err := lr.Value() + if err != nil { + procNetRouteErr.Store(true) + return likelyHomeRouterIP() + } + lineNum++ if lineNum == 1 { // Skip header line. 
- return nil + continue } if lineNum > maxProcNetRouteRead { - return errStopReading + break } f = mem.AppendFields(f[:0], mem.B(line)) if len(f) < 4 { - return nil + continue } gwHex, flagsHex := f[2], f[3] flags, err := mem.ParseUint(flagsHex, 16, 16) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } if flags&(unix.RTF_UP|unix.RTF_GATEWAY) != unix.RTF_UP|unix.RTF_GATEWAY { - return nil + continue } ipu32, err := mem.ParseUint(gwHex, 16, 32) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } ip := netaddr.IPv4(byte(ipu32), byte(ipu32>>8), byte(ipu32>>16), byte(ipu32>>24)) if ip.IsPrivate() { ret = ip - return errStopReading + break } - return nil - }) - if errors.Is(err, errStopReading) { - err = nil - } - if err != nil { - procNetRouteErr.Store(true) - return likelyHomeRouterIP() } if ret.IsValid() { // Try to get the local IP of the interface associated with @@ -144,23 +136,26 @@ func likelyHomeRouterIPHelper() (ret netip.Addr, _ netip.Addr, ok bool) { return } // Search for line like "default via 10.0.2.2 dev radio0 table 1016 proto static mtu 1500 " - lineread.Reader(out, func(line []byte) error { + for lr := range lineiter.Reader(out) { + line, err := lr.Value() + if err != nil { + break + } const pfx = "default via " if !mem.HasPrefix(mem.B(line), mem.S(pfx)) { - return nil + continue } line = line[len(pfx):] sp := bytes.IndexByte(line, ' ') if sp == -1 { - return nil + continue } ipb := line[:sp] if ip, err := netip.ParseAddr(string(ipb)); err == nil && ip.Is4() { ret = ip log.Printf("interfaces: found Android default route %v", ip) } - return nil - }) + } cmd.Process.Kill() cmd.Wait() return ret, netip.Addr{}, ret.IsValid() diff --git a/net/netmon/interfaces_darwin_test.go b/net/netmon/interfaces_darwin_test.go index d34040d60d31d..d756d13348bc3 100644 --- a/net/netmon/interfaces_darwin_test.go +++ 
b/net/netmon/interfaces_darwin_test.go @@ -4,14 +4,13 @@ package netmon import ( - "errors" "io" "net/netip" "os/exec" "testing" "go4.org/mem" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/version" ) @@ -73,31 +72,34 @@ func likelyHomeRouterIPDarwinExec() (ret netip.Addr, netif string, ok bool) { defer io.Copy(io.Discard, stdout) // clear the pipe to prevent hangs var f []mem.RO - lineread.Reader(stdout, func(lineb []byte) error { + for lr := range lineiter.Reader(stdout) { + lineb, err := lr.Value() + if err != nil { + break + } line := mem.B(lineb) if !mem.Contains(line, mem.S("default")) { - return nil + continue } f = mem.AppendFields(f[:0], line) if len(f) < 4 || !f[0].EqualString("default") { - return nil + continue } ipm, flagsm, netifm := f[1], f[2], f[3] if !mem.Contains(flagsm, mem.S("G")) { - return nil + continue } if mem.Contains(flagsm, mem.S("I")) { - return nil + continue } ip, err := netip.ParseAddr(string(mem.Append(nil, ipm))) if err == nil && ip.IsPrivate() { ret = ip netif = netifm.StringCopy() // We've found what we're looking for. - return errStopReadingNetstatTable + break } - return nil - }) + } return ret, netif, ret.IsValid() } @@ -110,5 +112,3 @@ func TestFetchRoutingTable(t *testing.T) { } } } - -var errStopReadingNetstatTable = errors.New("found private gateway") diff --git a/net/netmon/interfaces_linux.go b/net/netmon/interfaces_linux.go index 299f3101ea73b..d0fb15ababe9e 100644 --- a/net/netmon/interfaces_linux.go +++ b/net/netmon/interfaces_linux.go @@ -23,7 +23,7 @@ import ( "go4.org/mem" "golang.org/x/sys/unix" "tailscale.com/net/netaddr" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) func init() { @@ -32,11 +32,6 @@ func init() { var procNetRouteErr atomic.Bool -// errStopReading is a sentinel error value used internally by -// lineread.File callers to stop reading. It doesn't escape to -// callers/users. 
-var errStopReading = errors.New("stop reading") - /* Parse 10.0.0.1 out of: @@ -52,44 +47,42 @@ func likelyHomeRouterIPLinux() (ret netip.Addr, myIP netip.Addr, ok bool) { } lineNum := 0 var f []mem.RO - err := lineread.File(procNetRoutePath, func(line []byte) error { + for lr := range lineiter.File(procNetRoutePath) { + line, err := lr.Value() + if err != nil { + procNetRouteErr.Store(true) + log.Printf("interfaces: failed to read /proc/net/route: %v", err) + return ret, myIP, false + } lineNum++ if lineNum == 1 { // Skip header line. - return nil + continue } if lineNum > maxProcNetRouteRead { - return errStopReading + break } f = mem.AppendFields(f[:0], mem.B(line)) if len(f) < 4 { - return nil + continue } gwHex, flagsHex := f[2], f[3] flags, err := mem.ParseUint(flagsHex, 16, 16) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } if flags&(unix.RTF_UP|unix.RTF_GATEWAY) != unix.RTF_UP|unix.RTF_GATEWAY { - return nil + continue } ipu32, err := mem.ParseUint(gwHex, 16, 32) if err != nil { - return nil // ignore error, skip line and keep going + continue // ignore error, skip line and keep going } ip := netaddr.IPv4(byte(ipu32), byte(ipu32>>8), byte(ipu32>>16), byte(ipu32>>24)) if ip.IsPrivate() { ret = ip - return errStopReading + break } - return nil - }) - if errors.Is(err, errStopReading) { - err = nil - } - if err != nil { - procNetRouteErr.Store(true) - log.Printf("interfaces: failed to read /proc/net/route: %v", err) } if ret.IsValid() { // Try to get the local IP of the interface associated with diff --git a/net/netmon/interfaces_test.go b/net/netmon/interfaces_test.go index edd4f6d6e202b..e4274819f90df 100644 --- a/net/netmon/interfaces_test.go +++ b/net/netmon/interfaces_test.go @@ -13,7 +13,7 @@ import ( ) func TestGetState(t *testing.T) { - st, err := GetState() + st, err := getState("") if err != nil { t.Fatal(err) } diff --git a/net/netmon/loghelper.go 
b/net/netmon/loghelper.go new file mode 100644 index 0000000000000..824faeef09b1c --- /dev/null +++ b/net/netmon/loghelper.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "sync" + + "tailscale.com/types/logger" +) + +// LinkChangeLogLimiter returns a new [logger.Logf] that logs each unique +// format string to the underlying logger only once per major LinkChange event. +// +// The returned function should be called when the logger is no longer needed, +// to release resources from the Monitor. +func LinkChangeLogLimiter(logf logger.Logf, nm *Monitor) (_ logger.Logf, unregister func()) { + var formatSeen sync.Map // map[string]bool + unregister = nm.RegisterChangeCallback(func(cd *ChangeDelta) { + // If we're in a major change or a time jump, clear the seen map. + if cd.Major || cd.TimeJumped { + formatSeen.Clear() + } + }) + + return func(format string, args ...any) { + // We only store 'true' in the map, so if it's present then it + // means we've already logged this format string. + _, loaded := formatSeen.LoadOrStore(format, true) + if loaded { + // TODO(andrew-d): we may still want to log this + // message every N minutes (1x/hour?) even if it's been + // seen, so that debugging doesn't require searching + // back in the logs for an unbounded amount of time. + // + // See: https://github.com/tailscale/tailscale/issues/13145 + return + } + + logf(format, args...) 
+ }, unregister +} diff --git a/net/netmon/loghelper_test.go b/net/netmon/loghelper_test.go new file mode 100644 index 0000000000000..44aa46783de07 --- /dev/null +++ b/net/netmon/loghelper_test.go @@ -0,0 +1,82 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package netmon + +import ( + "bytes" + "fmt" + "testing" + + "tailscale.com/util/eventbus" +) + +func TestLinkChangeLogLimiter(t *testing.T) { + bus := eventbus.New() + defer bus.Close() + mon, err := New(bus, t.Logf) + if err != nil { + t.Fatal(err) + } + defer mon.Close() + + var logBuffer bytes.Buffer + logf := func(format string, args ...any) { + t.Logf("captured log: "+format, args...) + + if format[len(format)-1] != '\n' { + format += "\n" + } + fmt.Fprintf(&logBuffer, format, args...) + } + + logf, unregister := LinkChangeLogLimiter(logf, mon) + defer unregister() + + // Log once, which should write to our log buffer. + logf("hello %s", "world") + if got := logBuffer.String(); got != "hello world\n" { + t.Errorf("unexpected log buffer contents: %q", got) + } + + // Log again, which should not write to our log buffer. + logf("hello %s", "andrew") + if got := logBuffer.String(); got != "hello world\n" { + t.Errorf("unexpected log buffer contents: %q", got) + } + + // Log a different message, which should write to our log buffer. + logf("other message") + if got := logBuffer.String(); got != "hello world\nother message\n" { + t.Errorf("unexpected log buffer contents: %q", got) + } + + // Synthesize a fake major change event, which should clear the format + // string cache and allow the next log to write to our log buffer. + // + // InjectEvent doesn't work because it's not a major event, so we + // instead reach into the netmon and grab the callback, and then call + // it ourselves. 
+ mon.mu.Lock() + var cb func(*ChangeDelta) + for _, c := range mon.cbs { + cb = c + break + } + mon.mu.Unlock() + + cb(&ChangeDelta{Major: true}) + + logf("hello %s", "world") + if got := logBuffer.String(); got != "hello world\nother message\nhello world\n" { + t.Errorf("unexpected log buffer contents: %q", got) + } + + // Unregistering the callback should clear our 'cbs' set. + unregister() + mon.mu.Lock() + if len(mon.cbs) != 0 { + t.Errorf("expected no callbacks, got %v", mon.cbs) + } + mon.mu.Unlock() +} diff --git a/net/netmon/netmon.go b/net/netmon/netmon.go index 47b540d6a7f83..3f825bc9797fe 100644 --- a/net/netmon/netmon.go +++ b/net/netmon/netmon.go @@ -16,6 +16,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/set" ) @@ -50,7 +51,10 @@ type osMon interface { // Monitor represents a monitoring instance. type Monitor struct { - logf logger.Logf + logf logger.Logf + b *eventbus.Client + changed *eventbus.Publisher[*ChangeDelta] + om osMon // nil means not supported on this platform change chan bool // send false to wake poller, true to also force ChangeDeltas be sent stop chan struct{} // closed on Stop @@ -114,21 +118,23 @@ type ChangeDelta struct { // New instantiates and starts a monitoring instance. // The returned monitor is inactive until it's started by the Start method. // Use RegisterChangeCallback to get notified of network changes. 
-func New(logf logger.Logf) (*Monitor, error) { +func New(bus *eventbus.Bus, logf logger.Logf) (*Monitor, error) { logf = logger.WithPrefix(logf, "monitor: ") m := &Monitor{ logf: logf, + b: bus.Client("netmon"), change: make(chan bool, 1), stop: make(chan struct{}), lastWall: wallTime(), } + m.changed = eventbus.Publish[*ChangeDelta](m.b) st, err := m.interfaceStateUncached() if err != nil { return nil, err } m.ifState = st - m.om, err = newOSMon(logf, m) + m.om, err = newOSMon(bus, logf, m) if err != nil { return nil, err } @@ -161,7 +167,7 @@ func (m *Monitor) InterfaceState() *State { } func (m *Monitor) interfaceStateUncached() (*State, error) { - return GetState() + return getState(m.tsIfName) } // SetTailscaleInterfaceName sets the name of the Tailscale interface. For @@ -441,7 +447,6 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { delta.Major = m.IsMajorChangeFrom(oldState, newState) if delta.Major { m.gwValid = false - m.ifState = newState if s1, s2 := oldState.String(), delta.New.String(); s1 == s2 { m.logf("[unexpected] network state changed, but stringification didn't: %v", s1) @@ -449,6 +454,7 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { m.logf("[unexpected] new: %s", jsonSummary(newState)) } } + m.ifState = newState // See if we have a queued or new time jump signal. if timeJumped { m.resetTimeJumpedLocked() @@ -465,6 +471,7 @@ func (m *Monitor) handlePotentialChange(newState *State, forceCallbacks bool) { if delta.TimeJumped { metricChangeTimeJump.Add(1) } + m.changed.Publish(delta) for _, cb := range m.cbs { go cb(delta) } @@ -596,7 +603,7 @@ func (m *Monitor) pollWallTime() { // // We don't do this on mobile platforms for battery reasons, and because these // platforms don't really sleep in the same way. 
-const shouldMonitorTimeJump = runtime.GOOS != "android" && runtime.GOOS != "ios" +const shouldMonitorTimeJump = runtime.GOOS != "android" && runtime.GOOS != "ios" && runtime.GOOS != "plan9" // checkWallTimeAdvanceLocked reports whether wall time jumped more than 150% of // pollWallTimeInterval, indicating we probably just came out of sleep. Once a diff --git a/net/netmon/netmon_darwin.go b/net/netmon/netmon_darwin.go index cc630112523fa..9c5e76475f3fd 100644 --- a/net/netmon/netmon_darwin.go +++ b/net/netmon/netmon_darwin.go @@ -13,6 +13,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/net/netaddr" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) const debugRouteMessages = false @@ -24,7 +25,7 @@ type unspecifiedMessage struct{} func (unspecifiedMessage) ignore() bool { return false } -func newOSMon(logf logger.Logf, _ *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, _ *Monitor) (osMon, error) { fd, err := unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, 0) if err != nil { return nil, err @@ -56,7 +57,19 @@ func (m *darwinRouteMon) Receive() (message, error) { if err != nil { return nil, err } - msgs, err := route.ParseRIB(route.RIBTypeRoute, m.buf[:n]) + msgs, err := func() (msgs []route.Message, err error) { + defer func() { + // #14201: permanent panic protection, as we have been burned by + // ParseRIB panics too many times. 
+ msg := recover() + if msg != nil { + msgs = nil + m.logf("[unexpected] netmon: panic in route.ParseRIB from % 02x", m.buf[:n]) + err = fmt.Errorf("panic in route.ParseRIB: %s", msg) + } + }() + return route.ParseRIB(route.RIBTypeRoute, m.buf[:n]) + }() if err != nil { if debugRouteMessages { m.logf("read %d bytes (% 02x), failed to parse RIB: %v", n, m.buf[:n], err) diff --git a/net/netmon/netmon_freebsd.go b/net/netmon/netmon_freebsd.go index 30480a1d3387e..842cbdb0d6476 100644 --- a/net/netmon/netmon_freebsd.go +++ b/net/netmon/netmon_freebsd.go @@ -10,6 +10,7 @@ import ( "strings" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) // unspecifiedMessage is a minimal message implementation that should not @@ -24,7 +25,7 @@ type devdConn struct { conn net.Conn } -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) { conn, err := net.Dial("unixpacket", "/var/run/devd.seqpacket.pipe") if err != nil { logf("devd dial error: %v, falling back to polling method", err) diff --git a/net/netmon/netmon_linux.go b/net/netmon/netmon_linux.go index dd23dd34263c5..659fcc74bb0e6 100644 --- a/net/netmon/netmon_linux.go +++ b/net/netmon/netmon_linux.go @@ -16,6 +16,7 @@ import ( "tailscale.com/envknob" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) var debugNetlinkMessages = envknob.RegisterBool("TS_DEBUG_NETLINK") @@ -27,15 +28,26 @@ type unspecifiedMessage struct{} func (unspecifiedMessage) ignore() bool { return false } +// RuleDeleted reports that one of Tailscale's policy routing rules +// was deleted. +type RuleDeleted struct { + // Table is the table number that the deleted rule referenced. + Table uint8 + // Priority is the lookup priority of the deleted rule. + Priority uint32 +} + // nlConn wraps a *netlink.Conn and returns a monitor.Message // instead of a netlink.Message. 
Currently, messages are discarded, // but down the line, when messages trigger different logic depending // on the type of event, this provides the capability of handling // each architecture-specific message in a generic fashion. type nlConn struct { - logf logger.Logf - conn *netlink.Conn - buffered []netlink.Message + busClient *eventbus.Client + rulesDeleted *eventbus.Publisher[RuleDeleted] + logf logger.Logf + conn *netlink.Conn + buffered []netlink.Message // addrCache maps interface indices to a set of addresses, and is // used to suppress duplicate RTM_NEWADDR messages. It is populated @@ -44,7 +56,7 @@ type nlConn struct { addrCache map[uint32]map[netip.Addr]bool } -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { +func newOSMon(bus *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) { conn, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ // Routes get us most of the events of interest, but we need // address as well to cover things like DHCP deciding to give @@ -59,12 +71,22 @@ func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { logf("monitor_linux: AF_NETLINK RTMGRP failed, falling back to polling") return newPollingMon(logf, m) } - return &nlConn{logf: logf, conn: conn, addrCache: make(map[uint32]map[netip.Addr]bool)}, nil + client := bus.Client("netmon-iprules") + return &nlConn{ + busClient: client, + rulesDeleted: eventbus.Publish[RuleDeleted](client), + logf: logf, + conn: conn, + addrCache: make(map[uint32]map[netip.Addr]bool), + }, nil } func (c *nlConn) IsInterestingInterface(iface string) bool { return true } -func (c *nlConn) Close() error { return c.conn.Close() } +func (c *nlConn) Close() error { + c.busClient.Close() + return c.conn.Close() +} func (c *nlConn) Receive() (message, error) { if len(c.buffered) == 0 { @@ -219,6 +241,10 @@ func (c *nlConn) Receive() (message, error) { // On `ip -4 rule del pref 5210 table main`, logs: // monitor: ip rule deleted: {Family:2 DstLength:0 SrcLength:0 Tos:0 
Table:254 Protocol:0 Scope:0 Type:1 Flags:0 Attributes:{Dst: Src: Gateway: OutIface:0 Priority:5210 Table:254 Mark:4294967295 Expires: Metrics: Multipath:[]}} } + c.rulesDeleted.Publish(RuleDeleted{ + Table: rmsg.Table, + Priority: rmsg.Attributes.Priority, + }) rdm := ipRuleDeletedMessage{ table: rmsg.Table, priority: rmsg.Attributes.Priority, diff --git a/net/netmon/netmon_linux_test.go b/net/netmon/netmon_linux_test.go index d09fac26aecee..75d7c646559f1 100644 --- a/net/netmon/netmon_linux_test.go +++ b/net/netmon/netmon_linux_test.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build linux && !android + package netmon import ( diff --git a/net/netmon/netmon_polling.go b/net/netmon/netmon_polling.go index 3d6f94731077a..3b5ef6fe9206f 100644 --- a/net/netmon/netmon_polling.go +++ b/net/netmon/netmon_polling.go @@ -7,9 +7,10 @@ package netmon import ( "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) -func newOSMon(logf logger.Logf, m *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, m *Monitor) (osMon, error) { return newPollingMon(logf, m) } diff --git a/net/netmon/netmon_test.go b/net/netmon/netmon_test.go index ce55d19464100..a9af8fb004af3 100644 --- a/net/netmon/netmon_test.go +++ b/net/netmon/netmon_test.go @@ -11,11 +11,15 @@ import ( "testing" "time" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" ) func TestMonitorStartClose(t *testing.T) { - mon, err := New(t.Logf) + bus := eventbus.New() + defer bus.Close() + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } @@ -26,7 +30,10 @@ func TestMonitorStartClose(t *testing.T) { } func TestMonitorJustClose(t *testing.T) { - mon, err := New(t.Logf) + bus := eventbus.New() + defer bus.Close() + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } @@ -36,7 +43,10 @@ func TestMonitorJustClose(t *testing.T) { } func TestMonitorInjectEvent(t *testing.T) { - mon, err := New(t.Logf) + bus := 
eventbus.New() + defer bus.Close() + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } @@ -71,7 +81,11 @@ func TestMonitorMode(t *testing.T) { default: t.Skipf(`invalid --monitor value: must be "raw" or "callback"`) } - mon, err := New(t.Logf) + + bus := eventbus.New() + defer bus.Close() + + mon, err := New(bus, t.Logf) if err != nil { t.Fatal(err) } diff --git a/net/netmon/netmon_windows.go b/net/netmon/netmon_windows.go index ddf13a2e453b2..718724b6d3f8d 100644 --- a/net/netmon/netmon_windows.go +++ b/net/netmon/netmon_windows.go @@ -13,6 +13,7 @@ import ( "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" ) var ( @@ -45,7 +46,7 @@ type winMon struct { noDeadlockTicker *time.Ticker } -func newOSMon(logf logger.Logf, pm *Monitor) (osMon, error) { +func newOSMon(_ *eventbus.Bus, logf logger.Logf, pm *Monitor) (osMon, error) { m := &winMon{ logf: logf, isActive: pm.isActive, diff --git a/net/netmon/state.go b/net/netmon/state.go index d9b360f5eee45..bd09607682bb4 100644 --- a/net/netmon/state.go +++ b/net/netmon/state.go @@ -19,8 +19,14 @@ import ( "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" "tailscale.com/net/tshttpproxy" + "tailscale.com/util/mak" ) +// forceAllIPv6Endpoints is a debug knob that when set forces the client to +// report all IPv6 endpoints rather than trim endpoints that are siblings on the +// same interface and subnet. +var forceAllIPv6Endpoints = envknob.RegisterBool("TS_DEBUG_FORCE_ALL_IPV6_ENDPOINTS") + // LoginEndpointForProxyDetermination is the URL used for testing // which HTTP proxy the system should use. 
var LoginEndpointForProxyDetermination = "https://controlplane.tailscale.com/" @@ -65,6 +71,7 @@ func LocalAddresses() (regular, loopback []netip.Addr, err error) { if err != nil { return nil, nil, err } + var subnets map[netip.Addr]int for _, a := range addrs { switch v := a.(type) { case *net.IPNet: @@ -102,7 +109,15 @@ func LocalAddresses() (regular, loopback []netip.Addr, err error) { if ip.Is4() { regular4 = append(regular4, ip) } else { - regular6 = append(regular6, ip) + curMask, _ := netip.AddrFromSlice(v.IP.Mask(v.Mask)) + // Limit the number of addresses reported per subnet for + // IPv6, as we have seen some nodes with extremely large + // numbers of assigned addresses being carved out of + // same-subnet allocations. + if forceAllIPv6Endpoints() || subnets[curMask] < 2 { + regular6 = append(regular6, ip) + } + mak.Set(&subnets, curMask, subnets[curMask]+1) } } } @@ -446,21 +461,22 @@ func isTailscaleInterface(name string, ips []netip.Prefix) bool { // getPAC, if non-nil, returns the current PAC file URL. var getPAC func() string -// GetState returns the state of all the current machine's network interfaces. +// getState returns the state of all the current machine's network interfaces. // // It does not set the returned State.IsExpensive. The caller can populate that. // -// Deprecated: use netmon.Monitor.InterfaceState instead. -func GetState() (*State, error) { +// optTSInterfaceName is the name of the Tailscale interface, if known. +func getState(optTSInterfaceName string) (*State, error) { s := &State{ InterfaceIPs: make(map[string][]netip.Prefix), Interface: make(map[string]Interface), } if err := ForeachInterface(func(ni Interface, pfxs []netip.Prefix) { + isTSInterfaceName := optTSInterfaceName != "" && ni.Name == optTSInterfaceName ifUp := ni.IsUp() s.Interface[ni.Name] = ni s.InterfaceIPs[ni.Name] = append(s.InterfaceIPs[ni.Name], pfxs...) 
- if !ifUp || isTailscaleInterface(ni.Name, pfxs) { + if !ifUp || isTSInterfaceName || isTailscaleInterface(ni.Name, pfxs) { return } for _, pfx := range pfxs { @@ -740,11 +756,12 @@ func DefaultRoute() (DefaultRouteDetails, error) { // HasCGNATInterface reports whether there are any non-Tailscale interfaces that // use a CGNAT IP range. -func HasCGNATInterface() (bool, error) { +func (m *Monitor) HasCGNATInterface() (bool, error) { hasCGNATInterface := false cgnatRange := tsaddr.CGNATRange() err := ForeachInterface(func(i Interface, pfxs []netip.Prefix) { - if hasCGNATInterface || !i.IsUp() || isTailscaleInterface(i.Name, pfxs) { + isTSInterfaceName := m.tsIfName != "" && i.Name == m.tsIfName + if hasCGNATInterface || !i.IsUp() || isTSInterfaceName || isTailscaleInterface(i.Name, pfxs) { return } for _, pfx := range pfxs { diff --git a/net/netns/socks.go b/net/netns/socks.go index eea69d8651eda..ee8dfa20eec7f 100644 --- a/net/netns/socks.go +++ b/net/netns/socks.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !ios && !js +//go:build !ios && !js && !android package netns diff --git a/net/netutil/ip_forward.go b/net/netutil/ip_forward.go index 48cee68eaff88..c64a9e4269ae0 100644 --- a/net/netutil/ip_forward.go +++ b/net/netutil/ip_forward.go @@ -63,6 +63,11 @@ func CheckIPForwarding(routes []netip.Prefix, state *netmon.State) (warn, err er switch runtime.GOOS { case "dragonfly", "freebsd", "netbsd", "openbsd": return fmt.Errorf("Subnet routing and exit nodes only work with additional manual configuration on %v, and is not currently officially supported.", runtime.GOOS), nil + case "illumos", "solaris": + _, err := ipForwardingEnabledSunOS(ipv4, "") + if err != nil { + return nil, fmt.Errorf("Couldn't check system's IP forwarding configuration, subnet routing/exit nodes may not work: %w%s", err, "") + } } return nil, nil } @@ -325,3 +330,24 @@ func reversePathFilterValueLinux(iface string) (int, error) 
{ } return v, nil } + +func ipForwardingEnabledSunOS(p protocol, iface string) (bool, error) { + var proto string + if p == ipv4 { + proto = "ipv4" + } else if p == ipv6 { + proto = "ipv6" + } else { + return false, fmt.Errorf("unknown protocol") + } + + ipadmCmd := "\"ipadm show-prop " + proto + " -p forwarding -o CURRENT -c\"" + bs, err := exec.Command("ipadm", "show-prop", proto, "-p", "forwarding", "-o", "CURRENT", "-c").Output() + if err != nil { + return false, fmt.Errorf("couldn't check %s (%v).\nSubnet routes won't work without IP forwarding.", ipadmCmd, err) + } + if string(bs) != "on\n" { + return false, fmt.Errorf("IP forwarding is set to off. Subnet routes won't work. Try 'routeadm -u -e %s-forwarding'", proto) + } + return true, nil +} diff --git a/net/netutil/netutil_test.go b/net/netutil/netutil_test.go index fdc26b02f09aa..0523946e63c9b 100644 --- a/net/netutil/netutil_test.go +++ b/net/netutil/netutil_test.go @@ -10,6 +10,7 @@ import ( "testing" "tailscale.com/net/netmon" + "tailscale.com/util/eventbus" ) type conn struct { @@ -72,7 +73,10 @@ func TestCheckReversePathFiltering(t *testing.T) { if runtime.GOOS != "linux" { t.Skipf("skipping on %s", runtime.GOOS) } - netMon, err := netmon.New(t.Logf) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, t.Logf) if err != nil { t.Fatal(err) } diff --git a/net/netx/netx.go b/net/netx/netx.go new file mode 100644 index 0000000000000..014daa9a795cb --- /dev/null +++ b/net/netx/netx.go @@ -0,0 +1,53 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netx contains types to describe and abstract over how dialing and +// listening are performed. +package netx + +import ( + "context" + "fmt" + "net" +) + +// DialFunc is a function that dials a network address. +// +// It's the type implemented by net.Dialer.DialContext or required +// by net/http.Transport.DialContext, etc. 
+type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) + +// Network describes a network that can listen and dial. The two common +// implementations are [RealNetwork], using the net package to use the real +// network, or [memnet.Network], using an in-memory network (typically for testing). +type Network interface { + NewLocalTCPListener() net.Listener + Listen(network, address string) (net.Listener, error) + Dial(ctx context.Context, network, address string) (net.Conn, error) +} + +// RealNetwork returns a Network implementation that uses the real +// net package. +func RealNetwork() Network { return realNetwork{} } + +// realNetwork implements [Network] using the real net package. +type realNetwork struct{} + +func (realNetwork) Listen(network, address string) (net.Listener, error) { + return net.Listen(network, address) +} + +func (realNetwork) Dial(ctx context.Context, network, address string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, network, address) +} + +func (realNetwork) NewLocalTCPListener() net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + if ln, err = net.Listen("tcp6", "[::1]:0"); err != nil { + panic(fmt.Sprintf("failed to listen on either IPv4 or IPv6 localhost port: %v", err)) + } + } + return ln +} diff --git a/net/packet/capture.go b/net/packet/capture.go new file mode 100644 index 0000000000000..dd0ca411f2051 --- /dev/null +++ b/net/packet/capture.go @@ -0,0 +1,75 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "io" + "net/netip" + "time" +) + +// CaptureCallback describes a function which is called to +// record packets when debugging packet-capture. +// Such callbacks must not take ownership of the +// provided data slice: it may only copy out of it +// within the lifetime of the function. 
+type CaptureCallback func(CapturePath, time.Time, []byte, CaptureMeta) + +// CaptureSink is the minimal interface from [tailscale.com/feature/capture]'s +// Sink type that is needed by the core (magicsock/LocalBackend/wgengine/etc). +// This lets the relatively heavy feature/capture package be optionally linked. +type CaptureSink interface { + // Close closes the sink. + Close() error + + // NumOutputs returns the number of outputs registered with the sink. + NumOutputs() int + + // CaptureCallback returns a callback which can be used to + // write packets to the sink. + CaptureCallback() CaptureCallback + + // WaitCh returns a channel which blocks until + // the sink is closed. + WaitCh() <-chan struct{} + + // RegisterOutput connects an output to this sink, which + // will be written to with a pcap stream as packets are logged. + // A function is returned which unregisters the output when + // called. + // + // If w implements io.Closer, it will be closed upon error + // or when the sink is closed. If w implements http.Flusher, + // it will be flushed periodically. + RegisterOutput(w io.Writer) (unregister func()) +} + +// CaptureMeta contains metadata that is used when debugging. +type CaptureMeta struct { + DidSNAT bool // SNAT was performed & the address was updated. + OriginalSrc netip.AddrPort // The source address before SNAT was performed. + DidDNAT bool // DNAT was performed & the address was updated. + OriginalDst netip.AddrPort // The destination address before DNAT was performed. +} + +// CapturePath describes where in the data path the packet was captured. +type CapturePath uint8 + +// CapturePath values +const ( + // FromLocal indicates the packet was logged as it traversed the FromLocal path: + // i.e.: A packet from the local system into the TUN. + FromLocal CapturePath = 0 + // FromPeer indicates the packet was logged upon reception from a remote peer. 
+ FromPeer CapturePath = 1 + // SynthesizedToLocal indicates the packet was generated from within tailscaled, + // and is being routed to the local machine's network stack. + SynthesizedToLocal CapturePath = 2 + // SynthesizedToPeer indicates the packet was generated from within tailscaled, + // and is being routed to a remote Wireguard peer. + SynthesizedToPeer CapturePath = 3 + + // PathDisco indicates the packet is information about a disco frame. + PathDisco CapturePath = 254 +) diff --git a/net/packet/geneve.go b/net/packet/geneve.go new file mode 100644 index 0000000000000..29970a8fd6bfb --- /dev/null +++ b/net/packet/geneve.go @@ -0,0 +1,104 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "encoding/binary" + "errors" + "io" +) + +const ( + // GeneveFixedHeaderLength is the length of the fixed size portion of the + // Geneve header, in bytes. + GeneveFixedHeaderLength = 8 +) + +const ( + // GeneveProtocolDisco is the IEEE 802 Ethertype number used to represent + // the Tailscale Disco protocol in a Geneve header. + GeneveProtocolDisco uint16 = 0x7A11 + // GeneveProtocolWireGuard is the IEEE 802 Ethertype number used to represent the + // WireGuard protocol in a Geneve header. + GeneveProtocolWireGuard uint16 = 0x7A12 +) + +// GeneveHeader represents the fixed size Geneve header from RFC8926. +// TLVs/options are not implemented/supported. +// +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |Ver| Opt Len |O|C| Rsvd. | Protocol Type | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// | Virtual Network Identifier (VNI) | Reserved | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +type GeneveHeader struct { + // Ver (2 bits): The current version number is 0. Packets received by a + // tunnel endpoint with an unknown version MUST be dropped. 
Transit devices + // interpreting Geneve packets with an unknown version number MUST treat + // them as UDP packets with an unknown payload. + Version uint8 + + // Protocol Type (16 bits): The type of protocol data unit appearing after + // the Geneve header. This follows the Ethertype [ETYPES] convention, with + // Ethernet itself being represented by the value 0x6558. + Protocol uint16 + + // Virtual Network Identifier (VNI) (24 bits): An identifier for a unique + // element of a virtual network. In many situations, this may represent an + // L2 segment; however, the control plane defines the forwarding semantics + // of decapsulated packets. The VNI MAY be used as part of ECMP forwarding + // decisions or MAY be used as a mechanism to distinguish between + // overlapping address spaces contained in the encapsulated packet when load + // balancing across CPUs. + VNI uint32 + + // O (1 bit): Control packet. This packet contains a control message. + // Control messages are sent between tunnel endpoints. Tunnel endpoints MUST + // NOT forward the payload, and transit devices MUST NOT attempt to + // interpret it. Since control messages are less frequent, it is RECOMMENDED + // that tunnel endpoints direct these packets to a high-priority control + // queue (for example, to direct the packet to a general purpose CPU from a + // forwarding Application-Specific Integrated Circuit (ASIC) or to separate + // out control traffic on a NIC). Transit devices MUST NOT alter forwarding + // behavior on the basis of this bit, such as ECMP link selection. + Control bool +} + +// Encode encodes GeneveHeader into b. If len(b) < GeneveFixedHeaderLength an +// io.ErrShortBuffer error is returned. 
+func (h *GeneveHeader) Encode(b []byte) error { + if len(b) < GeneveFixedHeaderLength { + return io.ErrShortBuffer + } + if h.Version > 3 { + return errors.New("version must be <= 3") + } + b[0] = 0 + b[1] = 0 + b[0] |= h.Version << 6 + if h.Control { + b[1] |= 0x80 + } + binary.BigEndian.PutUint16(b[2:], h.Protocol) + if h.VNI > 1<<24-1 { + return errors.New("VNI must be <= 2^24-1") + } + binary.BigEndian.PutUint32(b[4:], h.VNI<<8) + return nil +} + +// Decode decodes GeneveHeader from b. If len(b) < GeneveFixedHeaderLength an +// io.ErrShortBuffer error is returned. +func (h *GeneveHeader) Decode(b []byte) error { + if len(b) < GeneveFixedHeaderLength { + return io.ErrShortBuffer + } + h.Version = b[0] >> 6 + if b[1]&0x80 != 0 { + h.Control = true + } + h.Protocol = binary.BigEndian.Uint16(b[2:]) + h.VNI = binary.BigEndian.Uint32(b[4:]) >> 8 + return nil +} diff --git a/net/packet/geneve_test.go b/net/packet/geneve_test.go new file mode 100644 index 0000000000000..029638638aa96 --- /dev/null +++ b/net/packet/geneve_test.go @@ -0,0 +1,32 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package packet + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestGeneveHeader(t *testing.T) { + in := GeneveHeader{ + Version: 3, + Protocol: GeneveProtocolDisco, + VNI: 1<<24 - 1, + Control: true, + } + b := make([]byte, GeneveFixedHeaderLength) + err := in.Encode(b) + if err != nil { + t.Fatal(err) + } + out := GeneveHeader{} + err = out.Decode(b) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(out, in); diff != "" { + t.Fatalf("wrong results (-got +want)\n%s", diff) + } +} diff --git a/net/packet/packet.go b/net/packet/packet.go index c9521ad4667c2..b683b22126948 100644 --- a/net/packet/packet.go +++ b/net/packet/packet.go @@ -34,14 +34,6 @@ const ( TCPECNBits TCPFlag = TCPECNEcho | TCPCWR ) -// CaptureMeta contains metadata that is used when debugging. 
-type CaptureMeta struct { - DidSNAT bool // SNAT was performed & the address was updated. - OriginalSrc netip.AddrPort // The source address before SNAT was performed. - DidDNAT bool // DNAT was performed & the address was updated. - OriginalDst netip.AddrPort // The destination address before DNAT was performed. -} - // Parsed is a minimal decoding of a packet suitable for use in filters. type Parsed struct { // b is the byte buffer that this decodes. diff --git a/net/portmapper/igd_test.go b/net/portmapper/igd_test.go index 5c24d03aadde1..3ef7989a3a241 100644 --- a/net/portmapper/igd_test.go +++ b/net/portmapper/igd_test.go @@ -18,7 +18,10 @@ import ( "tailscale.com/net/netaddr" "tailscale.com/net/netmon" "tailscale.com/syncs" + "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" + "tailscale.com/util/testenv" ) // TestIGD is an IGD (Internet Gateway Device) for testing. It supports fake @@ -63,7 +66,8 @@ type igdCounters struct { invalidPCPMapPkt int32 } -func NewTestIGD(logf logger.Logf, t TestIGDOptions) (*TestIGD, error) { +func NewTestIGD(tb testenv.TB, t TestIGDOptions) (*TestIGD, error) { + logf := tstest.WhileTestRunningLogger(tb) d := &TestIGD{ doPMP: t.PMP, doPCP: t.PCP, @@ -258,15 +262,25 @@ func (d *TestIGD) handlePCPQuery(pkt []byte, src netip.AddrPort) { } } -func newTestClient(t *testing.T, igd *TestIGD) *Client { +// newTestClient configures a new test client connected to igd for mapping updates. +// If bus != nil, update events are published to it. +// A cleanup for the resulting client is added to t. 
+func newTestClient(t *testing.T, igd *TestIGD, bus *eventbus.Bus) *Client { var c *Client - c = NewClient(t.Logf, netmon.NewStatic(), nil, new(controlknobs.Knobs), func() { - t.Logf("port map changed") - t.Logf("have mapping: %v", c.HaveMapping()) + c = NewClient(Config{ + Logf: tstest.WhileTestRunningLogger(t), + NetMon: netmon.NewStatic(), + ControlKnobs: new(controlknobs.Knobs), + EventBus: bus, + OnChange: func() { + t.Logf("port map changed") + t.Logf("have mapping: %v", c.HaveMapping()) + }, }) c.testPxPPort = igd.TestPxPPort() c.testUPnPPort = igd.TestUPnPPort() c.netMon = netmon.NewStatic() c.SetGatewayLookupFunc(testIPAndGateway) + t.Cleanup(func() { c.Close() }) return c } diff --git a/net/portmapper/portmapper.go b/net/portmapper/portmapper.go index 71b55b8a7f240..59f88e96604a5 100644 --- a/net/portmapper/portmapper.go +++ b/net/portmapper/portmapper.go @@ -31,6 +31,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" ) var disablePortMapperEnv = envknob.RegisterBool("TS_DISABLE_PORTMAPPER") @@ -84,6 +85,11 @@ const trustServiceStillAvailableDuration = 10 * time.Minute // Client is a port mapping client. type Client struct { + // The following two fields must either both be nil, or both non-nil. + // Both are immutable after construction. + pubClient *eventbus.Client + updates *eventbus.Publisher[Mapping] + logf logger.Logf netMon *netmon.Monitor // optional; nil means interfaces will be looked up on-demand controlKnobs *controlknobs.Knobs @@ -201,32 +207,56 @@ func (m *pmpMapping) Release(ctx context.Context) { uc.WriteToUDPAddrPort(pkt, m.gw) } -// NewClient returns a new portmapping client. -// -// The netMon parameter is required. -// -// The debug argument allows configuring the behaviour of the portmapper for -// debugging; if nil, a sensible set of defaults will be used. 
-// -// The controlKnobs, if non-nil, specifies the control knobs from the control -// plane that might disable portmapping. -// -// The optional onChange argument specifies a func to run in a new goroutine -// whenever the port mapping status has changed. If nil, it doesn't make a -// callback. -func NewClient(logf logger.Logf, netMon *netmon.Monitor, debug *DebugKnobs, controlKnobs *controlknobs.Knobs, onChange func()) *Client { - if netMon == nil { +// Config carries the settings for a [Client]. +type Config struct { + // EventBus, if non-nil, is used for event publication and subscription by + // portmapper clients created from this config. + // + // TODO(creachadair): As of 2025-03-19 this is optional, but is intended to + // become required non-nil. + EventBus *eventbus.Bus + + // Logf is called to generate text logs for the client. If nil, logger.Discard is used. + Logf logger.Logf + + // NetMon is the network monitor used by the client. It must be non-nil. + NetMon *netmon.Monitor + + // DebugKnobs, if non-nil, configure the behaviour of the portmapper for + // debugging. If nil, a sensible set of defaults will be used. + DebugKnobs *DebugKnobs + + // ControlKnobs, if non-nil, specifies knobs from the control plane that + // might disable port mapping. + ControlKnobs *controlknobs.Knobs + + // OnChange is called to run in a new goroutine whenever the port mapping + // status has changed. If nil, no callback is issued. + OnChange func() +} + +// NewClient constructs a new portmapping [Client] from c. It will panic if any +// required parameters are omitted. 
+func NewClient(c Config) *Client { + if c.NetMon == nil { panic("nil netMon") } ret := &Client{ - logf: logf, - netMon: netMon, + logf: c.Logf, + netMon: c.NetMon, ipAndGateway: netmon.LikelyHomeRouterIP, // TODO(bradfitz): move this to method on netMon - onChange: onChange, - controlKnobs: controlKnobs, + onChange: c.OnChange, + controlKnobs: c.ControlKnobs, } - if debug != nil { - ret.debug = *debug + if c.EventBus != nil { + ret.pubClient = c.EventBus.Client("portmapper") + ret.updates = eventbus.Publish[Mapping](ret.pubClient) + } + if ret.logf == nil { + ret.logf = logger.Discard + } + if c.DebugKnobs != nil { + ret.debug = *c.DebugKnobs } return ret } @@ -256,6 +286,10 @@ func (c *Client) Close() error { } c.closed = true c.invalidateMappingsLocked(true) + if c.updates != nil { + c.updates.Close() + c.pubClient.Close() + } // TODO: close some future ever-listening UDP socket(s), // waiting for multicast announcements from router. return nil @@ -467,13 +501,34 @@ func (c *Client) createMapping() { c.runningCreate = false }() - if _, err := c.createOrGetMapping(ctx); err == nil && c.onChange != nil { + mapping, _, err := c.createOrGetMapping(ctx) + if err != nil { + if !IsNoMappingError(err) { + c.logf("createOrGetMapping: %v", err) + } + return + } + if c.updates != nil { + c.updates.Publish(Mapping{ + External: mapping.External(), + Type: mapping.MappingType(), + GoodUntil: mapping.GoodUntil(), + }) + } + if c.onChange != nil { go c.onChange() - } else if err != nil && !IsNoMappingError(err) { - c.logf("createOrGetMapping: %v", err) } } +// Mapping is an event recording the allocation of a port mapping. +type Mapping struct { + External netip.AddrPort + Type string + GoodUntil time.Time + + // TODO(creachadair): Record whether we reused an existing mapping? +} + // wildcardIP is used when the previous external IP is not known for PCP port mapping. 
var wildcardIP = netip.MustParseAddr("0.0.0.0") @@ -482,19 +537,19 @@ var wildcardIP = netip.MustParseAddr("0.0.0.0") // // If no mapping is available, the error will be of type // NoMappingError; see IsNoMappingError. -func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPort, err error) { +func (c *Client) createOrGetMapping(ctx context.Context) (mapping mapping, external netip.AddrPort, err error) { if c.debug.disableAll() { - return netip.AddrPort{}, NoMappingError{ErrPortMappingDisabled} + return nil, netip.AddrPort{}, NoMappingError{ErrPortMappingDisabled} } if c.debug.DisableUPnP && c.debug.DisablePCP && c.debug.DisablePMP { - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } gw, myIP, ok := c.gatewayAndSelfIP() if !ok { - return netip.AddrPort{}, NoMappingError{ErrGatewayRange} + return nil, netip.AddrPort{}, NoMappingError{ErrGatewayRange} } if gw.Is6() { - return netip.AddrPort{}, NoMappingError{ErrGatewayIPv6} + return nil, netip.AddrPort{}, NoMappingError{ErrGatewayIPv6} } now := time.Now() @@ -523,6 +578,17 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor return } + // TODO(creachadair): This is more subtle than it should be. Ideally we + // would just return the mapping directly, but there are many different + // paths through the function with carefully-balanced locks, and not all + // the paths have a mapping to return. As a workaround, while we're here + // doing cleanup under the lock, grab the final mapping value and return + // it, so the caller does not need to grab the lock again and potentially + // race with a later update. The mapping itself is concurrency-safe. + // + // We should restructure this code so the locks are properly scoped. + mapping = c.mapping + // Print the internal details of each mapping if we're being verbose. 
if c.debug.VerboseLogs { c.logf("successfully obtained mapping: now=%d external=%v type=%s mapping=%s", @@ -548,7 +614,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if now.Before(m.RenewAfter()) { defer c.mu.Unlock() reusedExisting = true - return m.External(), nil + return nil, m.External(), nil } // The mapping might still be valid, so just try to renew it. prevPort = m.External().Port() @@ -557,10 +623,10 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if c.debug.DisablePCP && c.debug.DisablePMP { c.mu.Unlock() if external, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok { - return external, nil + return nil, external, nil } c.vlogf("fallback to UPnP due to PCP and PMP being disabled failed") - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } // If we just did a Probe (e.g. via netchecker) but didn't @@ -587,16 +653,16 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor c.mu.Unlock() // fallback to UPnP portmapping if external, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok { - return external, nil + return nil, external, nil } c.vlogf("fallback to UPnP due to no PCP and PMP failed") - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } c.mu.Unlock() uc, err := c.listenPacket(ctx, "udp4", ":0") if err != nil { - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } defer uc.Close() @@ -616,7 +682,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if neterror.TreatAsLostUDP(err) { err = NoMappingError{ErrNoPortMappingServices} } - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } } else { // Ask for our external address if needed. 
@@ -625,7 +691,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if neterror.TreatAsLostUDP(err) { err = NoMappingError{ErrNoPortMappingServices} } - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } } @@ -634,7 +700,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if neterror.TreatAsLostUDP(err) { err = NoMappingError{ErrNoPortMappingServices} } - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } } @@ -643,13 +709,13 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor n, src, err := uc.ReadFromUDPAddrPort(res) if err != nil { if ctx.Err() == context.Canceled { - return netip.AddrPort{}, err + return nil, netip.AddrPort{}, err } // fallback to UPnP portmapping if mapping, ok := c.getUPnPPortMapping(ctx, gw, internalAddr, prevPort); ok { - return mapping, nil + return nil, mapping, nil } - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } src = netaddr.Unmap(src) if !src.IsValid() { @@ -665,7 +731,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor continue } if pres.ResultCode != 0 { - return netip.AddrPort{}, NoMappingError{fmt.Errorf("PMP response Op=0x%x,Res=0x%x", pres.OpCode, pres.ResultCode)} + return nil, netip.AddrPort{}, NoMappingError{fmt.Errorf("PMP response Op=0x%x,Res=0x%x", pres.OpCode, pres.ResultCode)} } if pres.OpCode == pmpOpReply|pmpOpMapPublicAddr { m.external = netip.AddrPortFrom(pres.PublicAddr, m.external.Port()) @@ -683,7 +749,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor if err != nil { c.logf("failed to get PCP mapping: %v", err) // PCP should only have a single packet response - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } pcpMapping.c = c 
pcpMapping.internal = m.internal @@ -691,10 +757,10 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor c.mu.Lock() defer c.mu.Unlock() c.mapping = pcpMapping - return pcpMapping.external, nil + return pcpMapping, pcpMapping.external, nil default: c.logf("unknown PMP/PCP version number: %d %v", version, res[:n]) - return netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} + return nil, netip.AddrPort{}, NoMappingError{ErrNoPortMappingServices} } } @@ -702,7 +768,7 @@ func (c *Client) createOrGetMapping(ctx context.Context) (external netip.AddrPor c.mu.Lock() defer c.mu.Unlock() c.mapping = m - return m.external, nil + return nil, m.external, nil } } } diff --git a/net/portmapper/portmapper_test.go b/net/portmapper/portmapper_test.go index d321b720a02b3..515a0c28c993f 100644 --- a/net/portmapper/portmapper_test.go +++ b/net/portmapper/portmapper_test.go @@ -12,20 +12,21 @@ import ( "time" "tailscale.com/control/controlknobs" + "tailscale.com/util/eventbus" ) func TestCreateOrGetMapping(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf, nil, nil, new(controlknobs.Knobs), nil) + c := NewClient(Config{Logf: t.Logf, ControlKnobs: new(controlknobs.Knobs)}) defer c.Close() c.SetLocalPort(1234) for i := range 2 { if i > 0 { time.Sleep(100 * time.Millisecond) } - ext, err := c.createOrGetMapping(context.Background()) + _, ext, err := c.createOrGetMapping(context.Background()) t.Logf("Got: %v, %v", ext, err) } } @@ -34,7 +35,7 @@ func TestClientProbe(t *testing.T) { if v, _ := strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf, nil, nil, new(controlknobs.Knobs), nil) + c := NewClient(Config{Logf: t.Logf, ControlKnobs: new(controlknobs.Knobs)}) defer c.Close() for i := range 3 { if i > 0 { @@ -49,26 +50,25 @@ func TestClientProbeThenMap(t *testing.T) { if v, _ := 
strconv.ParseBool(os.Getenv("HIT_NETWORK")); !v { t.Skip("skipping test without HIT_NETWORK=1") } - c := NewClient(t.Logf, nil, nil, new(controlknobs.Knobs), nil) + c := NewClient(Config{Logf: t.Logf, ControlKnobs: new(controlknobs.Knobs)}) defer c.Close() c.debug.VerboseLogs = true c.SetLocalPort(1234) res, err := c.Probe(context.Background()) t.Logf("Probe: %+v, %v", res, err) - ext, err := c.createOrGetMapping(context.Background()) + _, ext, err := c.createOrGetMapping(context.Background()) t.Logf("createOrGetMapping: %v, %v", ext, err) } func TestProbeIntegration(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{PMP: true, PCP: true, UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{PMP: true, PCP: true, UPnP: true}) if err != nil { t.Fatal(err) } defer igd.Close() - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on pxp=%v, upnp=%v", c.testPxPPort, c.testUPnPPort) - defer c.Close() res, err := c.Probe(context.Background()) if err != nil { @@ -95,14 +95,13 @@ func TestProbeIntegration(t *testing.T) { } func TestPCPIntegration(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{PMP: false, PCP: true, UPnP: false}) + igd, err := NewTestIGD(t, TestIGDOptions{PMP: false, PCP: true, UPnP: false}) if err != nil { t.Fatal(err) } defer igd.Close() - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) res, err := c.Probe(context.Background()) if err != nil { t.Fatalf("probe failed: %v", err) @@ -114,7 +113,7 @@ func TestPCPIntegration(t *testing.T) { t.Fatalf("probe did not see pcp: %+v", res) } - external, err := c.createOrGetMapping(context.Background()) + _, external, err := c.createOrGetMapping(context.Background()) if err != nil { t.Fatalf("failed to get mapping: %v", err) } @@ -136,3 +135,29 @@ func TestGetUPnPErrorsMetric(t *testing.T) { getUPnPErrorsMetric(0) getUPnPErrorsMetric(-100) } + +func TestUpdateEvent(t *testing.T) { + igd, err := NewTestIGD(t, TestIGDOptions{PCP: 
true}) + if err != nil { + t.Fatalf("Create test gateway: %v", err) + } + + bus := eventbus.New() + defer bus.Close() + + sub := eventbus.Subscribe[Mapping](bus.Client("TestUpdateEvent")) + c := newTestClient(t, igd, bus) + if _, err := c.Probe(t.Context()); err != nil { + t.Fatalf("Probe failed: %v", err) + } + c.GetCachedMappingOrStartCreatingOne() + + select { + case evt := <-sub.Events(): + t.Logf("Received portmap update: %+v", evt) + case <-sub.Done(): + t.Error("Subscriber closed prematurely") + case <-time.After(5 * time.Second): + t.Error("Timed out waiting for an update event") + } +} diff --git a/net/portmapper/select_test.go b/net/portmapper/select_test.go index 9e99c9a9d3a90..af2e35cbfb764 100644 --- a/net/portmapper/select_test.go +++ b/net/portmapper/select_test.go @@ -28,7 +28,7 @@ func TestSelectBestService(t *testing.T) { } // Run a fake IGD server to respond to UPnP requests. - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -163,9 +163,8 @@ func TestSelectBestService(t *testing.T) { Desc: rootDesc, Control: tt.control, }) - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on upnp=%v", c.testUPnPPort) - defer c.Close() // Ensure that we're using the HTTP client that talks to our test IGD server ctx := context.Background() diff --git a/net/portmapper/upnp.go b/net/portmapper/upnp.go index f1199f0a6c584..13418313597f0 100644 --- a/net/portmapper/upnp.go +++ b/net/portmapper/upnp.go @@ -610,8 +610,9 @@ func (c *Client) tryUPnPPortmapWithDevice( } // From the UPnP spec: http://upnp.org/specs/gw/UPnP-gw-WANIPConnection-v2-Service.pdf + // 402: Invalid Args (see: https://github.com/tailscale/tailscale/issues/15223) // 725: OnlyPermanentLeasesSupported - if ok && code == 725 { + if ok && (code == 402 || code == 725) { newPort, err = addAnyPortMapping( ctx, client, @@ -620,7 +621,7 @@ func (c *Client) 
tryUPnPPortmapWithDevice( internal.Addr().String(), 0, // permanent ) - c.vlogf("addAnyPortMapping: 725 retry %v, err=%q", newPort, err) + c.vlogf("addAnyPortMapping: errcode=%d retried: port=%v err=%v", code, newPort, err) } } if err != nil { diff --git a/net/portmapper/upnp_test.go b/net/portmapper/upnp_test.go index c41b535a54df2..c07ec020813ed 100644 --- a/net/portmapper/upnp_test.go +++ b/net/portmapper/upnp_test.go @@ -533,7 +533,7 @@ func TestGetUPnPClient(t *testing.T) { } func TestGetUPnPPortMapping(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -586,9 +586,8 @@ func TestGetUPnPPortMapping(t *testing.T) { }, }) - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on upnp=%v", c.testUPnPPort) - defer c.Close() c.debug.VerboseLogs = true @@ -628,13 +627,102 @@ func TestGetUPnPPortMapping(t *testing.T) { } } +func TestGetUPnPPortMapping_LeaseDuration(t *testing.T) { + testCases := []struct { + name string + resp string + }{ + {"only_permanent_leases", testAddPortMappingPermanentLease}, + {"invalid_args", testAddPortMappingPermanentLease_InvalidArgs}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + // This is a very basic fake UPnP server handler. + var sawRequestWithLease atomic.Bool + handlers := map[string]any{ + "AddPortMapping": func(body []byte) (int, string) { + // Decode a minimal body to determine whether we skip the request or not. 
+ var req struct { + Protocol string `xml:"NewProtocol"` + InternalPort string `xml:"NewInternalPort"` + ExternalPort string `xml:"NewExternalPort"` + InternalClient string `xml:"NewInternalClient"` + LeaseDuration string `xml:"NewLeaseDuration"` + } + if err := xml.Unmarshal(body, &req); err != nil { + t.Errorf("bad request: %v", err) + return http.StatusBadRequest, "bad request" + } + + if req.Protocol != "UDP" { + t.Errorf(`got Protocol=%q, want "UDP"`, req.Protocol) + } + if req.LeaseDuration != "0" { + // Return a fake error to ensure that we fall back to a permanent lease. + sawRequestWithLease.Store(true) + return http.StatusOK, tc.resp + } + + return http.StatusOK, testAddPortMappingResponse + }, + "GetExternalIPAddress": testGetExternalIPAddressResponse, + "GetStatusInfo": testGetStatusInfoResponse, + "DeletePortMapping": "", // Do nothing for test + } + + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) + if err != nil { + t.Fatal(err) + } + defer igd.Close() + + igd.SetUPnPHandler(&upnpServer{ + t: t, + Desc: testRootDesc, + Control: map[string]map[string]any{ + "/ctl/IPConn": handlers, + "/upnp/control/yomkmsnooi/wanipconn-1": handlers, + }, + }) + + ctx := context.Background() + c := newTestClient(t, igd, nil) + c.debug.VerboseLogs = true + t.Logf("Listening on upnp=%v", c.testUPnPPort) + + // Actually test the UPnP port mapping. 
+ mustProbeUPnP(t, ctx, c) + + gw, myIP, ok := c.gatewayAndSelfIP() + if !ok { + t.Fatalf("could not get gateway and self IP") + } + t.Logf("gw=%v myIP=%v", gw, myIP) + + ext, ok := c.getUPnPPortMapping(ctx, gw, netip.AddrPortFrom(myIP, 12345), 0) + if !ok { + t.Fatal("could not get UPnP port mapping") + } + if got, want := ext.Addr(), netip.MustParseAddr("123.123.123.123"); got != want { + t.Errorf("bad external address; got %v want %v", got, want) + } + if !sawRequestWithLease.Load() { + t.Errorf("wanted request with lease, but didn't see one") + } + t.Logf("external IP: %v", ext) + }) + } +} + // TestGetUPnPPortMapping_NoValidServices tests that getUPnPPortMapping doesn't // crash when a valid UPnP response with no supported services is discovered // and parsed. // // See https://github.com/tailscale/tailscale/issues/10911 func TestGetUPnPPortMapping_NoValidServices(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -645,8 +733,7 @@ func TestGetUPnPPortMapping_NoValidServices(t *testing.T) { Desc: noSupportedServicesRootDesc, }) - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) c.debug.VerboseLogs = true ctx := context.Background() @@ -666,7 +753,7 @@ func TestGetUPnPPortMapping_NoValidServices(t *testing.T) { // Tests the legacy behaviour with the pre-UPnP standard portmapping service. 
func TestGetUPnPPortMapping_Legacy(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -688,8 +775,7 @@ func TestGetUPnPPortMapping_Legacy(t *testing.T) { }, }) - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) c.debug.VerboseLogs = true ctx := context.Background() @@ -710,15 +796,14 @@ func TestGetUPnPPortMapping_Legacy(t *testing.T) { } func TestGetUPnPPortMappingNoResponses(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } defer igd.Close() - c := newTestClient(t, igd) + c := newTestClient(t, igd, nil) t.Logf("Listening on upnp=%v", c.testUPnPPort) - defer c.Close() c.debug.VerboseLogs = true @@ -827,7 +912,7 @@ func TestGetUPnPPortMapping_Invalid(t *testing.T) { "127.0.0.1", } { t.Run(responseAddr, func(t *testing.T) { - igd, err := NewTestIGD(t.Logf, TestIGDOptions{UPnP: true}) + igd, err := NewTestIGD(t, TestIGDOptions{UPnP: true}) if err != nil { t.Fatal(err) } @@ -849,8 +934,7 @@ func TestGetUPnPPortMapping_Invalid(t *testing.T) { }, }) - c := newTestClient(t, igd) - defer c.Close() + c := newTestClient(t, igd, nil) c.debug.VerboseLogs = true ctx := context.Background() @@ -1045,6 +1129,23 @@ const testAddPortMappingPermanentLease = ` ` +const testAddPortMappingPermanentLease_InvalidArgs = ` + + + + SOAP:Client + UPnPError + + + 402 + Invalid Args + + + + + +` + const testAddPortMappingResponse = ` diff --git a/net/routetable/routetable_linux.go b/net/routetable/routetable_linux.go index 88dc8535a99e4..0b2cb305d7154 100644 --- a/net/routetable/routetable_linux.go +++ b/net/routetable/routetable_linux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android package routetable diff --git 
a/net/routetable/routetable_other.go b/net/routetable/routetable_other.go index 35c83e374564f..e547ab0ac769a 100644 --- a/net/routetable/routetable_other.go +++ b/net/routetable/routetable_other.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux && !darwin && !freebsd +//go:build android || (!linux && !darwin && !freebsd) package routetable diff --git a/net/socks5/socks5.go b/net/socks5/socks5.go index 0d651537fac9a..4a5befa1d2fef 100644 --- a/net/socks5/socks5.go +++ b/net/socks5/socks5.go @@ -81,6 +81,12 @@ const ( addrTypeNotSupported replyCode = 8 ) +// UDP conn default buffer size and read timeout. +const ( + bufferSize = 8 * 1024 + readTimeout = 5 * time.Second +) + // Server is a SOCKS5 proxy server. type Server struct { // Logf optionally specifies the logger to use. @@ -143,7 +149,8 @@ type Conn struct { clientConn net.Conn request *request - udpClientAddr net.Addr + udpClientAddr net.Addr + udpTargetConns map[socksAddr]net.Conn } // Run starts the new connection. 
@@ -276,15 +283,6 @@ func (c *Conn) handleUDP() error { } defer clientUDPConn.Close() - serverUDPConn, err := net.ListenPacket("udp", "[::]:0") - if err != nil { - res := errorResponse(generalFailure) - buf, _ := res.marshal() - c.clientConn.Write(buf) - return err - } - defer serverUDPConn.Close() - bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String()) if err != nil { return err @@ -305,25 +303,32 @@ func (c *Conn) handleUDP() error { } c.clientConn.Write(buf) - return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn) + return c.transferUDP(c.clientConn, clientUDPConn) } -func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error { +func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - const bufferSize = 8 * 1024 - const readTimeout = 5 * time.Second // client -> target go func() { defer cancel() + + c.udpTargetConns = make(map[socksAddr]net.Conn) + // close all target udp connections when the client connection is closed + defer func() { + for _, conn := range c.udpTargetConns { + _ = conn.Close() + } + }() + buf := make([]byte, bufferSize) for { select { case <-ctx.Done(): return default: - err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout) + err := c.handleUDPRequest(ctx, clientConn, buf) if err != nil { if isTimeout(err) { continue @@ -337,21 +342,44 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() + // A UDP association terminates when the TCP connection that the UDP + // ASSOCIATE request arrived on terminates. 
RFC1928 + _, err := io.Copy(io.Discard, associatedTCP) + if err != nil { + err = fmt.Errorf("udp associated tcp conn: %w", err) + } + return err +} + +func (c *Conn) getOrDialTargetConn( + ctx context.Context, + clientConn net.PacketConn, + targetAddr socksAddr, +) (net.Conn, error) { + conn, exist := c.udpTargetConns[targetAddr] + if exist { + return conn, nil + } + conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort()) + if err != nil { + return nil, err + } + c.udpTargetConns[targetAddr] = conn + // target -> client go func() { - defer cancel() buf := make([]byte, bufferSize) for { select { case <-ctx.Done(): return default: - err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout) + err := c.handleUDPResponse(clientConn, targetAddr, conn, buf) if err != nil { if isTimeout(err) { continue } - if errors.Is(err, net.ErrClosed) { + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { return } c.logf("udp transfer: handle udp response fail: %v", err) @@ -360,20 +388,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta } }() - // A UDP association terminates when the TCP connection that the UDP - // ASSOCIATE request arrived on terminates. 
RFC1928 - _, err := io.Copy(io.Discard, associatedTCP) - if err != nil { - err = fmt.Errorf("udp associated tcp conn: %w", err) - } - return err + return conn, nil } func (c *Conn) handleUDPRequest( + ctx context.Context, clientConn net.PacketConn, - targetConn net.PacketConn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = clientConn.SetReadDeadline(time.Now().Add(readTimeout)) @@ -386,38 +407,35 @@ func (c *Conn) handleUDPRequest( if err != nil { return fmt.Errorf("parse udp request: %w", err) } - targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort()) + + targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr) if err != nil { - c.logf("resolve target addr fail: %v", err) + return fmt.Errorf("dial target %s fail: %w", req.addr, err) } - nn, err := targetConn.WriteTo(data, targetAddr) + nn, err := targetConn.Write(data) if err != nil { - return fmt.Errorf("write to target %s fail: %w", targetAddr, err) + return fmt.Errorf("write to target %s fail: %w", req.addr, err) } if nn != len(data) { - return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite) + return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite) } return nil } func (c *Conn) handleUDPResponse( - targetConn net.PacketConn, clientConn net.PacketConn, + targetAddr socksAddr, + targetConn net.Conn, buf []byte, - readTimeout time.Duration, ) error { // add a deadline for the read to avoid blocking forever _ = targetConn.SetReadDeadline(time.Now().Add(readTimeout)) - n, addr, err := targetConn.ReadFrom(buf) + n, err := targetConn.Read(buf) if err != nil { return fmt.Errorf("read from target: %w", err) } - host, port, err := splitHostPort(addr.String()) - if err != nil { - return fmt.Errorf("split host port: %w", err) - } - hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}} + hdr := udpRequest{addr: targetAddr} pkt, err := hdr.marshal() if err != 
nil { return fmt.Errorf("marshal udp request: %w", err) @@ -627,10 +645,15 @@ func (s socksAddr) marshal() ([]byte, error) { pkt = binary.BigEndian.AppendUint16(pkt, s.port) return pkt, nil } + func (s socksAddr) hostPort() string { return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port))) } +func (s socksAddr) String() string { + return s.hostPort() +} + // response contains the contents of // a response packet sent from the proxy // to the client. diff --git a/net/socks5/socks5_test.go b/net/socks5/socks5_test.go index 11ea59d4b57d1..bc6fac79fdcf9 100644 --- a/net/socks5/socks5_test.go +++ b/net/socks5/socks5_test.go @@ -169,12 +169,25 @@ func TestReadPassword(t *testing.T) { func TestUDP(t *testing.T) { // backend UDP server which we'll use SOCKS5 to connect to - listener, err := net.ListenPacket("udp", ":0") - if err != nil { - t.Fatal(err) + newUDPEchoServer := func() net.PacketConn { + listener, err := net.ListenPacket("udp", ":0") + if err != nil { + t.Fatal(err) + } + go udpEchoServer(listener) + return listener } - backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port - go udpEchoServer(listener) + + const echoServerNumber = 3 + echoServerListener := make([]net.PacketConn, echoServerNumber) + for i := 0; i < echoServerNumber; i++ { + echoServerListener[i] = newUDPEchoServer() + } + defer func() { + for i := 0; i < echoServerNumber; i++ { + _ = echoServerListener[i].Close() + } + }() // SOCKS5 server socks5, err := net.Listen("tcp", ":0") @@ -184,84 +197,93 @@ func TestUDP(t *testing.T) { socks5Port := socks5.Addr().(*net.TCPAddr).Port go socks5Server(socks5) - // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) - if err != nil { - t.Fatal(err) - } - _, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth - if err != nil { - t.Fatal(err) - } - buf := make([]byte, 1024) - n, err := conn.Read(buf) // server hello - if err != nil { 
- t.Fatal(err) - } - if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 { - t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) - } + // make a socks5 udpAssociate conn + newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) { + // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request + conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port)) + if err != nil { + t.Fatal(err) + } + _, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := conn.Read(buf) // server hello + if err != nil { + t.Fatal(err) + } + if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired { + t.Fatalf("got: %q want: 0x05 0x00", buf[:n]) + } - targetAddr := socksAddr{ - addrType: domainName, - addr: "localhost", - port: uint16(backendServerPort), - } - targetAddrPkt, err := targetAddr.marshal() - if err != nil { - t.Fatal(err) - } - _, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust - if err != nil { - t.Fatal(err) - } + targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0} + targetAddrPkt, err := targetAddr.marshal() + if err != nil { + t.Fatal(err) + } + _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust + if err != nil { + t.Fatal(err) + } - n, err = conn.Read(buf) // server response - if err != nil { - t.Fatal(err) - } - if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) { - t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + n, err = conn.Read(buf) // server response + if err != nil { + t.Fatal(err) + } + if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) { + t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n]) + } + udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n])) + if err != nil { + t.Fatal(err) + } + + return conn, udpProxySocksAddr } - udpProxySocksAddr, err := 
parseSocksAddr(bytes.NewReader(buf[3:n])) - if err != nil { - t.Fatal(err) + + conn, udpProxySocksAddr := newUdpAssociateConn() + defer conn.Close() + + sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) { + udpPayload, err := (&udpRequest{addr: addr}).marshal() + if err != nil { + t.Fatal(err) + } + udpPayload = append(udpPayload, body...) + _, err = socks5UDPConn.Write(udpPayload) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, 1024) + n, err := socks5UDPConn.Read(buf) + if err != nil { + t.Fatal(err) + } + _, responseBody, err = parseUDPRequest(buf[:n]) + if err != nil { + t.Fatal(err) + } + return responseBody } udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort()) if err != nil { t.Fatal(err) } - udpConn, err := net.DialUDP("udp", nil, udpProxyAddr) - if err != nil { - t.Fatal(err) - } - udpPayload, err := (&udpRequest{addr: targetAddr}).marshal() - if err != nil { - t.Fatal(err) - } - udpPayload = append(udpPayload, []byte("Test")...) 
- _, err = udpConn.Write(udpPayload) // send udp package - if err != nil { - t.Fatal(err) - } - n, _, err = udpConn.ReadFrom(buf) - if err != nil { - t.Fatal(err) - } - _, responseBody, err := parseUDPRequest(buf[:n]) // read udp response - if err != nil { - t.Fatal(err) - } - if string(responseBody) != "Test" { - t.Fatalf("got: %q want: Test", responseBody) - } - err = udpConn.Close() + socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr) if err != nil { t.Fatal(err) } - err = conn.Close() - if err != nil { - t.Fatal(err) + defer socks5UDPConn.Close() + + for i := 0; i < echoServerNumber; i++ { + port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port + addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)} + requestBody := []byte(fmt.Sprintf("Test %d", i)) + responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody) + if !bytes.Equal(requestBody, responseBody) { + t.Fatalf("got: %q want: %q", responseBody, requestBody) + } } } diff --git a/net/sockstats/sockstats_tsgo.go b/net/sockstats/sockstats_tsgo.go index af691302f8be8..fec9ec3b0dad2 100644 --- a/net/sockstats/sockstats_tsgo.go +++ b/net/sockstats/sockstats_tsgo.go @@ -279,7 +279,13 @@ func setNetMon(netMon *netmon.Monitor) { if ifName == "" { return } - ifIndex := state.Interface[ifName].Index + // DefaultRouteInterface and Interface are gathered at different points in time. + // Check for existence first, to avoid a nil pointer dereference. 
+ iface, ok := state.Interface[ifName] + if !ok { + return + } + ifIndex := iface.Index sockStats.mu.Lock() defer sockStats.mu.Unlock() // Ignore changes to unknown interfaces -- it would require diff --git a/net/tlsdial/blockblame/blockblame.go b/net/tlsdial/blockblame/blockblame.go new file mode 100644 index 0000000000000..57dc7a6e6d885 --- /dev/null +++ b/net/tlsdial/blockblame/blockblame.go @@ -0,0 +1,104 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package blockblame blames specific firewall manufacturers for blocking Tailscale, +// by analyzing the SSL certificate presented when attempting to connect to a remote +// server. +package blockblame + +import ( + "crypto/x509" + "strings" +) + +// VerifyCertificate checks if the given certificate c is issued by a firewall manufacturer +// that is known to block Tailscale connections. It returns true and the Manufacturer of +// the equipment if it is, or false and nil if it is not. +func VerifyCertificate(c *x509.Certificate) (m *Manufacturer, ok bool) { + for _, m := range Manufacturers { + if m.match != nil && m.match(c) { + return m, true + } + } + return nil, false +} + +// Manufacturer represents a firewall manufacturer that may be blocking Tailscale. +type Manufacturer struct { + // Name is the name of the firewall manufacturer to be + // mentioned in health warning messages, e.g. "Fortinet". + Name string + // match is a function that returns true if the given certificate looks like it might + // be issued by this manufacturer. 
+ match matchFunc +} + +var Manufacturers = []*Manufacturer{ + { + Name: "Aruba Networks", + match: issuerContains("Aruba"), + }, + { + Name: "Cisco", + match: issuerContains("Cisco"), + }, + { + Name: "Fortinet", + match: matchAny( + issuerContains("Fortinet"), + certEmail("support@fortinet.com"), + ), + }, + { + Name: "Huawei", + match: certEmail("mobile@huawei.com"), + }, + { + Name: "Palo Alto Networks", + match: matchAny( + issuerContains("Palo Alto Networks"), + issuerContains("PAN-FW"), + ), + }, + { + Name: "Sophos", + match: issuerContains("Sophos"), + }, + { + Name: "Ubiquiti", + match: matchAny( + issuerContains("UniFi"), + issuerContains("Ubiquiti"), + ), + }, +} + +type matchFunc func(*x509.Certificate) bool + +func issuerContains(s string) matchFunc { + return func(c *x509.Certificate) bool { + return strings.Contains(strings.ToLower(c.Issuer.String()), strings.ToLower(s)) + } +} + +func certEmail(v string) matchFunc { + return func(c *x509.Certificate) bool { + for _, email := range c.EmailAddresses { + if strings.Contains(strings.ToLower(email), strings.ToLower(v)) { + return true + } + } + return false + } +} + +func matchAny(fs ...matchFunc) matchFunc { + return func(c *x509.Certificate) bool { + for _, f := range fs { + if f(c) { + return true + } + } + return false + } +} diff --git a/net/tlsdial/blockblame/blockblame_test.go b/net/tlsdial/blockblame/blockblame_test.go new file mode 100644 index 0000000000000..6d3592c60a3de --- /dev/null +++ b/net/tlsdial/blockblame/blockblame_test.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package blockblame + +import ( + "crypto/x509" + "encoding/pem" + "testing" +) + +const controlplaneDotTailscaleDotComPEM = ` +-----BEGIN CERTIFICATE----- +MIIDkzCCAxqgAwIBAgISA2GOahsftpp59yuHClbDuoduMAoGCCqGSM49BAMDMDIx +CzAJBgNVBAYTAlVTMRYwFAYDVQQKEw1MZXQncyBFbmNyeXB0MQswCQYDVQQDEwJF +NjAeFw0yNDEwMTIxNjE2NDVaFw0yNTAxMTAxNjE2NDRaMCUxIzAhBgNVBAMTGmNv 
+bnRyb2xwbGFuZS50YWlsc2NhbGUuY29tMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcD +QgAExfraDUc1t185zuGtZlnPDtEJJSDBqvHN4vQcXSzSTPSAdDYHcA8fL5woU2Kg +jK/2C0wm/rYy2Rre/ulhkS4wB6OCAhswggIXMA4GA1UdDwEB/wQEAwIHgDAdBgNV +HSUEFjAUBggrBgEFBQcDAQYIKwYBBQUHAwIwDAYDVR0TAQH/BAIwADAdBgNVHQ4E +FgQUpArnpDj8Yh6NTgMOZjDPx0TuLmcwHwYDVR0jBBgwFoAUkydGmAOpUWiOmNbE +QkjbI79YlNIwVQYIKwYBBQUHAQEESTBHMCEGCCsGAQUFBzABhhVodHRwOi8vZTYu +by5sZW5jci5vcmcwIgYIKwYBBQUHMAKGFmh0dHA6Ly9lNi5pLmxlbmNyLm9yZy8w +JQYDVR0RBB4wHIIaY29udHJvbHBsYW5lLnRhaWxzY2FsZS5jb20wEwYDVR0gBAww +CjAIBgZngQwBAgEwggEDBgorBgEEAdZ5AgQCBIH0BIHxAO8AdgDgkrP8DB3I52g2 +H95huZZNClJ4GYpy1nLEsE2lbW9UBAAAAZKBujCyAAAEAwBHMEUCIQDHMgUaL4H9 +ZJa090ZOpBeEVu3+t+EF4HlHI1NqAai6uQIgeY/lLfjAXfcVgxBHHR4zjd0SzhaP +TREHXzwxzN/8blkAdQDPEVbu1S58r/OHW9lpLpvpGnFnSrAX7KwB0lt3zsw7CAAA +AZKBujh8AAAEAwBGMEQCICQwhMk45t9aiFjfwOC/y6+hDbszqSCpIv63kFElweUy +AiAqTdkqmbqUVpnav5JdWkNERVAIlY4jqrThLsCLZYbNszAKBggqhkjOPQQDAwNn +ADBkAjALyfgAt1XQp1uSfxy4GapR5OsmjEMBRVq6IgsPBlCRBfmf0Q3/a6mF0pjb +Sj4oa+cCMEhZk4DmBTIdZY9zjuh8s7bXNfKxUQS0pEhALtXqyFr+D5dF7JcQo9+s +Z98JY7/PCA== +-----END CERTIFICATE-----` + +func TestVerifyCertificateOurControlPlane(t *testing.T) { + p, _ := pem.Decode([]byte(controlplaneDotTailscaleDotComPEM)) + if p == nil { + t.Fatalf("failed to extract certificate bytes for controlplane.tailscale.com") + return + } + cert, err := x509.ParseCertificate(p.Bytes) + if err != nil { + t.Fatalf("failed to parse certificate: %v", err) + return + } + m, found := VerifyCertificate(cert) + if found { + t.Fatalf("expected to not get a result for the controlplane.tailscale.com certificate") + } + if m != nil { + t.Fatalf("expected nil manufacturer for controlplane.tailscale.com certificate") + } +} diff --git a/net/tlsdial/tlsdial.go b/net/tlsdial/tlsdial.go index a49e7f0f730ee..1bd2450aa3c5d 100644 --- a/net/tlsdial/tlsdial.go +++ b/net/tlsdial/tlsdial.go @@ -12,6 +12,7 @@ package tlsdial import ( "bytes" "context" + "crypto/sha256" "crypto/tls" "crypto/x509" "errors" @@ -20,13 
+21,17 @@ import ( "net" "net/http" "os" + "strings" "sync" "sync/atomic" "time" + "tailscale.com/derp/derpconst" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/hostinfo" + "tailscale.com/net/bakedroots" + "tailscale.com/net/tlsdial/blockblame" ) var counterFallbackOK int32 // atomic @@ -44,6 +49,16 @@ var debug = envknob.RegisterBool("TS_DEBUG_TLS_DIAL") // Headscale, etc. var tlsdialWarningPrinted sync.Map // map[string]bool +var mitmBlockWarnable = health.Register(&health.Warnable{ + Code: "blockblame-mitm-detected", + Title: "Network may be blocking Tailscale", + Text: func(args health.Args) string { + return fmt.Sprintf("Network equipment from %q may be blocking Tailscale traffic on this network. Connect to another network, or contact your network administrator for assistance.", args["manufacturer"]) + }, + Severity: health.SeverityMedium, + ImpactsConnectivity: true, +}) + // Config returns a tls.Config for connecting to a server. // If base is non-nil, it's cloned as the base config before // being configured and returned. @@ -78,20 +93,37 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // (with the baked-in fallback root) in the VerifyConnection hook. conf.InsecureSkipVerify = true conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) { - if host == "log.tailscale.io" && hostinfo.IsNATLabGuestVM() { - // Allow log.tailscale.io TLS MITM for integration tests when + if host == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() { + // Allow log.tailscale.com TLS MITM for integration tests when // the client's running within a NATLab VM. return nil } // Perform some health checks on this certificate before we do // any verification. 
+ var cert *x509.Certificate var selfSignedIssuer string - if certs := cs.PeerCertificates; len(certs) > 0 && certIsSelfSigned(certs[0]) { - selfSignedIssuer = certs[0].Issuer.String() + if certs := cs.PeerCertificates; len(certs) > 0 { + cert = certs[0] + if certIsSelfSigned(cert) { + selfSignedIssuer = cert.Issuer.String() + } } if ht != nil { defer func() { + if retErr != nil && cert != nil { + // Is it a MITM SSL certificate from a well-known network appliance manufacturer? + // Show a dedicated warning. + m, ok := blockblame.VerifyCertificate(cert) + if ok { + log.Printf("tlsdial: server cert for %q looks like %q equipment (could be blocking Tailscale)", host, m.Name) + ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name}) + } else { + ht.SetHealthy(mitmBlockWarnable) + } + } else { + ht.SetHealthy(mitmBlockWarnable) + } if retErr != nil && selfSignedIssuer != "" { // Self-signed certs are never valid. // @@ -126,7 +158,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config { // Always verify with our baked-in Let's Encrypt certificate, // so we can log an informational message. This is useful for // detecting SSL MiTM. 
- opts.Roots = bakedInRoots() + opts.Roots = bakedroots.Get() _, bakedErr := cs.PeerCertificates[0].Verify(opts) if debug() { log.Printf("tlsdial(bake %q): %v", host, bakedErr) @@ -205,7 +237,7 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { if errSys == nil { return nil } - opts.Roots = bakedInRoots() + opts.Roots = bakedroots.Get() _, err := certs[0].Verify(opts) if debug() { log.Printf("tlsdial(bake %q/%q): %v", c.ServerName, certDNSName, err) @@ -217,6 +249,54 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) { } } +// SetConfigExpectedCertHash configures c's VerifyPeerCertificate function to +// require that exactly 1 cert is presented (not counting any present MetaCert), +// and that the hex of its SHA256 hash is equal to wantFullCertSHA256Hex and +// that it's a valid cert for c.ServerName. +func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) { + if c.VerifyPeerCertificate != nil { + panic("refusing to override tls.Config.VerifyPeerCertificate") + } + // Set InsecureSkipVerify to prevent crypto/tls from doing its + // own cert verification, but do the same work that it'd do + // (but using certDNSName) in the VerifyPeerCertificate hook. 
+ c.InsecureSkipVerify = true + c.VerifyConnection = nil + c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error { + var sawGoodCert bool + for _, rawCert := range rawCerts { + cert, err := x509.ParseCertificate(rawCert) + if err != nil { + return fmt.Errorf("ParseCertificate: %w", err) + } + if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) { + continue + } + if sawGoodCert { + return errors.New("unexpected multiple certs presented") + } + if fmt.Sprintf("%02x", sha256.Sum256(rawCert)) != wantFullCertSHA256Hex { + return fmt.Errorf("cert hash does not match expected cert hash") + } + if err := cert.VerifyHostname(c.ServerName); err != nil { + return fmt.Errorf("cert does not match server name %q: %w", c.ServerName, err) + } + now := time.Now() + if now.After(cert.NotAfter) { + return fmt.Errorf("cert expired %v", cert.NotAfter) + } + if now.Before(cert.NotBefore) { + return fmt.Errorf("cert not yet valid until %v; is your clock correct?", cert.NotBefore) + } + sawGoodCert = true + } + if !sawGoodCert { + return errors.New("expected cert not presented") + } + return nil + } +} + // NewTransport returns a new HTTP transport that verifies TLS certs using this // package, including its baked-in LetsEncrypt fallback roots. 
func NewTransport() *http.Transport { @@ -232,84 +312,3 @@ func NewTransport() *http.Transport { }, } } - -/* -letsEncryptX1 is the LetsEncrypt X1 root: - -Certificate: - - Data: - Version: 3 (0x2) - Serial Number: - 82:10:cf:b0:d2:40:e3:59:44:63:e0:bb:63:82:8b:00 - Signature Algorithm: sha256WithRSAEncryption - Issuer: C = US, O = Internet Security Research Group, CN = ISRG Root X1 - Validity - Not Before: Jun 4 11:04:38 2015 GMT - Not After : Jun 4 11:04:38 2035 GMT - Subject: C = US, O = Internet Security Research Group, CN = ISRG Root X1 - Subject Public Key Info: - Public Key Algorithm: rsaEncryption - RSA Public-Key: (4096 bit) - -We bake it into the binary as a fallback verification root, -in case the system we're running on doesn't have it. -(Tailscale runs on some ancient devices.) - -To test that this code is working on Debian/Ubuntu: - -$ sudo mv /usr/share/ca-certificates/mozilla/ISRG_Root_X1.crt{,.old} -$ sudo update-ca-certificates - -Then restart tailscaled. To also test dnsfallback's use of it, nuke -your /etc/resolv.conf and it should still start & run fine. 
-*/ -const letsEncryptX1 = ` ------BEGIN CERTIFICATE----- -MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw -TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh -cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 -WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu -ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY -MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc -h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ -0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U -A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW -T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH -B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC -B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv -KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn -OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn -jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw -qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI -rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV -HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq -hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL -ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ -3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK -NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 -ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur -TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC -jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc -oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq -4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA -mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d -emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= ------END CERTIFICATE----- 
-` - -var bakedInRootsOnce struct { - sync.Once - p *x509.CertPool -} - -func bakedInRoots() *x509.CertPool { - bakedInRootsOnce.Do(func() { - p := x509.NewCertPool() - if !p.AppendCertsFromPEM([]byte(letsEncryptX1)) { - panic("bogus PEM") - } - bakedInRootsOnce.p = p - }) - return bakedInRootsOnce.p -} diff --git a/net/tlsdial/tlsdial_test.go b/net/tlsdial/tlsdial_test.go index 26814ebbd8dc0..6723b82e0d1c9 100644 --- a/net/tlsdial/tlsdial_test.go +++ b/net/tlsdial/tlsdial_test.go @@ -4,37 +4,22 @@ package tlsdial import ( - "crypto/x509" "io" "net" "net/http" "os" "os/exec" "path/filepath" - "reflect" "runtime" "sync/atomic" "testing" "tailscale.com/health" + "tailscale.com/net/bakedroots" ) -func resetOnce() { - rv := reflect.ValueOf(&bakedInRootsOnce).Elem() - rv.Set(reflect.Zero(rv.Type())) -} - -func TestBakedInRoots(t *testing.T) { - resetOnce() - p := bakedInRoots() - got := p.Subjects() - if len(got) != 1 { - t.Errorf("subjects = %v; want 1", len(got)) - } -} - func TestFallbackRootWorks(t *testing.T) { - defer resetOnce() + defer bakedroots.ResetForTest(t, nil) const debug = false if runtime.GOOS != "linux" { @@ -69,14 +54,7 @@ func TestFallbackRootWorks(t *testing.T) { if err != nil { t.Fatal(err) } - resetOnce() - bakedInRootsOnce.Do(func() { - p := x509.NewCertPool() - if !p.AppendCertsFromPEM(caPEM) { - t.Fatal("failed to add") - } - bakedInRootsOnce.p = p - }) + bakedroots.ResetForTest(t, caPEM) ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/net/tsaddr/tsaddr.go b/net/tsaddr/tsaddr.go index 88069538724b6..06e6a26ddb721 100644 --- a/net/tsaddr/tsaddr.go +++ b/net/tsaddr/tsaddr.go @@ -66,15 +66,21 @@ const ( TailscaleServiceIPv6String = "fd7a:115c:a1e0::53" ) -// IsTailscaleIP reports whether ip is an IP address in a range that +// IsTailscaleIP reports whether IP is an IP address in a range that // Tailscale assigns from. 
func IsTailscaleIP(ip netip.Addr) bool { if ip.Is4() { - return CGNATRange().Contains(ip) && !ChromeOSVMRange().Contains(ip) + return IsTailscaleIPv4(ip) } return TailscaleULARange().Contains(ip) } +// IsTailscaleIPv4 reports whether an IPv4 IP is an IP address that +// Tailscale assigns from. +func IsTailscaleIPv4(ip netip.Addr) bool { + return CGNATRange().Contains(ip) && !ChromeOSVMRange().Contains(ip) +} + // TailscaleULARange returns the IPv6 Unique Local Address range that // is the superset range that Tailscale assigns out of. func TailscaleULARange() netip.Prefix { @@ -180,8 +186,7 @@ func PrefixIs6(p netip.Prefix) bool { return p.Addr().Is6() } // IPv6 /0 route. func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { var v4, v6 bool - for i := range rr.Len() { - r := rr.At(i) + for _, r := range rr.All() { if r == allIPv4 { v4 = true } else if r == allIPv6 { @@ -194,8 +199,8 @@ func ContainsExitRoutes(rr views.Slice[netip.Prefix]) bool { // ContainsExitRoute reports whether rr contains at least one of IPv4 or // IPv6 /0 (exit) routes. func ContainsExitRoute(rr views.Slice[netip.Prefix]) bool { - for i := range rr.Len() { - if rr.At(i).Bits() == 0 { + for _, r := range rr.All() { + if r.Bits() == 0 { return true } } @@ -205,8 +210,8 @@ func ContainsExitRoute(rr views.Slice[netip.Prefix]) bool { // ContainsNonExitSubnetRoutes reports whether v contains Subnet // Routes other than ExitNode Routes. 
func ContainsNonExitSubnetRoutes(rr views.Slice[netip.Prefix]) bool { - for i := range rr.Len() { - if rr.At(i).Bits() != 0 { + for _, r := range rr.All() { + if r.Bits() != 0 { return true } } diff --git a/net/tsaddr/tsaddr_test.go b/net/tsaddr/tsaddr_test.go index 4aa2f8c60f5b3..9ac1ce3036299 100644 --- a/net/tsaddr/tsaddr_test.go +++ b/net/tsaddr/tsaddr_test.go @@ -222,3 +222,71 @@ func TestContainsExitRoute(t *testing.T) { } } } + +func TestIsTailscaleIPv4(t *testing.T) { + tests := []struct { + in netip.Addr + want bool + }{ + { + in: netip.MustParseAddr("100.67.19.57"), + want: true, + }, + { + in: netip.MustParseAddr("10.10.10.10"), + want: false, + }, + { + + in: netip.MustParseAddr("fd7a:115c:a1e0:3f2b:7a1d:4e88:9c2b:7f01"), + want: false, + }, + { + in: netip.MustParseAddr("bc9d:0aa0:1f0a:69ab:eb5c:28e0:5456:a518"), + want: false, + }, + { + in: netip.MustParseAddr("100.115.92.157"), + want: false, + }, + } + for _, tt := range tests { + if got := IsTailscaleIPv4(tt.in); got != tt.want { + t.Errorf("IsTailscaleIPv4(%v) = %v, want %v", tt.in, got, tt.want) + } + } +} + +func TestIsTailscaleIP(t *testing.T) { + tests := []struct { + in netip.Addr + want bool + }{ + { + in: netip.MustParseAddr("100.67.19.57"), + want: true, + }, + { + in: netip.MustParseAddr("10.10.10.10"), + want: false, + }, + { + + in: netip.MustParseAddr("fd7a:115c:a1e0:3f2b:7a1d:4e88:9c2b:7f01"), + want: true, + }, + { + in: netip.MustParseAddr("bc9d:0aa0:1f0a:69ab:eb5c:28e0:5456:a518"), + want: false, + }, + { + in: netip.MustParseAddr("100.115.92.157"), + want: false, + }, + } + for _, tt := range tests { + if got := IsTailscaleIP(tt.in); got != tt.want { + t.Errorf("IsTailscaleIP(%v) = %v, want %v", tt.in, got, tt.want) + } + } +} diff --git a/net/tsdial/dnsmap.go b/net/tsdial/dnsmap.go index f5d13861bb65f..2ef1cb1f171c0 100644 --- a/net/tsdial/dnsmap.go +++ b/net/tsdial/dnsmap.go @@ -42,8 +42,8 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if 
dnsname.HasSuffix(nm.Name, suffix) { ret[canonMapKey(dnsname.TrimSuffix(nm.Name, suffix))] = ip } - for i := range addrs.Len() { - if addrs.At(i).Addr().Is4() { + for _, p := range addrs.All() { + if p.Addr().Is4() { have4 = true } } @@ -52,9 +52,8 @@ func dnsMapFromNetworkMap(nm *netmap.NetworkMap) dnsMap { if p.Name() == "" { continue } - for i := range p.Addresses().Len() { - a := p.Addresses().At(i) - ip := a.Addr() + for _, pfx := range p.Addresses().All() { + ip := pfx.Addr() if ip.Is4() && !have4 { continue } diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index 3606dd67f7ea2..2492f666cf063 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -23,6 +23,7 @@ import ( "tailscale.com/net/netknob" "tailscale.com/net/netmon" "tailscale.com/net/netns" + "tailscale.com/net/netx" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/types/netmap" @@ -71,6 +72,7 @@ type Dialer struct { netnsDialerOnce sync.Once netnsDialer netns.Dialer + sysDialForTest netx.DialFunc // or nil routes atomic.Pointer[bart.Table[bool]] // or nil if UserDial should not use routes. `true` indicates routes that point into the Tailscale interface @@ -149,6 +151,7 @@ func (d *Dialer) SetRoutes(routes, localRoutes []netip.Prefix) { for _, r := range localRoutes { rt.Insert(r, false) } + d.logf("tsdial: bart table size: %d", rt.Size()) } d.routes.Store(rt) @@ -242,7 +245,7 @@ func changeAffectsConn(delta *netmon.ChangeDelta, conn net.Conn) bool { // In a few cases, we don't have a new DefaultRouteInterface (e.g. on // Android; see tailscale/corp#19124); if so, pessimistically assume // that all connections are affected. - if delta.New.DefaultRouteInterface == "" { + if delta.New.DefaultRouteInterface == "" && runtime.GOOS != "plan9" { return true } @@ -361,6 +364,13 @@ func (d *Dialer) logf(format string, args ...any) { } } +// SetSystemDialerForTest sets an alternate function to use for SystemDial +// instead of netns.Dialer. 
This is intended for use with nettest.MemoryNetwork. +func (d *Dialer) SetSystemDialerForTest(fn netx.DialFunc) { + testenv.AssertInTest() + d.sysDialForTest = fn +} + // SystemDial connects to the provided network address without going over // Tailscale. It prefers going over the default interface and closes existing // connections if the default interface changes. It is used to connect to @@ -380,10 +390,16 @@ func (d *Dialer) SystemDial(ctx context.Context, network, addr string) (net.Conn return nil, net.ErrClosed } - d.netnsDialerOnce.Do(func() { - d.netnsDialer = netns.NewDialer(d.logf, d.netMon) - }) - c, err := d.netnsDialer.DialContext(ctx, network, addr) + var c net.Conn + var err error + if d.sysDialForTest != nil { + c, err = d.sysDialForTest(ctx, network, addr) + } else { + d.netnsDialerOnce.Do(func() { + d.netnsDialer = netns.NewDialer(d.logf, d.netMon) + }) + c, err = d.netnsDialer.DialContext(ctx, network, addr) + } if err != nil { return nil, err } diff --git a/net/tshttpproxy/tshttpproxy_synology.go b/net/tshttpproxy/tshttpproxy_synology.go index cda95764865d4..e28844f7dbf67 100644 --- a/net/tshttpproxy/tshttpproxy_synology.go +++ b/net/tshttpproxy/tshttpproxy_synology.go @@ -17,7 +17,7 @@ import ( "sync" "time" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) // These vars are overridden for tests. 
@@ -47,7 +47,7 @@ func synologyProxyFromConfigCached(req *http.Request) (*url.URL, error) { var err error modtime := mtime(synologyProxyConfigPath) - if modtime != cache.updated { + if !modtime.Equal(cache.updated) { cache.httpProxy, cache.httpsProxy, err = synologyProxiesFromConfig() cache.updated = modtime } @@ -76,21 +76,22 @@ func synologyProxiesFromConfig() (*url.URL, *url.URL, error) { func parseSynologyConfig(r io.Reader) (*url.URL, *url.URL, error) { cfg := map[string]string{} - if err := lineread.Reader(r, func(line []byte) error { + for lr := range lineiter.Reader(r) { + line, err := lr.Value() + if err != nil { + return nil, nil, err + } // accept and skip over empty lines line = bytes.TrimSpace(line) if len(line) == 0 { - return nil + continue } key, value, ok := strings.Cut(string(line), "=") if !ok { - return fmt.Errorf("missing \"=\" in proxy.conf line: %q", line) + return nil, nil, fmt.Errorf("missing \"=\" in proxy.conf line: %q", line) } cfg[string(key)] = string(value) - return nil - }); err != nil { - return nil, nil, err } if cfg["proxy_enabled"] != "yes" { diff --git a/net/tshttpproxy/tshttpproxy_synology_test.go b/net/tshttpproxy/tshttpproxy_synology_test.go index 3061740f3beff..b6e8b948c3ae9 100644 --- a/net/tshttpproxy/tshttpproxy_synology_test.go +++ b/net/tshttpproxy/tshttpproxy_synology_test.go @@ -41,7 +41,7 @@ func TestSynologyProxyFromConfigCached(t *testing.T) { t.Fatalf("got %s, %v; want nil, nil", val, err) } - if got, want := cache.updated, time.Unix(0, 0); got != want { + if got, want := cache.updated.UTC(), time.Unix(0, 0).UTC(); !got.Equal(want) { t.Fatalf("got %s, want %s", got, want) } if cache.httpProxy != nil { diff --git a/net/tstun/linkattrs_linux.go b/net/tstun/linkattrs_linux.go index 681e79269f75f..320385ba694dc 100644 --- a/net/tstun/linkattrs_linux.go +++ b/net/tstun/linkattrs_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !android + package 
tstun import ( diff --git a/net/tstun/linkattrs_notlinux.go b/net/tstun/linkattrs_notlinux.go index 7a7b40fc2652b..77d227934083e 100644 --- a/net/tstun/linkattrs_notlinux.go +++ b/net/tstun/linkattrs_notlinux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || android package tstun diff --git a/net/tstun/tap_unsupported.go b/net/tstun/tap_unsupported.go deleted file mode 100644 index 6792b229f6b79..0000000000000 --- a/net/tstun/tap_unsupported.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !linux || ts_omit_tap - -package tstun - -func (*Wrapper) handleTAPFrame([]byte) bool { panic("unreachable") } diff --git a/net/tstun/tstun_stub.go b/net/tstun/tstun_stub.go index 7a4f71a099fd5..d21eda6b07a57 100644 --- a/net/tstun/tstun_stub.go +++ b/net/tstun/tstun_stub.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build plan9 || aix +//go:build aix || solaris || illumos package tstun diff --git a/net/tstun/tun.go b/net/tstun/tun.go index 66e209d1acb5a..88679daa24b6c 100644 --- a/net/tstun/tun.go +++ b/net/tstun/tun.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !wasm && !plan9 && !tamago && !aix +//go:build !wasm && !tamago && !aix && !solaris && !illumos // Package tun creates a tuntap device, working around OS-specific // quirks if necessary. @@ -9,16 +9,20 @@ package tstun import ( "errors" + "fmt" + "log" + "os" "runtime" "strings" "time" "github.com/tailscale/wireguard-go/tun" + "tailscale.com/feature" "tailscale.com/types/logger" ) -// createTAP is non-nil on Linux. -var createTAP func(tapName, bridgeName string) (tun.Device, error) +// CreateTAP is the hook set by feature/tap.
+var CreateTAP feature.Hook[func(logf logger.Logf, tapName, bridgeName string) (tun.Device, error)] // New returns a tun.Device for the requested device name, along with // the OS-dependent name that was allocated to the device. @@ -29,7 +33,7 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { if runtime.GOOS != "linux" { return nil, "", errors.New("tap only works on Linux") } - if createTAP == nil { // if the ts_omit_tap tag is used + if !CreateTAP.IsSet() { // if the ts_omit_tap tag is used return nil, "", errors.New("tap is not supported in this build") } f := strings.Split(tunName, ":") @@ -42,8 +46,11 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { default: return nil, "", errors.New("bogus tap argument") } - dev, err = createTAP(tapName, bridgeName) + dev, err = CreateTAP.Get()(logf, tapName, bridgeName) } else { + if runtime.GOOS == "plan9" { + cleanUpPlan9Interfaces() + } dev, err = tun.CreateTUN(tunName, int(DefaultTUNMTU())) } if err != nil { @@ -64,6 +71,36 @@ func New(logf logger.Logf, tunName string) (tun.Device, string, error) { return dev, name, nil } +func cleanUpPlan9Interfaces() { + maybeUnbind := func(n int) { + b, err := os.ReadFile(fmt.Sprintf("/net/ipifc/%d/status", n)) + if err != nil { + return + } + status := string(b) + if !(strings.HasPrefix(status, "device maxtu ") || + strings.Contains(status, "fd7a:115c:a1e0:")) { + return + } + f, err := os.OpenFile(fmt.Sprintf("/net/ipifc/%d/ctl", n), os.O_RDWR, 0) + if err != nil { + return + } + defer f.Close() + if _, err := fmt.Fprintf(f, "unbind\n"); err != nil { + log.Printf("unbind interface %v: %v", n, err) + return + } + log.Printf("tun: unbound stale interface %v", n) + } + + // A common case: after unclean shutdown we might leave interfaces + // behind. Look for our straggler(s) and clean them up. 
+ for n := 2; n < 5; n++ { + maybeUnbind(n) + } +} + // tunDiagnoseFailure, if non-nil, does OS-specific diagnostics of why // TUN failed to work. var tunDiagnoseFailure func(tunName string, logf logger.Logf, err error) diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index dcd43d5718ca8..442184065aa92 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -36,7 +36,6 @@ import ( "tailscale.com/types/logger" "tailscale.com/util/clientmetric" "tailscale.com/util/usermetric" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/netstack/gro" "tailscale.com/wgengine/wgcfg" @@ -53,7 +52,8 @@ const PacketStartOffset = device.MessageTransportHeaderSize // of a packet that can be injected into a tstun.Wrapper. const MaxPacketSize = device.MaxContentSize -const tapDebug = false // for super verbose TAP debugging +// TAPDebug is whether super verbose TAP debugging is enabled. +const TAPDebug = false var ( // ErrClosed is returned when attempting an operation on a closed Wrapper. @@ -109,9 +109,7 @@ type Wrapper struct { lastActivityAtomic mono.Time // time of last send or receive destIPActivity syncs.AtomicValue[map[netip.Addr]func()] - //lint:ignore U1000 used in tap_linux.go - destMACAtomic syncs.AtomicValue[[6]byte] - discoKey syncs.AtomicValue[key.DiscoPublic] + discoKey syncs.AtomicValue[key.DiscoPublic] // timeNow, if non-nil, will be used to obtain the current time. timeNow func() time.Time @@ -209,30 +207,20 @@ type Wrapper struct { // stats maintains per-connection counters. 
stats atomic.Pointer[connstats.Statistics] - captureHook syncs.AtomicValue[capture.Callback] + captureHook syncs.AtomicValue[packet.CaptureCallback] metrics *metrics } type metrics struct { - inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel] - outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[dropPacketLabel] + inboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels] + outboundDroppedPacketsTotal *tsmetrics.MultiLabelMap[usermetric.DropLabels] } func registerMetrics(reg *usermetric.Registry) *metrics { return &metrics{ - inboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel]( - reg, - "tailscaled_inbound_dropped_packets_total", - "counter", - "Counts the number of dropped packets received by the node from other peers", - ), - outboundDroppedPacketsTotal: usermetric.NewMultiLabelMapWithRegistry[dropPacketLabel]( - reg, - "tailscaled_outbound_dropped_packets_total", - "counter", - "Counts the number of packets dropped while being sent to other peers", - ), + inboundDroppedPacketsTotal: reg.DroppedPacketsInbound(), + outboundDroppedPacketsTotal: reg.DroppedPacketsOutbound(), } } @@ -257,12 +245,6 @@ type tunVectorReadResult struct { dataOffset int } -type setWrapperer interface { - // setWrapper enables the underlying TUN/TAP to have access to the Wrapper. - // It MUST be called only once during initialization, other usage is unsafe. - setWrapper(*Wrapper) -} - // Start unblocks any Wrapper.Read calls that have already started // and makes the Wrapper functional. 
// @@ -313,10 +295,6 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry) w.bufferConsumed <- struct{}{} w.noteActivity() - if sw, ok := w.tdev.(setWrapperer); ok { - sw.setWrapper(w) - } - return w } @@ -459,12 +437,18 @@ const ethernetFrameSize = 14 // 2 six byte MACs, 2 bytes ethertype func (t *Wrapper) pollVector() { sizes := make([]int, len(t.vectorBuffer)) readOffset := PacketStartOffset + reader := t.tdev.Read if t.isTAP { - readOffset = PacketStartOffset - ethernetFrameSize + type tapReader interface { + ReadEthernet(buffs [][]byte, sizes []int, offset int) (int, error) + } + if r, ok := t.tdev.(tapReader); ok { + readOffset = PacketStartOffset - ethernetFrameSize + reader = r.ReadEthernet + } } for range t.bufferConsumed { - DoRead: for i := range t.vectorBuffer { t.vectorBuffer[i] = t.vectorBuffer[i][:cap(t.vectorBuffer[i])] } @@ -474,8 +458,8 @@ func (t *Wrapper) pollVector() { if t.isClosed() { return } - n, err = t.tdev.Read(t.vectorBuffer[:], sizes, readOffset) - if t.isTAP && tapDebug { + n, err = reader(t.vectorBuffer[:], sizes, readOffset) + if t.isTAP && TAPDebug { s := fmt.Sprintf("% x", t.vectorBuffer[0][:]) for strings.HasSuffix(s, " 00") { s = strings.TrimSuffix(s, " 00") @@ -486,21 +470,6 @@ func (t *Wrapper) pollVector() { for i := range sizes[:n] { t.vectorBuffer[i] = t.vectorBuffer[i][:readOffset+sizes[i]] } - if t.isTAP { - if err == nil { - ethernetFrame := t.vectorBuffer[0][readOffset:] - if t.handleTAPFrame(ethernetFrame) { - goto DoRead - } - } - // Fall through. We got an IP packet. 
- if sizes[0] >= ethernetFrameSize { - t.vectorBuffer[0] = t.vectorBuffer[0][:readOffset+sizes[0]-ethernetFrameSize] - } - if tapDebug { - t.logf("tap regular frame: %x", t.vectorBuffer[0][PacketStartOffset:PacketStartOffset+sizes[0]]) - } - } t.sendVectorOutbound(tunVectorReadResult{ data: t.vectorBuffer[:n], dataOffset: PacketStartOffset, @@ -823,10 +792,21 @@ func (pc *peerConfigTable) outboundPacketIsJailed(p *packet.Parsed) bool { return c.jailed } +// SetIPer is the interface expected to be implemented by the TAP implementation +// of tun.Device. +type SetIPer interface { + // SetIP sets the IP addresses of the TAP device. + SetIP(ipV4, ipV6 netip.Addr) error +} + // SetWGConfig is called when a new NetworkMap is received. func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) { + if t.isTAP { + if sip, ok := t.tdev.(SetIPer); ok { + sip.SetIP(findV4(wcfg.Addresses), findV6(wcfg.Addresses)) + } + } cfg := peerConfigTableFromWGConfig(wcfg) - old := t.peerConfig.Swap(cfg) if !reflect.DeepEqual(old, cfg) { t.logf("peer config: %v", cfg) @@ -896,11 +876,13 @@ func (t *Wrapper) filterPacketOutboundToWireGuard(p *packet.Parsed, pc *peerConf return filter.Drop, gro } - if filt.RunOut(p, t.filterFlags) != filter.Accept { + if resp, reason := filt.RunOut(p, t.filterFlags); resp != filter.Accept { metricPacketOutDropFilter.Add(1) - t.metrics.outboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonACL, - }, 1) + if reason != "" { + t.metrics.outboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: reason, + }, 1) + } return filter.Drop, gro } @@ -925,9 +907,23 @@ func (t *Wrapper) IdleDuration() time.Duration { return mono.Since(t.lastActivityAtomic.LoadAtomic()) } +func (t *Wrapper) awaitStart() { + for { + select { + case <-t.startCh: + return + case <-time.After(1 * time.Second): + // Multiple times while remixing tailscaled I (Brad) have forgotten + // to call Start and then wasted far too much time debugging. 
+ // I do not wish that debugging on anyone else. Hopefully this'll help: + t.logf("tstun: awaiting Wrapper.Start call") + } + } +} + func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { if !t.started.Load() { - <-t.startCh + t.awaitStart() } // packet from OS read and sent to WG res, ok := <-t.vectorOutbound @@ -958,7 +954,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { } } if captHook != nil { - captHook(capture.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) + captHook(packet.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) } if !t.disableFilter { var response filter.Response @@ -1104,9 +1100,9 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i return n, err } -func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook capture.Callback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { +func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { if captHook != nil { - captHook(capture.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) + captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) } if p.IPProto == ipproto.TSMP { @@ -1170,8 +1166,8 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook ca if outcome != filter.Accept { metricPacketInDropFilter.Add(1) - t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonACL, + t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: usermetric.ReasonACL, }, 1) // Tell them, via TSMP, we're dropping them due to the ACL. 
@@ -1251,8 +1247,8 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) { t.noteActivity() _, err := t.tdevWrite(buffs, offset) if err != nil { - t.metrics.inboundDroppedPacketsTotal.Add(dropPacketLabel{ - Reason: DropReasonError, + t.metrics.inboundDroppedPacketsTotal.Add(usermetric.DropLabels{ + Reason: usermetric.ReasonError, }, int64(len(buffs))) } return len(buffs), err @@ -1320,7 +1316,7 @@ func (t *Wrapper) InjectInboundPacketBuffer(pkt *stack.PacketBuffer, buffs [][]b p.Decode(buf) captHook := t.captureHook.Load() if captHook != nil { - captHook(capture.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta) + captHook(packet.SynthesizedToLocal, t.now(), p.Buffer(), p.CaptureMeta) } invertGSOChecksum(buf, pkt.GSOOptions) @@ -1452,7 +1448,7 @@ func (t *Wrapper) InjectOutboundPacketBuffer(pkt *stack.PacketBuffer) error { } if capt := t.captureHook.Load(); capt != nil { b := pkt.ToBuffer() - capt(capture.SynthesizedToPeer, t.now(), b.Flatten(), packet.CaptureMeta{}) + capt(packet.SynthesizedToPeer, t.now(), b.Flatten(), packet.CaptureMeta{}) } t.injectOutbound(tunInjectedRead{packet: pkt}) @@ -1494,20 +1490,6 @@ var ( metricPacketOutDropSelfDisco = clientmetric.NewCounter("tstun_out_to_wg_drop_self_disco") ) -type DropReason string - -const ( - DropReasonACL DropReason = "acl" - DropReasonError DropReason = "error" -) - -type dropPacketLabel struct { - // Reason indicates what we have done with the packet, and has the following values: - // - acl (rejected packets because of ACL) - // - error (rejected packets because of an error) - Reason DropReason -} - -func (t *Wrapper) InstallCaptureHook(cb capture.Callback) { +func (t *Wrapper) InstallCaptureHook(cb packet.CaptureCallback) { t.captureHook.Store(cb) } diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 0ed0075b616ee..223ee34f4336a 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -40,7 +40,6 @@ import ( "tailscale.com/types/views" "tailscale.com/util/must" 
"tailscale.com/util/usermetric" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" ) @@ -441,19 +440,19 @@ func TestFilter(t *testing.T) { } var metricInboundDroppedPacketsACL, metricInboundDroppedPacketsErr, metricOutboundDroppedPacketsACL int64 - if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok { + if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok { metricInboundDroppedPacketsACL = m.Value() } - if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonError}).(*expvar.Int); ok { + if m, ok := tun.metrics.inboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonError}).(*expvar.Int); ok { metricInboundDroppedPacketsErr = m.Value() } - if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(dropPacketLabel{Reason: DropReasonACL}).(*expvar.Int); ok { + if m, ok := tun.metrics.outboundDroppedPacketsTotal.Get(usermetric.DropLabels{Reason: usermetric.ReasonACL}).(*expvar.Int); ok { metricOutboundDroppedPacketsACL = m.Value() } assertMetricPackets(t, "inACL", 3, metricInboundDroppedPacketsACL) assertMetricPackets(t, "inError", 0, metricInboundDroppedPacketsErr) - assertMetricPackets(t, "outACL", 1, metricOutboundDroppedPacketsACL) + assertMetricPackets(t, "outACL", 0, metricOutboundDroppedPacketsACL) } func assertMetricPackets(t *testing.T, metricName string, want, got int64) { @@ -871,14 +870,14 @@ func TestPeerCfg_NAT(t *testing.T) { // with the correct parameters when various packet operations are performed. 
func TestCaptureHook(t *testing.T) { type captureRecord struct { - path capture.Path + path packet.CapturePath now time.Time pkt []byte meta packet.CaptureMeta } var captured []captureRecord - hook := func(path capture.Path, now time.Time, pkt []byte, meta packet.CaptureMeta) { + hook := func(path packet.CapturePath, now time.Time, pkt []byte, meta packet.CaptureMeta) { captured = append(captured, captureRecord{ path: path, now: now, @@ -935,19 +934,19 @@ func TestCaptureHook(t *testing.T) { // Assert that the right packets are captured. want := []captureRecord{ { - path: capture.FromPeer, + path: packet.FromPeer, pkt: []byte("Write1"), }, { - path: capture.FromPeer, + path: packet.FromPeer, pkt: []byte("Write2"), }, { - path: capture.SynthesizedToLocal, + path: packet.SynthesizedToLocal, pkt: []byte("InjectInboundPacketBuffer"), }, { - path: capture.SynthesizedToPeer, + path: packet.SynthesizedToPeer, pkt: []byte("InjectOutboundPacketBuffer"), }, } diff --git a/net/udprelay/endpoint/endpoint.go b/net/udprelay/endpoint/endpoint.go new file mode 100644 index 0000000000000..2672a856b797b --- /dev/null +++ b/net/udprelay/endpoint/endpoint.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package endpoint contains types relating to UDP relay server endpoints. It +// does not import tailscale.com/net/udprelay. +package endpoint + +import ( + "net/netip" + + "tailscale.com/tstime" + "tailscale.com/types/key" +) + +// ServerEndpoint contains details for an endpoint served by a +// [tailscale.com/net/udprelay.Server]. +type ServerEndpoint struct { + // ServerDisco is the Server's Disco public key used as part of the 3-way + // bind handshake. Server will use the same ServerDisco for its lifetime. + // ServerDisco value in combination with LamportID value represents a + // unique ServerEndpoint allocation. 
+ ServerDisco key.DiscoPublic + + // LamportID is unique and monotonically non-decreasing across + // ServerEndpoint allocations for the lifetime of Server. It enables clients + // to dedup and resolve allocation event order. Clients may race to allocate + // on the same Server, and signal ServerEndpoint details via alternative + // channels, e.g. DERP. Additionally, Server.AllocateEndpoint() requests may + // not result in a new allocation depending on existing server-side endpoint + // state. Therefore, where clients have local, existing state that contains + // ServerDisco and LamportID values matching a newly learned endpoint, these + // can be considered one and the same. If ServerDisco is equal, but + // LamportID is unequal, LamportID comparison determines which + // ServerEndpoint was allocated most recently. + LamportID uint64 + + // AddrPorts are the IP:Port candidate pairs the Server may be reachable + // over. + AddrPorts []netip.AddrPort + + // VNI (Virtual Network Identifier) is the Geneve header VNI the Server + // will use for transmitted packets, and expects for received packets + // associated with this endpoint. + VNI uint32 + + // BindLifetime is amount of time post-allocation the Server will consider + // the endpoint active while it has yet to be bound via 3-way bind handshake + // from both client parties. + BindLifetime tstime.GoDuration + + // SteadyStateLifetime is the amount of time post 3-way bind handshake from + // both client parties the Server will consider the endpoint active lacking + // bidirectional data flow. 
+ SteadyStateLifetime tstime.GoDuration +} diff --git a/net/udprelay/endpoint/endpoint_test.go b/net/udprelay/endpoint/endpoint_test.go new file mode 100644 index 0000000000000..f12a6e2f62240 --- /dev/null +++ b/net/udprelay/endpoint/endpoint_test.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package endpoint + +import ( + "encoding/json" + "math" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tstime" + "tailscale.com/types/key" +) + +func TestServerEndpointJSONUnmarshal(t *testing.T) { + tests := []struct { + name string + json []byte + wantErr bool + }{ + { + name: "valid", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: false, + }, + { + name: "invalid ServerDisco", + json: []byte(`{"ServerDisco":"1","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid LamportID", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":1.1,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid AddrPorts", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid VNI", + json: 
[]byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":18446744073709551615,"BindLifetime":"30s","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid BindLifetime", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"5","SteadyStateLifetime":"5m0s"}`), + wantErr: true, + }, + { + name: "invalid SteadyStateLifetime", + json: []byte(`{"ServerDisco":"discokey:003cd7453e04a653eb0e7a18f206fc353180efadb2facfd05ebd6982a1392c7f","LamportID":18446744073709551615,"AddrPorts":["127.0.0.1:1","127.0.0.2:2"],"VNI":16777215,"BindLifetime":"30s","SteadyStateLifetime":"5"}`), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var out ServerEndpoint + err := json.Unmarshal(tt.json, &out) + if tt.wantErr != (err != nil) { + t.Fatalf("wantErr: %v (err == nil): %v", tt.wantErr, err == nil) + } + if tt.wantErr { + return + } + }) + } +} + +func TestServerEndpointJSONMarshal(t *testing.T) { + tests := []struct { + name string + serverEndpoint ServerEndpoint + }{ + { + name: "valid roundtrip", + serverEndpoint: ServerEndpoint{ + ServerDisco: key.NewDisco().Public(), + LamportID: uint64(math.MaxUint64), + AddrPorts: []netip.AddrPort{netip.MustParseAddrPort("127.0.0.1:1"), netip.MustParseAddrPort("127.0.0.2:2")}, + VNI: 1<<24 - 1, + BindLifetime: tstime.GoDuration{Duration: time.Second * 30}, + SteadyStateLifetime: tstime.GoDuration{Duration: time.Minute * 5}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := json.Marshal(&tt.serverEndpoint) + if err != nil { + t.Fatal(err) + } + var got ServerEndpoint + err = json.Unmarshal(b, &got) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(got, 
tt.serverEndpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("ServerEndpoint unequal (-got +want)\n%s", diff) + } + }) + } +} diff --git a/net/udprelay/server.go b/net/udprelay/server.go new file mode 100644 index 0000000000000..7b63ec95e77c4 --- /dev/null +++ b/net/udprelay/server.go @@ -0,0 +1,478 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package udprelay contains constructs for relaying Disco and WireGuard packets +// between Tailscale clients over UDP. This package is currently considered +// experimental. +package udprelay + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "net" + "net/netip" + "slices" + "strconv" + "sync" + "time" + + "go4.org/mem" + "tailscale.com/disco" + "tailscale.com/net/packet" + "tailscale.com/net/udprelay/endpoint" + "tailscale.com/tstime" + "tailscale.com/types/key" +) + +const ( + // defaultBindLifetime is somewhat arbitrary. We attempt to account for + // high latency between client and [Server], and high latency between + // clients over side channels, e.g. DERP, used to exchange + // [endpoint.ServerEndpoint] details. So, a total of 3 paths with + // potentially high latency. Using a conservative 10s "high latency" bounds + // for each path we end up at a 30s total. It is worse to set an aggressive + // bind lifetime as this may lead to path discovery failure, vs dealing with + // a slight increase of [Server] resource utilization (VNIs, RAM, etc) while + // tracking endpoints that won't bind. + defaultBindLifetime = time.Second * 30 + defaultSteadyStateLifetime = time.Minute * 5 +) + +// Server implements an experimental UDP relay server. 
+type Server struct { + // disco keypair used as part of 3-way bind handshake + disco key.DiscoPrivate + discoPublic key.DiscoPublic + + bindLifetime time.Duration + steadyStateLifetime time.Duration + + // addrPorts contains the ip:port pairs returned as candidate server + // endpoints in response to an allocation request. + addrPorts []netip.AddrPort + + uc *net.UDPConn + + closeOnce sync.Once + wg sync.WaitGroup + closeCh chan struct{} + closed bool + + mu sync.Mutex // guards the following fields + lamportID uint64 + vniPool []uint32 // the pool of available VNIs + byVNI map[uint32]*serverEndpoint + byDisco map[pairOfDiscoPubKeys]*serverEndpoint +} + +// pairOfDiscoPubKeys is a pair of key.DiscoPublic. It must be constructed via +// newPairOfDiscoPubKeys to ensure lexicographical ordering. +type pairOfDiscoPubKeys [2]key.DiscoPublic + +func (p pairOfDiscoPubKeys) String() string { + return fmt.Sprintf("%s <=> %s", p[0].ShortString(), p[1].ShortString()) +} + +func newPairOfDiscoPubKeys(discoA, discoB key.DiscoPublic) pairOfDiscoPubKeys { + pair := pairOfDiscoPubKeys([2]key.DiscoPublic{discoA, discoB}) + slices.SortFunc(pair[:], func(a, b key.DiscoPublic) int { + return a.Compare(b) + }) + return pair +} + +// serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state. +// serverEndpoint methods are not thread-safe. +type serverEndpoint struct { + // discoPubKeys contains the key.DiscoPublic of the served clients. The + // indexing of this array aligns with the following fields, e.g. + // discoSharedSecrets[0] is the shared secret to use when sealing + // Disco protocol messages for transmission towards discoPubKeys[0]. 
+ discoPubKeys pairOfDiscoPubKeys + discoSharedSecrets [2]key.DiscoShared + handshakeState [2]disco.BindUDPRelayHandshakeState + addrPorts [2]netip.AddrPort + lastSeen [2]time.Time // TODO(jwhited): consider using mono.Time + challenge [2][disco.BindUDPRelayEndpointChallengeLen]byte + + lamportID uint64 + vni uint32 + allocatedAt time.Time +} + +func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, uw udpWriter, serverDisco key.DiscoPublic) { + if senderIndex != 0 && senderIndex != 1 { + return + } + handshakeState := e.handshakeState[senderIndex] + if handshakeState == disco.BindUDPRelayHandshakeStateAnswerReceived { + // this sender is already bound + return + } + switch discoMsg := discoMsg.(type) { + case *disco.BindUDPRelayEndpoint: + switch handshakeState { + case disco.BindUDPRelayHandshakeStateInit: + // set sender addr + e.addrPorts[senderIndex] = from + fallthrough + case disco.BindUDPRelayHandshakeStateChallengeSent: + if from != e.addrPorts[senderIndex] { + // this is a later arriving bind from a different source, or + // a retransmit and the sender's source has changed, discard + return + } + m := new(disco.BindUDPRelayEndpointChallenge) + copy(m.Challenge[:], e.challenge[senderIndex][:]) + reply := make([]byte, packet.GeneveFixedHeaderLength, 512) + gh := packet.GeneveHeader{Control: true, VNI: e.vni, Protocol: packet.GeneveProtocolDisco} + err := gh.Encode(reply) + if err != nil { + return + } + reply = append(reply, disco.Magic...) + reply = serverDisco.AppendTo(reply) + box := e.discoSharedSecrets[senderIndex].Seal(m.AppendMarshal(nil)) + reply = append(reply, box...) 
+ uw.WriteMsgUDPAddrPort(reply, nil, from) + // set new state + e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateChallengeSent + return + default: + // disco.BindUDPRelayEndpoint is unexpected in all other handshake states + return + } + case *disco.BindUDPRelayEndpointAnswer: + switch handshakeState { + case disco.BindUDPRelayHandshakeStateChallengeSent: + if from != e.addrPorts[senderIndex] { + // sender source has changed + return + } + if !bytes.Equal(discoMsg.Answer[:], e.challenge[senderIndex][:]) { + // bad answer + return + } + // sender is now bound + // TODO: Consider installing a fast path via netfilter or similar to + // relay (NAT) data packets for this serverEndpoint. + e.handshakeState[senderIndex] = disco.BindUDPRelayHandshakeStateAnswerReceived + // record last seen as bound time + e.lastSeen[senderIndex] = time.Now() + return + default: + // disco.BindUDPRelayEndpointAnswer is unexpected in all other handshake + // states, or we've already handled it + return + } + default: + // unexpected Disco message type + return + } +} + +func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, uw udpWriter, serverDisco key.DiscoPublic) { + senderRaw, isDiscoMsg := disco.Source(b) + if !isDiscoMsg { + // Not a Disco message + return + } + sender := key.DiscoPublicFromRaw32(mem.B(senderRaw)) + senderIndex := -1 + switch { + case sender.Compare(e.discoPubKeys[0]) == 0: + senderIndex = 0 + case sender.Compare(e.discoPubKeys[1]) == 0: + senderIndex = 1 + default: + // unknown Disco public key + return + } + + const headerLen = len(disco.Magic) + key.DiscoPublicRawLen + discoPayload, ok := e.discoSharedSecrets[senderIndex].Open(b[headerLen:]) + if !ok { + // unable to decrypt the Disco payload + return + } + + discoMsg, err := disco.Parse(discoPayload) + if err != nil { + // unable to parse the Disco payload + return + } + + e.handleDiscoControlMsg(from, senderIndex, discoMsg, uw, serverDisco) +} + +type udpWriter 
interface { + WriteMsgUDPAddrPort(b []byte, oob []byte, addr netip.AddrPort) (n, oobn int, err error) +} + +func (e *serverEndpoint) handlePacket(from netip.AddrPort, gh packet.GeneveHeader, b []byte, uw udpWriter, serverDisco key.DiscoPublic) { + if !gh.Control { + if !e.isBound() { + // not a control packet, but serverEndpoint isn't bound + return + } + var to netip.AddrPort + switch { + case from == e.addrPorts[0]: + e.lastSeen[0] = time.Now() + to = e.addrPorts[1] + case from == e.addrPorts[1]: + e.lastSeen[1] = time.Now() + to = e.addrPorts[0] + default: + // unrecognized source + return + } + // relay packet + uw.WriteMsgUDPAddrPort(b, nil, to) + return + } + + if e.isBound() { + // control packet, but serverEndpoint is already bound + return + } + + if gh.Protocol != packet.GeneveProtocolDisco { + // control packet, but not Disco + return + } + + msg := b[packet.GeneveFixedHeaderLength:] + e.handleSealedDiscoControlMsg(from, msg, uw, serverDisco) +} + +func (e *serverEndpoint) isExpired(now time.Time, bindLifetime, steadyStateLifetime time.Duration) bool { + if !e.isBound() { + if now.Sub(e.allocatedAt) > bindLifetime { + return true + } + return false + } + if now.Sub(e.lastSeen[0]) > steadyStateLifetime || now.Sub(e.lastSeen[1]) > steadyStateLifetime { + return true + } + return false +} + +// isBound returns true if both clients have completed their 3-way handshake, +// otherwise false. +func (e *serverEndpoint) isBound() bool { + return e.handshakeState[0] == disco.BindUDPRelayHandshakeStateAnswerReceived && + e.handshakeState[1] == disco.BindUDPRelayHandshakeStateAnswerReceived +} + +// NewServer constructs a [Server] listening on 0.0.0.0:'port'. IPv6 is not yet +// supported. Port may be 0, and what ultimately gets bound is returned as +// 'boundPort'. Supplied 'addrs' are joined with 'boundPort' and returned as +// [endpoint.ServerEndpoint.AddrPorts] in response to Server.AllocateEndpoint() +// requests. 
+// +// TODO: IPv6 support +// TODO: dynamic addrs:port discovery +func NewServer(port int, addrs []netip.Addr) (s *Server, boundPort int, err error) { + s = &Server{ + disco: key.NewDisco(), + bindLifetime: defaultBindLifetime, + steadyStateLifetime: defaultSteadyStateLifetime, + closeCh: make(chan struct{}), + byDisco: make(map[pairOfDiscoPubKeys]*serverEndpoint), + byVNI: make(map[uint32]*serverEndpoint), + } + s.discoPublic = s.disco.Public() + // TODO: instead of allocating 10s of MBs for the full pool, allocate + // smaller chunks and increase as needed + s.vniPool = make([]uint32, 0, 1<<24-1) + for i := 1; i < 1<<24; i++ { + s.vniPool = append(s.vniPool, uint32(i)) + } + boundPort, err = s.listenOn(port) + if err != nil { + return nil, 0, err + } + addrPorts := make([]netip.AddrPort, 0, len(addrs)) + for _, addr := range addrs { + addrPort, err := netip.ParseAddrPort(net.JoinHostPort(addr.String(), strconv.Itoa(boundPort))) + if err != nil { + return nil, 0, err + } + addrPorts = append(addrPorts, addrPort) + } + s.addrPorts = addrPorts + s.wg.Add(2) + go s.packetReadLoop() + go s.endpointGCLoop() + return s, boundPort, nil +} + +func (s *Server) listenOn(port int) (int, error) { + uc, err := net.ListenUDP("udp4", &net.UDPAddr{Port: port}) + if err != nil { + return 0, err + } + // TODO: set IP_PKTINFO sockopt + _, boundPortStr, err := net.SplitHostPort(uc.LocalAddr().String()) + if err != nil { + s.uc.Close() + return 0, err + } + boundPort, err := strconv.Atoi(boundPortStr) + if err != nil { + s.uc.Close() + return 0, err + } + s.uc = uc + return boundPort, nil +} + +// Close closes the server. 
+func (s *Server) Close() error { + s.closeOnce.Do(func() { + s.mu.Lock() + defer s.mu.Unlock() + s.uc.Close() + close(s.closeCh) + s.wg.Wait() + clear(s.byVNI) + clear(s.byDisco) + s.vniPool = nil + s.closed = true + }) + return nil +} + +func (s *Server) endpointGCLoop() { + defer s.wg.Done() + ticker := time.NewTicker(s.bindLifetime) + defer ticker.Stop() + + gc := func() { + now := time.Now() + // TODO: consider performance implications of scanning all endpoints and + // holding s.mu for the duration. Keep it simple (and slow) for now. + s.mu.Lock() + defer s.mu.Unlock() + for k, v := range s.byDisco { + if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) { + delete(s.byDisco, k) + delete(s.byVNI, v.vni) + s.vniPool = append(s.vniPool, v.vni) + } + } + } + + for { + select { + case <-ticker.C: + gc() + case <-s.closeCh: + return + } + } +} + +func (s *Server) handlePacket(from netip.AddrPort, b []byte, uw udpWriter) { + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + return + } + // TODO: consider performance implications of holding s.mu for the remainder + // of this method, which does a bunch of disco/crypto work depending. Keep + // it simple (and slow) for now. + s.mu.Lock() + defer s.mu.Unlock() + e, ok := s.byVNI[gh.VNI] + if !ok { + // unknown VNI + return + } + + e.handlePacket(from, gh, b, uw, s.discoPublic) +} + +func (s *Server) packetReadLoop() { + defer func() { + s.wg.Done() + s.Close() + }() + b := make([]byte, 1<<16-1) + for { + // TODO: extract laddr from IP_PKTINFO for use in reply + n, from, err := s.uc.ReadFromUDPAddrPort(b) + if err != nil { + return + } + s.handlePacket(from, b[:n], s.uc) + } +} + +var ErrServerClosed = errors.New("server closed") + +// AllocateEndpoint allocates an [endpoint.ServerEndpoint] for the provided pair +// of [key.DiscoPublic]'s. If an allocation already exists for discoA and discoB +// it is returned without modification/reallocation. 
AllocateEndpoint returns +// [ErrServerClosed] if the server has been closed. +func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.ServerEndpoint, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return endpoint.ServerEndpoint{}, ErrServerClosed + } + + if discoA.Compare(s.discoPublic) == 0 || discoB.Compare(s.discoPublic) == 0 { + return endpoint.ServerEndpoint{}, fmt.Errorf("client disco equals server disco: %s", s.discoPublic.ShortString()) + } + + pair := newPairOfDiscoPubKeys(discoA, discoB) + e, ok := s.byDisco[pair] + if ok { + // Return the existing allocation. Clients can resolve duplicate + // [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID]. + // + // TODO: consider ServerEndpoint.BindLifetime -= time.Now()-e.allocatedAt + // to give the client a more accurate picture of the bind window. + return endpoint.ServerEndpoint{ + ServerDisco: s.discoPublic, + AddrPorts: s.addrPorts, + VNI: e.vni, + LamportID: e.lamportID, + BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, + SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime}, + }, nil + } + + if len(s.vniPool) == 0 { + return endpoint.ServerEndpoint{}, errors.New("VNI pool exhausted") + } + + s.lamportID++ + e = &serverEndpoint{ + discoPubKeys: pair, + lamportID: s.lamportID, + allocatedAt: time.Now(), + } + e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys[0]) + e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys[1]) + e.vni, s.vniPool = s.vniPool[0], s.vniPool[1:] + rand.Read(e.challenge[0][:]) + rand.Read(e.challenge[1][:]) + + s.byDisco[pair] = e + s.byVNI[e.vni] = e + + return endpoint.ServerEndpoint{ + ServerDisco: s.discoPublic, + AddrPorts: s.addrPorts, + VNI: e.vni, + LamportID: e.lamportID, + BindLifetime: tstime.GoDuration{Duration: s.bindLifetime}, + SteadyStateLifetime: tstime.GoDuration{Duration: s.steadyStateLifetime}, + }, nil +} diff --git a/net/udprelay/server_test.go b/net/udprelay/server_test.go 
new file mode 100644 index 0000000000000..38c7ae5d9a749 --- /dev/null +++ b/net/udprelay/server_test.go @@ -0,0 +1,212 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package udprelay + +import ( + "bytes" + "net" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "go4.org/mem" + "tailscale.com/disco" + "tailscale.com/net/packet" + "tailscale.com/types/key" +) + +type testClient struct { + vni uint32 + local key.DiscoPrivate + server key.DiscoPublic + uc *net.UDPConn +} + +func newTestClient(t *testing.T, vni uint32, serverEndpoint netip.AddrPort, local key.DiscoPrivate, server key.DiscoPublic) *testClient { + rAddr := &net.UDPAddr{IP: serverEndpoint.Addr().AsSlice(), Port: int(serverEndpoint.Port())} + uc, err := net.DialUDP("udp4", nil, rAddr) + if err != nil { + t.Fatal(err) + } + return &testClient{ + vni: vni, + local: local, + server: server, + uc: uc, + } +} + +func (c *testClient) write(t *testing.T, b []byte) { + _, err := c.uc.Write(b) + if err != nil { + t.Fatal(err) + } +} + +func (c *testClient) read(t *testing.T) []byte { + c.uc.SetReadDeadline(time.Now().Add(time.Second)) + b := make([]byte, 1<<16-1) + n, err := c.uc.Read(b) + if err != nil { + t.Fatal(err) + } + return b[:n] +} + +func (c *testClient) writeDataPkt(t *testing.T, b []byte) { + pkt := make([]byte, packet.GeneveFixedHeaderLength, packet.GeneveFixedHeaderLength+len(b)) + gh := packet.GeneveHeader{Control: false, VNI: c.vni, Protocol: packet.GeneveProtocolWireGuard} + err := gh.Encode(pkt) + if err != nil { + t.Fatal(err) + } + pkt = append(pkt, b...) 
+ c.write(t, pkt) +} + +func (c *testClient) readDataPkt(t *testing.T) []byte { + b := c.read(t) + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + t.Fatal(err) + } + if gh.Protocol != packet.GeneveProtocolWireGuard { + t.Fatal("unexpected geneve protocol") + } + if gh.Control { + t.Fatal("unexpected control") + } + if gh.VNI != c.vni { + t.Fatal("unexpected vni") + } + return b[packet.GeneveFixedHeaderLength:] +} + +func (c *testClient) writeControlDiscoMsg(t *testing.T, msg disco.Message) { + pkt := make([]byte, packet.GeneveFixedHeaderLength, 512) + gh := packet.GeneveHeader{Control: true, VNI: c.vni, Protocol: packet.GeneveProtocolDisco} + err := gh.Encode(pkt) + if err != nil { + t.Fatal(err) + } + pkt = append(pkt, disco.Magic...) + pkt = c.local.Public().AppendTo(pkt) + box := c.local.Shared(c.server).Seal(msg.AppendMarshal(nil)) + pkt = append(pkt, box...) + c.write(t, pkt) +} + +func (c *testClient) readControlDiscoMsg(t *testing.T) disco.Message { + b := c.read(t) + gh := packet.GeneveHeader{} + err := gh.Decode(b) + if err != nil { + t.Fatal(err) + } + if gh.Protocol != packet.GeneveProtocolDisco { + t.Fatal("unexpected geneve protocol") + } + if !gh.Control { + t.Fatal("unexpected non-control") + } + if gh.VNI != c.vni { + t.Fatal("unexpected vni") + } + b = b[packet.GeneveFixedHeaderLength:] + headerLen := len(disco.Magic) + key.DiscoPublicRawLen + if len(b) < headerLen { + t.Fatal("disco message too short") + } + sender := key.DiscoPublicFromRaw32(mem.B(b[len(disco.Magic):headerLen])) + if sender.Compare(c.server) != 0 { + t.Fatal("unknown disco public key") + } + payload, ok := c.local.Shared(c.server).Open(b[headerLen:]) + if !ok { + t.Fatal("failed to open sealed disco msg") + } + msg, err := disco.Parse(payload) + if err != nil { + t.Fatal("failed to parse disco payload") + } + return msg +} + +func (c *testClient) handshake(t *testing.T) { + c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpoint{}) + msg := 
c.readControlDiscoMsg(t) + challenge, ok := msg.(*disco.BindUDPRelayEndpointChallenge) + if !ok { + t.Fatal("unexepcted disco message type") + } + c.writeControlDiscoMsg(t, &disco.BindUDPRelayEndpointAnswer{Answer: challenge.Challenge}) +} + +func (c *testClient) close() { + c.uc.Close() +} + +func TestServer(t *testing.T) { + discoA := key.NewDisco() + discoB := key.NewDisco() + + ipv4LoopbackAddr := netip.MustParseAddr("127.0.0.1") + + server, _, err := NewServer(0, []netip.Addr{ipv4LoopbackAddr}) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + endpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + dupEndpoint, err := server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + + // We expect the same endpoint details pre-handshake. + if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } + + if len(endpoint.AddrPorts) != 1 { + t.Fatalf("unexpected endpoint.AddrPorts: %v", endpoint.AddrPorts) + } + tcA := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoA, endpoint.ServerDisco) + defer tcA.close() + tcB := newTestClient(t, endpoint.VNI, endpoint.AddrPorts[0], discoB, endpoint.ServerDisco) + defer tcB.close() + + tcA.handshake(t) + tcB.handshake(t) + + dupEndpoint, err = server.AllocateEndpoint(discoA.Public(), discoB.Public()) + if err != nil { + t.Fatal(err) + } + // We expect the same endpoint details post-handshake. 
+ if diff := cmp.Diff(dupEndpoint, endpoint, cmpopts.EquateComparable(netip.AddrPort{}, key.DiscoPublic{})); diff != "" { + t.Fatalf("wrong dupEndpoint (-got +want)\n%s", diff) + } + + txToB := []byte{1, 2, 3} + tcA.writeDataPkt(t, txToB) + rxFromA := tcB.readDataPkt(t) + if !bytes.Equal(txToB, rxFromA) { + t.Fatal("unexpected msg A->B") + } + + txToA := []byte{4, 5, 6} + tcB.writeDataPkt(t, txToA) + rxFromB := tcA.readDataPkt(t) + if !bytes.Equal(txToA, rxFromB) { + t.Fatal("unexpected msg B->A") + } +} diff --git a/paths/paths_unix.go b/paths/paths_unix.go index 6a2b28733a93b..50a8b7ca502f7 100644 --- a/paths/paths_unix.go +++ b/paths/paths_unix.go @@ -22,7 +22,7 @@ func init() { func statePath() string { switch runtime.GOOS { - case "linux": + case "linux", "illumos", "solaris": return "/var/lib/tailscale/tailscaled.state" case "freebsd", "openbsd": return "/var/db/tailscale/tailscaled.state" diff --git a/pkgdoc_test.go b/pkgdoc_test.go index be08a358b7c7a..0f4a455288950 100644 --- a/pkgdoc_test.go +++ b/pkgdoc_test.go @@ -26,6 +26,9 @@ func TestPackageDocs(t *testing.T) { if err != nil { return err } + if fi.Mode().IsDir() && path == ".git" { + return filepath.SkipDir // No documentation lives in .git + } if fi.Mode().IsRegular() && strings.HasSuffix(path, ".go") { if strings.HasSuffix(path, "_test.go") { return nil diff --git a/portlist/portlist_plan9.go b/portlist/portlist_plan9.go new file mode 100644 index 0000000000000..77f8619f97ffa --- /dev/null +++ b/portlist/portlist_plan9.go @@ -0,0 +1,122 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package portlist + +import ( + "bufio" + "bytes" + "os" + "strconv" + "strings" + "time" +) + +func init() { + newOSImpl = newPlan9Impl + + pollInterval = 5 * time.Second +} + +type plan9Impl struct { + known map[protoPort]*portMeta // inode string => metadata + + br *bufio.Reader // reused + portsBuf []Port + includeLocalhost bool +} + +type protoPort struct { + proto string + 
port uint16 +} + +type portMeta struct { + port Port + keep bool +} + +func newPlan9Impl(includeLocalhost bool) osImpl { + return &plan9Impl{ + known: map[protoPort]*portMeta{}, + br: bufio.NewReader(bytes.NewReader(nil)), + includeLocalhost: includeLocalhost, + } +} + +func (*plan9Impl) Close() error { return nil } + +func (im *plan9Impl) AppendListeningPorts(base []Port) ([]Port, error) { + ret := base + + des, err := os.ReadDir("/proc") + if err != nil { + return nil, err + } + for _, de := range des { + if !de.IsDir() { + continue + } + pidStr := de.Name() + pid, err := strconv.Atoi(pidStr) + if err != nil { + continue + } + st, _ := os.ReadFile("/proc/" + pidStr + "/fd") + if !bytes.Contains(st, []byte("/net/tcp/clone")) { + continue + } + args, _ := os.ReadFile("/proc/" + pidStr + "/args") + procName := string(bytes.TrimSpace(args)) + // term% cat /proc/417/fd + // /usr/glenda + // 0 r M 35 (0000000000000001 0 00) 16384 260 /dev/cons + // 1 w c 0 (000000000000000a 0 00) 0 471 /dev/null + // 2 w M 35 (0000000000000001 0 00) 16384 108 /dev/cons + // 3 rw I 0 (000000000000002c 0 00) 0 14 /net/tcp/clone + for line := range bytes.Lines(st) { + if !bytes.Contains(line, []byte("/net/tcp/clone")) { + continue + } + f := strings.Fields(string(line)) + if len(f) < 10 { + continue + } + if f[9] != "/net/tcp/clone" { + continue + } + qid, err := strconv.ParseUint(strings.TrimPrefix(f[4], "("), 16, 64) + if err != nil { + continue + } + tcpN := (qid >> 5) & (1<<12 - 1) + tcpNStr := strconv.FormatUint(tcpN, 10) + st, _ := os.ReadFile("/net/tcp/" + tcpNStr + "/status") + if !bytes.Contains(st, []byte("Listen ")) { + // Unexpected. Or a race. 
+ continue + } + bl, _ := os.ReadFile("/net/tcp/" + tcpNStr + "/local") + i := bytes.LastIndexByte(bl, '!') + if i == -1 { + continue + } + if bytes.HasPrefix(bl, []byte("127.0.0.1!")) && !im.includeLocalhost { + continue + } + portStr := strings.TrimSpace(string(bl[i+1:])) + port, _ := strconv.Atoi(portStr) + if port == 0 { + continue + } + ret = append(ret, Port{ + Proto: "tcp", + Port: uint16(port), + Process: procName, + Pid: pid, + }) + } + } + + return sortAndDedup(ret), nil +} diff --git a/prober/derp.go b/prober/derp.go index 0dadbe8c2fe06..98e61ff54b89a 100644 --- a/prober/derp.go +++ b/prober/derp.go @@ -8,24 +8,34 @@ import ( "cmp" "context" crand "crypto/rand" + "encoding/binary" "encoding/json" "errors" "expvar" "fmt" + "io" "log" + "maps" "net" "net/http" + "net/netip" + "slices" "strconv" "strings" "sync" "time" "github.com/prometheus/client_golang/prometheus" - "tailscale.com/client/tailscale" + wgconn "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "go4.org/netipx" + "tailscale.com/client/local" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netmon" "tailscale.com/net/stun" + "tailscale.com/net/tstun" "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -42,8 +52,16 @@ type derpProber struct { tlsInterval time.Duration // Optional bandwidth probing. - bwInterval time.Duration - bwProbeSize int64 + bwInterval time.Duration + bwProbeSize int64 + bwTUNIPv4Prefix *netip.Prefix // or nil to not use TUN + + // Optional queuing delay probing. + qdPacketsPerSecond int // in packets per second + qdPacketTimeout time.Duration + + // Optionally restrict probes to a single regionCodeOrID. + regionCodeOrID string // Probe class for fetching & updating the DERP map. 
ProbeMap ProbeClass @@ -53,6 +71,7 @@ type derpProber struct { udpProbeFn func(string, int) ProbeClass meshProbeFn func(string, string) ProbeClass bwProbeFn func(string, string, int64) ProbeClass + qdProbeFn func(string, string, int, time.Duration) ProbeClass sync.Mutex lastDERPMap *tailcfg.DERPMap @@ -65,11 +84,30 @@ type DERPOpt func(*derpProber) // WithBandwidthProbing enables bandwidth probing. When enabled, a payload of // `size` bytes will be regularly transferred through each DERP server, and each -// pair of DERP servers in every region. -func WithBandwidthProbing(interval time.Duration, size int64) DERPOpt { +// pair of DERP servers in every region. If tunAddress is specified, probes will +// use a TCP connection over a TUN device at this address in order to exercise +// TCP-in-TCP in similar fashion to TCP over Tailscale via DERP. +func WithBandwidthProbing(interval time.Duration, size int64, tunAddress string) DERPOpt { return func(d *derpProber) { d.bwInterval = interval d.bwProbeSize = size + if tunAddress != "" { + prefix, err := netip.ParsePrefix(fmt.Sprintf("%s/30", tunAddress)) + if err != nil { + log.Fatalf("failed to parse IP prefix from bw-tun-ipv4-addr: %v", err) + } + d.bwTUNIPv4Prefix = &prefix + } + } +} + +// WithQueuingDelayProbing enables/disables queuing delay probing. qdSendRate +// is the number of packets sent per second. qdTimeout is the amount of time +// after which a sent packet is considered to have timed out. +func WithQueuingDelayProbing(qdPacketsPerSecond int, qdPacketTimeout time.Duration) DERPOpt { + return func(d *derpProber) { + d.qdPacketsPerSecond = qdPacketsPerSecond + d.qdPacketTimeout = qdPacketTimeout } } @@ -97,6 +135,14 @@ func WithTLSProbing(interval time.Duration) DERPOpt { } } +// WithRegionCodeOrID restricts probing to the specified region identified by its code +// (e.g. "lax") or its id (e.g. "17"). This is case sensitive. 
+func WithRegionCodeOrID(regionCode string) DERPOpt { + return func(d *derpProber) { + d.regionCodeOrID = regionCode + } +} + // DERP creates a new derpProber. // // If derpMapURL is "local", the DERPMap is fetched via @@ -119,6 +165,7 @@ func DERP(p *Prober, derpMapURL string, opts ...DERPOpt) (*derpProber, error) { d.udpProbeFn = d.ProbeUDP d.meshProbeFn = d.probeMesh d.bwProbeFn = d.probeBandwidth + d.qdProbeFn = d.probeQueuingDelay return d, nil } @@ -135,6 +182,10 @@ func (d *derpProber) probeMapFn(ctx context.Context) error { defer d.Unlock() for _, region := range d.lastDERPMap.Regions { + if d.skipRegion(region) { + continue + } + for _, server := range region.Nodes { labels := Labels{ "region": region.RegionCode, @@ -181,14 +232,27 @@ func (d *derpProber) probeMapFn(ctx context.Context) error { } } - if d.bwInterval > 0 && d.bwProbeSize > 0 { + if d.bwInterval != 0 && d.bwProbeSize > 0 { n := fmt.Sprintf("derp/%s/%s/%s/bw", region.RegionCode, server.Name, to.Name) wantProbes[n] = true if d.probes[n] == nil { - log.Printf("adding DERP bandwidth probe for %s->%s (%s) %v bytes every %v", server.Name, to.Name, region.RegionName, d.bwProbeSize, d.bwInterval) + tunString := "" + if d.bwTUNIPv4Prefix != nil { + tunString = " (TUN)" + } + log.Printf("adding%s DERP bandwidth probe for %s->%s (%s) %v bytes every %v", tunString, server.Name, to.Name, region.RegionName, d.bwProbeSize, d.bwInterval) d.probes[n] = d.p.Run(n, d.bwInterval, labels, d.bwProbeFn(server.Name, to.Name, d.bwProbeSize)) } } + + if d.qdPacketsPerSecond > 0 { + n := fmt.Sprintf("derp/%s/%s/%s/qd", region.RegionCode, server.Name, to.Name) + wantProbes[n] = true + if d.probes[n] == nil { + log.Printf("adding DERP queuing delay probe for %s->%s (%s)", server.Name, to.Name, region.RegionName) + d.probes[n] = d.p.Run(n, -10*time.Second, labels, d.qdProbeFn(server.Name, to.Name, d.qdPacketsPerSecond, d.qdPacketTimeout)) + } + } } } } @@ -204,7 +268,7 @@ func (d *derpProber) probeMapFn(ctx 
context.Context) error { return nil } -// probeMesh returs a probe class that sends a test packet through a pair of DERP +// probeMesh returns a probe class that sends a test packet through a pair of DERP // servers (or just one server, if 'from' and 'to' are the same). 'from' and 'to' // are expected to be names (DERPNode.Name) of two DERP servers in the same region. func (d *derpProber) probeMesh(from, to string) ProbeClass { @@ -227,7 +291,7 @@ func (d *derpProber) probeMesh(from, to string) ProbeClass { } } -// probeBandwidth returs a probe class that sends a payload of a given size +// probeBandwidth returns a probe class that sends a payload of a given size // through a pair of DERP servers (or just one server, if 'from' and 'to' are // the same). 'from' and 'to' are expected to be names (DERPNode.Name) of two // DERP servers in the same region. @@ -236,26 +300,224 @@ func (d *derpProber) probeBandwidth(from, to string, size int64) ProbeClass { if from == to { derpPath = "single" } - var transferTime expvar.Float + var transferTimeSeconds expvar.Float + var totalBytesTransferred expvar.Float return ProbeClass{ Probe: func(ctx context.Context) error { fromN, toN, err := d.getNodePair(from, to) if err != nil { return err } - return derpProbeBandwidth(ctx, d.lastDERPMap, fromN, toN, size, &transferTime) + return derpProbeBandwidth(ctx, d.lastDERPMap, fromN, toN, size, &transferTimeSeconds, &totalBytesTransferred, d.bwTUNIPv4Prefix) + }, + Class: "derp_bw", + Labels: Labels{ + "derp_path": derpPath, + "tcp_in_tcp": strconv.FormatBool(d.bwTUNIPv4Prefix != nil), }, - Class: "derp_bw", - Labels: Labels{"derp_path": derpPath}, Metrics: func(l prometheus.Labels) []prometheus.Metric { - return []prometheus.Metric{ + metrics := []prometheus.Metric{ prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_probe_size_bytes", "Payload size of the bandwidth prober", nil, l), prometheus.GaugeValue, float64(size)), - 
prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_transfer_time_seconds_total", "Time it took to transfer data", nil, l), prometheus.CounterValue, transferTime.Value()), + prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_transfer_time_seconds_total", "Time it took to transfer data", nil, l), prometheus.CounterValue, transferTimeSeconds.Value()), } + if d.bwTUNIPv4Prefix != nil { + // For TCP-in-TCP probes, also record cumulative bytes transferred. + metrics = append(metrics, prometheus.MustNewConstMetric(prometheus.NewDesc("derp_bw_bytes_total", "Amount of data transferred", nil, l), prometheus.CounterValue, totalBytesTransferred.Value())) + } + return metrics }, } } +// probeQueuingDelay returns a probe class that continuously sends packets +// through a pair of DERP servers (or just one server, if 'from' and 'to' are +// the same) at a rate of `packetsPerSecond` packets per second in order to +// measure queuing delays. Packets arriving after `packetTimeout` don't contribute +// to the queuing delay measurement and are recorded as dropped. 'from' and 'to' are +// expected to be names (DERPNode.Name) of two DERP servers in the same region, +// and may refer to the same server. 
+func (d *derpProber) probeQueuingDelay(from, to string, packetsPerSecond int, packetTimeout time.Duration) ProbeClass { + derpPath := "mesh" + if from == to { + derpPath = "single" + } + var packetsDropped expvar.Float + qdh := newHistogram([]float64{.005, .01, .025, .05, .1, .25, .5, 1}) + return ProbeClass{ + Probe: func(ctx context.Context) error { + fromN, toN, err := d.getNodePair(from, to) + if err != nil { + return err + } + return derpProbeQueuingDelay(ctx, d.lastDERPMap, fromN, toN, packetsPerSecond, packetTimeout, &packetsDropped, qdh) + }, + Class: "derp_qd", + Labels: Labels{"derp_path": derpPath}, + Metrics: func(l prometheus.Labels) []prometheus.Metric { + qdh.mx.Lock() + result := []prometheus.Metric{ + prometheus.MustNewConstMetric(prometheus.NewDesc("derp_qd_probe_dropped_packets", "Total packets dropped", nil, l), prometheus.CounterValue, float64(packetsDropped.Value())), + prometheus.MustNewConstHistogram(prometheus.NewDesc("derp_qd_probe_delays_seconds", "Distribution of queuing delays", nil, l), qdh.count, qdh.sum, maps.Clone(qdh.bucketedCounts)), + } + qdh.mx.Unlock() + return result + }, + } +} + +// derpProbeQueuingDelay continuously sends data between two local DERP clients +// connected to two DERP servers in order to measure queuing delays. From and to +// can be the same server. +func derpProbeQueuingDelay(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode, packetsPerSecond int, packetTimeout time.Duration, packetsDropped *expvar.Float, qdh *histogram) (err error) { + // This probe uses clients with isProber=false to avoid spamming the derper + // logs with every packet sent by the queuing delay probe. 
+ fromc, err := newConn(ctx, dm, from, false) + if err != nil { + return err + } + defer fromc.Close() + toc, err := newConn(ctx, dm, to, false) + if err != nil { + return err + } + defer toc.Close() + + // Wait a bit for from's node to hear about to existing on the + // other node in the region, in the case where the two nodes + // are different. + if from.Name != to.Name { + time.Sleep(100 * time.Millisecond) // pretty arbitrary + } + + if err := runDerpProbeQueuingDelayContinously(ctx, from, to, fromc, toc, packetsPerSecond, packetTimeout, packetsDropped, qdh); err != nil { + // Record pubkeys on failed probes to aid investigation. + return fmt.Errorf("%s -> %s: %w", + fromc.SelfPublicKey().ShortString(), + toc.SelfPublicKey().ShortString(), err) + } + return nil +} + +func runDerpProbeQueuingDelayContinously(ctx context.Context, from, to *tailcfg.DERPNode, fromc, toc *derphttp.Client, packetsPerSecond int, packetTimeout time.Duration, packetsDropped *expvar.Float, qdh *histogram) error { + // Make sure all goroutines have finished. + var wg sync.WaitGroup + defer wg.Wait() + + // Close the clients to make sure goroutines that are reading/writing from them terminate. + defer fromc.Close() + defer toc.Close() + + type txRecord struct { + at time.Time + seq uint64 + } + // txRecords is sized to hold enough transmission records to keep timings + // for packets up to their timeout. As records age out of the front of this + // list, if the associated packet arrives, we won't have a txRecord for it + // and will consider it to have timed out. + txRecords := make([]txRecord, 0, packetsPerSecond*int(packetTimeout.Seconds())) + var txRecordsMu sync.Mutex + + // Send the packets. + sendErrC := make(chan error, 1) + // TODO: construct a disco CallMeMaybe in the same fashion as magicsock, e.g. magic bytes, src pub, seal payload. + // DERP server handling of disco may vary from non-disco, and we may want to measure queue delay of both. 
+ pkt := make([]byte, 260) // the same size as a CallMeMaybe packet observed on a Tailscale client. + crand.Read(pkt) + + wg.Add(1) + go func() { + defer wg.Done() + t := time.NewTicker(time.Second / time.Duration(packetsPerSecond)) + defer t.Stop() + + toDERPPubKey := toc.SelfPublicKey() + seq := uint64(0) + for { + select { + case <-ctx.Done(): + return + case <-t.C: + txRecordsMu.Lock() + if len(txRecords) == cap(txRecords) { + txRecords = slices.Delete(txRecords, 0, 1) + packetsDropped.Add(1) + } + txRecords = append(txRecords, txRecord{time.Now(), seq}) + txRecordsMu.Unlock() + binary.BigEndian.PutUint64(pkt, seq) + seq++ + if err := fromc.Send(toDERPPubKey, pkt); err != nil { + sendErrC <- fmt.Errorf("sending packet %w", err) + return + } + } + } + }() + + // Receive the packets. + recvFinishedC := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + defer close(recvFinishedC) // to break out of 'select' below. + fromDERPPubKey := fromc.SelfPublicKey() + for { + m, err := toc.Recv() + if err != nil { + recvFinishedC <- err + return + } + switch v := m.(type) { + case derp.ReceivedPacket: + now := time.Now() + if v.Source != fromDERPPubKey { + recvFinishedC <- fmt.Errorf("got data packet from unexpected source, %v", v.Source) + return + } + seq := binary.BigEndian.Uint64(v.Data) + txRecordsMu.Lock() + findTxRecord: + for i, record := range txRecords { + switch { + case record.seq == seq: + rtt := now.Sub(record.at) + qdh.add(rtt.Seconds()) + txRecords = slices.Delete(txRecords, i, i+1) + break findTxRecord + case record.seq > seq: + // No sent time found, probably a late arrival already + // recorded as drop by sender when deleted. + break findTxRecord + case record.seq < seq: + continue + } + } + txRecordsMu.Unlock() + + case derp.KeepAliveMessage: + // Silently ignore. + + default: + log.Printf("%v: ignoring Recv frame type %T", to.Name, v) + // Loop. 
+ } + } + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("timeout: %w", ctx.Err()) + case err := <-sendErrC: + return fmt.Errorf("error sending via %q: %w", from.Name, err) + case err := <-recvFinishedC: + if err != nil { + return fmt.Errorf("error receiving from %q: %w", to.Name, err) + } + } + return nil +} + // getNodePair returns DERPNode objects for two DERP servers based on their // short names. func (d *derpProber) getNodePair(n1, n2 string) (ret1, ret2 *tailcfg.DERPNode, _ error) { @@ -272,7 +534,7 @@ func (d *derpProber) getNodePair(n1, n2 string) (ret1, ret2 *tailcfg.DERPNode, _ return ret1, ret2, nil } -var tsLocalClient tailscale.LocalClient +var tsLocalClient local.Client // updateMap refreshes the locally-cached DERP map. func (d *derpProber) updateMap(ctx context.Context) error { @@ -316,6 +578,10 @@ func (d *derpProber) updateMap(ctx context.Context) error { d.lastDERPMapAt = time.Now() d.nodes = make(map[string]*tailcfg.DERPNode) for _, reg := range d.lastDERPMap.Regions { + if d.skipRegion(reg) { + continue + } + for _, n := range reg.Nodes { if existing, ok := d.nodes[n.Name]; ok { return fmt.Errorf("derpmap has duplicate nodes: %+v and %+v", existing, n) @@ -330,14 +596,30 @@ func (d *derpProber) updateMap(ctx context.Context) error { } func (d *derpProber) ProbeUDP(ipaddr string, port int) ProbeClass { + initLabels := make(Labels) + ip := net.ParseIP(ipaddr) + + if ip.To4() != nil { + initLabels["address_family"] = "ipv4" + } else if ip.To16() != nil { // Will return an IPv4 as 16 byte, so ensure the check for IPv4 precedes this + initLabels["address_family"] = "ipv6" + } else { + initLabels["address_family"] = "unknown" + } + return ProbeClass{ Probe: func(ctx context.Context) error { return derpProbeUDP(ctx, ipaddr, port) }, - Class: "derp_udp", + Class: "derp_udp", + Labels: initLabels, } } +func (d *derpProber) skipRegion(region *tailcfg.DERPRegion) bool { + return d.regionCodeOrID != "" && region.RegionCode != d.regionCodeOrID 
&& strconv.Itoa(region.RegionID) != d.regionCodeOrID +} + func derpProbeUDP(ctx context.Context, ipStr string, port int) error { pc, err := net.ListenPacket("udp", ":0") if err != nil { @@ -389,8 +671,10 @@ func derpProbeUDP(ctx context.Context, ipStr string, port int) error { } // derpProbeBandwidth sends a payload of a given size between two local -// DERP clients connected to two DERP servers. -func derpProbeBandwidth(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode, size int64, transferTime *expvar.Float) (err error) { +// DERP clients connected to two DERP servers.If tunIPv4Address is specified, +// probes will use a TCP connection over a TUN device at this address in order +// to exercise TCP-in-TCP in similar fashion to TCP over Tailscale via DERP. +func derpProbeBandwidth(ctx context.Context, dm *tailcfg.DERPMap, from, to *tailcfg.DERPNode, size int64, transferTimeSeconds, totalBytesTransferred *expvar.Float, tunIPv4Prefix *netip.Prefix) (err error) { // This probe uses clients with isProber=false to avoid spamming the derper logs with every packet // sent by the bandwidth probe. fromc, err := newConn(ctx, dm, from, false) @@ -411,10 +695,13 @@ func derpProbeBandwidth(ctx context.Context, dm *tailcfg.DERPMap, from, to *tail time.Sleep(100 * time.Millisecond) // pretty arbitrary } - start := time.Now() - defer func() { transferTime.Add(time.Since(start).Seconds()) }() + if tunIPv4Prefix != nil { + err = derpProbeBandwidthTUN(ctx, transferTimeSeconds, totalBytesTransferred, from, to, fromc, toc, size, tunIPv4Prefix) + } else { + err = derpProbeBandwidthDirect(ctx, transferTimeSeconds, from, to, fromc, toc, size) + } - if err := runDerpProbeNodePair(ctx, from, to, fromc, toc, size); err != nil { + if err != nil { // Record pubkeys on failed probes to aid investigation. 
return fmt.Errorf("%s -> %s: %w", fromc.SelfPublicKey().ShortString(), @@ -494,9 +781,10 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc // Send the packets. sendc := make(chan error, 1) go func() { + toDERPPubKey := toc.SelfPublicKey() for idx, pkt := range pkts { inFlight.AcquireContext(ctx) - if err := fromc.Send(toc.SelfPublicKey(), pkt); err != nil { + if err := fromc.Send(toDERPPubKey, pkt); err != nil { sendc <- fmt.Errorf("sending packet %d: %w", idx, err) return } @@ -508,6 +796,7 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc go func() { defer close(recvc) // to break out of 'select' below. idx := 0 + fromDERPPubKey := fromc.SelfPublicKey() for { m, err := toc.Recv() if err != nil { @@ -517,10 +806,12 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc switch v := m.(type) { case derp.ReceivedPacket: inFlight.Release() - if v.Source != fromc.SelfPublicKey() { + if v.Source != fromDERPPubKey { recvc <- fmt.Errorf("got data packet %d from unexpected source, %v", idx, v.Source) return } + // This assumes that the packets are received reliably and in order. + // The DERP protocol does not guarantee this, but this probe assumes it. if got, want := v.Data, pkts[idx]; !bytes.Equal(got, want) { recvc <- fmt.Errorf("unexpected data packet %d (out of %d)", idx, len(pkts)) return @@ -554,6 +845,277 @@ func runDerpProbeNodePair(ctx context.Context, from, to *tailcfg.DERPNode, fromc return nil } +// derpProbeBandwidthDirect takes two DERP clients (fromc and toc) connected to two +// DERP servers (from and to) and sends a test payload of a given size from one +// to another using runDerpProbeNodePair. The time taken to finish the transfer is +// recorded in `transferTimeSeconds`. 
+func derpProbeBandwidthDirect(ctx context.Context, transferTimeSeconds *expvar.Float, from, to *tailcfg.DERPNode, fromc, toc *derphttp.Client, size int64) error { + start := time.Now() + defer func() { transferTimeSeconds.Add(time.Since(start).Seconds()) }() + + return runDerpProbeNodePair(ctx, from, to, fromc, toc, size) +} + +// derpProbeBandwidthTUNMu ensures that TUN bandwidth probes don't run concurrently. +// This is necessary to avoid conflicts trying to create the TUN device, and +// it also has the nice benefit of preventing concurrent bandwidth probes from +// influencing each other's results. +// +// This guards derpProbeBandwidthTUN. +var derpProbeBandwidthTUNMu sync.Mutex + +// derpProbeBandwidthTUN takes two DERP clients (fromc and toc) connected to two +// DERP servers (from and to) and sends a test payload of a given size from one +// to another over a TUN device at an address at the start of the usable host IP +// range that the given tunAddress lives in. The time taken to finish the transfer +// is recorded in `transferTimeSeconds`. +func derpProbeBandwidthTUN(ctx context.Context, transferTimeSeconds, totalBytesTransferred *expvar.Float, from, to *tailcfg.DERPNode, fromc, toc *derphttp.Client, size int64, prefix *netip.Prefix) error { + // Make sure all goroutines have finished. + var wg sync.WaitGroup + defer wg.Wait() + + // Close the clients to make sure goroutines that are reading/writing from them terminate. + defer fromc.Close() + defer toc.Close() + + ipRange := netipx.RangeOfPrefix(*prefix) + // Start of the usable host IP range from the address we have been passed in. + ifAddr := ipRange.From().Next() + // Destination address to dial. This is the next address in the range from + // our ifAddr to ensure that the underlying networking stack is actually being + // utilized instead of being optimized away and treated as a loopback. Packets + // sent to this address will be routed over the TUN. 
+ destinationAddr := ifAddr.Next() + + derpProbeBandwidthTUNMu.Lock() + defer derpProbeBandwidthTUNMu.Unlock() + + // Temporarily set up a TUN device with which to simulate a real client TCP connection + // tunneling over DERP. Use `tstun.DefaultTUNMTU()` (e.g., 1280) as our MTU as this is + // the minimum safe MTU used by Tailscale. + dev, err := tun.CreateTUN(tunName, int(tstun.DefaultTUNMTU())) + if err != nil { + return fmt.Errorf("failed to create TUN device: %w", err) + } + defer func() { + if err := dev.Close(); err != nil { + log.Printf("failed to close TUN device: %s", err) + } + }() + mtu, err := dev.MTU() + if err != nil { + return fmt.Errorf("failed to get TUN MTU: %w", err) + } + + name, err := dev.Name() + if err != nil { + return fmt.Errorf("failed to get device name: %w", err) + } + + // Perform platform specific configuration of the TUN device. + err = configureTUN(*prefix, name) + if err != nil { + return fmt.Errorf("failed to configure tun: %w", err) + } + + // Depending on platform, we need some space for headers at the front + // of TUN I/O op buffers. The below constant is more than enough space + // for any platform that this might run on. + tunStartOffset := device.MessageTransportHeaderSize + + // This goroutine reads packets from the TUN device and evaluates if they + // are IPv4 packets destined for loopback via DERP. If so, it performs L3 NAT + // (swap src/dst) and writes them towards DERP in order to loopback via the + // `toc` DERP client. It only reports errors to `tunReadErrC`. 
+ wg.Add(1) + tunReadErrC := make(chan error, 1) + go func() { + defer wg.Done() + + numBufs := wgconn.IdealBatchSize + bufs := make([][]byte, 0, numBufs) + sizes := make([]int, numBufs) + for range numBufs { + bufs = append(bufs, make([]byte, mtu+tunStartOffset)) + } + + destinationAddrBytes := destinationAddr.AsSlice() + scratch := make([]byte, 4) + toDERPPubKey := toc.SelfPublicKey() + for { + n, err := dev.Read(bufs, sizes, tunStartOffset) + if err != nil { + tunReadErrC <- err + return + } + + for i := range n { + pkt := bufs[i][tunStartOffset : sizes[i]+tunStartOffset] + // Skip everything except valid IPv4 packets + if len(pkt) < 20 { + // Doesn't even have a full IPv4 header + continue + } + if pkt[0]>>4 != 4 { + // Not IPv4 + continue + } + + if !bytes.Equal(pkt[16:20], destinationAddrBytes) { + // Unexpected dst address + continue + } + + copy(scratch, pkt[12:16]) + copy(pkt[12:16], pkt[16:20]) + copy(pkt[16:20], scratch) + + if err := fromc.Send(toDERPPubKey, pkt); err != nil { + tunReadErrC <- err + return + } + } + } + }() + + // This goroutine reads packets from the `toc` DERP client and writes them towards the TUN. + // It only reports errors to `recvErrC` channel. + wg.Add(1) + recvErrC := make(chan error, 1) + go func() { + defer wg.Done() + + buf := make([]byte, mtu+tunStartOffset) + bufs := make([][]byte, 1) + + fromDERPPubKey := fromc.SelfPublicKey() + for { + m, err := toc.Recv() + if err != nil { + recvErrC <- fmt.Errorf("failed to receive: %w", err) + return + } + switch v := m.(type) { + case derp.ReceivedPacket: + if v.Source != fromDERPPubKey { + recvErrC <- fmt.Errorf("got data packet from unexpected source, %v", v.Source) + return + } + pkt := v.Data + copy(buf[tunStartOffset:], pkt) + bufs[0] = buf[:len(pkt)+tunStartOffset] + if _, err := dev.Write(bufs, tunStartOffset); err != nil { + recvErrC <- fmt.Errorf("failed to write to TUN device: %w", err) + return + } + case derp.KeepAliveMessage: + // Silently ignore. 
+ default: + log.Printf("%v: ignoring Recv frame type %T", to.Name, v) + // Loop. + } + } + }() + + // Start a listener to receive the data + l, err := net.Listen("tcp", net.JoinHostPort(ifAddr.String(), "0")) + if err != nil { + return fmt.Errorf("failed to listen: %s", err) + } + defer l.Close() + + // 128KB by default + const writeChunkSize = 128 << 10 + + randData := make([]byte, writeChunkSize) + _, err = crand.Read(randData) + if err != nil { + return fmt.Errorf("failed to initialize random data: %w", err) + } + + // Dial ourselves + _, port, err := net.SplitHostPort(l.Addr().String()) + if err != nil { + return fmt.Errorf("failed to split address %q: %w", l.Addr().String(), err) + } + + connAddr := net.JoinHostPort(destinationAddr.String(), port) + conn, err := net.Dial("tcp", connAddr) + if err != nil { + return fmt.Errorf("failed to dial address %q: %w", connAddr, err) + } + defer conn.Close() + + // Timing only includes the actual sending and receiving of data. + start := time.Now() + + // This goroutine reads data from the TCP stream being looped back via DERP. + // It reports to `readFinishedC` when `size` bytes have been read, or if an + // error occurs. + wg.Add(1) + readFinishedC := make(chan error, 1) + go func() { + defer wg.Done() + + readConn, err := l.Accept() + if err != nil { + readFinishedC <- err + return + } + defer readConn.Close() + deadline, ok := ctx.Deadline() + if ok { + // Don't try reading past our context's deadline. + if err := readConn.SetReadDeadline(deadline); err != nil { + readFinishedC <- fmt.Errorf("unable to set read deadline: %w", err) + return + } + } + n, err := io.CopyN(io.Discard, readConn, size) + // Measure transfer time and bytes transferred irrespective of whether it succeeded or failed. + transferTimeSeconds.Add(time.Since(start).Seconds()) + totalBytesTransferred.Add(float64(n)) + readFinishedC <- err + }() + + // This goroutine sends data to the TCP stream being looped back via DERP. 
+ // It only reports errors to `sendErrC`. + wg.Add(1) + sendErrC := make(chan error, 1) + go func() { + defer wg.Done() + + for wrote := 0; wrote < int(size); wrote += len(randData) { + b := randData + if wrote+len(randData) > int(size) { + // This is the last chunk and we don't need the whole thing + b = b[0 : int(size)-wrote] + } + if _, err := conn.Write(b); err != nil { + sendErrC <- fmt.Errorf("failed to write to conn: %w", err) + return + } + } + }() + + select { + case <-ctx.Done(): + return fmt.Errorf("timeout: %w", ctx.Err()) + case err := <-tunReadErrC: + return fmt.Errorf("error reading from TUN via %q: %w", from.Name, err) + case err := <-sendErrC: + return fmt.Errorf("error sending via %q: %w", from.Name, err) + case err := <-recvErrC: + return fmt.Errorf("error receiving from %q: %w", to.Name, err) + case err := <-readFinishedC: + if err != nil { + return fmt.Errorf("error reading from %q to TUN: %w", to.Name, err) + } + } + + return nil +} + func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode, isProber bool) (*derphttp.Client, error) { // To avoid spamming the log with regular connection messages. l := logger.Filtered(log.Printf, func(s string) bool { @@ -574,18 +1136,22 @@ func newConn(ctx context.Context, dm *tailcfg.DERPMap, n *tailcfg.DERPNode, isPr if err != nil { return nil, err } - cs, ok := dc.TLSConnectionState() - if !ok { - dc.Close() - return nil, errors.New("no TLS state") - } - if len(cs.PeerCertificates) == 0 { - dc.Close() - return nil, errors.New("no peer certificates") - } - if cs.ServerName != n.HostName { - dc.Close() - return nil, fmt.Errorf("TLS server name %q != derp hostname %q", cs.ServerName, n.HostName) + + // Only verify TLS state if this is a prober. 
+ if isProber { + cs, ok := dc.TLSConnectionState() + if !ok { + dc.Close() + return nil, errors.New("no TLS state") + } + if len(cs.PeerCertificates) == 0 { + dc.Close() + return nil, errors.New("no peer certificates") + } + if cs.ServerName != n.HostName { + dc.Close() + return nil, fmt.Errorf("TLS server name %q != derp hostname %q", cs.ServerName, n.HostName) + } } errc := make(chan error, 1) diff --git a/prober/derp_test.go b/prober/derp_test.go index a34292a23b6f4..93b8d760b3f18 100644 --- a/prober/derp_test.go +++ b/prober/derp_test.go @@ -44,6 +44,19 @@ func TestDerpProber(t *testing.T) { }, }, }, + 1: { + RegionID: 1, + RegionCode: "one", + Nodes: []*tailcfg.DERPNode{ + { + Name: "n3", + RegionID: 0, + HostName: "derpn3.tailscale.test", + IPv4: "1.1.1.1", + IPv6: "::1", + }, + }, + }, }, } srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -58,16 +71,17 @@ func TestDerpProber(t *testing.T) { clk := newFakeTime() p := newForTest(clk.Now, clk.NewTicker) dp := &derpProber{ - p: p, - derpMapURL: srv.URL, - tlsInterval: time.Second, - tlsProbeFn: func(_ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, - udpInterval: time.Second, - udpProbeFn: func(_ string, _ int) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, - meshInterval: time.Second, - meshProbeFn: func(_, _ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, - nodes: make(map[string]*tailcfg.DERPNode), - probes: make(map[string]*Probe), + p: p, + derpMapURL: srv.URL, + tlsInterval: time.Second, + tlsProbeFn: func(_ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, + udpInterval: time.Second, + udpProbeFn: func(_ string, _ int) ProbeClass { return FuncProbe(func(context.Context) error { return nil }) }, + meshInterval: time.Second, + meshProbeFn: func(_, _ string) ProbeClass { return FuncProbe(func(context.Context) error { return nil 
}) }, + nodes: make(map[string]*tailcfg.DERPNode), + probes: make(map[string]*Probe), + regionCodeOrID: "zero", } if err := dp.probeMapFn(context.Background()); err != nil { t.Errorf("unexpected probeMapFn() error: %s", err) @@ -84,9 +98,9 @@ func TestDerpProber(t *testing.T) { // Add one more node and check that probes got created. dm.Regions[0].Nodes = append(dm.Regions[0].Nodes, &tailcfg.DERPNode{ - Name: "n3", + Name: "n4", RegionID: 0, - HostName: "derpn3.tailscale.test", + HostName: "derpn4.tailscale.test", IPv4: "1.1.1.1", IPv6: "::1", }) @@ -113,6 +127,19 @@ func TestDerpProber(t *testing.T) { if len(dp.probes) != 4 { t.Errorf("unexpected probes: %+v", dp.probes) } + + // Stop filtering regions. + dp.regionCodeOrID = "" + if err := dp.probeMapFn(context.Background()); err != nil { + t.Errorf("unexpected probeMapFn() error: %s", err) + } + if len(dp.nodes) != 2 { + t.Errorf("unexpected nodes: %+v", dp.nodes) + } + // 6 regular probes + 2 mesh probe + if len(dp.probes) != 8 { + t.Errorf("unexpected probes: %+v", dp.probes) + } } func TestRunDerpProbeNodePair(t *testing.T) { diff --git a/prober/histogram.go b/prober/histogram.go new file mode 100644 index 0000000000000..c544a5f79bb17 --- /dev/null +++ b/prober/histogram.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package prober + +import ( + "slices" + "sync" +) + +// histogram serves as an adapter to the Prometheus histogram datatype. +// The prober framework passes labels at custom metric collection time that +// it expects to be coupled with the returned metrics. See ProbeClass.Metrics +// and its call sites. Native prometheus histograms cannot be collected while +// injecting more labels. Instead we use this type and pass observations + +// collection labels to prometheus.MustNewConstHistogram() at prometheus +// metric collection time. 
+type histogram struct { + count uint64 + sum float64 + buckets []float64 + bucketedCounts map[float64]uint64 + mx sync.Mutex +} + +// newHistogram constructs a histogram that buckets data based on the given +// slice of upper bounds. +func newHistogram(buckets []float64) *histogram { + slices.Sort(buckets) + return &histogram{ + buckets: buckets, + bucketedCounts: make(map[float64]uint64, len(buckets)), + } +} + +func (h *histogram) add(v float64) { + h.mx.Lock() + defer h.mx.Unlock() + + h.count++ + h.sum += v + + for _, b := range h.buckets { + if v > b { + continue + } + h.bucketedCounts[b] += 1 + } +} diff --git a/prober/histogram_test.go b/prober/histogram_test.go new file mode 100644 index 0000000000000..dbb5eda6741a5 --- /dev/null +++ b/prober/histogram_test.go @@ -0,0 +1,29 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package prober + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestHistogram(t *testing.T) { + h := newHistogram([]float64{1, 2}) + h.add(0.5) + h.add(1) + h.add(1.5) + h.add(2) + h.add(2.5) + + if diff := cmp.Diff(h.count, uint64(5)); diff != "" { + t.Errorf("wrong count; (-got+want):%v", diff) + } + if diff := cmp.Diff(h.sum, 7.5); diff != "" { + t.Errorf("wrong sum; (-got+want):%v", diff) + } + if diff := cmp.Diff(h.bucketedCounts, map[float64]uint64{1: 2, 2: 4}); diff != "" { + t.Errorf("wrong bucketedCounts; (-got+want):%v", diff) + } +} diff --git a/prober/prober.go b/prober/prober.go index 2a43628bda908..1237611f4e0c9 100644 --- a/prober/prober.go +++ b/prober/prober.go @@ -7,6 +7,7 @@ package prober import ( + "cmp" "container/ring" "context" "encoding/json" @@ -20,6 +21,7 @@ import ( "time" "github.com/prometheus/client_golang/prometheus" + "tailscale.com/syncs" "tailscale.com/tsweb" ) @@ -44,6 +46,14 @@ type ProbeClass struct { // exposed by this probe class. 
Labels Labels + // Timeout is the maximum time the probe function is allowed to run before + // its context is cancelled. Defaults to 80% of the scheduling interval. + Timeout time.Duration + + // Concurrency is the maximum number of concurrent probe executions + // allowed for this probe class. Defaults to 1. + Concurrency int + // Metrics allows a probe class to export custom Metrics. Can be nil. Metrics func(prometheus.Labels) []prometheus.Metric } @@ -94,6 +104,9 @@ func newForTest(now func() time.Time, newTicker func(time.Duration) ticker) *Pro // Run executes probe class function every interval, and exports probe results under probeName. // +// If interval is negative, the probe will run continuously. If it encounters a failure while +// running continuously, it will pause for -1*interval and then retry. +// // Registering a probe under an already-registered name panics. func (p *Prober) Run(name string, interval time.Duration, labels Labels, pc ProbeClass) *Probe { p.mu.Lock() @@ -128,9 +141,12 @@ func newProbe(p *Prober, name string, interval time.Duration, l prometheus.Label cancel: cancel, stopped: make(chan struct{}), + runSema: syncs.NewSemaphore(cmp.Or(pc.Concurrency, 1)), + name: name, probeClass: pc, interval: interval, + timeout: cmp.Or(pc.Timeout, time.Duration(float64(interval)*0.8)), initialDelay: initialDelay(name, interval), successHist: ring.New(recentHistSize), latencyHist: ring.New(recentHistSize), @@ -223,11 +239,12 @@ type Probe struct { ctx context.Context cancel context.CancelFunc // run to initiate shutdown stopped chan struct{} // closed when shutdown is complete - runMu sync.Mutex // ensures only one probe runs at a time + runSema syncs.Semaphore // restricts concurrency per probe name string probeClass ProbeClass interval time.Duration + timeout time.Duration initialDelay time.Duration tick ticker @@ -256,6 +273,11 @@ type Probe struct { latencyHist *ring.Ring } +// IsContinuous indicates that this is a continuous probe. 
+func (p *Probe) IsContinuous() bool { + return p.interval < 0 +} + // Close shuts down the Probe and unregisters it from its Prober. // It is safe to Run a new probe of the same name after Close returns. func (p *Probe) Close() error { @@ -274,26 +296,43 @@ func (p *Probe) loop() { t := p.prober.newTicker(p.initialDelay) select { case <-t.Chan(): - p.run() case <-p.ctx.Done(): t.Stop() return } t.Stop() - } else { - p.run() } if p.prober.once { + p.run() return } + if p.IsContinuous() { + // Probe function is going to run continuously. + for { + p.run() + // Wait and then retry if probe fails. We use the inverse of the + // configured negative interval as our sleep period. + // TODO(percy):implement exponential backoff, possibly using logtail/backoff. + select { + case <-time.After(-1 * p.interval): + p.run() + case <-p.ctx.Done(): + return + } + } + } + p.tick = p.prober.newTicker(p.interval) defer p.tick.Stop() for { + // Run the probe in a new goroutine every tick. Default concurrency & timeout + // settings will ensure that only one probe is running at a time. + go p.run() + select { case <-p.tick.Chan(): - p.run() case <-p.ctx.Done(): return } @@ -307,8 +346,13 @@ func (p *Probe) loop() { // that the probe either succeeds or fails before the next cycle is scheduled to // start. func (p *Probe) run() (pi ProbeInfo, err error) { - p.runMu.Lock() - defer p.runMu.Unlock() + // Probes are scheduled each p.interval, so we don't wait longer than that. 
+ semaCtx, cancel := context.WithTimeout(p.ctx, p.interval) + defer cancel() + if !p.runSema.AcquireContext(semaCtx) { + return pi, fmt.Errorf("probe %s: context cancelled", p.name) + } + defer p.runSema.Release() p.recordStart() defer func() { @@ -320,15 +364,21 @@ func (p *Probe) run() (pi ProbeInfo, err error) { if r := recover(); r != nil { log.Printf("probe %s panicked: %v", p.name, r) err = fmt.Errorf("panic: %v", r) - p.recordEnd(err) + p.recordEndLocked(err) } }() - timeout := time.Duration(float64(p.interval) * 0.8) - ctx, cancel := context.WithTimeout(p.ctx, timeout) - defer cancel() + ctx := p.ctx + if !p.IsContinuous() { + var cancel func() + ctx, cancel = context.WithTimeout(ctx, p.timeout) + defer cancel() + } err = p.probeClass.Probe(ctx) - p.recordEnd(err) + + p.mu.Lock() + defer p.mu.Unlock() + p.recordEndLocked(err) if err != nil { log.Printf("probe %s: %v", p.name, err) } @@ -342,10 +392,8 @@ func (p *Probe) recordStart() { p.mu.Unlock() } -func (p *Probe) recordEnd(err error) { +func (p *Probe) recordEndLocked(err error) { end := p.prober.now() - p.mu.Lock() - defer p.mu.Unlock() p.end = end p.succeeded = err == nil p.lastErr = err @@ -356,15 +404,29 @@ func (p *Probe) recordEnd(err error) { p.mSeconds.WithLabelValues("ok").Add(latency.Seconds()) p.latencyHist.Value = latency p.latencyHist = p.latencyHist.Next() + p.mAttempts.WithLabelValues("fail").Add(0) + p.mSeconds.WithLabelValues("fail").Add(0) } else { p.latency = 0 p.mAttempts.WithLabelValues("fail").Inc() p.mSeconds.WithLabelValues("fail").Add(latency.Seconds()) + p.mAttempts.WithLabelValues("ok").Add(0) + p.mSeconds.WithLabelValues("ok").Add(0) } p.successHist.Value = p.succeeded p.successHist = p.successHist.Next() } +// ProbeStatus indicates the status of a probe. 
+type ProbeStatus string + +const ( + ProbeStatusUnknown = "unknown" + ProbeStatusRunning = "running" + ProbeStatusFailed = "failed" + ProbeStatusSucceeded = "succeeded" +) + // ProbeInfo is a snapshot of the configuration and state of a Probe. type ProbeInfo struct { Name string @@ -374,7 +436,7 @@ type ProbeInfo struct { Start time.Time End time.Time Latency time.Duration - Result bool + Status ProbeStatus Error string RecentResults []bool RecentLatencies []time.Duration @@ -402,6 +464,10 @@ func (pb ProbeInfo) RecentMedianLatency() time.Duration { return pb.RecentLatencies[len(pb.RecentLatencies)/2] } +func (pb ProbeInfo) Continuous() bool { + return pb.Interval < 0 +} + // ProbeInfo returns the state of all probes. func (p *Prober) ProbeInfo() map[string]ProbeInfo { out := map[string]ProbeInfo{} @@ -429,9 +495,14 @@ func (probe *Probe) probeInfoLocked() ProbeInfo { Labels: probe.metricLabels, Start: probe.start, End: probe.end, - Result: probe.succeeded, } - if probe.lastErr != nil { + inf.Status = ProbeStatusUnknown + if probe.end.Before(probe.start) { + inf.Status = ProbeStatusRunning + } else if probe.succeeded { + inf.Status = ProbeStatusSucceeded + } else if probe.lastErr != nil { + inf.Status = ProbeStatusFailed inf.Error = probe.lastErr.Error() } if probe.latency > 0 { @@ -467,7 +538,7 @@ func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error { p.mu.Lock() probe, ok := p.probes[name] p.mu.Unlock() - if !ok { + if !ok || probe.IsContinuous() { return tsweb.Error(http.StatusNotFound, fmt.Sprintf("unknown probe %q", name), nil) } @@ -488,8 +559,8 @@ func (p *Prober) RunHandler(w http.ResponseWriter, r *http.Request) error { PreviousSuccessRatio: prevInfo.RecentSuccessRatio(), PreviousMedianLatency: prevInfo.RecentMedianLatency(), } - w.WriteHeader(respStatus) w.Header().Set("Content-Type", "application/json") + w.WriteHeader(respStatus) if err := json.NewEncoder(w).Encode(resp); err != nil { return 
tsweb.Error(http.StatusInternalServerError, "error encoding JSON response", err) } @@ -531,7 +602,8 @@ func (p *Probe) Collect(ch chan<- prometheus.Metric) { if !p.start.IsZero() { ch <- prometheus.MustNewConstMetric(p.mStartTime, prometheus.GaugeValue, float64(p.start.Unix())) } - if p.end.IsZero() { + // For periodic probes that haven't ended, don't collect probe metrics yet. + if p.end.IsZero() && !p.IsContinuous() { return } ch <- prometheus.MustNewConstMetric(p.mEndTime, prometheus.GaugeValue, float64(p.end.Unix())) diff --git a/prober/prober_test.go b/prober/prober_test.go index 742a914b24661..21c975a73a655 100644 --- a/prober/prober_test.go +++ b/prober/prober_test.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "net/http" "net/http/httptest" "strings" "sync" @@ -149,6 +150,74 @@ func TestProberTimingSpread(t *testing.T) { notCalled() } +func TestProberTimeout(t *testing.T) { + clk := newFakeTime() + p := newForTest(clk.Now, clk.NewTicker) + + var done sync.WaitGroup + done.Add(1) + pfunc := FuncProbe(func(ctx context.Context) error { + defer done.Done() + select { + case <-ctx.Done(): + return ctx.Err() + } + }) + pfunc.Timeout = time.Microsecond + probe := p.Run("foo", 30*time.Second, nil, pfunc) + waitActiveProbes(t, p, clk, 1) + done.Wait() + probe.mu.Lock() + info := probe.probeInfoLocked() + probe.mu.Unlock() + wantInfo := ProbeInfo{ + Name: "foo", + Interval: 30 * time.Second, + Labels: map[string]string{"class": "", "name": "foo"}, + Status: ProbeStatusFailed, + Error: "context deadline exceeded", + RecentResults: []bool{false}, + RecentLatencies: nil, + } + if diff := cmp.Diff(wantInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Latency")); diff != "" { + t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff) + } + if got := info.Latency; got > time.Second { + t.Errorf("info.Latency = %v, want at most 1s", got) + } +} + +func TestProberConcurrency(t *testing.T) { + clk := newFakeTime() + p := newForTest(clk.Now, clk.NewTicker) + + 
var ran atomic.Int64 + stopProbe := make(chan struct{}) + pfunc := FuncProbe(func(ctx context.Context) error { + ran.Add(1) + <-stopProbe + return nil + }) + pfunc.Timeout = time.Hour + pfunc.Concurrency = 3 + p.Run("foo", time.Second, nil, pfunc) + waitActiveProbes(t, p, clk, 1) + + for range 50 { + clk.Advance(time.Second) + } + + if err := tstest.WaitFor(convergenceTimeout, func() error { + if got, want := ran.Load(), int64(3); got != want { + return fmt.Errorf("expected %d probes to run concurrently, got %d", want, got) + } + return nil + }); err != nil { + t.Fatal(err) + } + close(stopProbe) +} + func TestProberRun(t *testing.T) { clk := newFakeTime() p := newForTest(clk.Now, clk.NewTicker) @@ -316,7 +385,7 @@ func TestProberProbeInfo(t *testing.T) { Interval: probeInterval, Labels: map[string]string{"class": "", "name": "probe1"}, Latency: 500 * time.Millisecond, - Result: true, + Status: ProbeStatusSucceeded, RecentResults: []bool{true}, RecentLatencies: []time.Duration{500 * time.Millisecond}, }, @@ -324,6 +393,7 @@ func TestProberProbeInfo(t *testing.T) { Name: "probe2", Interval: probeInterval, Labels: map[string]string{"class": "", "name": "probe2"}, + Status: ProbeStatusFailed, Error: "error2", RecentResults: []bool{false}, RecentLatencies: nil, // no latency for failed probes @@ -349,7 +419,7 @@ func TestProbeInfoRecent(t *testing.T) { }{ { name: "no_runs", - wantProbeInfo: ProbeInfo{}, + wantProbeInfo: ProbeInfo{Status: ProbeStatusUnknown}, wantRecentSuccessRatio: 0, wantRecentMedianLatency: 0, }, @@ -358,7 +428,7 @@ func TestProbeInfoRecent(t *testing.T) { results: []probeResult{{latency: 100 * time.Millisecond, err: nil}}, wantProbeInfo: ProbeInfo{ Latency: 100 * time.Millisecond, - Result: true, + Status: ProbeStatusSucceeded, RecentResults: []bool{true}, RecentLatencies: []time.Duration{100 * time.Millisecond}, }, @@ -369,7 +439,7 @@ func TestProbeInfoRecent(t *testing.T) { name: "single_failure", results: []probeResult{{latency: 100 * 
time.Millisecond, err: errors.New("error123")}}, wantProbeInfo: ProbeInfo{ - Result: false, + Status: ProbeStatusFailed, RecentResults: []bool{false}, RecentLatencies: nil, Error: "error123", @@ -390,7 +460,7 @@ func TestProbeInfoRecent(t *testing.T) { {latency: 80 * time.Millisecond, err: nil}, }, wantProbeInfo: ProbeInfo{ - Result: true, + Status: ProbeStatusSucceeded, Latency: 80 * time.Millisecond, RecentResults: []bool{false, true, true, false, true, true, false, true}, RecentLatencies: []time.Duration{ @@ -420,7 +490,7 @@ func TestProbeInfoRecent(t *testing.T) { {latency: 110 * time.Millisecond, err: nil}, }, wantProbeInfo: ProbeInfo{ - Result: true, + Status: ProbeStatusSucceeded, Latency: 110 * time.Millisecond, RecentResults: []bool{true, true, true, true, true, true, true, true, true, true}, RecentLatencies: []time.Duration{ @@ -449,9 +519,11 @@ func TestProbeInfoRecent(t *testing.T) { for _, r := range tt.results { probe.recordStart() clk.Advance(r.latency) - probe.recordEnd(r.err) + probe.recordEndLocked(r.err) } + probe.mu.Lock() info := probe.probeInfoLocked() + probe.mu.Unlock() if diff := cmp.Diff(tt.wantProbeInfo, info, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Interval")); diff != "" { t.Fatalf("unexpected ProbeInfo (-want +got):\n%s", diff) } @@ -483,7 +555,7 @@ func TestProberRunHandler(t *testing.T) { ProbeInfo: ProbeInfo{ Name: "success", Interval: probeInterval, - Result: true, + Status: ProbeStatusSucceeded, RecentResults: []bool{true, true}, }, PreviousSuccessRatio: 1, @@ -498,7 +570,7 @@ func TestProberRunHandler(t *testing.T) { ProbeInfo: ProbeInfo{ Name: "failure", Interval: probeInterval, - Result: false, + Status: ProbeStatusFailed, Error: "error123", RecentResults: []bool{false, false}, }, @@ -515,27 +587,48 @@ func TestProberRunHandler(t *testing.T) { defer probe.Close() <-probe.stopped // wait for the first run. 
- w := httptest.NewRecorder() + mux := http.NewServeMux() + server := httptest.NewServer(mux) + defer server.Close() + + mux.Handle("/prober/run/", tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{})) + + req, err := http.NewRequest("GET", server.URL+"/prober/run/?name="+tt.name, nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } - req := httptest.NewRequest("GET", "/prober/run/?name="+tt.name, nil) if reqJSON { req.Header.Set("Accept", "application/json") } - tsweb.StdHandler(tsweb.ReturnHandlerFunc(p.RunHandler), tsweb.HandlerOptions{}).ServeHTTP(w, req) - if w.Result().StatusCode != tt.wantResponseCode { - t.Errorf("unexpected response code: got %d, want %d", w.Code, tt.wantResponseCode) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to make request: %v", err) + } + + if resp.StatusCode != tt.wantResponseCode { + t.Errorf("unexpected response code: got %d, want %d", resp.StatusCode, tt.wantResponseCode) } if reqJSON { + if resp.Header.Get("Content-Type") != "application/json" { + t.Errorf("unexpected content type: got %q, want application/json", resp.Header.Get("Content-Type")) + } var gotJSON RunHandlerResponse - if err := json.Unmarshal(w.Body.Bytes(), &gotJSON); err != nil { - t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, w.Body.String()) + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + + if err := json.Unmarshal(body, &gotJSON); err != nil { + t.Fatalf("failed to unmarshal JSON response: %v; body: %s", err, body) } if diff := cmp.Diff(tt.wantJSONResponse, gotJSON, cmpopts.IgnoreFields(ProbeInfo{}, "Start", "End", "Labels", "RecentLatencies")); diff != "" { t.Errorf("unexpected JSON response (-want +got):\n%s", diff) } } else { - body, _ := io.ReadAll(w.Result().Body) + body, _ := io.ReadAll(resp.Body) if !strings.Contains(string(body), tt.wantPlaintextResponse) { t.Errorf("unexpected response 
body: got %q, want to contain %q", body, tt.wantPlaintextResponse) } diff --git a/prober/status.go b/prober/status.go index aa9ef99d05d2c..20fbeec58a77e 100644 --- a/prober/status.go +++ b/prober/status.go @@ -62,8 +62,9 @@ func (p *Prober) StatusHandler(opts ...statusHandlerOpt) tsweb.ReturnHandlerFunc return func(w http.ResponseWriter, r *http.Request) error { type probeStatus struct { ProbeInfo - TimeSinceLast time.Duration - Links map[string]template.URL + TimeSinceLastStart time.Duration + TimeSinceLastEnd time.Duration + Links map[string]template.URL } vars := struct { Title string @@ -81,12 +82,15 @@ func (p *Prober) StatusHandler(opts ...statusHandlerOpt) tsweb.ReturnHandlerFunc for name, info := range p.ProbeInfo() { vars.TotalProbes++ - if !info.Result { + if info.Error != "" { vars.UnhealthyProbes++ } s := probeStatus{ProbeInfo: info} + if !info.Start.IsZero() { + s.TimeSinceLastStart = time.Since(info.Start).Truncate(time.Second) + } if !info.End.IsZero() { - s.TimeSinceLast = time.Since(info.End).Truncate(time.Second) + s.TimeSinceLastEnd = time.Since(info.End).Truncate(time.Second) } for textTpl, urlTpl := range params.probeLinks { text, err := renderTemplate(textTpl, info) diff --git a/prober/status.html b/prober/status.html index ff0f06c13fe62..d26588da19431 100644 --- a/prober/status.html +++ b/prober/status.html @@ -73,8 +73,9 @@

Probes:

Name Probe Class & Labels Interval - Last Attempt - Success + Last Finished + Last Started + Status Latency Last Error @@ -85,9 +86,11 @@

Probes:

{{$name}} {{range $text, $url := $probeInfo.Links}}
- + {{if not $probeInfo.Continuous}} + + {{end}} {{end}} {{$probeInfo.Class}}
@@ -97,28 +100,48 @@

Probes:

{{end}} - {{$probeInfo.Interval}} - - {{if $probeInfo.TimeSinceLast}} - {{$probeInfo.TimeSinceLast.String}} ago
+ + {{if $probeInfo.Continuous}} + Continuous + {{else}} + {{$probeInfo.Interval}} + {{end}} + + + {{if $probeInfo.TimeSinceLastEnd}} + {{$probeInfo.TimeSinceLastEnd.String}} ago
{{$probeInfo.End.Format "2006-01-02T15:04:05Z07:00"}} {{else}} Never {{end}} + + {{if $probeInfo.TimeSinceLastStart}} + {{$probeInfo.TimeSinceLastStart.String}} ago
+ {{$probeInfo.Start.Format "2006-01-02T15:04:05Z07:00"}} + {{else}} + Never + {{end}} + - {{if $probeInfo.Result}} - {{$probeInfo.Result}} + {{if $probeInfo.Error}} + {{$probeInfo.Status}} {{else}} - {{$probeInfo.Result}} + {{$probeInfo.Status}} {{end}}
-
Recent: {{$probeInfo.RecentResults}}
-
Mean: {{$probeInfo.RecentSuccessRatio}}
+ {{if not $probeInfo.Continuous}} +
Recent: {{$probeInfo.RecentResults}}
+
Mean: {{$probeInfo.RecentSuccessRatio}}
+ {{end}} - {{$probeInfo.Latency.String}} -
Recent: {{$probeInfo.RecentLatencies}}
-
Median: {{$probeInfo.RecentMedianLatency}}
+ {{if $probeInfo.Continuous}} + n/a + {{else}} + {{$probeInfo.Latency.String}} +
Recent: {{$probeInfo.RecentLatencies}}
+
Median: {{$probeInfo.RecentMedianLatency}}
+ {{end}} {{$probeInfo.Error}} diff --git a/prober/tls.go b/prober/tls.go index 787df05c2c3a9..4fb4aa9c6becf 100644 --- a/prober/tls.go +++ b/prober/tls.go @@ -4,7 +4,6 @@ package prober import ( - "bytes" "context" "crypto/tls" "crypto/x509" @@ -15,12 +14,14 @@ import ( "net/netip" "time" - "github.com/pkg/errors" - "golang.org/x/crypto/ocsp" "tailscale.com/util/multierr" ) const expiresSoon = 7 * 24 * time.Hour // 7 days from now +// Let’s Encrypt promises to issue certificates with CRL servers after 2025-05-07: +// https://letsencrypt.org/2024/12/05/ending-ocsp/ +// https://github.com/tailscale/tailscale/issues/15912 +const letsEncryptStartedStaplingCRL int64 = 1746576000 // 2025-05-07 00:00:00 UTC // TLS returns a Probe that healthchecks a TLS endpoint. // @@ -106,50 +107,55 @@ func validateConnState(ctx context.Context, cs *tls.ConnectionState) (returnerr } } - if len(leafCert.OCSPServer) == 0 { - errs = append(errs, fmt.Errorf("no OCSP server presented in leaf cert for %v", leafCert.Subject)) + if len(leafCert.CRLDistributionPoints) == 0 { + if leafCert.NotBefore.Before(time.Unix(letsEncryptStartedStaplingCRL, 0)) { + // Certificate might not have a CRL. 
+ return + } + errs = append(errs, fmt.Errorf("no CRL server presented in leaf cert for %v", leafCert.Subject)) return } - ocspResp, err := getOCSPResponse(ctx, leafCert.OCSPServer[0], leafCert, issuerCert) + err := checkCertCRL(ctx, leafCert.CRLDistributionPoints[0], leafCert, issuerCert) if err != nil { - errs = append(errs, errors.Wrapf(err, "OCSP verification failed for %v", leafCert.Subject)) - return - } - - if ocspResp.Status == ocsp.Unknown { - errs = append(errs, fmt.Errorf("unknown OCSP verification status for %v", leafCert.Subject)) - } - - if ocspResp.Status == ocsp.Revoked { - errs = append(errs, fmt.Errorf("cert for %v has been revoked on %v, reason: %v", leafCert.Subject, ocspResp.RevokedAt, ocspResp.RevocationReason)) + errs = append(errs, fmt.Errorf("CRL verification failed for %v: %w", leafCert.Subject, err)) } return } -func getOCSPResponse(ctx context.Context, ocspServer string, leafCert, issuerCert *x509.Certificate) (*ocsp.Response, error) { - reqb, err := ocsp.CreateRequest(leafCert, issuerCert, nil) - if err != nil { - return nil, errors.Wrap(err, "could not create OCSP request") - } - hreq, err := http.NewRequestWithContext(ctx, "POST", ocspServer, bytes.NewReader(reqb)) +func checkCertCRL(ctx context.Context, crlURL string, leafCert, issuerCert *x509.Certificate) error { + hreq, err := http.NewRequestWithContext(ctx, "GET", crlURL, nil) if err != nil { - return nil, errors.Wrap(err, "could not create OCSP POST request") + return fmt.Errorf("could not create CRL GET request: %w", err) } - hreq.Header.Add("Content-Type", "application/ocsp-request") - hreq.Header.Add("Accept", "application/ocsp-response") hresp, err := http.DefaultClient.Do(hreq) if err != nil { - return nil, errors.Wrap(err, "OCSP request failed") + return fmt.Errorf("CRL request failed: %w", err) } defer hresp.Body.Close() if hresp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("ocsp: non-200 status code from OCSP server: %s", hresp.Status) + return fmt.Errorf("crl: 
non-200 status code from CRL server: %s", hresp.Status) } lr := io.LimitReader(hresp.Body, 10<<20) // 10MB - ocspB, err := io.ReadAll(lr) + crlB, err := io.ReadAll(lr) if err != nil { - return nil, err + return err } - return ocsp.ParseResponse(ocspB, issuerCert) + + crl, err := x509.ParseRevocationList(crlB) + if err != nil { + return fmt.Errorf("could not parse CRL: %w", err) + } + + if err := crl.CheckSignatureFrom(issuerCert); err != nil { + return fmt.Errorf("could not verify CRL signature: %w", err) + } + + for _, revoked := range crl.RevokedCertificateEntries { + if revoked.SerialNumber.Cmp(leafCert.SerialNumber) == 0 { + return fmt.Errorf("cert for %v has been revoked on %v, reason: %v", leafCert.Subject, revoked.RevocationTime, revoked.ReasonCode) + } + } + + return nil } diff --git a/prober/tls_test.go b/prober/tls_test.go index 5bfb739db928d..9ba17f79da911 100644 --- a/prober/tls_test.go +++ b/prober/tls_test.go @@ -6,7 +6,6 @@ package prober import ( "bytes" "context" - "crypto" "crypto/rand" "crypto/rsa" "crypto/tls" @@ -20,8 +19,6 @@ import ( "strings" "testing" "time" - - "golang.org/x/crypto/ocsp" ) var leafCert = x509.Certificate{ @@ -118,11 +115,6 @@ func TestCertExpiration(t *testing.T) { }, "one of the certs expires in", }, - { - "valid duration but no OCSP", - func() *x509.Certificate { return &leafCert }, - "no OCSP server presented in leaf cert for CN=tlsprobe.test", - }, } { t.Run(tt.name, func(t *testing.T) { cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert()}} @@ -134,100 +126,157 @@ func TestCertExpiration(t *testing.T) { } } -type ocspServer struct { - issuer *x509.Certificate - responderCert *x509.Certificate - template *ocsp.Response - priv crypto.Signer +type CRLServer struct { + crlBytes []byte } -func (s *ocspServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if s.template == nil { +func (s *CRLServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.crlBytes == nil { 
w.WriteHeader(http.StatusInternalServerError) return } - resp, err := ocsp.CreateResponse(s.issuer, s.responderCert, *s.template, s.priv) - if err != nil { - panic(err) - } - w.Write(resp) + w.Header().Set("Content-Type", "application/pkix-crl") + w.WriteHeader(http.StatusOK) + w.Write(s.crlBytes) } -func TestOCSP(t *testing.T) { - issuerKey, err := rsa.GenerateKey(rand.Reader, 4096) +func TestCRL(t *testing.T) { + // Generate CA key and self-signed CA cert + caKey, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { t.Fatal(err) } - issuerBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &issuerKey.PublicKey, issuerKey) + caTpl := issuerCertTpl + caTpl.BasicConstraintsValid = true + caTpl.IsCA = true + caTpl.KeyUsage = x509.KeyUsageCertSign | x509.KeyUsageCRLSign | x509.KeyUsageDigitalSignature + caBytes, err := x509.CreateCertificate(rand.Reader, &caTpl, &caTpl, &caKey.PublicKey, caKey) if err != nil { t.Fatal(err) } - issuerCert, err := x509.ParseCertificate(issuerBytes) + caCert, err := x509.ParseCertificate(caBytes) if err != nil { t.Fatal(err) } - responderKey, err := rsa.GenerateKey(rand.Reader, 4096) + // Issue a leaf cert signed by the CA + leaf := leafCert + leaf.SerialNumber = big.NewInt(20001) + leaf.Issuer = caCert.Subject + leafKey, err := rsa.GenerateKey(rand.Reader, 4096) if err != nil { t.Fatal(err) } - // issuer cert template re-used here, but with a different key - responderBytes, err := x509.CreateCertificate(rand.Reader, &issuerCertTpl, &issuerCertTpl, &responderKey.PublicKey, responderKey) + leafBytes, err := x509.CreateCertificate(rand.Reader, &leaf, caCert, &leafKey.PublicKey, caKey) if err != nil { t.Fatal(err) } - responderCert, err := x509.ParseCertificate(responderBytes) + leafCertParsed, err := x509.ParseCertificate(leafBytes) if err != nil { t.Fatal(err) } - handler := &ocspServer{ - issuer: issuerCert, - responderCert: responderCert, - priv: issuerKey, + // Catch no CRL set by Let's Encrypt date. 
+ noCRLCert := leafCert + noCRLCert.SerialNumber = big.NewInt(20002) + noCRLCert.CRLDistributionPoints = []string{} + noCRLCert.NotBefore = time.Unix(letsEncryptStartedStaplingCRL, 0).Add(-48 * time.Hour) + noCRLCert.Issuer = caCert.Subject + noCRLCertKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + t.Fatal(err) } - srv := httptest.NewUnstartedServer(handler) - srv.Start() - defer srv.Close() - - cert := leafCert - cert.OCSPServer = append(cert.OCSPServer, srv.URL) - key, err := rsa.GenerateKey(rand.Reader, 4096) + noCRLStapledBytes, err := x509.CreateCertificate(rand.Reader, &noCRLCert, caCert, &noCRLCertKey.PublicKey, caKey) + if err != nil { + t.Fatal(err) + } + noCRLStapledParsed, err := x509.ParseCertificate(noCRLStapledBytes) if err != nil { t.Fatal(err) } - certBytes, err := x509.CreateCertificate(rand.Reader, &cert, issuerCert, &key.PublicKey, issuerKey) + + crlServer := &CRLServer{crlBytes: nil} + srv := httptest.NewServer(crlServer) + defer srv.Close() + + // Create a CRL that revokes the leaf cert using x509.CreateRevocationList + now := time.Now() + revoked := []x509.RevocationListEntry{{ + SerialNumber: leaf.SerialNumber, + RevocationTime: now, + ReasonCode: 1, // Key compromise + }} + rl := x509.RevocationList{ + SignatureAlgorithm: caCert.SignatureAlgorithm, + Issuer: caCert.Subject, + ThisUpdate: now, + NextUpdate: now.Add(24 * time.Hour), + RevokedCertificateEntries: revoked, + Number: big.NewInt(1), + } + rlBytes, err := x509.CreateRevocationList(rand.Reader, &rl, caCert, caKey) if err != nil { t.Fatal(err) } - parsed, err := x509.ParseCertificate(certBytes) + + emptyRlBytes, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{Number: big.NewInt(2)}, caCert, caKey) if err != nil { t.Fatal(err) } for _, tt := range []struct { - name string - resp *ocsp.Response - wantErr string + name string + cert *x509.Certificate + crlBytes []byte + wantErr string }{ - {"good response", &ocsp.Response{Status: ocsp.Good}, ""}, - 
{"unknown response", &ocsp.Response{Status: ocsp.Unknown}, "unknown OCSP verification status for CN=tlsprobe.test"}, - {"revoked response", &ocsp.Response{Status: ocsp.Revoked}, "cert for CN=tlsprobe.test has been revoked"}, - {"error 500 from ocsp", nil, "non-200 status code from OCSP"}, + { + "ValidCert", + leafCertParsed, + emptyRlBytes, + "", + }, + { + "RevokedCert", + leafCertParsed, + rlBytes, + "has been revoked on", + }, + { + "EmptyCRL", + leafCertParsed, + emptyRlBytes, + "", + }, + { + "NoCRL", + leafCertParsed, + nil, + "no CRL server presented in leaf cert for", + }, + { + "NotBeforeCRLStaplingDate", + noCRLStapledParsed, + nil, + "", + }, } { t.Run(tt.name, func(t *testing.T) { - handler.template = tt.resp - if handler.template != nil { - handler.template.SerialNumber = big.NewInt(1337) + cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{tt.cert, caCert}} + if tt.crlBytes != nil { + crlServer.crlBytes = tt.crlBytes + tt.cert.CRLDistributionPoints = []string{srv.URL} + } else { + crlServer.crlBytes = nil + tt.cert.CRLDistributionPoints = []string{} } - cs := &tls.ConnectionState{PeerCertificates: []*x509.Certificate{parsed, issuerCert}} err := validateConnState(context.Background(), cs) if err == nil && tt.wantErr == "" { return } - if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + if err == nil || tt.wantErr == "" || !strings.Contains(err.Error(), tt.wantErr) { t.Errorf("unexpected error %q; want %q", err, tt.wantErr) } }) diff --git a/prober/tun_darwin.go b/prober/tun_darwin.go new file mode 100644 index 0000000000000..0ef22e41e4076 --- /dev/null +++ b/prober/tun_darwin.go @@ -0,0 +1,35 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package prober + +import ( + "fmt" + "net/netip" + "os/exec" + + "go4.org/netipx" +) + +const tunName = "utun" + +func configureTUN(addr netip.Prefix, tunname string) error { + cmd := exec.Command("ifconfig", tunname, "inet", 
addr.String(), addr.Addr().String()) + res, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to add address: %w (%s)", err, string(res)) + } + + net := netipx.PrefixIPNet(addr) + nip := net.IP.Mask(net.Mask) + nstr := fmt.Sprintf("%v/%d", nip, addr.Bits()) + cmd = exec.Command("route", "-q", "-n", "add", "-inet", nstr, "-iface", addr.Addr().String()) + res, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to add route: %w (%s)", err, string(res)) + } + + return nil +} diff --git a/prober/tun_default.go b/prober/tun_default.go new file mode 100644 index 0000000000000..93a5b07fd442a --- /dev/null +++ b/prober/tun_default.go @@ -0,0 +1,18 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux && !darwin + +package prober + +import ( + "fmt" + "net/netip" + "runtime" +) + +const tunName = "unused" + +func configureTUN(addr netip.Prefix, tunname string) error { + return fmt.Errorf("not implemented on " + runtime.GOOS) +} diff --git a/prober/tun_linux.go b/prober/tun_linux.go new file mode 100644 index 0000000000000..52a31efbbf66a --- /dev/null +++ b/prober/tun_linux.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package prober + +import ( + "fmt" + "net/netip" + + "github.com/tailscale/netlink" + "go4.org/netipx" +) + +const tunName = "derpprobe" + +func configureTUN(addr netip.Prefix, tunname string) error { + link, err := netlink.LinkByName(tunname) + if err != nil { + return fmt.Errorf("failed to look up link %q: %w", tunname, err) + } + + // We need to bring the TUN device up before assigning an address. This + // allows the OS to automatically create a route for it. Otherwise, we'd + // have to manually create the route. 
+ if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring tun %q up: %w", tunname, err) + } + + if err := netlink.AddrReplace(link, &netlink.Addr{IPNet: netipx.PrefixIPNet(addr)}); err != nil { + return fmt.Errorf("failed to add address: %w", err) + } + + return nil +} diff --git a/release/dist/cli/cli.go b/release/dist/cli/cli.go index 9b861ddd72dc6..f4480cbdbdfa4 100644 --- a/release/dist/cli/cli.go +++ b/release/dist/cli/cli.go @@ -65,6 +65,7 @@ func CLI(getTargets func() ([]dist.Target, error)) *ffcli.Command { fs.StringVar(&buildArgs.manifest, "manifest", "", "manifest file to write") fs.BoolVar(&buildArgs.verbose, "verbose", false, "verbose logging") fs.StringVar(&buildArgs.webClientRoot, "web-client-root", "", "path to root of web client source to build") + fs.StringVar(&buildArgs.outPath, "out", "", "path to write output artifacts (defaults to '$PWD/dist' if not set)") return fs })(), LongHelp: strings.TrimSpace(` @@ -156,6 +157,7 @@ var buildArgs struct { manifest string verbose bool webClientRoot string + outPath string } func runBuild(ctx context.Context, filters []string, targets []dist.Target) error { @@ -172,7 +174,11 @@ func runBuild(ctx context.Context, filters []string, targets []dist.Target) erro if err != nil { return fmt.Errorf("getting working directory: %w", err) } - b, err := dist.NewBuild(wd, filepath.Join(wd, "dist")) + outPath := filepath.Join(wd, "dist") + if buildArgs.outPath != "" { + outPath = buildArgs.outPath + } + b, err := dist.NewBuild(wd, outPath) if err != nil { return fmt.Errorf("creating build context: %w", err) } diff --git a/release/dist/qnap/files/scripts/Dockerfile.qpkg b/release/dist/qnap/files/scripts/Dockerfile.qpkg index 135d5d20fc94c..1f4c2406d7642 100644 --- a/release/dist/qnap/files/scripts/Dockerfile.qpkg +++ b/release/dist/qnap/files/scripts/Dockerfile.qpkg @@ -1,9 +1,21 @@ -FROM ubuntu:20.04 +FROM ubuntu:24.04 RUN apt-get update -y && \ apt-get install -y --no-install-recommends \ 
git-core \ - ca-certificates -RUN git clone https://github.com/qnap-dev/QDK.git + ca-certificates \ + apt-transport-https \ + gnupg \ + curl \ + patch + +# Install QNAP QDK (force a specific version to pick up updates) +RUN git clone https://github.com/tailscale/QDK.git && cd /QDK && git reset --hard 9a31a67387c583d19a81a378dcf7c25e2abe231d RUN cd /QDK && ./InstallToUbuntu.sh install -ENV PATH="/usr/share/QDK/bin:${PATH}" \ No newline at end of file +ENV PATH="/usr/share/QDK/bin:${PATH}" + +# Install Google Cloud PKCS11 module +RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg +RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list +RUN apt-get update -y && apt-get install -y --no-install-recommends google-cloud-cli libengine-pkcs11-openssl +RUN curl -L https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/pkcs11-v1.6/libkmsp11-1.6-linux-amd64.tar.gz | tar xz diff --git a/release/dist/qnap/files/scripts/sign-qpkg.sh b/release/dist/qnap/files/scripts/sign-qpkg.sh new file mode 100755 index 0000000000000..5629672f85e95 --- /dev/null +++ b/release/dist/qnap/files/scripts/sign-qpkg.sh @@ -0,0 +1,40 @@ +#! 
/usr/bin/env bash +set -xeu + +mkdir -p "$HOME/.config/gcloud" +echo "$GCLOUD_CREDENTIALS_BASE64" | base64 --decode > /root/.config/gcloud/application_default_credentials.json +gcloud config set project "$GCLOUD_PROJECT" + +echo "--- +tokens: + - key_ring: \"$GCLOUD_KEYRING\" +log_directory: "/tmp/kmsp11" +" > pkcs11-config.yaml +chmod 0600 pkcs11-config.yaml + +export KMS_PKCS11_CONFIG=`readlink -f pkcs11-config.yaml` +export PKCS11_MODULE_PATH=/libkmsp11-1.6-linux-amd64/libkmsp11.so + +# Verify signature of pkcs11 module +# See https://github.com/GoogleCloudPlatform/kms-integrations/blob/master/kmsp11/docs/user_guide.md#downloading-and-verifying-the-library +echo "-----BEGIN PUBLIC KEY----- +MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAEtfLbXkHUVc9oUPTNyaEK3hIwmuGRoTtd +6zDhwqjJuYaMwNd1aaFQLMawTwZgR0Xn27ymVWtqJHBe0FU9BPIQ+SFmKw+9jSwu +/FuqbJnLmTnWMJ1jRCtyHNZawvv2wbiB +-----END PUBLIC KEY-----" > pkcs11-release-signing-key.pem +openssl dgst -sha384 -verify pkcs11-release-signing-key.pem -signature "$PKCS11_MODULE_PATH.sig" "$PKCS11_MODULE_PATH" + +echo "$QNAP_SIGNING_CERT_BASE64" | base64 --decode > cert.crt + +openssl cms \ + -sign \ + -binary \ + -nodetach \ + -engine pkcs11 \ + -keyform engine \ + -inkey "pkcs11:object=$QNAP_SIGNING_KEY_NAME" \ + -keyopt rsa_padding_mode:pss \ + -keyopt rsa_pss_saltlen:digest \ + -signer cert.crt \ + -in "$1" \ + -out - diff --git a/release/dist/qnap/pkgs.go b/release/dist/qnap/pkgs.go index 9df649ddbfed4..7dc3b94958639 100644 --- a/release/dist/qnap/pkgs.go +++ b/release/dist/qnap/pkgs.go @@ -27,8 +27,11 @@ type target struct { } type signer struct { - privateKeyPath string - certificatePath string + gcloudCredentialsBase64 string + gcloudProject string + gcloudKeyring string + keyName string + certificateBase64 string } func (t *target) String() string { @@ -66,7 +69,8 @@ func (t *target) buildQPKG(b *dist.Build, qnapBuilds *qnapBuilds, inner *innerPk filename := fmt.Sprintf("Tailscale_%s-%s_%s.qpkg", b.Version.Short, qnapTag, t.arch) 
filePath := filepath.Join(b.Out, filename) - cmd := b.Command(b.Repo, "docker", "run", "--rm", + args := []string{"run", "--rm", + "--network=host", "-e", fmt.Sprintf("ARCH=%s", t.arch), "-e", fmt.Sprintf("TSTAG=%s", b.Version.Short), "-e", fmt.Sprintf("QNAPTAG=%s", qnapTag), @@ -76,10 +80,28 @@ func (t *target) buildQPKG(b *dist.Build, qnapBuilds *qnapBuilds, inner *innerPk "-v", fmt.Sprintf("%s:/Tailscale", filepath.Join(qnapBuilds.tmpDir, "files/Tailscale")), "-v", fmt.Sprintf("%s:/build-qpkg.sh", filepath.Join(qnapBuilds.tmpDir, "files/scripts/build-qpkg.sh")), "-v", fmt.Sprintf("%s:/out", b.Out), + } + + if t.signer != nil { + log.Println("Will sign with Google Cloud HSM") + args = append(args, + "-e", fmt.Sprintf("GCLOUD_CREDENTIALS_BASE64=%s", t.signer.gcloudCredentialsBase64), + "-e", fmt.Sprintf("GCLOUD_PROJECT=%s", t.signer.gcloudProject), + "-e", fmt.Sprintf("GCLOUD_KEYRING=%s", t.signer.gcloudKeyring), + "-e", fmt.Sprintf("QNAP_SIGNING_KEY_NAME=%s", t.signer.keyName), + "-e", fmt.Sprintf("QNAP_SIGNING_CERT_BASE64=%s", t.signer.certificateBase64), + "-e", fmt.Sprintf("QNAP_SIGNING_SCRIPT=%s", "/sign-qpkg.sh"), + "-v", fmt.Sprintf("%s:/sign-qpkg.sh", filepath.Join(qnapBuilds.tmpDir, "files/scripts/sign-qpkg.sh")), + ) + } + + args = append(args, "build.tailscale.io/qdk:latest", "/build-qpkg.sh", ) + cmd := b.Command(b.Repo, "docker", args...) + // dist.Build runs target builds in parallel goroutines by default. 
// For QNAP, this is an issue because the underlaying qbuild builder will // create tmp directories in the shared docker image that end up conflicting @@ -176,32 +198,6 @@ func newQNAPBuilds(b *dist.Build, signer *signer) (*qnapBuilds, error) { return nil, err } - if signer != nil { - log.Print("Setting up qnap signing files") - - key, err := os.ReadFile(signer.privateKeyPath) - if err != nil { - return nil, err - } - cert, err := os.ReadFile(signer.certificatePath) - if err != nil { - return nil, err - } - - // QNAP's qbuild command expects key and cert files to be in the root - // of the project directory (in our case release/dist/qnap/Tailscale). - // So here, we copy the key and cert over to the project folder for the - // duration of qnap package building and then delete them on close. - - keyPath := filepath.Join(m.tmpDir, "files/Tailscale/private_key") - if err := os.WriteFile(keyPath, key, 0400); err != nil { - return nil, err - } - certPath := filepath.Join(m.tmpDir, "files/Tailscale/certificate") - if err := os.WriteFile(certPath, cert, 0400); err != nil { - return nil, err - } - } return m, nil } diff --git a/release/dist/qnap/targets.go b/release/dist/qnap/targets.go index a069dd6238513..1c1818a700cd1 100644 --- a/release/dist/qnap/targets.go +++ b/release/dist/qnap/targets.go @@ -3,16 +3,31 @@ package qnap -import "tailscale.com/release/dist" +import ( + "slices" + + "tailscale.com/release/dist" +) // Targets defines the dist.Targets for QNAP devices. // -// If privateKeyPath and certificatePath are both provided non-empty, -// these targets will be signed for QNAP app store release with built. -func Targets(privateKeyPath, certificatePath string) []dist.Target { +// If all parameters are provided non-empty, then the build will be signed using +// a Google Cloud hosted key. +// +// gcloudCredentialsBase64 is the JSON credential for connecting to Google Cloud, base64 encoded. 
+// gcloudKeyring is the full path to the Google Cloud keyring containing the signing key. +// keyName is the name of the key. +// certificateBase64 is the PEM certificate to use in the signature, base64 encoded. +func Targets(gcloudCredentialsBase64, gcloudProject, gcloudKeyring, keyName, certificateBase64 string) []dist.Target { var signerInfo *signer - if privateKeyPath != "" && certificatePath != "" { - signerInfo = &signer{privateKeyPath, certificatePath} + if !slices.Contains([]string{gcloudCredentialsBase64, gcloudProject, gcloudKeyring, keyName, certificateBase64}, "") { + signerInfo = &signer{ + gcloudCredentialsBase64: gcloudCredentialsBase64, + gcloudProject: gcloudProject, + gcloudKeyring: gcloudKeyring, + keyName: keyName, + certificateBase64: certificateBase64, + } } return []dist.Target{ &target{ diff --git a/release/dist/synology/pkgs.go b/release/dist/synology/pkgs.go index 7802470e167fe..ab89dbee3e19f 100644 --- a/release/dist/synology/pkgs.go +++ b/release/dist/synology/pkgs.go @@ -155,8 +155,22 @@ func (t *target) mkInfo(b *dist.Build, uncompressedSz int64) []byte { f("os_min_ver", "6.0.1-7445") f("os_max_ver", "7.0-40000") case 7: - f("os_min_ver", "7.0-40000") - f("os_max_ver", "") + if t.packageCenter { + switch t.dsmMinorVersion { + case 0: + f("os_min_ver", "7.0-40000") + f("os_max_ver", "7.2-60000") + case 2: + f("os_min_ver", "7.2-60000") + default: + panic(fmt.Sprintf("unsupported DSM major.minor version %s", t.dsmVersionString())) + } + } else { + // We do not clamp the os_max_ver currently for non-package center builds as + // the binaries for 7.0 and 7.2 are identical. 
+ f("os_min_ver", "7.0-40000") + f("os_max_ver", "") + } default: panic(fmt.Sprintf("unsupported DSM major version %d", t.dsmMajorVersion)) } diff --git a/safesocket/safesocket.go b/safesocket/safesocket.go index 991fddf5fc347..721b694dcf86c 100644 --- a/safesocket/safesocket.go +++ b/safesocket/safesocket.go @@ -61,7 +61,11 @@ func ConnectContext(ctx context.Context, path string) (net.Conn, error) { if ctx.Err() != nil { return nil, ctx.Err() } - time.Sleep(250 * time.Millisecond) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(250 * time.Millisecond): + } continue } return c, err diff --git a/safesocket/safesocket_darwin.go b/safesocket/safesocket_darwin.go index 62e6f7e6de340..bb014d1ec321d 100644 --- a/safesocket/safesocket_darwin.go +++ b/safesocket/safesocket_darwin.go @@ -6,8 +6,11 @@ package safesocket import ( "bufio" "bytes" + crand "crypto/rand" "errors" "fmt" + "io/fs" + "log" "net" "os" "os/exec" @@ -17,6 +20,7 @@ import ( "sync" "time" + "golang.org/x/sys/unix" "tailscale.com/version" ) @@ -24,96 +28,287 @@ func init() { localTCPPortAndToken = localTCPPortAndTokenDarwin } -// localTCPPortAndTokenMacsys returns the localhost TCP port number and auth token -// from /Library/Tailscale. 
-// -// In that case the files are: +const sameUserProofTokenLength = 10 + +type safesocketDarwin struct { + mu sync.Mutex + token string // safesocket auth token + port int // safesocket port + sameuserproofFD *os.File // File descriptor for macos app store sameuserproof file + sharedDir string // Shared directory for location of sameuserproof file + + checkConn bool // If true, check macsys safesocket port before returning it + isMacSysExt func() bool // Reports true if this binary is the macOS System Extension + isMacGUIApp func() bool // Reports true if running as a macOS GUI app (Tailscale.app) +} + +var ssd = safesocketDarwin{ + isMacSysExt: version.IsMacSysExt, + isMacGUIApp: func() bool { return version.IsMacAppStoreGUI() || version.IsMacSysGUI() }, + checkConn: true, + sharedDir: "/Library/Tailscale", +} + +// There are three ways a Darwin binary can be run: as the Mac App Store (macOS) +// standalone notarized (macsys), or a separate CLI (tailscale) that was +// built or downloaded. // -// /Library/Tailscale/ipnport => $port (symlink with localhost port number target) -// /Library/Tailscale/sameuserproof-$port is a file with auth -func localTCPPortAndTokenMacsys() (port int, token string, err error) { +// The macOS and macsys binaries can communicate directly via XPC with +// the NEPacketTunnelProvider managed tailscaled process and are responsible for +// calling SetCredentials when they need to operate as a CLI. + +// A built/downloaded CLI binary will not be managing the NEPacketTunnelProvider +// hosting tailscaled directly and must source the credentials from a 'sameuserproof' file. +// This file is written to sharedDir when tailscaled/NEPacketTunnelProvider +// calls InitListenerDarwin. 
+ +// localTCPPortAndTokenDarwin returns the localhost TCP port number and auth token +// either from the sameuserproof mechanism, or source and set directly from the +// NEPacketTunnelProvider managed tailscaled process when the CLI is invoked +// from the Tailscale.app GUI. +func localTCPPortAndTokenDarwin() (port int, token string, err error) { + ssd.mu.Lock() + defer ssd.mu.Unlock() + + switch { + case ssd.port != 0 && ssd.token != "": + // If something has explicitly set our credentials (typically non-standalone macos binary), use them. + return ssd.port, ssd.token, nil + case !ssd.isMacGUIApp(): + // We're not a GUI app (probably cmd/tailscale), so try falling back to sameuserproof. + // If portAndTokenFromSameUserProof returns an error here, cmd/tailscale will + // attempt to use the default unix socket mechanism supported by tailscaled. + return portAndTokenFromSameUserProof() + default: + return 0, "", ErrTokenNotFound + } +} + +// SetCredentials sets an token and port used to authenticate safesocket generated +// by the NEPacketTunnelProvider tailscaled process. This is only used when running +// the CLI via Tailscale.app. +func SetCredentials(token string, port int) { + ssd.mu.Lock() + defer ssd.mu.Unlock() + + if ssd.token != "" || ssd.port != 0 { + // Not fatal, but likely programmer error. Credentials do not change. + log.Printf("warning: SetCredentials credentials already set") + } + + ssd.token = token + ssd.port = port +} + +// InitListenerDarwin initializes the listener for the CLI commands +// and localapi HTTP server and sets the port/token. This will override +// any credentials set explicitly via SetCredentials(). Calling this multiple times +// has no effect. The listener and it's corresponding token/port is initialized only once. 
+func InitListenerDarwin(sharedDir string) (*net.Listener, error) { + ssd.mu.Lock() + defer ssd.mu.Unlock() - const dir = "/Library/Tailscale" - portStr, err := os.Readlink(filepath.Join(dir, "ipnport")) + ln := onceListener.ln + if ln != nil { + return ln, nil + } + + var err error + ln, err = localhostListener() if err != nil { - return 0, "", err + log.Printf("InitListenerDarwin: listener initialization failed") + return nil, err } - port, err = strconv.Atoi(portStr) + + port, err := localhostTCPPort() if err != nil { - return 0, "", err + log.Printf("localhostTCPPort: listener initialization failed") + return nil, err } - authb, err := os.ReadFile(filepath.Join(dir, "sameuserproof-"+portStr)) + + token, err := getToken() if err != nil { - return 0, "", err + log.Printf("localhostTCPPort: getToken failed") + return nil, err } - auth := strings.TrimSpace(string(authb)) - if auth == "" { - return 0, "", errors.New("empty auth token in sameuserproof file") + + if port == 0 || token == "" { + log.Printf("localhostTCPPort: Invalid token or port") + return nil, fmt.Errorf("invalid localhostTCPPort: returned 0") } - // The above files exist forever after the first run of - // /Applications/Tailscale.app, so check we can connect to avoid returning a - // port nothing is listening on. Connect to "127.0.0.1" rather than - // "localhost" due to #7851. 
- conn, err := net.DialTimeout("tcp", "127.0.0.1:"+portStr, time.Second) + ssd.sharedDir = sharedDir + ssd.token = token + ssd.port = port + + // Write the port and token to a sameuserproof file + err = initSameUserProofToken(sharedDir, port, token) if err != nil { - return 0, "", err + // Not fatal + log.Printf("initSameUserProofToken: failed: %v", err) } - conn.Close() - return port, auth, nil + return ln, nil } -var warnAboutRootOnce sync.Once +var onceListener struct { + once sync.Once + ln *net.Listener +} -func localTCPPortAndTokenDarwin() (port int, token string, err error) { - // There are two ways this binary can be run: as the Mac App Store sandboxed binary, - // or a normal binary that somebody built or download and are being run from outside - // the sandbox. Detect which way we're running and then figure out how to connect - // to the local daemon. - - if dir := os.Getenv("TS_MACOS_CLI_SHARED_DIR"); dir != "" { - // First see if we're running as the non-AppStore "macsys" variant. 
- if version.IsMacSys() { - if port, token, err := localTCPPortAndTokenMacsys(); err == nil { - return port, token, nil +func localhostTCPPort() (int, error) { + if onceListener.ln == nil { + return 0, fmt.Errorf("listener not initialized") + } + + ln, err := localhostListener() + if err != nil { + return 0, err + } + + return (*ln).Addr().(*net.TCPAddr).Port, nil +} + +func localhostListener() (*net.Listener, error) { + onceListener.once.Do(func() { + ln, err := net.Listen("tcp4", "127.0.0.1:0") + if err != nil { + return + } + onceListener.ln = &ln + }) + if onceListener.ln == nil { + return nil, fmt.Errorf("failed to get TCP listener") + } + return onceListener.ln, nil +} + +var onceToken struct { + once sync.Once + token string +} + +func getToken() (string, error) { + onceToken.once.Do(func() { + buf := make([]byte, sameUserProofTokenLength) + if _, err := crand.Read(buf); err != nil { + return + } + t := fmt.Sprintf("%x", buf) + onceToken.token = t + }) + if onceToken.token == "" { + return "", fmt.Errorf("failed to generate token") + } + + return onceToken.token, nil +} + +// initSameUserProofToken writes the port and token to a sameuserproof +// file owned by the current user. We leave the file open to allow us +// to discover it via lsof. +// +// "sameuserproof" is intended to convey that the user attempting to read +// the credentials from the file is the same user that wrote them. For +// standalone macsys where tailscaled is running as root, we set group +// permissions to allow users in the admin group to read the file. 
+func initSameUserProofToken(sharedDir string, port int, token string) error { + var err error + + // Guard against bad sharedDir + old, err := os.ReadDir(sharedDir) + if err == os.ErrNotExist { + log.Printf("failed to read shared dir %s: %v", sharedDir, err) + return err + } + + // Remove all old sameuserproof files + for _, fi := range old { + if name := fi.Name(); strings.HasPrefix(name, "sameuserproof-") { + err := os.Remove(filepath.Join(sharedDir, name)) + if err != nil { + log.Printf("failed to remove %s: %v", name, err) } } + } - // The current binary (this process) is sandboxed. The user is - // running the CLI via /Applications/Tailscale.app/Contents/MacOS/Tailscale - // which sets the TS_MACOS_CLI_SHARED_DIR environment variable. - fis, err := os.ReadDir(dir) + var baseFile string + var perm fs.FileMode + if ssd.isMacSysExt() { + perm = 0640 // allow wheel to read + baseFile = fmt.Sprintf("sameuserproof-%d", port) + portFile := filepath.Join(sharedDir, "ipnport") + err := os.Remove(portFile) if err != nil { - return 0, "", err + log.Printf("failed to remove portfile %s: %v", portFile, err) } - for _, fi := range fis { - name := filepath.Base(fi.Name()) - // Look for name like "sameuserproof-61577-2ae2ec9e0aa2005784f1" - // to extract out the port number and token. - if strings.HasPrefix(name, "sameuserproof-") { - f := strings.SplitN(name, "-", 3) - if len(f) == 3 { - if port, err := strconv.Atoi(f[1]); err == nil { - return port, f[2], nil - } - } - } + symlinkErr := os.Symlink(fmt.Sprint(port), portFile) + if symlinkErr != nil { + log.Printf("failed to symlink portfile: %v", symlinkErr) } - if os.Geteuid() == 0 { - // Log a warning as the clue to the user, in case the error - // message is swallowed. Only do this once since we may retry - // multiple times to connect, and don't want to spam. - warnAboutRootOnce.Do(func() { - fmt.Fprintf(os.Stderr, "Warning: The CLI is running as root from within a sandboxed binary. 
It cannot reach the local tailscaled, please try again as a regular user.\n") - }) + } else { + perm = 0666 + baseFile = fmt.Sprintf("sameuserproof-%d-%s", port, token) + } + + path := filepath.Join(sharedDir, baseFile) + ssd.sameuserproofFD, err = os.OpenFile(path, os.O_RDWR|os.O_CREATE|os.O_TRUNC, perm) + log.Printf("initSameUserProofToken : done=%v", err == nil) + + if ssd.isMacSysExt() && err == nil { + fmt.Fprintf(ssd.sameuserproofFD, "%s\n", token) + + // Macsys runs as root so ownership of this file will be + // root/wheel. Change ownership to root/admin which will let all members + // of the admin group to read it. + unix.Fchown(int(ssd.sameuserproofFD.Fd()), 0, 80 /* admin */) + } + + return err +} + +// readMacsysSameuserproof returns the localhost TCP port number and auth token +// from a sameuserproof file written to /Library/Tailscale. +// +// In that case the files are: +// +// /Library/Tailscale/ipnport => $port (symlink with localhost port number target) +// /Library/Tailscale/sameuserproof-$port is a file containing only the auth token as a hex string. +func readMacsysSameUserProof() (port int, token string, err error) { + portStr, err := os.Readlink(filepath.Join(ssd.sharedDir, "ipnport")) + if err != nil { + return 0, "", err + } + port, err = strconv.Atoi(portStr) + if err != nil { + return 0, "", err + } + authb, err := os.ReadFile(filepath.Join(ssd.sharedDir, "sameuserproof-"+portStr)) + if err != nil { + return 0, "", err + } + auth := strings.TrimSpace(string(authb)) + if auth == "" { + return 0, "", errors.New("empty auth token in sameuserproof file") + } + + if ssd.checkConn { + // Files may be stale and there is no guarantee that the sameuserproof + // derived port is open and valid. Check it before returning it. 
+ conn, err := net.DialTimeout("tcp", "127.0.0.1:"+portStr, time.Second) + if err != nil { + return 0, "", err } - return 0, "", fmt.Errorf("failed to find sandboxed sameuserproof-* file in TS_MACOS_CLI_SHARED_DIR %q", dir) + conn.Close() } - // The current process is running outside the sandbox, so use - // lsof to find the IPNExtension (the Mac App Store variant). + return port, auth, nil +} +// readMacosSameUserProof searches for open sameuserproof files belonging +// to the current user and the IPNExtension (macOS App Store) process and returns a +// port and token. +func readMacosSameUserProof() (port int, token string, err error) { cmd := exec.Command("lsof", "-n", // numeric sockets; don't do DNS lookups, etc "-a", // logical AND remaining options @@ -122,39 +317,45 @@ func localTCPPortAndTokenDarwin() (port int, token string, err error) { "-F", // machine-readable output ) out, err := cmd.Output() - if err != nil { - // Before returning an error, see if we're running the - // macsys variant at the normal location. 
- if port, token, err := localTCPPortAndTokenMacsys(); err == nil { + + if err == nil { + bs := bufio.NewScanner(bytes.NewReader(out)) + subStr := []byte(".tailscale.ipn.macos/sameuserproof-") + for bs.Scan() { + line := bs.Bytes() + i := bytes.Index(line, subStr) + if i == -1 { + continue + } + f := strings.SplitN(string(line[i+len(subStr):]), "-", 2) + if len(f) != 2 { + continue + } + portStr, token := f[0], f[1] + port, err := strconv.Atoi(portStr) + if err != nil { + return 0, "", fmt.Errorf("invalid port %q found in lsof", portStr) + } + return port, token, nil } - - return 0, "", fmt.Errorf("failed to run '%s' looking for IPNExtension: %w", cmd, err) } - bs := bufio.NewScanner(bytes.NewReader(out)) - subStr := []byte(".tailscale.ipn.macos/sameuserproof-") - for bs.Scan() { - line := bs.Bytes() - i := bytes.Index(line, subStr) - if i == -1 { - continue - } - f := strings.SplitN(string(line[i+len(subStr):]), "-", 2) - if len(f) != 2 { - continue - } - portStr, token := f[0], f[1] - port, err := strconv.Atoi(portStr) - if err != nil { - return 0, "", fmt.Errorf("invalid port %q found in lsof", portStr) - } + return 0, "", ErrTokenNotFound +} + +func portAndTokenFromSameUserProof() (port int, token string, err error) { + // When we're cmd/tailscale, we have no idea what tailscaled is, so we'll try + // macos, then macsys and finally, fallback to tailscaled via a unix socket + // if both of those return an error. You can run macos or macsys and + // tailscaled at the same time, but we are forced to choose one and the GUI + // clients are first in line here. You cannot run macos and macsys simultaneously. + if port, token, err := readMacosSameUserProof(); err == nil { return port, token, nil } - // Before returning an error, see if we're running the - // macsys variant at the normal location. 
- if port, token, err := localTCPPortAndTokenMacsys(); err == nil { + if port, token, err := readMacsysSameUserProof(); err == nil { return port, token, nil } + return 0, "", ErrTokenNotFound } diff --git a/safesocket/safesocket_darwin_test.go b/safesocket/safesocket_darwin_test.go new file mode 100644 index 0000000000000..e52959ad58dcf --- /dev/null +++ b/safesocket/safesocket_darwin_test.go @@ -0,0 +1,190 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package safesocket + +import ( + "os" + "strings" + "testing" + + "tailscale.com/tstest" +) + +// TestSetCredentials verifies that calling SetCredentials +// sets the port and token correctly and that LocalTCPPortAndToken +// returns the given values. +func TestSetCredentials(t *testing.T) { + const ( + wantToken = "token" + wantPort = 123 + ) + + tstest.Replace(t, &ssd.isMacGUIApp, func() bool { return false }) + SetCredentials(wantToken, wantPort) + + gotPort, gotToken, err := LocalTCPPortAndToken() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotPort != wantPort { + t.Errorf("port: got %d, want %d", gotPort, wantPort) + } + + if gotToken != wantToken { + t.Errorf("token: got %s, want %s", gotToken, wantToken) + } +} + +// TestFallbackToSameuserproof verifies that we fallback to the +// sameuserproof file via LocalTCPPortAndToken when we're running +// +// s cmd/tailscale +func TestFallbackToSameuserproof(t *testing.T) { + dir := t.TempDir() + const ( + wantToken = "token" + wantPort = 123 + ) + + // Mimics cmd/tailscale falling back to sameuserproof + tstest.Replace(t, &ssd.isMacGUIApp, func() bool { return false }) + tstest.Replace(t, &ssd.sharedDir, dir) + tstest.Replace(t, &ssd.checkConn, false) + + // Behave as macSysExt when initializing sameuserproof + tstest.Replace(t, &ssd.isMacSysExt, func() bool { return true }) + if err := initSameUserProofToken(dir, wantPort, wantToken); err != nil { + t.Fatalf("initSameUserProofToken: %v", err) + } + 
+ gotPort, gotToken, err := LocalTCPPortAndToken() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotPort != wantPort { + t.Errorf("port: got %d, want %d", gotPort, wantPort) + } + + if gotToken != wantToken { + t.Errorf("token: got %s, want %s", gotToken, wantToken) + } +} + +// TestInitListenerDarwin verifies that InitListenerDarwin +// returns a listener and a non-zero port and non-empty token. +func TestInitListenerDarwin(t *testing.T) { + temp := t.TempDir() + tstest.Replace(t, &ssd.isMacGUIApp, func() bool { return false }) + + ln, err := InitListenerDarwin(temp) + if err != nil || ln == nil { + t.Fatalf("InitListenerDarwin failed: %v", err) + } + defer (*ln).Close() + + port, token, err := LocalTCPPortAndToken() + if err != nil { + t.Fatalf("LocalTCPPortAndToken failed: %v", err) + } + + if port == 0 { + t.Errorf("port: got %d, want non-zero", port) + } + + if token == "" { + t.Errorf("token: got %s, want non-empty", token) + } +} + +func TestTokenGeneration(t *testing.T) { + token, err := getToken() + if err != nil { + t.Fatalf("getToken: %v", err) + } + + // Verify token length (hex string is 2x byte length) + wantLen := sameUserProofTokenLength * 2 + if got := len(token); got != wantLen { + t.Errorf("token length: got %d, want %d", got, wantLen) + } + + // Verify token persistence + subsequentToken, err := getToken() + if err != nil { + t.Fatalf("subsequent getToken: %v", err) + } + if subsequentToken != token { + t.Errorf("subsequent token: got %q, want %q", subsequentToken, token) + } +} + +// TestSameUserProofToken verifies that the sameuserproof file +// is created and read correctly for the macsys variant +func TestMacsysSameuserproof(t *testing.T) { + dir := t.TempDir() + + tstest.Replace(t, &ssd.isMacSysExt, func() bool { return true }) + tstest.Replace(t, &ssd.checkConn, false) + tstest.Replace(t, &ssd.sharedDir, dir) + + const ( + wantToken = "token" + wantPort = 123 + ) + + if err := initSameUserProofToken(dir, wantPort, 
wantToken); err != nil { + t.Fatalf("initSameUserProofToken: %v", err) + } + + gotPort, gotToken, err := readMacsysSameUserProof() + if err != nil { + t.Fatalf("readMacOSSameUserProof: %v", err) + } + + if gotPort != wantPort { + t.Errorf("port: got %d, want %d", gotPort, wantPort) + } + if wantToken != gotToken { + t.Errorf("token: got %s, want %s", wantToken, gotToken) + } + assertFileCount(t, dir, 1, "sameuserproof-") +} + +// TestMacosSameuserproof verifies that the sameuserproof file +// is created correctly for the macos variant +func TestMacosSameuserproof(t *testing.T) { + dir := t.TempDir() + wantToken := "token" + wantPort := 123 + + initSameUserProofToken(dir, wantPort, wantToken) + + // initSameUserProofToken should never leave duplicates + initSameUserProofToken(dir, wantPort, wantToken) + + // we can't just call readMacosSameUserProof because it relies on lsof + // and makes some assumptions about the user. But we can make sure + // the file exists + assertFileCount(t, dir, 1, "sameuserproof-") +} + +func assertFileCount(t *testing.T, dir string, want int, prefix string) { + t.Helper() + + files, err := os.ReadDir(dir) + if err != nil { + t.Fatalf("[unexpected] error: %v", err) + } + count := 0 + for _, file := range files { + if strings.HasPrefix(file.Name(), prefix) { + count += 1 + } + } + if count != want { + t.Errorf("files: got %d, want 1", count) + } +} diff --git a/safesocket/safesocket_plan9.go b/safesocket/safesocket_plan9.go index 196c1df9ca4a7..c8a5e3b05bbef 100644 --- a/safesocket/safesocket_plan9.go +++ b/safesocket/safesocket_plan9.go @@ -7,119 +7,13 @@ package safesocket import ( "context" - "fmt" "net" - "os" - "syscall" - "time" - - "golang.org/x/sys/plan9" ) -// Plan 9's devsrv srv(3) is a server registry and -// it is conventionally bound to "/srv" in the default -// namespace. It is "a one level directory for holding -// already open channels to services". 
Post one end of -// a pipe to "/srv/tailscale.sock" and use the other -// end for communication with a requestor. Plan 9 pipes -// are bidirectional. - -type plan9SrvAddr string - -func (sl plan9SrvAddr) Network() string { - return "/srv" -} - -func (sl plan9SrvAddr) String() string { - return string(sl) -} - -// There is no net.FileListener for Plan 9 at this time -type plan9SrvListener struct { - name string - srvf *os.File - file *os.File -} - -func (sl *plan9SrvListener) Accept() (net.Conn, error) { - // sl.file is the server end of the pipe that's - // connected to /srv/tailscale.sock - return plan9FileConn{name: sl.name, file: sl.file}, nil -} - -func (sl *plan9SrvListener) Close() error { - sl.file.Close() - return sl.srvf.Close() -} - -func (sl *plan9SrvListener) Addr() net.Addr { - return plan9SrvAddr(sl.name) -} - -type plan9FileConn struct { - name string - file *os.File -} - -func (fc plan9FileConn) Read(b []byte) (n int, err error) { - return fc.file.Read(b) -} -func (fc plan9FileConn) Write(b []byte) (n int, err error) { - return fc.file.Write(b) -} -func (fc plan9FileConn) Close() error { - return fc.file.Close() -} -func (fc plan9FileConn) LocalAddr() net.Addr { - return plan9SrvAddr(fc.name) -} -func (fc plan9FileConn) RemoteAddr() net.Addr { - return plan9SrvAddr(fc.name) -} -func (fc plan9FileConn) SetDeadline(t time.Time) error { - return syscall.EPLAN9 -} -func (fc plan9FileConn) SetReadDeadline(t time.Time) error { - return syscall.EPLAN9 -} -func (fc plan9FileConn) SetWriteDeadline(t time.Time) error { - return syscall.EPLAN9 -} - func connect(_ context.Context, path string) (net.Conn, error) { - f, err := os.OpenFile(path, os.O_RDWR, 0666) - if err != nil { - return nil, err - } - - return plan9FileConn{name: path, file: f}, nil + return net.Dial("tcp", "localhost:5252") } -// Create an entry in /srv, open a pipe, write the -// client end to the entry and return the server -// end of the pipe to the caller. 
When the server -// end of the pipe is closed, /srv name associated -// with it will be removed (controlled by ORCLOSE flag) func listen(path string) (net.Listener, error) { - const O_RCLOSE = 64 // remove on close; should be in plan9 package - var pip [2]int - - err := plan9.Pipe(pip[:]) - if err != nil { - return nil, err - } - defer plan9.Close(pip[1]) - - srvfd, err := plan9.Create(path, plan9.O_WRONLY|plan9.O_CLOEXEC|O_RCLOSE, 0600) - if err != nil { - return nil, err - } - srv := os.NewFile(uintptr(srvfd), path) - - _, err = fmt.Fprintf(srv, "%d", pip[1]) - if err != nil { - return nil, err - } - - return &plan9SrvListener{name: path, srvf: srv, file: os.NewFile(uintptr(pip[0]), path)}, nil + return net.Listen("tcp", "localhost:5252") } diff --git a/safesocket/safesocket_ps.go b/safesocket/safesocket_ps.go index 18197846d307f..48a8dd483478b 100644 --- a/safesocket/safesocket_ps.go +++ b/safesocket/safesocket_ps.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || windows || (darwin && !ios) || freebsd +//go:build (linux && !android) || windows || (darwin && !ios) || freebsd package safesocket diff --git a/safeweb/http.go b/safeweb/http.go index 14c61336ac311..d085fcb8819d8 100644 --- a/safeweb/http.go +++ b/safeweb/http.go @@ -71,28 +71,78 @@ package safeweb import ( "cmp" + "context" crand "crypto/rand" "fmt" "log" + "maps" "net" "net/http" "net/url" "path" + "slices" "strings" "github.com/gorilla/csrf" ) -// The default Content-Security-Policy header. 
-var defaultCSP = strings.Join([]string{ - `default-src 'self'`, // origin is the only valid source for all content types - `script-src 'self'`, // disallow inline javascript - `frame-ancestors 'none'`, // disallow framing of the page - `form-action 'self'`, // disallow form submissions to other origins - `base-uri 'self'`, // disallow base URIs from other origins - `block-all-mixed-content`, // disallow mixed content when serving over HTTPS - `object-src 'self'`, // disallow embedding of resources from other origins -}, "; ") +// CSP is the value of a Content-Security-Policy header. Keys are CSP +// directives (like "default-src") and values are source expressions (like +// "'self'" or "https://tailscale.com"). A nil slice value is allowed for some +// directives like "upgrade-insecure-requests" that don't expect a list of +// source definitions. +type CSP map[string][]string + +// DefaultCSP is the recommended CSP to use when not loading resources from +// other domains and not embedding the current website. If you need to tweak +// the CSP, it is recommended to extend DefaultCSP instead of writing your own +// from scratch. +func DefaultCSP() CSP { + return CSP{ + "default-src": {"self"}, // origin is the only valid source for all content types + "frame-ancestors": {"none"}, // disallow framing of the page + "form-action": {"self"}, // disallow form submissions to other origins + "base-uri": {"self"}, // disallow base URIs from other origins + // TODO(awly): consider upgrade-insecure-requests in SecureContext + // instead, as this is deprecated. + "block-all-mixed-content": nil, // disallow mixed content when serving over HTTPS + } +} + +// Set sets the values for a given directive. Empty values are allowed, if the +// directive doesn't expect any (like "upgrade-insecure-requests"). +func (csp CSP) Set(directive string, values ...string) { + csp[directive] = values +} + +// Add adds a source expression to an existing directive. 
+func (csp CSP) Add(directive, value string) { + csp[directive] = append(csp[directive], value) +} + +// Del deletes a directive and all its values. +func (csp CSP) Del(directive string) { + delete(csp, directive) +} + +func (csp CSP) String() string { + keys := slices.Collect(maps.Keys(csp)) + slices.Sort(keys) + var s strings.Builder + for _, k := range keys { + s.WriteString(k) + for _, v := range csp[k] { + // Special values like 'self', 'none', 'unsafe-inline', etc., must + // be quoted. Do it implicitly as a convenience here. + if !strings.Contains(v, ".") && len(v) > 1 && v[0] != '\'' && v[len(v)-1] != '\'' { + v = "'" + v + "'" + } + s.WriteString(" " + v) + } + s.WriteString("; ") + } + return strings.TrimSpace(s.String()) +} // The default Strict-Transport-Security header. This header tells the browser // to exclusively use HTTPS for all requests to the origin for the next year. @@ -130,6 +180,9 @@ type Config struct { // startup. CSRFSecret []byte + // CSP is the Content-Security-Policy header to return with BrowserMux + // responses. + CSP CSP // CSPAllowInlineStyles specifies whether to include `style-src: // unsafe-inline` in the Content-Security-Policy header to permit the use of // inline CSS. 
@@ -168,6 +221,10 @@ func (c *Config) setDefaults() error { } } + if c.CSP == nil { + c.CSP = DefaultCSP() + } + return nil } @@ -199,16 +256,20 @@ func NewServer(config Config) (*Server, error) { if config.CookiesSameSiteLax { sameSite = csrf.SameSiteLaxMode } + if config.CSPAllowInlineStyles { + if _, ok := config.CSP["style-src"]; ok { + config.CSP.Add("style-src", "unsafe-inline") + } else { + config.CSP.Set("style-src", "self", "unsafe-inline") + } + } s := &Server{ Config: config, - csp: defaultCSP, + csp: config.CSP.String(), // only set Secure flag on CSRF cookies if we are in a secure context // as otherwise the browser will reject the cookie csrfProtect: csrf.Protect(config.CSRFSecret, csrf.Secure(config.SecureContext), csrf.SameSite(sameSite)), } - if config.CSPAllowInlineStyles { - s.csp = defaultCSP + `; style-src 'self' 'unsafe-inline'` - } s.h = cmp.Or(config.HTTPServer, &http.Server{}) if s.h.Handler != nil { return nil, fmt.Errorf("use safeweb.Config.APIMux and safeweb.Config.BrowserMux instead of http.Server.Handler") @@ -225,12 +286,27 @@ const ( browserHandler ) +func (h handlerType) String() string { + switch h { + case browserHandler: + return "browser" + case apiHandler: + return "api" + default: + return "unknown" + } +} + // checkHandlerType returns either apiHandler or browserHandler, depending on // whether apiPattern or browserPattern is more specific (i.e. which pattern // contains more pathname components). If they are equally specific, it returns // unknownHandler. 
func checkHandlerType(apiPattern, browserPattern string) handlerType { - c := cmp.Compare(strings.Count(path.Clean(apiPattern), "/"), strings.Count(path.Clean(browserPattern), "/")) + apiPattern, browserPattern = path.Clean(apiPattern), path.Clean(browserPattern) + c := cmp.Compare(strings.Count(apiPattern, "/"), strings.Count(browserPattern, "/")) + if apiPattern == "/" || browserPattern == "/" { + c = cmp.Compare(len(apiPattern), len(browserPattern)) + } switch { case c > 0: return apiHandler @@ -242,6 +318,12 @@ func checkHandlerType(apiPattern, browserPattern string) handlerType { } func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // if we are not in a secure context, signal to the CSRF middleware that + // TLS-only header checks should be skipped + if !s.Config.SecureContext { + r = csrf.PlaintextHTTPRequest(r) + } + _, bp := s.BrowserMux.Handler(r) _, ap := s.APIMux.Handler(r) switch { @@ -294,6 +376,7 @@ func (s *Server) serveBrowser(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Security-Policy", s.csp) w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Referer-Policy", "same-origin") + w.Header().Set("Cross-Origin-Opener-Policy", "same-origin") if s.SecureContext { w.Header().Set("Strict-Transport-Security", cmp.Or(s.StrictTransportSecurityOptions, DefaultStrictTransportSecurityOptions)) } @@ -341,3 +424,7 @@ func (s *Server) ListenAndServe(addr string) error { func (s *Server) Close() error { return s.h.Close() } + +// Shutdown gracefully shuts down the server without interrupting any active +// connections. It has the same semantics as[http.Server.Shutdown]. 
+func (s *Server) Shutdown(ctx context.Context) error { return s.h.Shutdown(ctx) } diff --git a/safeweb/http_test.go b/safeweb/http_test.go index cec14b2b9bb8b..852ce326ba374 100644 --- a/safeweb/http_test.go +++ b/safeweb/http_test.go @@ -241,18 +241,26 @@ func TestCSRFProtection(t *testing.T) { func TestContentSecurityPolicyHeader(t *testing.T) { tests := []struct { name string + csp CSP apiRoute bool - wantCSP bool + wantCSP string }{ { - name: "default routes get CSP headers", - apiRoute: false, - wantCSP: true, + name: "default CSP", + wantCSP: `base-uri 'self'; block-all-mixed-content; default-src 'self'; form-action 'self'; frame-ancestors 'none';`, + }, + { + name: "custom CSP", + csp: CSP{ + "default-src": {"'self'", "https://tailscale.com"}, + "upgrade-insecure-requests": nil, + }, + wantCSP: `default-src 'self' https://tailscale.com; upgrade-insecure-requests;`, }, { name: "`/api/*` routes do not get CSP headers", apiRoute: true, - wantCSP: false, + wantCSP: "", }, } @@ -265,9 +273,9 @@ func TestContentSecurityPolicyHeader(t *testing.T) { var s *Server var err error if tt.apiRoute { - s, err = NewServer(Config{APIMux: h}) + s, err = NewServer(Config{APIMux: h, CSP: tt.csp}) } else { - s, err = NewServer(Config{BrowserMux: h}) + s, err = NewServer(Config{BrowserMux: h, CSP: tt.csp}) } if err != nil { t.Fatal(err) @@ -279,8 +287,8 @@ func TestContentSecurityPolicyHeader(t *testing.T) { s.h.Handler.ServeHTTP(w, req) resp := w.Result() - if (resp.Header.Get("Content-Security-Policy") == "") == tt.wantCSP { - t.Fatalf("content security policy want: %v; got: %v", tt.wantCSP, resp.Header.Get("Content-Security-Policy")) + if got := resp.Header.Get("Content-Security-Policy"); got != tt.wantCSP { + t.Fatalf("content security policy want: %q; got: %q", tt.wantCSP, got) } }) } @@ -397,7 +405,7 @@ func TestCSPAllowInlineStyles(t *testing.T) { csp := resp.Header.Get("Content-Security-Policy") allowsStyles := strings.Contains(csp, "style-src 'self' 'unsafe-inline'") if 
allowsStyles != allow { - t.Fatalf("CSP inline styles want: %v; got: %v", allow, allowsStyles) + t.Fatalf("CSP inline styles want: %v, got: %v in %q", allow, allowsStyles, csp) } }) } @@ -527,13 +535,13 @@ func TestGetMoreSpecificPattern(t *testing.T) { { desc: "same prefix", a: "/foo/bar/quux", - b: "/foo/bar/", + b: "/foo/bar/", // path.Clean will strip the trailing slash. want: apiHandler, }, { desc: "almost same prefix, but not a path component", a: "/goat/sheep/cheese", - b: "/goat/sheepcheese/", + b: "/goat/sheepcheese/", // path.Clean will strip the trailing slash. want: apiHandler, }, { @@ -554,6 +562,12 @@ func TestGetMoreSpecificPattern(t *testing.T) { b: "///////", want: unknownHandler, }, + { + desc: "root-level", + a: "/latest", + b: "/", // path.Clean will NOT strip the trailing slash. + want: apiHandler, + }, } { t.Run(tt.desc, func(t *testing.T) { got := checkHandlerType(tt.a, tt.b) diff --git a/scripts/installer.sh b/scripts/installer.sh index 19911ee23c8a7..f81ae529298aa 100755 --- a/scripts/installer.sh +++ b/scripts/installer.sh @@ -68,6 +68,14 @@ main() { if [ -z "${VERSION_ID:-}" ]; then # rolling release. If you haven't kept current, that's on you. APT_KEY_TYPE="keyring" + # Parrot Security is a special case that uses ID=debian + elif [ "$NAME" = "Parrot Security" ]; then + # All versions new enough to have this behaviour prefer keyring + # and their VERSION_ID is not consistent with Debian. + APT_KEY_TYPE="keyring" + # They don't specify the Debian version they're based off in os-release + # but Parrot 6 is based on Debian 12 Bookworm. 
+ VERSION=bookworm elif [ "$VERSION_ID" -lt 11 ]; then APT_KEY_TYPE="legacy" else @@ -154,7 +162,7 @@ main() { APT_KEY_TYPE="keyring" fi ;; - Deepin) # https://github.com/tailscale/tailscale/issues/7862 + Deepin|deepin) # https://github.com/tailscale/tailscale/issues/7862 OS="debian" PACKAGETYPE="apt" if [ "$VERSION_ID" -lt 20 ]; then @@ -165,6 +173,25 @@ main() { VERSION="bullseye" fi ;; + pika) + PACKAGETYPE="apt" + # All versions of PikaOS are new enough to prefer keyring + APT_KEY_TYPE="keyring" + # Older versions of PikaOS are based on Ubuntu rather than Debian + if [ "$VERSION_ID" -lt 4 ]; then + OS="ubuntu" + VERSION="$UBUNTU_CODENAME" + else + OS="debian" + VERSION="$DEBIAN_CODENAME" + fi + ;; + sparky) + OS="debian" + PACKAGETYPE="apt" + VERSION="$DEBIAN_CODENAME" + APT_KEY_TYPE="keyring" + ;; centos) OS="$ID" VERSION="$VERSION_ID" @@ -181,8 +208,11 @@ main() { PACKAGETYPE="yum" fi ;; - rhel) + rhel|miraclelinux) OS="$ID" + if [ "$ID" = "miraclelinux" ]; then + OS="rhel" + fi VERSION="$(echo "$VERSION_ID" | cut -f1 -d.)" PACKAGETYPE="dnf" if [ "$VERSION" = "7" ]; then @@ -224,12 +254,12 @@ main() { VERSION="leap/15.4" PACKAGETYPE="zypper" ;; - arch|archarm|endeavouros|blendos|garuda) + arch|archarm|endeavouros|blendos|garuda|archcraft|cachyos) OS="arch" VERSION="" # rolling release PACKAGETYPE="pacman" ;; - manjaro|manjaro-arm) + manjaro|manjaro-arm|biglinux) OS="manjaro" VERSION="" # rolling release PACKAGETYPE="pacman" @@ -369,7 +399,8 @@ main() { ;; freebsd) if [ "$VERSION" != "12" ] && \ - [ "$VERSION" != "13" ] + [ "$VERSION" != "13" ] && \ + [ "$VERSION" != "14" ] then OS_UNSUPPORTED=1 fi @@ -465,10 +496,13 @@ main() { legacy) $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.asc" | $SUDO apt-key add - $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.list" | $SUDO tee /etc/apt/sources.list.d/tailscale.list + $SUDO chmod 0644 /etc/apt/sources.list.d/tailscale.list ;; keyring) $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.noarmor.gpg" 
| $SUDO tee /usr/share/keyrings/tailscale-archive-keyring.gpg >/dev/null + $SUDO chmod 0644 /usr/share/keyrings/tailscale-archive-keyring.gpg $CURL "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION.tailscale-keyring.list" | $SUDO tee /etc/apt/sources.list.d/tailscale.list + $SUDO chmod 0644 /etc/apt/sources.list.d/tailscale.list ;; esac $SUDO apt-get update @@ -488,9 +522,41 @@ main() { set +x ;; dnf) + # DNF 5 has a different argument format; determine which one we have. + DNF_VERSION="3" + if LANG=C.UTF-8 dnf --version | grep -q '^dnf5 version'; then + DNF_VERSION="5" + fi + + # The 'config-manager' plugin wasn't implemented when + # DNF5 was released; detect that and use the old + # version if necessary. + if [ "$DNF_VERSION" = "5" ]; then + set -x + $SUDO dnf install -y 'dnf-command(config-manager)' && DNF_HAVE_CONFIG_MANAGER=1 || DNF_HAVE_CONFIG_MANAGER=0 + set +x + + if [ "$DNF_HAVE_CONFIG_MANAGER" != "1" ]; then + if type dnf-3 >/dev/null; then + DNF_VERSION="3" + else + echo "dnf 5 detected, but 'dnf-command(config-manager)' not available and dnf-3 not found" + exit 1 + fi + fi + fi + set -x - $SUDO dnf install -y 'dnf-command(config-manager)' - $SUDO dnf config-manager --add-repo "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + if [ "$DNF_VERSION" = "3" ]; then + $SUDO dnf install -y 'dnf-command(config-manager)' + $SUDO dnf config-manager --add-repo "https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + elif [ "$DNF_VERSION" = "5" ]; then + # Already installed config-manager, above. 
+ $SUDO dnf config-manager addrepo --from-repofile="https://pkgs.tailscale.com/$TRACK/$OS/$VERSION/tailscale.repo" + else + echo "unexpected: unknown dnf version $DNF_VERSION" + exit 1 + fi $SUDO dnf install -y tailscale $SUDO systemctl enable --now tailscaled set +x @@ -519,7 +585,7 @@ main() { ;; pkg) set -x - $SUDO pkg install tailscale + $SUDO pkg install --yes tailscale $SUDO service tailscaled enable $SUDO service tailscaled start set +x diff --git a/sessionrecording/connect.go b/sessionrecording/connect.go index db966ba2cdee2..dc697d071dad2 100644 --- a/sessionrecording/connect.go +++ b/sessionrecording/connect.go @@ -7,6 +7,8 @@ package sessionrecording import ( "context" + "crypto/tls" + "encoding/json" "errors" "fmt" "io" @@ -14,12 +16,31 @@ import ( "net/http" "net/http/httptrace" "net/netip" + "sync/atomic" "time" + "golang.org/x/net/http2" + "tailscale.com/net/netx" "tailscale.com/tailcfg" + "tailscale.com/util/httpm" "tailscale.com/util/multierr" ) +const ( + // Timeout for an individual DialFunc call for a single recorder address. + perDialAttemptTimeout = 5 * time.Second + // Timeout for the V2 API HEAD probe request (supportsV2). + http2ProbeTimeout = 10 * time.Second + // Maximum timeout for trying all available recorders, including V2 API + // probes and dial attempts. + allDialAttemptsTimeout = 30 * time.Second +) + +// uploadAckWindow is the period of time to wait for an ackFrame from recorder +// before terminating the connection. This is a variable to allow overriding it +// in tests. +var uploadAckWindow = 30 * time.Second + // ConnectToRecorder connects to the recorder at any of the provided addresses. // It returns the first successful response, or a multierr if all attempts fail. // @@ -32,19 +53,15 @@ import ( // attempts are in order the recorder(s) was attempted. If successful a // successful connection is made, the last attempt in the slice is the // attempt for connected recorder. 
-func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(context.Context, string, string) (net.Conn, error)) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { +func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial netx.DialFunc) (io.WriteCloser, []*tailcfg.SSHRecordingAttempt, <-chan error, error) { if len(recs) == 0 { return nil, nil, nil, errors.New("no recorders configured") } // We use a special context for dialing the recorder, so that we can // limit the time we spend dialing to 30 seconds and still have an // unbounded context for the upload. - dialCtx, dialCancel := context.WithTimeout(ctx, 30*time.Second) + dialCtx, dialCancel := context.WithTimeout(ctx, allDialAttemptsTimeout) defer dialCancel() - hc, err := SessionRecordingClientForDialer(dialCtx, dial) - if err != nil { - return nil, nil, nil, err - } var errs []error var attempts []*tailcfg.SSHRecordingAttempt @@ -54,74 +71,230 @@ func ConnectToRecorder(ctx context.Context, recs []netip.AddrPort, dial func(con } attempts = append(attempts, attempt) - // We dial the recorder and wait for it to send a 100-continue - // response before returning from this function. This ensures that - // the recorder is ready to accept the recording. - - // got100 is closed when we receive the 100-continue response. - got100 := make(chan struct{}) - ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ - Got100Continue: func() { - close(got100) - }, - }) - - pr, pw := io.Pipe() - req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s:%d/record", ap.Addr(), ap.Port()), pr) + var pw io.WriteCloser + var errChan <-chan error + var err error + hc := clientHTTP2(dialCtx, dial) + // We need to probe V2 support using a separate HEAD request. Sending + // an HTTP/2 POST request to a HTTP/1 server will just "hang" until the + // request body is closed (instead of returning a 404 as one would + // expect). 
Sending a HEAD request without a body does not have that + // problem. + if supportsV2(ctx, hc, ap) { + pw, errChan, err = connectV2(ctx, hc, ap) + } else { + pw, errChan, err = connectV1(ctx, clientHTTP1(dialCtx, dial), ap) + } if err != nil { - err = fmt.Errorf("recording: error starting recording: %w", err) + err = fmt.Errorf("recording: error starting recording on %q: %w", ap, err) attempt.FailureMessage = err.Error() errs = append(errs, err) continue } - // We set the Expect header to 100-continue, so that the recorder - // will send a 100-continue response before it starts reading the - // request body. - req.Header.Set("Expect", "100-continue") + return pw, attempts, errChan, nil + } + return nil, attempts, nil, multierr.New(errs...) +} - // errChan is used to indicate the result of the request. - errChan := make(chan error, 1) - go func() { - resp, err := hc.Do(req) - if err != nil { - errChan <- fmt.Errorf("recording: error starting recording: %w", err) +// supportsV2 checks whether a recorder instance supports the /v2/record +// endpoint. +func supportsV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) bool { + ctx, cancel := context.WithTimeout(ctx, http2ProbeTimeout) + defer cancel() + req, err := http.NewRequestWithContext(ctx, httpm.HEAD, fmt.Sprintf("http://%s/v2/record", ap), nil) + if err != nil { + return false + } + resp, err := hc.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode == http.StatusOK && resp.ProtoMajor > 1 +} + +// connectV1 connects to the legacy /record endpoint on the recorder. It is +// used for backwards-compatibility with older tsrecorder instances. +// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. 
+func connectV1(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + // We dial the recorder and wait for it to send a 100-continue + // response before returning from this function. This ensures that + // the recorder is ready to accept the recording. + + // got100 is closed when we receive the 100-continue response. + got100 := make(chan struct{}) + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + Got100Continue: func() { + close(got100) + }, + }) + + pr, pw := io.Pipe() + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/record", ap), pr) + if err != nil { + return nil, nil, err + } + // We set the Expect header to 100-continue, so that the recorder + // will send a 100-continue response before it starts reading the + // request body. + req.Header.Set("Expect", "100-continue") + + // errChan is used to indicate the result of the request. + errChan := make(chan error, 1) + go func() { + defer close(errChan) + resp, err := hc.Do(req) + if err != nil { + errChan <- err + return + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + return + } + }() + select { + case <-got100: + return pw, errChan, nil + case err := <-errChan: + // If we get an error before we get the 100-continue response, + // we need to try another recorder. + if err == nil { + // If the error is nil, we got a 200 response, which + // is unexpected as we haven't sent any data yet. + err = errors.New("recording: unexpected EOF") + } + return nil, nil, err + } +} + +// connectV2 connects to the /v2/record endpoint on the recorder over HTTP/2. +// It explicitly tracks ack frames sent in the response and terminates the +// connection if sent recording data is un-acked for uploadAckWindow. 
+// +// On success, it returns a WriteCloser that can be used to upload the +// recording, and a channel that will be sent an error (or nil) when the upload +// fails or completes. +func connectV2(ctx context.Context, hc *http.Client, ap netip.AddrPort) (io.WriteCloser, <-chan error, error) { + pr, pw := io.Pipe() + upload := &readCounter{r: pr} + req, err := http.NewRequestWithContext(ctx, "POST", fmt.Sprintf("http://%s/v2/record", ap), upload) + if err != nil { + return nil, nil, err + } + + // With HTTP/2, hc.Do will not block while the request body is being sent. + // It will return immediately and allow us to consume the response body at + // the same time. + resp, err := hc.Do(req) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + return nil, nil, fmt.Errorf("recording: unexpected status: %v", resp.Status) + } + + errChan := make(chan error, 1) + acks := make(chan int64) + // Read acks from the response and send them to the acks channel. + go func() { + defer close(errChan) + defer close(acks) + defer resp.Body.Close() + defer pw.Close() + dec := json.NewDecoder(resp.Body) + for { + var frame v2ResponseFrame + if err := dec.Decode(&frame); err != nil { + if !errors.Is(err, io.EOF) { + errChan <- fmt.Errorf("recording: unexpected error receiving acks: %w", err) + } return } - if resp.StatusCode != 200 { - errChan <- fmt.Errorf("recording: unexpected status: %v", resp.Status) + if frame.Error != "" { + errChan <- fmt.Errorf("recording: received error from the recorder: %q", frame.Error) return } - errChan <- nil - }() - select { - case <-got100: - case err := <-errChan: - // If we get an error before we get the 100-continue response, - // we need to try another recorder. - if err == nil { - // If the error is nil, we got a 200 response, which - // is unexpected as we haven't sent any data yet. 
- err = errors.New("recording: unexpected EOF") + select { + case acks <- frame.Ack: + case <-ctx.Done(): + return } - attempt.FailureMessage = err.Error() - errs = append(errs, err) - continue // try the next recorder } - return pw, attempts, errChan, nil - } - return nil, attempts, nil, multierr.New(errs...) + }() + // Track acks from the acks channel. + go func() { + // Hack for tests: some tests modify uploadAckWindow and reset it when + // the test ends. This can race with t.Reset call below. Making a copy + // here is a lazy workaround to not wait for this goroutine to exit in + // the test cases. + uploadAckWindow := uploadAckWindow + // This timer fires if we didn't receive an ack for too long. + t := time.NewTimer(uploadAckWindow) + defer t.Stop() + for { + select { + case <-t.C: + // Close the pipe which terminates the connection and cleans up + // other goroutines. Note that tsrecorder will send us ack + // frames even if there is no new data to ack. This helps + // detect broken recorder connection if the session is idle. + pr.CloseWithError(errNoAcks) + resp.Body.Close() + return + case _, ok := <-acks: + if !ok { + // acks channel closed means that the goroutine reading them + // finished, which means that the request has ended. + return + } + // TODO(awly): limit how far behind the received acks can be. This + // should handle scenarios where a session suddenly dumps a lot of + // output. + t.Reset(uploadAckWindow) + case <-ctx.Done(): + return + } + } + }() + + return pw, errChan, nil } -// SessionRecordingClientForDialer returns an http.Client that uses a clone of -// the provided Dialer's PeerTransport to dial connections. This is used to make -// requests to the session recording server to upload session recordings. It -// uses the provided dialCtx to dial connections, and limits a single dial to 5 -// seconds. 
-func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context.Context, string, string) (net.Conn, error)) (*http.Client, error) { - tr := http.DefaultTransport.(*http.Transport).Clone() +var errNoAcks = errors.New("did not receive ack frames from the recorder in 30s") + +type v2ResponseFrame struct { + // Ack is the number of bytes received from the client so far. The bytes + // are not guaranteed to be durably stored yet. + Ack int64 `json:"ack,omitempty"` + // Error is an error encountered while storing the recording. Error is only + // ever set as the last frame in the response. + Error string `json:"error,omitempty"` +} +// readCounter is an io.Reader that counts how many bytes were read. +type readCounter struct { + r io.Reader + sent atomic.Int64 +} + +func (u *readCounter) Read(buf []byte) (int, error) { + n, err := u.r.Read(buf) + u.sent.Add(int64(n)) + return n, err +} + +// clientHTTP1 returns a claassic http.Client with a per-dial context. It uses +// dialCtx and adds a 5s timeout to it. +func clientHTTP1(dialCtx context.Context, dial netx.DialFunc) *http.Client { + tr := http.DefaultTransport.(*http.Transport).Clone() tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { - perAttemptCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) defer cancel() go func() { select { @@ -132,7 +305,32 @@ func SessionRecordingClientForDialer(dialCtx context.Context, dial func(context. }() return dial(perAttemptCtx, network, addr) } + return &http.Client{Transport: tr} +} + +// clientHTTP2 is like clientHTTP1 but returns an http.Client suitable for h2c +// requests (HTTP/2 over plaintext). Unfortunately the same client does not +// work for HTTP/1 so we need to split these up. 
+func clientHTTP2(dialCtx context.Context, dial netx.DialFunc) *http.Client { return &http.Client{ - Transport: tr, - }, nil + Transport: &http2.Transport{ + // Allow "http://" scheme in URLs. + AllowHTTP: true, + // Pretend like we're using TLS, but actually use the provided + // DialFunc underneath. This is necessary to convince the transport + // to actually dial. + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + perAttemptCtx, cancel := context.WithTimeout(ctx, perDialAttemptTimeout) + defer cancel() + go func() { + select { + case <-perAttemptCtx.Done(): + case <-dialCtx.Done(): + cancel() + } + }() + return dial(perAttemptCtx, network, addr) + }, + }, + } } diff --git a/sessionrecording/connect_test.go b/sessionrecording/connect_test.go new file mode 100644 index 0000000000000..c0fcf6d40c617 --- /dev/null +++ b/sessionrecording/connect_test.go @@ -0,0 +1,189 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package sessionrecording + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/sha256" + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +func TestConnectToRecorder(t *testing.T) { + tests := []struct { + desc string + http2 bool + // setup returns a recorder server mux, and a channel which sends the + // hash of the recording uploaded to it. The channel is expected to + // fire only once. 
+ setup func(t *testing.T) (*http.ServeMux, <-chan []byte) + wantErr bool + }{ + { + desc: "v1 recorder", + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + uploadHash <- hash.Sum(nil) + }) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder", + http2: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + body := &readCounter{r: r.Body} + hash := sha256.New() + ctx, cancel := context.WithCancel(r.Context()) + go func() { + defer cancel() + if _, err := io.Copy(hash, body); err != nil { + t.Error(err) + } + }() + + // Send acks for received bytes. + tick := time.NewTicker(time.Millisecond) + defer tick.Stop() + enc := json.NewEncoder(w) + outer: + for { + select { + case <-ctx.Done(): + break outer + case <-tick.C: + if err := enc.Encode(v2ResponseFrame{Ack: body.sent.Load()}); err != nil { + t.Errorf("writing ack frame: %v", err) + break outer + } + } + } + + uploadHash <- hash.Sum(nil) + }) + // Probing HEAD endpoint which always returns 200 OK. + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + { + desc: "v2 recorder no acks", + http2: true, + wantErr: true, + setup: func(t *testing.T) (*http.ServeMux, <-chan []byte) { + // Make the client no-ack timeout quick for the test. 
+ oldAckWindow := uploadAckWindow + uploadAckWindow = 100 * time.Millisecond + t.Cleanup(func() { uploadAckWindow = oldAckWindow }) + + uploadHash := make(chan []byte, 1) + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(w http.ResponseWriter, r *http.Request) { + t.Error("received request to v1 endpoint") + http.Error(w, "not found", http.StatusNotFound) + }) + mux.HandleFunc("POST /v2/record", func(w http.ResponseWriter, r *http.Request) { + // Force the status to send to unblock the client waiting + // for it. + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + + // Consume the whole request body but don't send any acks + // back. + hash := sha256.New() + if _, err := io.Copy(hash, r.Body); err != nil { + t.Error(err) + } + // Goes in the channel buffer, non-blocking. + uploadHash <- hash.Sum(nil) + + // Block until the parent test case ends to prevent the + // request termination. We want to exercise the ack + // tracking logic specifically. + ctx, cancel := context.WithCancel(r.Context()) + t.Cleanup(cancel) + <-ctx.Done() + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + return mux, uploadHash + }, + }, + } + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + mux, uploadHash := tt.setup(t) + + srv := httptest.NewUnstartedServer(mux) + if tt.http2 { + // Wire up h2c-compatible HTTP/2 server. This is optional + // because the v1 recorder didn't support HTTP/2 and we try to + // mimic that. 
+ h2s := &http2.Server{} + srv.Config.Handler = h2c.NewHandler(mux, h2s) + if err := http2.ConfigureServer(srv.Config, h2s); err != nil { + t.Errorf("configuring HTTP/2 support in server: %v", err) + } + } + srv.Start() + t.Cleanup(srv.Close) + + d := new(net.Dialer) + + ctx := context.Background() + w, _, errc, err := ConnectToRecorder(ctx, []netip.AddrPort{netip.MustParseAddrPort(srv.Listener.Addr().String())}, d.DialContext) + if err != nil { + t.Fatalf("ConnectToRecorder: %v", err) + } + + // Send some random data and hash it to compare with the recorded + // data hash. + hash := sha256.New() + const numBytes = 1 << 20 // 1MB + if _, err := io.CopyN(io.MultiWriter(w, hash), rand.Reader, numBytes); err != nil { + t.Fatalf("writing recording data: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("closing recording stream: %v", err) + } + if err := <-errc; err != nil && !tt.wantErr { + t.Fatalf("error from the channel: %v", err) + } else if err == nil && tt.wantErr { + t.Fatalf("did not receive expected error from the channel") + } + + if recv, sent := <-uploadHash, hash.Sum(nil); !bytes.Equal(recv, sent) { + t.Errorf("mismatch in recording data hash, sent %x, received %x", sent, recv) + } + }) + } +} diff --git a/shell.nix b/shell.nix index 4d2e24366ae46..bb8eacb67ee18 100644 --- a/shell.nix +++ b/shell.nix @@ -16,4 +16,4 @@ ) { src = ./.; }).shellNix -# nix-direnv cache busting line: sha256-xO1DuLWi6/lpA9ubA2ZYVJM+CkVNA5IaVGZxX9my0j0= +# nix-direnv cache busting line: sha256-av4kr09rjNRmag94ziNjJuI/cg8b8lAD3Tk24t/ezH4= diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index 3ff676d519898..442fedcf22242 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -12,11 +12,13 @@ package tailssh import ( + "context" "encoding/json" "errors" "flag" "fmt" "io" + "io/fs" "log" "log/syslog" "os" @@ -29,6 +31,7 @@ import ( "strings" "sync/atomic" "syscall" + "time" "github.com/creack/pty" "github.com/pkg/sftp" @@ -43,6 +46,13 @@ 
import ( "tailscale.com/version/distro" ) +const ( + linux = "linux" + darwin = "darwin" + freebsd = "freebsd" + openbsd = "openbsd" +) + func init() { childproc.Add("ssh", beIncubator) childproc.Add("sftp", beSFTP) @@ -63,11 +73,36 @@ var maybeStartLoginSession = func(dlogf logger.Logf, ia incubatorArgs) (close fu return nil } +// tryExecInDir tries to run a command in dir and returns nil if it succeeds. +// Otherwise, it returns a filesystem error or a timeout error if the command +// took too long. +func tryExecInDir(ctx context.Context, dir string) error { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Assume that the following executables exist, are executable, and + // immediately return. + var name string + switch runtime.GOOS { + case "windows": + windir := os.Getenv("windir") + name = filepath.Join(windir, "system32", "doskey.exe") + default: + name = "/bin/true" + } + + cmd := exec.CommandContext(ctx, name) + cmd.Dir = dir + return cmd.Run() +} + // newIncubatorCommand returns a new exec.Cmd configured with // `tailscaled be-child ssh` as the entrypoint. // -// If ss.srv.tailscaledPath is empty, this method is equivalent to -// exec.CommandContext. +// If ss.srv.tailscaledPath is empty, this method is almost equivalent to +// exec.CommandContext. It will refuse to run in SFTP-mode. It will simulate the +// behavior of SSHD by falling back to the root directory if it cannot run +// a command in the user’s home directory. // // The returned Cmd.Env is guaranteed to be nil; the caller populates it.
func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err error) { @@ -97,7 +132,35 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err loginShell := ss.conn.localUser.LoginShell() args := shellArgs(isShell, ss.RawCommand()) logf("directly running %s %q", loginShell, args) - return exec.CommandContext(ss.ctx, loginShell, args...), nil + cmd = exec.CommandContext(ss.ctx, loginShell, args...) + + // While running directly instead of using `tailscaled be-child`, + // do what sshd does by running inside the home directory, + // falling back to the root directory if it doesn't have permissions. + // This can happen if the system has networked home directories, + // i.e. NFS or SMB, which enable root-squashing by default. + cmd.Dir = ss.conn.localUser.HomeDir + err := tryExecInDir(ss.ctx, cmd.Dir) + switch { + case errors.Is(err, exec.ErrNotFound): + // /bin/true might not be installed on a barebones system, + // so we assume that the home directory does not exist. + cmd.Dir = "/" + case errors.Is(err, fs.ErrPermission) || errors.Is(err, fs.ErrNotExist): + // Ensure that cmd.Dir is the source of the error. + var pathErr *fs.PathError + if errors.As(err, &pathErr) && pathErr.Path == cmd.Dir { + // If we cannot run loginShell in localUser.HomeDir, + // we will try to run this command in the root directory. + cmd.Dir = "/" + } else { + return nil, err + } + case err != nil: + return nil, err + } + + return cmd, nil } lu := ss.conn.localUser @@ -126,7 +189,7 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err // We have to check the below outside of the incubator process, because it // relies on the "getenforce" command being on the PATH, which it is not // when in the incubator.
- if runtime.GOOS == "linux" && hostinfo.IsSELinuxEnforcing() { + if runtime.GOOS == linux && hostinfo.IsSELinuxEnforcing() { incubatorArgs = append(incubatorArgs, "--is-selinux-enforcing") } @@ -171,7 +234,10 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err } } - return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil + cmd = exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...) + // The incubator will chdir into the home directory after it drops privileges. + cmd.Dir = "/" + return cmd, nil } var debugIncubator bool @@ -247,32 +313,44 @@ func parseIncubatorArgs(args []string) (incubatorArgs, error) { return ia, nil } -func (ia incubatorArgs) forwadedEnviron() ([]string, string, error) { +// forwardedEnviron returns the concatenation of the current environment with +// any environment variables specified in ia.encodedEnv. +// +// It also returns allowedExtraKeys, containing the env keys that were passed in +// to ia.encodedEnv. +func (ia incubatorArgs) forwardedEnviron() (env, allowedExtraKeys []string, err error) { environ := os.Environ() + // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding - allowListKeys := "SSH_AUTH_SOCK" + // TODO(bradfitz,percy): why is this listed specially? If the parent wanted to included + // it, couldn't it have just passed it to the incubator in encodedEnv? + // If it didn't, no reason for us to pass it to "su -w ..." if it's not in our env + // anyway? 
(Surely we don't want to inherit the tailscaled parent SSH_AUTH_SOCK, if any) + allowedExtraKeys = []string{"SSH_AUTH_SOCK"} if ia.encodedEnv != "" { unquoted, err := strconv.Unquote(ia.encodedEnv) if err != nil { - return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + return nil, nil, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) } var extraEnviron []string err = json.Unmarshal([]byte(unquoted), &extraEnviron) if err != nil { - return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + return nil, nil, fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) } environ = append(environ, extraEnviron...) - for _, v := range extraEnviron { - allowListKeys = fmt.Sprintf("%s,%s", allowListKeys, strings.Split(v, "=")[0]) + for _, kv := range extraEnviron { + if k, _, ok := strings.Cut(kv, "="); ok { + allowedExtraKeys = append(allowedExtraKeys, k) + } } } - return environ, allowListKeys, nil + return environ, allowedExtraKeys, nil } // beIncubator is the entrypoint to the `tailscaled be-child ssh` subcommand. @@ -428,13 +506,13 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error { // Only the macOS version of the login command supports executing a // command, all other versions only support launching a shell without // taking any arguments. 
- if !ia.isShell && runtime.GOOS != "darwin" { + if !ia.isShell && runtime.GOOS != darwin { dlogf("won't use login because we're not in a shell or on macOS") return nil } switch runtime.GOOS { - case "linux", "freebsd", "openbsd": + case linux, freebsd, openbsd: if !ia.hasTTY { dlogf("can't use login because of missing TTY") // We can only use the login command if a shell was requested with @@ -452,7 +530,7 @@ func tryExecLogin(dlogf logger.Logf, ia incubatorArgs) error { loginArgs := ia.loginArgs(loginCmdPath) dlogf("logging in with %+v", loginArgs) - environ, _, err := ia.forwadedEnviron() + environ, _, err := ia.forwardedEnviron() if err != nil { return err } @@ -491,14 +569,14 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { defer sessionCloser() } - environ, allowListEnvKeys, err := ia.forwadedEnviron() + environ, allowListEnvKeys, err := ia.forwardedEnviron() if err != nil { return false, err } loginArgs := []string{ su, - "-w", allowListEnvKeys, + "-w", strings.Join(allowListEnvKeys, ","), "-l", ia.localUser, } @@ -523,7 +601,7 @@ func trySU(dlogf logger.Logf, ia incubatorArgs) (handled bool, err error) { func findSU(dlogf logger.Logf, ia incubatorArgs) string { // Currently, we only support falling back to su on Linux. This // potentially could work on BSDs as well, but requires testing. - if runtime.GOOS != "linux" { + if runtime.GOOS != linux { return "" } @@ -539,7 +617,7 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string { return "" } - _, allowListEnvKeys, err := ia.forwadedEnviron() + _, allowListEnvKeys, err := ia.forwardedEnviron() if err != nil { return "" } @@ -548,7 +626,7 @@ func findSU(dlogf logger.Logf, ia incubatorArgs) string { // to make sure su supports the necessary arguments. 
err = exec.Command( su, - "-w", allowListEnvKeys, + "-w", strings.Join(allowListEnvKeys, ","), "-l", ia.localUser, "-c", "true", @@ -575,7 +653,7 @@ func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error { return err } - environ, _, err := ia.forwadedEnviron() + environ, _, err := ia.forwardedEnviron() if err != nil { return err } @@ -659,7 +737,7 @@ func doDropPrivileges(dlogf logger.Logf, wantUid, wantGid int, supplementaryGrou euid := os.Geteuid() egid := os.Getegid() - if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" { + if runtime.GOOS == darwin || runtime.GOOS == freebsd { // On FreeBSD and Darwin, the first entry returned from the // getgroups(2) syscall is the egid, and changing it with // setgroups(2) changes the egid of the process. This is @@ -758,7 +836,6 @@ func (ss *sshSession) launchProcess() error { } cmd := ss.cmd - cmd.Dir = "/" cmd.Env = envForUser(ss.conn.localUser) for _, kv := range ss.Environ() { if acceptEnvPair(kv) { @@ -1014,10 +1091,10 @@ func (ss *sshSession) startWithStdPipes() (err error) { func envForUser(u *userMeta) []string { return []string{ - fmt.Sprintf("SHELL=" + u.LoginShell()), - fmt.Sprintf("USER=" + u.Username), - fmt.Sprintf("HOME=" + u.HomeDir), - fmt.Sprintf("PATH=" + defaultPathForUser(&u.User)), + fmt.Sprintf("SHELL=%s", u.LoginShell()), + fmt.Sprintf("USER=%s", u.Username), + fmt.Sprintf("HOME=%s", u.HomeDir), + fmt.Sprintf("PATH=%s", defaultPathForUser(&u.User)), } } @@ -1051,7 +1128,7 @@ func fileExists(path string) bool { // loginArgs returns the arguments to use to exec the login binary. 
func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { switch runtime.GOOS { - case "darwin": + case darwin: args := []string{ loginCmdPath, "-f", // already authenticated @@ -1071,7 +1148,7 @@ func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { } return args - case "linux": + case linux: if distro.Get() == distro.Arch && !fileExists("/etc/pam.d/remote") { // See https://github.com/tailscale/tailscale/issues/4924 // @@ -1081,7 +1158,7 @@ func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { return []string{loginCmdPath, "-f", ia.localUser, "-p"} } return []string{loginCmdPath, "-f", ia.localUser, "-h", ia.remoteIP, "-p"} - case "freebsd", "openbsd": + case freebsd, openbsd: return []string{loginCmdPath, "-fp", "-h", ia.remoteIP, ia.localUser} } panic("unimplemented") @@ -1089,6 +1166,10 @@ func (ia *incubatorArgs) loginArgs(loginCmdPath string) []string { func shellArgs(isShell bool, cmd string) []string { if isShell { + if runtime.GOOS == freebsd || runtime.GOOS == openbsd { + // bsd shells don't support the "-l" option, so we can't run as a login shell + return []string{} + } return []string{"-l"} } else { return []string{"-c", cmd} @@ -1096,7 +1177,7 @@ func shellArgs(isShell bool, cmd string) []string { } func setGroups(groupIDs []int) error { - if runtime.GOOS == "darwin" && len(groupIDs) > 16 { + if runtime.GOOS == darwin && len(groupIDs) > 16 { // darwin returns "invalid argument" if more than 16 groups are passed to syscall.Setgroups // some info can be found here: // https://opensource.apple.com/source/samba/samba-187.8/patches/support-darwin-initgroups-syscall.auto.html diff --git a/ssh/tailssh/incubator_plan9.go b/ssh/tailssh/incubator_plan9.go new file mode 100644 index 0000000000000..61b6a54ebdc94 --- /dev/null +++ b/ssh/tailssh/incubator_plan9.go @@ -0,0 +1,421 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains the plan9-specific version of the 
incubator. Tailscaled +// launches the incubator as the same user as it was launched as. The +// incubator then registers a new session with the OS, sets its UID +// and groups to the specified `--uid`, `--gid` and `--groups`, and +// then launches the requested `--cmd`. + +package tailssh + +import ( + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "log" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "sync/atomic" + + "github.com/go4org/plan9netshell" + "github.com/pkg/sftp" + "tailscale.com/cmd/tailscaled/childproc" + "tailscale.com/tailcfg" + "tailscale.com/types/logger" +) + +func init() { + childproc.Add("ssh", beIncubator) + childproc.Add("sftp", beSFTP) + childproc.Add("plan9-netshell", beNetshell) +} + +// newIncubatorCommand returns a new exec.Cmd configured with +// `tailscaled be-child ssh` as the entrypoint. +// +// If ss.srv.tailscaledPath is empty, this method is equivalent to +// exec.CommandContext. +// +// The returned Cmd.Env is guaranteed to be nil; the caller populates it. +func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err error) { + defer func() { + if cmd.Env != nil { + panic("internal error") + } + }() + + var isSFTP, isShell bool + switch ss.Subsystem() { + case "sftp": + isSFTP = true + case "": + isShell = ss.RawCommand() == "" + default: + panic(fmt.Sprintf("unexpected subsystem: %v", ss.Subsystem())) + } + + if ss.conn.srv.tailscaledPath == "" { + if isSFTP { + // SFTP relies on the embedded Go-based SFTP server in tailscaled, + // so without tailscaled, we can't serve SFTP. 
+ return nil, errors.New("no tailscaled found on path, can't serve SFTP") + } + + loginShell := ss.conn.localUser.LoginShell() + logf("directly running /bin/rc -c %q", ss.RawCommand()) + return exec.CommandContext(ss.ctx, loginShell, "-c", ss.RawCommand()), nil + } + + lu := ss.conn.localUser + ci := ss.conn.info + remoteUser := ci.uprof.LoginName + if ci.node.IsTagged() { + remoteUser = strings.Join(ci.node.Tags().AsSlice(), ",") + } + + incubatorArgs := []string{ + "be-child", + "ssh", + // TODO: "--uid=" + lu.Uid, + // TODO: "--gid=" + lu.Gid, + "--local-user=" + lu.Username, + "--home-dir=" + lu.HomeDir, + "--remote-user=" + remoteUser, + "--remote-ip=" + ci.src.Addr().String(), + "--has-tty=false", // updated in-place by startWithPTY + "--tty-name=", // updated in-place by startWithPTY + } + + nm := ss.conn.srv.lb.NetMap() + forceV1Behavior := nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) + if forceV1Behavior { + incubatorArgs = append(incubatorArgs, "--force-v1-behavior") + } + + if debugTest.Load() { + incubatorArgs = append(incubatorArgs, "--debug-test") + } + + switch { + case isSFTP: + // Note that we include both the `--sftp` flag and a command to launch + // tailscaled as `be-child sftp`. If login or su is available, and + // we're not running with tailcfg.NodeAttrSSHBehaviorV1, this will + // result in serving SFTP within a login shell, with full PAM + // integration. Otherwise, we'll serve SFTP in the incubator process + // with no PAM integration. 
+ incubatorArgs = append(incubatorArgs, "--sftp", fmt.Sprintf("--cmd=%s be-child sftp", ss.conn.srv.tailscaledPath)) + case isShell: + incubatorArgs = append(incubatorArgs, "--shell") + default: + incubatorArgs = append(incubatorArgs, "--cmd="+ss.RawCommand()) + } + + allowSendEnv := nm.HasCap(tailcfg.NodeAttrSSHEnvironmentVariables) + if allowSendEnv { + env, err := filterEnv(ss.conn.acceptEnv, ss.Session.Environ()) + if err != nil { + return nil, err + } + + if len(env) > 0 { + encoded, err := json.Marshal(env) + if err != nil { + return nil, fmt.Errorf("failed to encode environment: %w", err) + } + incubatorArgs = append(incubatorArgs, fmt.Sprintf("--encoded-env=%q", encoded)) + } + } + + return exec.CommandContext(ss.ctx, ss.conn.srv.tailscaledPath, incubatorArgs...), nil +} + +var debugTest atomic.Bool + +type stdRWC struct{} + +func (stdRWC) Read(p []byte) (n int, err error) { + return os.Stdin.Read(p) +} + +func (stdRWC) Write(b []byte) (n int, err error) { + return os.Stdout.Write(b) +} + +func (stdRWC) Close() error { + os.Exit(0) + return nil +} + +type incubatorArgs struct { + localUser string + homeDir string + remoteUser string + remoteIP string + ttyName string + hasTTY bool + cmd string + isSFTP bool + isShell bool + forceV1Behavior bool + debugTest bool + isSELinuxEnforcing bool + encodedEnv string +} + +func parseIncubatorArgs(args []string) (incubatorArgs, error) { + var ia incubatorArgs + + flags := flag.NewFlagSet("", flag.ExitOnError) + flags.StringVar(&ia.localUser, "local-user", "", "the user to run as") + flags.StringVar(&ia.homeDir, "home-dir", "/", "the user's home directory") + flags.StringVar(&ia.remoteUser, "remote-user", "", "the remote user/tags") + flags.StringVar(&ia.remoteIP, "remote-ip", "", "the remote Tailscale IP") + flags.StringVar(&ia.ttyName, "tty-name", "", "the tty name (pts/3)") + flags.BoolVar(&ia.hasTTY, "has-tty", false, "is the output attached to a tty") + flags.StringVar(&ia.cmd, "cmd", "", "the cmd to launch, 
including all arguments (ignored in sftp mode)") + flags.BoolVar(&ia.isShell, "shell", false, "is launching a shell (with no cmds)") + flags.BoolVar(&ia.isSFTP, "sftp", false, "run sftp server (cmd is ignored)") + flags.BoolVar(&ia.forceV1Behavior, "force-v1-behavior", false, "allow falling back to the su command if login is unavailable") + flags.BoolVar(&ia.debugTest, "debug-test", false, "should debug in test mode") + flags.BoolVar(&ia.isSELinuxEnforcing, "is-selinux-enforcing", false, "whether SELinux is in enforcing mode") + flags.StringVar(&ia.encodedEnv, "encoded-env", "", "JSON encoded array of environment variables in '['key=value']' format") + flags.Parse(args) + return ia, nil +} + +func (ia incubatorArgs) forwardedEnviron() ([]string, string, error) { + environ := os.Environ() + // pass through SSH_AUTH_SOCK environment variable to support ssh agent forwarding + allowListKeys := "SSH_AUTH_SOCK" + + if ia.encodedEnv != "" { + unquoted, err := strconv.Unquote(ia.encodedEnv) + if err != nil { + return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + } + + var extraEnviron []string + + err = json.Unmarshal([]byte(unquoted), &extraEnviron) + if err != nil { + return nil, "", fmt.Errorf("unable to parse encodedEnv %q: %w", ia.encodedEnv, err) + } + + environ = append(environ, extraEnviron...) + + for _, v := range extraEnviron { + allowListKeys = fmt.Sprintf("%s,%s", allowListKeys, strings.Split(v, "=")[0]) + } + } + + return environ, allowListKeys, nil +} + +func beNetshell(args []string) error { + plan9netshell.Main() + return nil +} + +// beIncubator is the entrypoint to the `tailscaled be-child ssh` subcommand. +// It is responsible for informing the system of a new login session for the +// user. This is sometimes necessary for mounting home directories and +// decrypting file systems. +// +// Tailscaled launches the incubator as the same user as it was launched as. 
+func beIncubator(args []string) error { + // To defend against issues like https://golang.org/issue/1435, + // defensively lock our current goroutine's thread to the current + // system thread before we start making any UID/GID/group changes. + // + // This shouldn't matter on Linux because syscall.AllThreadsSyscall is + // used to invoke syscalls on all OS threads, but (as of 2023-03-23) + // that function is not implemented on all platforms. + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + ia, err := parseIncubatorArgs(args) + if err != nil { + return err + } + if ia.isSFTP && ia.isShell { + return fmt.Errorf("--sftp and --shell are mutually exclusive") + } + + if ia.isShell { + plan9netshell.Main() + return nil + } + + dlogf := logger.Discard + if ia.debugTest { + // In testing, we don't always have syslog, so log to a temp file. + if logFile, err := os.OpenFile("/tmp/tailscalessh.log", os.O_APPEND|os.O_WRONLY, 0666); err == nil { + lf := log.New(logFile, "", 0) + dlogf = func(msg string, args ...any) { + lf.Printf(msg, args...) + logFile.Sync() + } + defer logFile.Close() + } + } + + return handleInProcess(dlogf, ia) +} + +func handleInProcess(dlogf logger.Logf, ia incubatorArgs) error { + if ia.isSFTP { + return handleSFTPInProcess(dlogf, ia) + } + return handleSSHInProcess(dlogf, ia) +} + +func handleSFTPInProcess(dlogf logger.Logf, ia incubatorArgs) error { + dlogf("handling sftp") + + return serveSFTP() +} + +// beSFTP serves SFTP in-process. +func beSFTP(args []string) error { + return serveSFTP() +} + +func serveSFTP() error { + server, err := sftp.NewServer(stdRWC{}) + if err != nil { + return err + } + // TODO(https://github.com/pkg/sftp/pull/554): Revert the check for io.EOF, + // when sftp is patched to report clean termination. + if err := server.Serve(); err != nil && err != io.EOF { + return err + } + return nil +} + +// handleSSHInProcess is a last resort if we couldn't use login or su. 
It +// registers a new session with the OS, sets its UID, GID and groups to the +// specified values, and then launches the requested `--cmd` in the user's +// login shell. +func handleSSHInProcess(dlogf logger.Logf, ia incubatorArgs) error { + + environ, _, err := ia.forwardedEnviron() + if err != nil { + return err + } + + dlogf("running /bin/rc -c %q", ia.cmd) + cmd := newCommand("/bin/rc", environ, []string{"-c", ia.cmd}) + err = cmd.Run() + if ee, ok := err.(*exec.ExitError); ok { + ps := ee.ProcessState + code := ps.ExitCode() + if code < 0 { + // TODO(bradfitz): do we need to also check the syscall.WaitStatus + // and make our process look like it also died by signal/same signal + // as our child process? For now we just do the exit code. + fmt.Fprintf(os.Stderr, "[tailscale-ssh: process died: %v]\n", ps.String()) + code = 1 // for now. so we don't exit with negative + } + os.Exit(code) + } + return err +} + +func newCommand(cmdPath string, cmdEnviron []string, cmdArgs []string) *exec.Cmd { + cmd := exec.Command(cmdPath, cmdArgs...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = cmdEnviron + + return cmd +} + +// launchProcess launches an incubator process for the provided session. +// It is responsible for configuring the process execution environment. +// The caller can wait for the process to exit by calling cmd.Wait(). +// +// It sets ss.cmd, stdin, stdout, and stderr. +func (ss *sshSession) launchProcess() error { + var err error + ss.cmd, err = ss.newIncubatorCommand(ss.logf) + if err != nil { + return err + } + + cmd := ss.cmd + cmd.Dir = "/" + cmd.Env = append(os.Environ(), envForUser(ss.conn.localUser)...) 
+ for _, kv := range ss.Environ() { + if acceptEnvPair(kv) { + cmd.Env = append(cmd.Env, kv) + } + } + + ci := ss.conn.info + cmd.Env = append(cmd.Env, + fmt.Sprintf("SSH_CLIENT=%s %d %d", ci.src.Addr(), ci.src.Port(), ci.dst.Port()), + fmt.Sprintf("SSH_CONNECTION=%s %d %s %d", ci.src.Addr(), ci.src.Port(), ci.dst.Addr(), ci.dst.Port()), + ) + + if ss.agentListener != nil { + cmd.Env = append(cmd.Env, fmt.Sprintf("SSH_AUTH_SOCK=%s", ss.agentListener.Addr())) + } + + return ss.startWithStdPipes() +} + +// startWithStdPipes starts cmd with os.Pipe for Stdin, Stdout and Stderr. +func (ss *sshSession) startWithStdPipes() (err error) { + var rdStdin, wrStdout, wrStderr io.ReadWriteCloser + defer func() { + if err != nil { + closeAll(rdStdin, ss.wrStdin, ss.rdStdout, wrStdout, ss.rdStderr, wrStderr) + } + }() + if ss.cmd == nil { + return errors.New("nil cmd") + } + if rdStdin, ss.wrStdin, err = os.Pipe(); err != nil { + return err + } + if ss.rdStdout, wrStdout, err = os.Pipe(); err != nil { + return err + } + if ss.rdStderr, wrStderr, err = os.Pipe(); err != nil { + return err + } + ss.cmd.Stdin = rdStdin + ss.cmd.Stdout = wrStdout + ss.cmd.Stderr = wrStderr + ss.childPipes = []io.Closer{rdStdin, wrStdout, wrStderr} + return ss.cmd.Start() +} + +func envForUser(u *userMeta) []string { + return []string{ + fmt.Sprintf("user=%s", u.Username), + fmt.Sprintf("home=%s", u.HomeDir), + fmt.Sprintf("path=%s", defaultPathForUser(&u.User)), + } +} + +// acceptEnvPair reports whether the environment variable key=value pair +// should be accepted from the client. It uses the same default as OpenSSH +// AcceptEnv. 
+func acceptEnvPair(kv string) bool { + k, _, ok := strings.Cut(kv, "=") + if !ok { + return false + } + _ = k + return true // permit anything on plan9 during bringup, for debugging at least +} diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 9ade1847e6b27..e42f09bdfb4e4 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || (darwin && !ios) || freebsd || openbsd +//go:build linux || (darwin && !ios) || freebsd || openbsd || plan9 // Package tailssh is an SSH server integrated into Tailscale. package tailssh @@ -10,7 +10,6 @@ import ( "bytes" "context" "crypto/rand" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -30,7 +29,7 @@ import ( "syscall" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" "tailscale.com/envknob" "tailscale.com/ipn/ipnlocal" "tailscale.com/logtail/backoff" @@ -45,7 +44,6 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/httpm" "tailscale.com/util/mak" - "tailscale.com/util/slicesx" ) var ( @@ -53,6 +51,11 @@ var ( sshDisableSFTP = envknob.RegisterBool("TS_SSH_DISABLE_SFTP") sshDisableForwarding = envknob.RegisterBool("TS_SSH_DISABLE_FORWARDING") sshDisablePTY = envknob.RegisterBool("TS_SSH_DISABLE_PTY") + + // errTerminal is an empty gossh.PartialSuccessError (with no 'Next' + // authentication methods that may proceed), which results in the SSH + // server immediately disconnecting the client. 
+ errTerminal = &gossh.PartialSuccessError{} ) const ( @@ -80,16 +83,14 @@ type server struct { logf logger.Logf tailscaledPath string - pubKeyHTTPClient *http.Client // or nil for http.DefaultClient - timeNow func() time.Time // or nil for time.Now + timeNow func() time.Time // or nil for time.Now sessionWaitGroup sync.WaitGroup // mu protects the following - mu sync.Mutex - activeConns map[*conn]bool // set; value is always true - fetchPublicKeysCache map[string]pubKeyCacheEntry // by https URL - shutdownCalled bool + mu sync.Mutex + activeConns map[*conn]bool // set; value is always true + shutdownCalled bool } func (srv *server) now() time.Time { @@ -202,9 +203,11 @@ func (srv *server) OnPolicyChange() { // Setup and discover server info // - ServerConfigCallback // -// Do the user auth -// - NoClientAuthHandler -// - PublicKeyHandler (only if NoClientAuthHandler returns errPubKeyRequired) +// Get access to a ServerPreAuthConn (useful for sending banners) +// +// Do the user auth with a NoClientAuthCallback. If user specified +// a username ending in "+password", follow this with password auth +// (to work around buggy SSH clients that don't work with noauth). // // Once auth is done, the conn can be multiplexed with multiple sessions and // channels concurrently. At which point any of the following can be called @@ -224,20 +227,16 @@ type conn struct { idH string connID string // ID that's shared with control - // anyPasswordIsOkay is whether the client is authorized but has requested - // password-based auth to work around their buggy SSH client. When set, we - // accept any password in the PasswordHandler. - anyPasswordIsOkay bool // set by NoClientAuthCallback + // spac is a [gossh.ServerPreAuthConn] used for sending auth banners. + // Banners cannot be sent after auth completes. 
+ spac gossh.ServerPreAuthConn - action0 *tailcfg.SSHAction // set by doPolicyAuth; first matching action - currentAction *tailcfg.SSHAction // set by doPolicyAuth, updated by resolveNextAction - finalAction *tailcfg.SSHAction // set by doPolicyAuth or resolveNextAction - finalActionErr error // set by doPolicyAuth or resolveNextAction + action0 *tailcfg.SSHAction // set by clientAuth + finalAction *tailcfg.SSHAction // set by clientAuth - info *sshConnInfo // set by setInfo - localUser *userMeta // set by doPolicyAuth - userGroupIDs []string // set by doPolicyAuth - pubKey gossh.PublicKey // set by doPolicyAuth + info *sshConnInfo // set by setInfo + localUser *userMeta // set by clientAuth + userGroupIDs []string // set by clientAuth acceptEnv []string // mu protects the following fields. @@ -260,172 +259,183 @@ func (c *conn) vlogf(format string, args ...any) { } } -// isAuthorized walks through the action chain and returns nil if the connection -// is authorized. If the connection is not authorized, it returns -// errDenied. If the action chain resolution fails, it returns the -// resolution error. -func (c *conn) isAuthorized(ctx ssh.Context) error { - action := c.currentAction - for { - if action.Accept { - if c.pubKey != nil { - metricPublicKeyAccepts.Add(1) - } - return nil - } - if action.Reject || action.HoldAndDelegate == "" { - return errDenied - } - var err error - action, err = c.resolveNextAction(ctx) - if err != nil { - return err - } - if action.Message != "" { - if err := ctx.SendAuthBanner(action.Message); err != nil { - return err - } - } - } -} - // errDenied is returned by auth callbacks when a connection is denied by the -// policy. -var errDenied = errors.New("ssh: access denied") - -// errPubKeyRequired is returned by NoClientAuthCallback to make the client -// resort to public-key auth; not user visible. 
-var errPubKeyRequired = errors.New("ssh publickey required") - -// NoClientAuthCallback implements gossh.NoClientAuthCallback and is called by -// the ssh.Server when the client first connects with the "none" -// authentication method. -// -// It is responsible for continuing policy evaluation from BannerCallback (or -// starting it afresh). It returns an error if the policy evaluation fails, or -// if the decision is "reject" -// -// It either returns nil (accept) or errPubKeyRequired or errDenied -// (reject). The errors may be wrapped. -func (c *conn) NoClientAuthCallback(ctx ssh.Context) error { - if c.insecureSkipTailscaleAuth { - return nil +// policy. It writes the message to an auth banner and then returns an empty +// gossh.PartialSuccessError in order to stop processing authentication +// attempts and immediately disconnect the client. +func (c *conn) errDenied(message string) error { + if message == "" { + message = "tailscale: access denied" } - if err := c.doPolicyAuth(ctx, nil /* no pub key */); err != nil { - return err + if err := c.spac.SendAuthBanner(message); err != nil { + c.logf("failed to send auth banner: %s", err) } - if err := c.isAuthorized(ctx); err != nil { - return err - } - - // Let users specify a username ending in +password to force password auth. - // This exists for buggy SSH clients that get confused by success from - // "none" auth. 
- if strings.HasSuffix(ctx.User(), forcePasswordSuffix) { - c.anyPasswordIsOkay = true - return errors.New("any password please") // not shown to users - } - return nil + return errTerminal } -func (c *conn) nextAuthMethodCallback(cm gossh.ConnMetadata, prevErrors []error) (nextMethod []string) { - switch { - case c.anyPasswordIsOkay: - nextMethod = append(nextMethod, "password") - case slicesx.LastEqual(prevErrors, errPubKeyRequired): - nextMethod = append(nextMethod, "publickey") +// errBanner writes the given message to an auth banner and then returns an +// empty gossh.PartialSuccessError in order to stop processing authentication +// attempts and immediately disconnect the client. The contents of err is not +// leaked in the auth banner, but it is logged to the server's log. +func (c *conn) errBanner(message string, err error) error { + if err != nil { + c.logf("%s: %s", message, err) } - - // The fake "tailscale" method is always appended to next so OpenSSH renders - // that in parens as the final failure. (It also shows up in "ssh -v", etc) - nextMethod = append(nextMethod, "tailscale") - return + if err := c.spac.SendAuthBanner("tailscale: " + message); err != nil { + c.logf("failed to send auth banner: %s", err) + } + return errTerminal } -// fakePasswordHandler is our implementation of the PasswordHandler hook that -// checks whether the user's password is correct. But we don't actually use -// passwords. This exists only for when the user's username ends in "+password" -// to signal that their SSH client is buggy and gets confused by auth type -// "none" succeeding and they want our SSH server to require a dummy password -// prompt instead. We then accept any password since we've already authenticated -// & authorized them. 
-func (c *conn) fakePasswordHandler(ctx ssh.Context, password string) bool { - return c.anyPasswordIsOkay +// errUnexpected is returned by auth callbacks that encounter an unexpected +// error, such as being unable to send an auth banner. It sends an empty +// gossh.PartialSuccessError to tell gossh.Server to stop processing +// authentication attempts and instead disconnect immediately. +func (c *conn) errUnexpected(err error) error { + c.logf("terminal error: %s", err) + return errTerminal } -// PublicKeyHandler implements ssh.PublicKeyHandler is called by the -// ssh.Server when the client presents a public key. -func (c *conn) PublicKeyHandler(ctx ssh.Context, pubKey ssh.PublicKey) error { - if err := c.doPolicyAuth(ctx, pubKey); err != nil { - // TODO(maisem/bradfitz): surface the error here. - c.logf("rejecting SSH public key %s: %v", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey)), err) - return err +// clientAuth is responsible for performing client authentication. +// +// If policy evaluation fails, it returns an error. +// If access is denied, it returns an error. This must always be an empty +// gossh.PartialSuccessError to prevent further authentication methods from +// being tried. 
+func (c *conn) clientAuth(cm gossh.ConnMetadata) (perms *gossh.Permissions, retErr error) { + defer func() { + if pse, ok := retErr.(*gossh.PartialSuccessError); ok { + if pse.Next.GSSAPIWithMICConfig != nil || + pse.Next.KeyboardInteractiveCallback != nil || + pse.Next.PasswordCallback != nil || + pse.Next.PublicKeyCallback != nil { + panic("clientAuth attempted to return a non-empty PartialSuccessError") + } + } else if retErr != nil { + panic(fmt.Sprintf("clientAuth attempted to return a non-PartialSuccessError error of type: %t", retErr)) + } + }() + + if c.insecureSkipTailscaleAuth { + return &gossh.Permissions{}, nil } - if err := c.isAuthorized(ctx); err != nil { - return err + + if err := c.setInfo(cm); err != nil { + return nil, c.errBanner("failed to get connection info", err) } - c.logf("accepting SSH public key %s", bytes.TrimSpace(gossh.MarshalAuthorizedKey(pubKey))) - return nil -} -// doPolicyAuth verifies that conn can proceed with the specified (optional) -// pubKey. It returns nil if the matching policy action is Accept or -// HoldAndDelegate. If pubKey is nil, there was no policy match but there is a -// policy that might match a public key it returns errPubKeyRequired. Otherwise, -// it returns errDenied. 
-func (c *conn) doPolicyAuth(ctx ssh.Context, pubKey ssh.PublicKey) error { - if err := c.setInfo(ctx); err != nil { - c.logf("failed to get conninfo: %v", err) - return errDenied - } - a, localUser, acceptEnv, err := c.evaluatePolicy(pubKey) + action, localUser, acceptEnv, err := c.evaluatePolicy() if err != nil { - if pubKey == nil && c.havePubKeyPolicy() { - return errPubKeyRequired - } - return fmt.Errorf("%w: %v", errDenied, err) - } - c.action0 = a - c.currentAction = a - c.pubKey = pubKey - c.acceptEnv = acceptEnv - if a.Message != "" { - if err := ctx.SendAuthBanner(a.Message); err != nil { - return fmt.Errorf("SendBanner: %w", err) - } + return nil, c.errBanner("failed to evaluate SSH policy", err) } - if a.Accept || a.HoldAndDelegate != "" { - if a.Accept { - c.finalAction = a - } + + c.action0 = action + + if action.Accept || action.HoldAndDelegate != "" { + // Immediately look up user information for purposes of generating + // hold and delegate URL (if necessary). lu, err := userLookup(localUser) if err != nil { - c.logf("failed to look up %v: %v", localUser, err) - ctx.SendAuthBanner(fmt.Sprintf("failed to look up %v\r\n", localUser)) - return err + return nil, c.errBanner(fmt.Sprintf("failed to look up local user %q ", localUser), err) } gids, err := lu.GroupIds() if err != nil { - c.logf("failed to look up local user's group IDs: %v", err) - return err + return nil, c.errBanner("failed to look up local user's group IDs", err) } c.userGroupIDs = gids c.localUser = lu - return nil + c.acceptEnv = acceptEnv } - if a.Reject { - c.finalAction = a - return errDenied + + for { + switch { + case action.Accept: + metricTerminalAccept.Add(1) + if action.Message != "" { + if err := c.spac.SendAuthBanner(action.Message); err != nil { + return nil, c.errUnexpected(fmt.Errorf("error sending auth welcome message: %w", err)) + } + } + c.finalAction = action + return &gossh.Permissions{}, nil + case action.Reject: + metricTerminalReject.Add(1) + c.finalAction = 
action + return nil, c.errDenied(action.Message) + case action.HoldAndDelegate != "": + if action.Message != "" { + if err := c.spac.SendAuthBanner(action.Message); err != nil { + return nil, c.errUnexpected(fmt.Errorf("error sending hold and delegate message: %w", err)) + } + } + + url := action.HoldAndDelegate + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + defer cancel() + + metricHolds.Add(1) + url = c.expandDelegateURLLocked(url) + + var err error + action, err = c.fetchSSHAction(ctx, url) + if err != nil { + metricTerminalFetchError.Add(1) + return nil, c.errBanner("failed to fetch next SSH action", fmt.Errorf("fetch failed from %s: %w", url, err)) + } + default: + metricTerminalMalformed.Add(1) + return nil, c.errBanner("reached Action that had neither Accept, Reject, nor HoldAndDelegate", nil) + } } - // Shouldn't get here, but: - return errDenied } // ServerConfig implements ssh.ServerConfigCallback. func (c *conn) ServerConfig(ctx ssh.Context) *gossh.ServerConfig { return &gossh.ServerConfig{ - NoClientAuth: true, // required for the NoClientAuthCallback to run - NextAuthMethodCallback: c.nextAuthMethodCallback, + PreAuthConnCallback: func(spac gossh.ServerPreAuthConn) { + c.spac = spac + }, + NoClientAuth: true, // required for the NoClientAuthCallback to run + NoClientAuthCallback: func(cm gossh.ConnMetadata) (*gossh.Permissions, error) { + // First perform client authentication, which can potentially + // involve multiple steps (for example prompting user to log in to + // Tailscale admin panel to confirm identity). + perms, err := c.clientAuth(cm) + if err != nil { + return nil, err + } + + // Authentication succeeded. Buggy SSH clients get confused by + // success from the "none" auth method. As a workaround, let users + // specify a username ending in "+password" to force password auth. + // The actual value of the password doesn't matter. 
+ if strings.HasSuffix(cm.User(), forcePasswordSuffix) { + return nil, &gossh.PartialSuccessError{ + Next: gossh.ServerAuthCallbacks{ + PasswordCallback: func(_ gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) { + return &gossh.Permissions{}, nil + }, + }, + } + } + + return perms, nil + }, + PasswordCallback: func(cm gossh.ConnMetadata, pword []byte) (*gossh.Permissions, error) { + // Some clients don't request 'none' authentication. Instead, they + // immediately supply a password. We humor them by accepting the + // password, but authenticate as usual, ignoring the actual value of + // the password. + return c.clientAuth(cm) + }, + PublicKeyCallback: func(cm gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) { + // Some clients don't request 'none' authentication. Instead, they + // immediately supply a public key. We humor them by accepting the + // key, but authenticate as usual, ignoring the actual content of + // the key. + return c.clientAuth(cm) + }, } } @@ -436,7 +446,7 @@ func (srv *server) newConn() (*conn, error) { // Stop accepting new connections. // Connections in the auth phase are handled in handleConnPostSSHAuth. // Existing sessions are terminated by Shutdown. - return nil, errDenied + return nil, errors.New("server is shutting down") } srv.mu.Unlock() c := &conn{srv: srv} @@ -447,10 +457,6 @@ func (srv *server) newConn() (*conn, error) { Version: "Tailscale", ServerConfigCallback: c.ServerConfig, - NoClientAuthHandler: c.NoClientAuthCallback, - PublicKeyHandler: c.PublicKeyHandler, - PasswordHandler: c.fakePasswordHandler, - Handler: c.handleSessionPostSSHAuth, LocalPortForwardingCallback: c.mayForwardLocalPortTo, ReversePortForwardingCallback: c.mayReversePortForwardTo, @@ -516,34 +522,6 @@ func (c *conn) mayForwardLocalPortTo(ctx ssh.Context, destinationHost string, de return false } -// havePubKeyPolicy reports whether any policy rule may provide access by means -// of a ssh.PublicKey. 
-func (c *conn) havePubKeyPolicy() bool { - if c.info == nil { - panic("havePubKeyPolicy called before setInfo") - } - // Is there any rule that looks like it'd require a public key for this - // sshUser? - pol, ok := c.sshPolicy() - if !ok { - return false - } - for _, r := range pol.Rules { - if c.ruleExpired(r) { - continue - } - if mapLocalUser(r.SSHUsers, c.info.sshUser) == "" { - continue - } - for _, p := range r.Principals { - if len(p.PubKeys) > 0 && c.principalMatchesTailscaleIdentity(p) { - return true - } - } - } - return false -} - // sshPolicy returns the SSHPolicy for current node. // If there is no SSHPolicy in the netmap, it returns a debugPolicy // if one is defined. @@ -589,16 +567,16 @@ func toIPPort(a net.Addr) (ipp netip.AddrPort) { return netip.AddrPortFrom(tanetaddr.Unmap(), uint16(ta.Port)) } -// connInfo returns a populated sshConnInfo from the provided arguments, +// connInfo populates the sshConnInfo from the provided arguments, // validating only that they represent a known Tailscale identity. -func (c *conn) setInfo(ctx ssh.Context) error { +func (c *conn) setInfo(cm gossh.ConnMetadata) error { if c.info != nil { return nil } ci := &sshConnInfo{ - sshUser: strings.TrimSuffix(ctx.User(), forcePasswordSuffix), - src: toIPPort(ctx.RemoteAddr()), - dst: toIPPort(ctx.LocalAddr()), + sshUser: strings.TrimSuffix(cm.User(), forcePasswordSuffix), + src: toIPPort(cm.RemoteAddr()), + dst: toIPPort(cm.LocalAddr()), } if !tsaddr.IsTailscaleIP(ci.dst.Addr()) { return fmt.Errorf("tailssh: rejecting non-Tailscale local address %v", ci.dst) @@ -613,124 +591,26 @@ func (c *conn) setInfo(ctx ssh.Context) error { ci.node = node ci.uprof = uprof - c.idH = ctx.SessionID() + c.idH = string(cm.SessionID()) c.info = ci c.logf("handling conn: %v", ci.String()) return nil } // evaluatePolicy returns the SSHAction and localUser after evaluating -// the SSHPolicy for this conn. The pubKey may be nil for "none" auth. 
-func (c *conn) evaluatePolicy(pubKey gossh.PublicKey) (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, _ error) { +// the SSHPolicy for this conn. +func (c *conn) evaluatePolicy() (_ *tailcfg.SSHAction, localUser string, acceptEnv []string, _ error) { pol, ok := c.sshPolicy() if !ok { return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no SSH policy") } - a, localUser, acceptEnv, ok := c.evalSSHPolicy(pol, pubKey) + a, localUser, acceptEnv, ok := c.evalSSHPolicy(pol) if !ok { return nil, "", nil, fmt.Errorf("tailssh: rejecting connection; no matching policy") } return a, localUser, acceptEnv, nil } -// pubKeyCacheEntry is the cache value for an HTTPS URL of public keys (like -// "https://github.com/foo.keys") -type pubKeyCacheEntry struct { - lines []string - etag string // if sent by server - at time.Time -} - -const ( - pubKeyCacheDuration = time.Minute // how long to cache non-empty public keys - pubKeyCacheEmptyDuration = 15 * time.Second // how long to cache empty responses -) - -func (srv *server) fetchPublicKeysURLCached(url string) (ce pubKeyCacheEntry, ok bool) { - srv.mu.Lock() - defer srv.mu.Unlock() - // Mostly don't care about the size of this cache. Clean rarely. - if m := srv.fetchPublicKeysCache; len(m) > 50 { - tooOld := srv.now().Add(pubKeyCacheDuration * 10) - for k, ce := range m { - if ce.at.Before(tooOld) { - delete(m, k) - } - } - } - ce, ok = srv.fetchPublicKeysCache[url] - if !ok { - return ce, false - } - maxAge := pubKeyCacheDuration - if len(ce.lines) == 0 { - maxAge = pubKeyCacheEmptyDuration - } - return ce, srv.now().Sub(ce.at) < maxAge -} - -func (srv *server) pubKeyClient() *http.Client { - if srv.pubKeyHTTPClient != nil { - return srv.pubKeyHTTPClient - } - return http.DefaultClient -} - -// fetchPublicKeysURL fetches the public keys from a URL. The strings are in the -// the typical public key "type base64-string [comment]" format seen at e.g. 
-// https://github.com/USER.keys -func (srv *server) fetchPublicKeysURL(url string) ([]string, error) { - if !strings.HasPrefix(url, "https://") { - return nil, errors.New("invalid URL scheme") - } - - ce, ok := srv.fetchPublicKeysURLCached(url) - if ok { - return ce.lines, nil - } - - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) - if err != nil { - return nil, err - } - if ce.etag != "" { - req.Header.Add("If-None-Match", ce.etag) - } - res, err := srv.pubKeyClient().Do(req) - if err != nil { - return nil, err - } - defer res.Body.Close() - var lines []string - var etag string - switch res.StatusCode { - default: - err = fmt.Errorf("unexpected status %v", res.Status) - srv.logf("fetching public keys from %s: %v", url, err) - case http.StatusNotModified: - lines = ce.lines - etag = ce.etag - case http.StatusOK: - var all []byte - all, err = io.ReadAll(io.LimitReader(res.Body, 4<<10)) - if s := strings.TrimSpace(string(all)); s != "" { - lines = strings.Split(s, "\n") - } - etag = res.Header.Get("Etag") - } - - srv.mu.Lock() - defer srv.mu.Unlock() - mak.Set(&srv.fetchPublicKeysCache, url, pubKeyCacheEntry{ - at: srv.now(), - lines: lines, - etag: etag, - }) - return lines, err -} - // handleSessionPostSSHAuth runs an SSH session after the SSH-level authentication, // but not necessarily before all the Tailscale-level extra verification has // completed. It also handles SFTP requests. @@ -758,62 +638,6 @@ func (c *conn) handleSessionPostSSHAuth(s ssh.Session) { ss.run() } -// resolveNextAction starts at c.currentAction and makes it way through the -// action chain one step at a time. An action without a HoldAndDelegate is -// considered the final action. Once a final action is reached, this function -// will keep returning that action. It updates c.currentAction to the next -// action in the chain. 
When the final action is reached, it also sets -// c.finalAction to the final action. -func (c *conn) resolveNextAction(sctx ssh.Context) (action *tailcfg.SSHAction, err error) { - if c.finalAction != nil || c.finalActionErr != nil { - return c.finalAction, c.finalActionErr - } - - defer func() { - if action != nil { - c.currentAction = action - if action.Accept || action.Reject { - c.finalAction = action - } - } - if err != nil { - c.finalActionErr = err - } - }() - - ctx, cancel := context.WithCancel(sctx) - defer cancel() - - // Loop processing/fetching Actions until one reaches a - // terminal state (Accept, Reject, or invalid Action), or - // until fetchSSHAction times out due to the context being - // done (client disconnect) or its 30 minute timeout passes. - // (Which is a long time for somebody to see login - // instructions and go to a URL to do something.) - action = c.currentAction - if action.Accept || action.Reject { - if action.Reject { - metricTerminalReject.Add(1) - } else { - metricTerminalAccept.Add(1) - } - return action, nil - } - url := action.HoldAndDelegate - if url == "" { - metricTerminalMalformed.Add(1) - return nil, errors.New("reached Action that lacked Accept, Reject, and HoldAndDelegate") - } - metricHolds.Add(1) - url = c.expandDelegateURLLocked(url) - nextAction, err := c.fetchSSHAction(ctx, url) - if err != nil { - metricTerminalFetchError.Add(1) - return nil, fmt.Errorf("fetching SSHAction from %s: %w", url, err) - } - return nextAction, nil -} - func (c *conn) expandDelegateURLLocked(actionURL string) string { nm := c.srv.lb.NetMap() ci := c.info @@ -832,18 +656,6 @@ func (c *conn) expandDelegateURLLocked(actionURL string) string { ).Replace(actionURL) } -func (c *conn) expandPublicKeyURL(pubKeyURL string) string { - if !strings.Contains(pubKeyURL, "$") { - return pubKeyURL - } - loginName := c.info.uprof.LoginName - localPart, _, _ := strings.Cut(loginName, "@") - return strings.NewReplacer( - "$LOGINNAME_EMAIL", loginName, - 
"$LOGINNAME_LOCALPART", localPart, - ).Replace(pubKeyURL) -} - // sshSession is an accepted Tailscale SSH session. type sshSession struct { ssh.Session @@ -894,7 +706,7 @@ func (c *conn) newSSHSession(s ssh.Session) *sshSession { // isStillValid reports whether the conn is still valid. func (c *conn) isStillValid() bool { - a, localUser, _, err := c.evaluatePolicy(c.pubKey) + a, localUser, _, err := c.evaluatePolicy() c.vlogf("stillValid: %+v %v %v", a, localUser, err) if err != nil { return false @@ -1091,7 +903,7 @@ func (ss *sshSession) run() { defer t.Stop() } - if euid := os.Geteuid(); euid != 0 { + if euid := os.Geteuid(); euid != 0 && runtime.GOOS != "plan9" { if lu.Uid != fmt.Sprint(euid) { ss.logf("can't switch to user %q from process euid %v", lu.Username, euid) fmt.Fprintf(ss, "can't switch user\r\n") @@ -1170,7 +982,7 @@ func (ss *sshSession) run() { if err != nil && !errors.Is(err, io.EOF) { isErrBecauseProcessExited := processDone.Load() && errors.Is(err, syscall.EIO) if !isErrBecauseProcessExited { - logf("stdout copy: %v, %T", err) + logf("stdout copy: %v", err) ss.cancelCtx(err) } } @@ -1277,9 +1089,9 @@ func (c *conn) ruleExpired(r *tailcfg.SSHRule) bool { return r.RuleExpires.Before(c.srv.now()) } -func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, ok bool) { +func (c *conn) evalSSHPolicy(pol *tailcfg.SSHPolicy) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, ok bool) { for _, r := range pol.Rules { - if a, localUser, acceptEnv, err := c.matchRule(r, pubKey); err == nil { + if a, localUser, acceptEnv, err := c.matchRule(r); err == nil { return a, localUser, acceptEnv, true } } @@ -1296,7 +1108,7 @@ var ( errInvalidConn = errors.New("invalid connection state") ) -func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) { +func (c *conn) matchRule(r *tailcfg.SSHRule) 
(a *tailcfg.SSHAction, localUser string, acceptEnv []string, err error) { defer func() { c.vlogf("matchRule(%+v): %v", r, err) }() @@ -1326,9 +1138,7 @@ func (c *conn) matchRule(r *tailcfg.SSHRule, pubKey gossh.PublicKey) (a *tailcfg return nil, "", nil, errUserMatch } } - if ok, err := c.anyPrincipalMatches(r.Principals, pubKey); err != nil { - return nil, "", nil, err - } else if !ok { + if !c.anyPrincipalMatches(r.Principals) { return nil, "", nil, errPrincipalMatch } return r.Action, localUser, r.AcceptEnv, nil @@ -1345,30 +1155,20 @@ func mapLocalUser(ruleSSHUsers map[string]string, reqSSHUser string) (localUser return v } -func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal, pubKey gossh.PublicKey) (bool, error) { +func (c *conn) anyPrincipalMatches(ps []*tailcfg.SSHPrincipal) bool { for _, p := range ps { if p == nil { continue } - if ok, err := c.principalMatches(p, pubKey); err != nil { - return false, err - } else if ok { - return true, nil + if c.principalMatchesTailscaleIdentity(p) { + return true } } - return false, nil -} - -func (c *conn) principalMatches(p *tailcfg.SSHPrincipal, pubKey gossh.PublicKey) (bool, error) { - if !c.principalMatchesTailscaleIdentity(p) { - return false, nil - } - return c.principalMatchesPubKey(p, pubKey) + return false } // principalMatchesTailscaleIdentity reports whether one of p's four fields // that match the Tailscale identity match (Node, NodeIP, UserLogin, Any). -// This function does not consider PubKeys. 
func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { ci := c.info if p.Any { @@ -1388,42 +1188,6 @@ func (c *conn) principalMatchesTailscaleIdentity(p *tailcfg.SSHPrincipal) bool { return false } -func (c *conn) principalMatchesPubKey(p *tailcfg.SSHPrincipal, clientPubKey gossh.PublicKey) (bool, error) { - if len(p.PubKeys) == 0 { - return true, nil - } - if clientPubKey == nil { - return false, nil - } - knownKeys := p.PubKeys - if len(knownKeys) == 1 && strings.HasPrefix(knownKeys[0], "https://") { - var err error - knownKeys, err = c.srv.fetchPublicKeysURL(c.expandPublicKeyURL(knownKeys[0])) - if err != nil { - return false, err - } - } - for _, knownKey := range knownKeys { - if pubKeyMatchesAuthorizedKey(clientPubKey, knownKey) { - return true, nil - } - } - return false, nil -} - -func pubKeyMatchesAuthorizedKey(pubKey ssh.PublicKey, wantKey string) bool { - wantKeyType, rest, ok := strings.Cut(wantKey, " ") - if !ok { - return false - } - if pubKey.Type() != wantKeyType { - return false - } - wantKeyB64, _, _ := strings.Cut(rest, " ") - wantKeyData, _ := base64.StdEncoding.DecodeString(wantKeyB64) - return len(wantKeyData) > 0 && bytes.Equal(pubKey.Marshal(), wantKeyData) -} - func randBytes(n int) []byte { b := make([]byte, n) if _, err := rand.Read(b); err != nil { @@ -1520,9 +1284,14 @@ func (ss *sshSession) startNewRecording() (_ *recording, err error) { go func() { err := <-errChan if err == nil { - // Success. - ss.logf("recording: finished uploading recording") - return + select { + case <-ss.ctx.Done(): + // Success. 
+ ss.logf("recording: finished uploading recording") + return + default: + err = errors.New("recording upload ended before the SSH session") + } } if onFailure != nil && onFailure.NotifyURL != "" && len(attempts) > 0 { lastAttempt := attempts[len(attempts)-1] @@ -1744,7 +1513,6 @@ func envEq(a, b string) bool { var ( metricActiveSessions = clientmetric.NewGauge("ssh_active_sessions") metricIncomingConnections = clientmetric.NewCounter("ssh_incoming_connections") - metricPublicKeyAccepts = clientmetric.NewCounter("ssh_publickey_accepts") // accepted subset of ssh_publickey_connections metricTerminalAccept = clientmetric.NewCounter("ssh_terminalaction_accept") metricTerminalReject = clientmetric.NewCounter("ssh_terminalaction_reject") metricTerminalMalformed = clientmetric.NewCounter("ssh_terminalaction_malformed") diff --git a/ssh/tailssh/tailssh_integration_test.go b/ssh/tailssh/tailssh_integration_test.go index 1799d340019cb..9ab26e169665b 100644 --- a/ssh/tailssh/tailssh_integration_test.go +++ b/ssh/tailssh/tailssh_integration_test.go @@ -2,7 +2,6 @@ // SPDX-License-Identifier: BSD-3-Clause //go:build integrationtest -// +build integrationtest package tailssh @@ -32,8 +31,8 @@ import ( "github.com/bramvdbogaerde/go-scp" "github.com/google/go-cmp/cmp" "github.com/pkg/sftp" - gossh "github.com/tailscale/golang-x-crypto/ssh" "golang.org/x/crypto/ssh" + gossh "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "tailscale.com/net/tsdial" "tailscale.com/tailcfg" @@ -410,6 +409,48 @@ func TestSSHAgentForwarding(t *testing.T) { } } +// TestIntegrationParamiko attempts to connect to Tailscale SSH using the +// paramiko Python library. This library does not request 'none' auth. This +// test ensures that Tailscale SSH can correctly handle clients that don't +// request 'none' auth and instead immediately authenticate with a public key +// or password. 
+func TestIntegrationParamiko(t *testing.T) { + debugTest.Store(true) + t.Cleanup(func() { + debugTest.Store(false) + }) + + addr := testServer(t, "testuser", true, false) + host, port, err := net.SplitHostPort(addr) + if err != nil { + t.Fatalf("Failed to split addr %q: %s", addr, err) + } + + out, err := exec.Command("python3", "-c", fmt.Sprintf(` +import paramiko.client as pm +from paramiko.ecdsakey import ECDSAKey +client = pm.SSHClient() +client.set_missing_host_key_policy(pm.AutoAddPolicy) +client.connect('%s', port=%s, username='testuser', pkey=ECDSAKey.generate(), allow_agent=False, look_for_keys=False) +client.exec_command('pwd') +`, host, port)).CombinedOutput() + if err != nil { + t.Fatalf("failed to connect with Paramiko using public key auth: %s\n%q", err, string(out)) + } + + out, err = exec.Command("python3", "-c", fmt.Sprintf(` +import paramiko.client as pm +from paramiko.ecdsakey import ECDSAKey +client = pm.SSHClient() +client.set_missing_host_key_policy(pm.AutoAddPolicy) +client.connect('%s', port=%s, username='testuser', password='doesntmatter', allow_agent=False, look_for_keys=False) +client.exec_command('pwd') +`, host, port)).CombinedOutput() + if err != nil { + t.Fatalf("failed to connect with Paramiko using password auth: %s\n%q", err, string(out)) + } +} + func fallbackToSUAvailable() bool { if runtime.GOOS != "linux" { return false diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 9e4f5ffd3d481..79479d7fbf5c7 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -8,9 +8,10 @@ package tailssh import ( "bytes" "context" + "crypto/ecdsa" "crypto/ed25519" + "crypto/elliptic" "crypto/rand" - "crypto/sha256" "encoding/json" "errors" "fmt" @@ -32,7 +33,9 @@ import ( "testing" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" 
"tailscale.com/net/memnet" @@ -40,15 +43,15 @@ import ( "tailscale.com/sessionrecording" "tailscale.com/tailcfg" "tailscale.com/tempfork/gliderlabs/ssh" + testssh "tailscale.com/tempfork/sshtest/ssh" "tailscale.com/tsd" "tailscale.com/tstest" "tailscale.com/types/key" - "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/netmap" "tailscale.com/types/ptr" "tailscale.com/util/cibuild" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/must" "tailscale.com/version/distro" "tailscale.com/wgengine" @@ -225,9 +228,9 @@ func TestMatchRule(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := &conn{ info: tt.ci, - srv: &server{logf: t.Logf}, + srv: &server{logf: tstest.WhileTestRunningLogger(t)}, } - got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule, nil) + got, gotUser, gotAcceptEnv, err := c.matchRule(tt.rule) if err != tt.wantErr { t.Errorf("err = %v; want %v", err, tt.wantErr) } @@ -344,9 +347,9 @@ func TestEvalSSHPolicy(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := &conn{ info: tt.ci, - srv: &server{logf: t.Logf}, + srv: &server{logf: tstest.WhileTestRunningLogger(t)}, } - got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy, nil) + got, gotUser, gotAcceptEnv, match := c.evalSSHPolicy(tt.policy) if match != tt.wantMatch { t.Errorf("match = %v; want %v", match, tt.wantMatch) } @@ -481,13 +484,12 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { } var handler http.HandlerFunc - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { handler(w, r) - })) - defer recordingServer.Close() + }) s := &server{ - logf: t.Logf, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -507,9 +509,9 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { defer s.Shutdown() 
const sshUser = "alice" - cfg := &gossh.ClientConfig{ + cfg := &testssh.ClientConfig{ User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: testssh.InsecureIgnoreHostKey(), } tests := []struct { @@ -533,9 +535,10 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { { name: "upload-fails-after-starting", handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() r.Body.Read(make([]byte, 1)) time.Sleep(100 * time.Millisecond) - w.WriteHeader(http.StatusInternalServerError) }, sshCommand: "echo hello && sleep 1 && echo world", wantClientOutput: "\r\n\r\nsession terminated\r\n\r\n", @@ -548,18 +551,19 @@ func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + s.logf = tstest.WhileTestRunningLogger(t) tstest.Replace(t, &handler, tt.handler) sc, dc := memnet.NewTCPConn(src, dst, 1024) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) return } - client := gossh.NewClient(c, chans, reqs) + client := testssh.NewClient(c, chans, reqs) defer client.Close() session, err := client.NewSession() if err != nil { @@ -597,12 +601,12 @@ func TestMultipleRecorders(t *testing.T) { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } done := make(chan struct{}) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer close(done) - io.ReadAll(r.Body) w.WriteHeader(http.StatusOK) - })) - defer recordingServer.Close() + w.(http.Flusher).Flush() + io.ReadAll(r.Body) + }) badRecorder, err := net.Listen("tcp", ":0") if err != nil { 
t.Fatal(err) @@ -610,18 +614,12 @@ func TestMultipleRecorders(t *testing.T) { badRecorderAddr := badRecorder.Addr().String() badRecorder.Close() - badRecordingServer500 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(500) - })) - defer badRecordingServer500.Close() - - badRecordingServer200 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - })) - defer badRecordingServer200.Close() + badRecordingServer500 := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + }) s := &server{ - logf: t.Logf, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -630,7 +628,6 @@ func TestMultipleRecorders(t *testing.T) { Recorders: []netip.AddrPort{ netip.MustParseAddrPort(badRecorderAddr), netip.MustParseAddrPort(badRecordingServer500.Listener.Addr().String()), - netip.MustParseAddrPort(badRecordingServer200.Listener.Addr().String()), netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), }, OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{ @@ -647,21 +644,21 @@ func TestMultipleRecorders(t *testing.T) { sc, dc := memnet.NewTCPConn(src, dst, 1024) const sshUser = "alice" - cfg := &gossh.ClientConfig{ + cfg := &testssh.ClientConfig{ User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: testssh.InsecureIgnoreHostKey(), } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) return } - client := gossh.NewClient(c, chans, reqs) + client := testssh.NewClient(c, chans, reqs) defer client.Close() session, err := client.NewSession() if err != nil { @@ -701,19 +698,21 @@ func TestSSHRecordingNonInteractive(t 
*testing.T) { } var recording []byte ctx, cancel := context.WithTimeout(context.Background(), time.Second) - recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recordingServer := mockRecordingServer(t, func(w http.ResponseWriter, r *http.Request) { defer cancel() + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + var err error recording, err = io.ReadAll(r.Body) if err != nil { t.Error(err) return } - })) - defer recordingServer.Close() + }) s := &server{ - logf: logger.Discard, + logf: tstest.WhileTestRunningLogger(t), lb: &localState{ sshEnabled: true, matchingRule: newSSHRule( @@ -736,21 +735,21 @@ func TestSSHRecordingNonInteractive(t *testing.T) { sc, dc := memnet.NewTCPConn(src, dst, 1024) const sshUser = "alice" - cfg := &gossh.ClientConfig{ + cfg := &testssh.ClientConfig{ User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), + HostKeyCallback: testssh.InsecureIgnoreHostKey(), } var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) if err != nil { t.Errorf("client: %v", err) return } - client := gossh.NewClient(c, chans, reqs) + client := testssh.NewClient(c, chans, reqs) defer client.Close() session, err := client.NewSession() if err != nil { @@ -808,7 +807,8 @@ func TestSSHAuthFlow(t *testing.T) { state: &localState{ sshEnabled: true, }, - authErr: true, + authErr: true, + wantBanners: []string{"tailscale: failed to evaluate SSH policy"}, }, { name: "accept", @@ -885,87 +885,160 @@ func TestSSHAuthFlow(t *testing.T) { }, } s := &server{ - logf: logger.Discard, + logf: tstest.WhileTestRunningLogger(t), } defer s.Shutdown() src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - sc, dc := 
memnet.NewTCPConn(src, dst, 1024) - s.lb = tc.state - sshUser := "alice" - if tc.sshUser != "" { - sshUser = tc.sshUser - } - var passwordUsed atomic.Bool - cfg := &gossh.ClientConfig{ - User: sshUser, - HostKeyCallback: gossh.InsecureIgnoreHostKey(), - Auth: []gossh.AuthMethod{ - gossh.PasswordCallback(func() (secret string, err error) { - if !tc.usesPassword { - t.Error("unexpected use of PasswordCallback") - return "", errors.New("unexpected use of PasswordCallback") - } + for _, authMethods := range [][]string{nil, {"publickey", "password"}, {"password", "publickey"}} { + t.Run(fmt.Sprintf("%s-skip-none-auth-%v", tc.name, strings.Join(authMethods, "-then-")), func(t *testing.T) { + s.logf = tstest.WhileTestRunningLogger(t) + + sc, dc := memnet.NewTCPConn(src, dst, 1024) + s.lb = tc.state + sshUser := "alice" + if tc.sshUser != "" { + sshUser = tc.sshUser + } + + wantBanners := slices.Clone(tc.wantBanners) + noneAuthEnabled := len(authMethods) == 0 + + var publicKeyUsed atomic.Bool + var passwordUsed atomic.Bool + var methods []testssh.AuthMethod + + for _, authMethod := range authMethods { + switch authMethod { + case "publickey": + methods = append(methods, + testssh.PublicKeysCallback(func() (signers []testssh.Signer, err error) { + publicKeyUsed.Store(true) + key, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + if err != nil { + return nil, err + } + sig, err := testssh.NewSignerFromKey(key) + if err != nil { + return nil, err + } + return []testssh.Signer{sig}, nil + })) + case "password": + methods = append(methods, testssh.PasswordCallback(func() (secret string, err error) { + passwordUsed.Store(true) + return "any-pass", nil + })) + } + } + + if noneAuthEnabled && tc.usesPassword { + methods = append(methods, testssh.PasswordCallback(func() (secret string, err error) { passwordUsed.Store(true) return "any-pass", nil - }), - }, - BannerCallback: func(message string) error { - if len(tc.wantBanners) == 0 { - t.Errorf("unexpected banner: %q", 
message) - } else if message != tc.wantBanners[0] { - t.Errorf("banner = %q; want %q", message, tc.wantBanners[0]) - } else { - t.Logf("banner = %q", message) - tc.wantBanners = tc.wantBanners[1:] + })) + } + + cfg := &testssh.ClientConfig{ + User: sshUser, + HostKeyCallback: testssh.InsecureIgnoreHostKey(), + SkipNoneAuth: !noneAuthEnabled, + Auth: methods, + BannerCallback: func(message string) error { + if len(wantBanners) == 0 { + t.Errorf("unexpected banner: %q", message) + } else if message != wantBanners[0] { + t.Errorf("banner = %q; want %q", message, wantBanners[0]) + } else { + t.Logf("banner = %q", message) + wantBanners = wantBanners[1:] + } + return nil + }, + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + c, chans, reqs, err := testssh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) + if err != nil { + if !tc.authErr { + t.Errorf("client: %v", err) + } + return + } else if tc.authErr { + c.Close() + t.Errorf("client: expected error, got nil") + return } - return nil - }, - } - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) - if err != nil { - if !tc.authErr { + client := testssh.NewClient(c, chans, reqs) + defer client.Close() + session, err := client.NewSession() + if err != nil { t.Errorf("client: %v", err) + return } - return - } else if tc.authErr { - c.Close() - t.Errorf("client: expected error, got nil") - return + defer session.Close() + _, err = session.CombinedOutput("echo Ran echo!") + if err != nil { + t.Errorf("client: %v", err) + } + }() + if err := s.HandleSSHConn(dc); err != nil { + t.Errorf("unexpected error: %v", err) } - client := gossh.NewClient(c, chans, reqs) - defer client.Close() - session, err := client.NewSession() - if err != nil { - t.Errorf("client: %v", err) - return + wg.Wait() + if len(wantBanners) > 0 { + t.Errorf("missing banners: %v", wantBanners) } - defer session.Close() - _, err = 
session.CombinedOutput("echo Ran echo!") - if err != nil { - t.Errorf("client: %v", err) + + // Check to see which callbacks were invoked. + // + // When `none` auth is enabled, the public key callback should + // never fire, and the password callback should only fire if + // authentication succeeded and the client was trying to force + // password authentication by connecting with the '+password' + // username suffix. + // + // When skipping `none` auth, the first callback should always + // fire, and the 2nd callback should fire only if + // authentication failed. + wantPublicKey := false + wantPassword := false + if noneAuthEnabled { + wantPassword = !tc.authErr && tc.usesPassword + } else { + for i, authMethod := range authMethods { + switch authMethod { + case "publickey": + wantPublicKey = i == 0 || tc.authErr + case "password": + wantPassword = i == 0 || tc.authErr + } + } } - }() - if err := s.HandleSSHConn(dc); err != nil { - t.Errorf("unexpected error: %v", err) - } - wg.Wait() - if len(tc.wantBanners) > 0 { - t.Errorf("missing banners: %v", tc.wantBanners) - } - }) + + if wantPublicKey && !publicKeyUsed.Load() { + t.Error("public key should have been attempted") + } else if !wantPublicKey && publicKeyUsed.Load() { + t.Errorf("public key should not have been attempted") + } + + if wantPassword && !passwordUsed.Load() { + t.Error("password should have been attempted") + } else if !wantPassword && passwordUsed.Load() { + t.Error("password should not have been attempted") + } + }) + } } } func TestSSH(t *testing.T) { - var logf logger.Logf = t.Logf - sys := &tsd.System{} - eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry()) + logf := tstest.WhileTestRunningLogger(t) + sys := tsd.NewSystem() + eng, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker(), sys.UserMetricsRegistry(), sys.Bus.Get()) if err != nil { t.Fatal(err) } @@ -1123,98 +1196,12 @@ func 
parseEnv(out []byte) map[string]string { e := map[string]string{} - lineread.Reader(bytes.NewReader(out), func(line []byte) error { - i := bytes.IndexByte(line, '=') - if i == -1 { - return nil - } - e[string(line[:i])] = string(line[i+1:]) - return nil - }) - return e -} - -func TestPublicKeyFetching(t *testing.T) { - var reqsTotal, reqsIfNoneMatchHit, reqsIfNoneMatchMiss int32 - ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt32((&reqsTotal), 1) - etag := fmt.Sprintf("W/%q", sha256.Sum256([]byte(r.URL.Path))) - w.Header().Set("Etag", etag) - if v := r.Header.Get("If-None-Match"); v != "" { - if v == etag { - atomic.AddInt32(&reqsIfNoneMatchHit, 1) - w.WriteHeader(304) - return - } - atomic.AddInt32(&reqsIfNoneMatchMiss, 1) - } - io.WriteString(w, "foo\nbar\n"+string(r.URL.Path)+"\n") - })) - ts.StartTLS() - defer ts.Close() - keys := ts.URL - - clock := &tstest.Clock{} - srv := &server{ - pubKeyHTTPClient: ts.Client(), - timeNow: clock.Now, - } - for range 2 { - got, err := srv.fetchPublicKeysURL(keys + "/alice.keys") - if err != nil { - t.Fatal(err) - } - if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) { - t.Errorf("got %q; want %q", got, want) + for line := range lineiter.Bytes(out) { + if i := bytes.IndexByte(line, '='); i != -1 { + e[string(line[:i])] = string(line[i+1:]) } } - if got, want := atomic.LoadInt32(&reqsTotal), int32(1); got != want { - t.Errorf("got %d requests; want %d", got, want) - } - if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(0); got != want { - t.Errorf("got %d etag hits; want %d", got, want) - } - clock.Advance(5 * time.Minute) - got, err := srv.fetchPublicKeysURL(keys + "/alice.keys") - if err != nil { - t.Fatal(err) - } - if want := []string{"foo", "bar", "/alice.keys"}; !reflect.DeepEqual(got, want) { - t.Errorf("got %q; want %q", got, want) - } - if got, want := atomic.LoadInt32(&reqsTotal), int32(2); got != want { - 
t.Errorf("got %d requests; want %d", got, want) - } - if got, want := atomic.LoadInt32(&reqsIfNoneMatchHit), int32(1); got != want { - t.Errorf("got %d etag hits; want %d", got, want) - } - if got, want := atomic.LoadInt32(&reqsIfNoneMatchMiss), int32(0); got != want { - t.Errorf("got %d etag misses; want %d", got, want) - } - -} - -func TestExpandPublicKeyURL(t *testing.T) { - c := &conn{ - info: &sshConnInfo{ - uprof: tailcfg.UserProfile{ - LoginName: "bar@baz.tld", - }, - }, - } - if got, want := c.expandPublicKeyURL("foo"), "foo"; got != want { - t.Errorf("basic: got %q; want %q", got, want) - } - if got, want := c.expandPublicKeyURL("https://example.com/$LOGINNAME_LOCALPART.keys"), "https://example.com/bar.keys"; got != want { - t.Errorf("localpart: got %q; want %q", got, want) - } - if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email=bar@baz.tld"; got != want { - t.Errorf("email: got %q; want %q", got, want) - } - c.info = new(sshConnInfo) - if got, want := c.expandPublicKeyURL("https://example.com/keys?email=$LOGINNAME_EMAIL"), "https://example.com/keys?email="; got != want { - t.Errorf("on empty: got %q; want %q", got, want) - } + return e } func TestAcceptEnvPair(t *testing.T) { @@ -1302,3 +1289,22 @@ func TestStdOsUserUserAssumptions(t *testing.T) { t.Errorf("os/user.User has %v fields; this package assumes %v", got, want) } } + +func mockRecordingServer(t *testing.T, handleRecord http.HandlerFunc) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + mux.HandleFunc("POST /record", func(http.ResponseWriter, *http.Request) { + t.Errorf("v1 recording endpoint called") + }) + mux.HandleFunc("HEAD /v2/record", func(http.ResponseWriter, *http.Request) {}) + mux.HandleFunc("POST /v2/record", handleRecord) + + h2s := &http2.Server{} + srv := httptest.NewUnstartedServer(h2c.NewHandler(mux, h2s)) + if err := http2.ConfigureServer(srv.Config, h2s); err != nil { + t.Errorf("configuring 
HTTP/2 support in recording server: %v", err) + } + srv.Start() + t.Cleanup(srv.Close) + return srv +} diff --git a/ssh/tailssh/testcontainers/Dockerfile b/ssh/tailssh/testcontainers/Dockerfile index c94c961d37c61..4ef1c1eb0bb7c 100644 --- a/ssh/tailssh/testcontainers/Dockerfile +++ b/ssh/tailssh/testcontainers/Dockerfile @@ -3,9 +3,12 @@ FROM ${BASE} ARG BASE -RUN echo "Install openssh, needed for scp." -RUN if echo "$BASE" | grep "ubuntu:"; then apt-get update -y && apt-get install -y openssh-client; fi -RUN if echo "$BASE" | grep "alpine:"; then apk add openssh; fi +RUN echo "Install openssh, needed for scp. Also install python3" +RUN if echo "$BASE" | grep "ubuntu:"; then apt-get update -y && apt-get install -y openssh-client python3 python3-pip; fi +RUN if echo "$BASE" | grep "alpine:"; then apk add openssh python3 py3-pip; fi + +RUN echo "Install paramiko" +RUN pip3 install paramiko==3.5.1 || pip3 install --break-system-packages paramiko==3.5.1 # Note - on Ubuntu, we do not create the user's home directory, pam_mkhomedir will do that # for us, and we want to test that PAM gets triggered by Tailscale SSH. @@ -33,6 +36,8 @@ RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH +RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi +RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationParamiko RUN echo "Then run tests as non-root user testuser and make sure tests still pass." 
RUN touch /tmp/tailscalessh.log diff --git a/ssh/tailssh/user.go b/ssh/tailssh/user.go index 33ebb4db729de..097f0d296e92c 100644 --- a/ssh/tailssh/user.go +++ b/ssh/tailssh/user.go @@ -1,12 +1,11 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux || (darwin && !ios) || freebsd || openbsd +//go:build linux || (darwin && !ios) || freebsd || openbsd || plan9 package tailssh import ( - "io" "os" "os/exec" "os/user" @@ -18,7 +17,7 @@ import ( "go4.org/mem" "tailscale.com/envknob" "tailscale.com/hostinfo" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" "tailscale.com/util/osuser" "tailscale.com/version/distro" ) @@ -49,6 +48,9 @@ func userLookup(username string) (*userMeta, error) { } func (u *userMeta) LoginShell() string { + if runtime.GOOS == "plan9" { + return "/bin/rc" + } if u.loginShellCached != "" { // This field should be populated on Linux, at least, because // func userLookup on Linux uses "getent" to look up the user @@ -86,6 +88,9 @@ func defaultPathForUser(u *user.User) string { if s := defaultPathTmpl(); s != "" { return expandDefaultPathTmpl(s, u) } + if runtime.GOOS == "plan9" { + return "/bin" + } isRoot := u.Uid == "0" switch distro.Get() { case distro.Debian: @@ -110,15 +115,16 @@ func defaultPathForUser(u *user.User) string { } func defaultPathForUserOnNixOS(u *user.User) string { - var path string - lineread.File("/etc/pam/environment", func(lineb []byte) error { + for lr := range lineiter.File("/etc/pam/environment") { + lineb, err := lr.Value() + if err != nil { + return "" + } if v := pathFromPAMEnvLine(lineb, u); v != "" { - path = v - return io.EOF // stop iteration + return v } - return nil - }) - return path + } + return "" } func pathFromPAMEnvLine(line []byte, u *user.User) (path string) { diff --git a/syncs/shardedint.go b/syncs/shardedint.go new file mode 100644 index 0000000000000..28c4168d54c79 --- /dev/null +++ b/syncs/shardedint.go @@ -0,0 +1,69 @@ +// Copyright (c) 
Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "encoding/json" + "strconv" + "sync/atomic" + + "golang.org/x/sys/cpu" +) + +// ShardedInt provides a sharded atomic int64 value that optimizes high +// frequency (Mhz range and above) writes in highly parallel workloads. +// The zero value is not safe for use; use [NewShardedInt]. +// ShardedInt implements the expvar.Var interface. +type ShardedInt struct { + sv *ShardValue[intShard] +} + +// NewShardedInt returns a new [ShardedInt]. +func NewShardedInt() *ShardedInt { + return &ShardedInt{ + sv: NewShardValue[intShard](), + } +} + +// Add adds delta to the value. +func (m *ShardedInt) Add(delta int64) { + m.sv.One(func(v *intShard) { + v.Add(delta) + }) +} + +type intShard struct { + atomic.Int64 + _ cpu.CacheLinePad // avoid false sharing of neighboring shards +} + +// Value returns the current value. +func (m *ShardedInt) Value() int64 { + var v int64 + for s := range m.sv.All { + v += s.Load() + } + return v +} + +// GetDistribution returns the current value in each shard. +// This is intended for observability/debugging only. 
+func (m *ShardedInt) GetDistribution() []int64 { + v := make([]int64, 0, m.sv.Len()) + for s := range m.sv.All { + v = append(v, s.Load()) + } + return v +} + +// String implements the expvar.Var interface +func (m *ShardedInt) String() string { + v, _ := json.Marshal(m.Value()) + return string(v) +} + +// AppendText implements the encoding.TextAppender interface +func (m *ShardedInt) AppendText(b []byte) ([]byte, error) { + return strconv.AppendInt(b, m.Value(), 10), nil +} diff --git a/syncs/shardedint_test.go b/syncs/shardedint_test.go new file mode 100644 index 0000000000000..d355a15400a90 --- /dev/null +++ b/syncs/shardedint_test.go @@ -0,0 +1,119 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "expvar" + "sync" + "testing" + + "tailscale.com/tstest" +) + +var ( + _ expvar.Var = (*ShardedInt)(nil) + // TODO(raggi): future go version: + // _ encoding.TextAppender = (*ShardedInt)(nil) +) + +func BenchmarkShardedInt(b *testing.B) { + b.ReportAllocs() + + b.Run("expvar", func(b *testing.B) { + var m expvar.Int + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + m.Add(1) + } + }) + }) + + b.Run("sharded int", func(b *testing.B) { + m := NewShardedInt() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + m.Add(1) + } + }) + }) +} + +func TestShardedInt(t *testing.T) { + t.Run("basics", func(t *testing.T) { + m := NewShardedInt() + if got, want := m.Value(), int64(0); got != want { + t.Errorf("got %v, want %v", got, want) + } + m.Add(1) + if got, want := m.Value(), int64(1); got != want { + t.Errorf("got %v, want %v", got, want) + } + m.Add(2) + if got, want := m.Value(), int64(3); got != want { + t.Errorf("got %v, want %v", got, want) + } + m.Add(-1) + if got, want := m.Value(), int64(2); got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("high concurrency", func(t *testing.T) { + m := NewShardedInt() + wg := sync.WaitGroup{} + numWorkers := 1000 + 
numIncrements := 1000 + wg.Add(numWorkers) + for i := 0; i < numWorkers; i++ { + go func() { + defer wg.Done() + for i := 0; i < numIncrements; i++ { + m.Add(1) + } + }() + } + wg.Wait() + if got, want := m.Value(), int64(numWorkers*numIncrements); got != want { + t.Errorf("got %v, want %v", got, want) + } + for i, shard := range m.GetDistribution() { + t.Logf("shard %d: %d", i, shard) + } + }) + + t.Run("encoding.TextAppender", func(t *testing.T) { + m := NewShardedInt() + m.Add(1) + b := make([]byte, 0, 10) + b, err := m.AppendText(b) + if err != nil { + t.Fatal(err) + } + if got, want := string(b), "1"; got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("allocs", func(t *testing.T) { + m := NewShardedInt() + tstest.MinAllocsPerRun(t, 0, func() { + m.Add(1) + _ = m.Value() + }) + + // TODO(raggi): fix access to expvar's internal append based + // interface, unfortunately it's not currently closed for external + // use, this will alloc when it escapes. + tstest.MinAllocsPerRun(t, 0, func() { + m.Add(1) + _ = m.String() + }) + + b := make([]byte, 0, 10) + tstest.MinAllocsPerRun(t, 0, func() { + m.Add(1) + m.AppendText(b) + }) + }) +} diff --git a/syncs/shardvalue.go b/syncs/shardvalue.go new file mode 100644 index 0000000000000..b1474477c7082 --- /dev/null +++ b/syncs/shardvalue.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +// TODO(raggi): this implementation is still imperfect as it will still result +// in cross CPU sharing periodically, we instead really want a per-CPU shard +// key, but the limitations of calling platform code make reaching for even the +// getcpu vdso very painful. See https://github.com/golang/go/issues/18802, and +// hopefully one day we can replace with a primitive that falls out of that +// work. + +// ShardValue contains a value sharded over a set of shards. +// In order to be useful, T should be aligned to cache lines. 
+// Users must organize that usage in One and All is concurrency safe. +// The zero value is not safe for use; use [NewShardValue]. +type ShardValue[T any] struct { + shards []T + + //lint:ignore U1000 unused under tailscale_go builds. + pool shardValuePool +} + +// Len returns the number of shards. +func (sp *ShardValue[T]) Len() int { + return len(sp.shards) +} + +// All yields a pointer to the value in each shard. +func (sp *ShardValue[T]) All(yield func(*T) bool) { + for i := range sp.shards { + if !yield(&sp.shards[i]) { + return + } + } +} diff --git a/syncs/shardvalue_go.go b/syncs/shardvalue_go.go new file mode 100644 index 0000000000000..9b9d252a796d4 --- /dev/null +++ b/syncs/shardvalue_go.go @@ -0,0 +1,36 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !tailscale_go + +package syncs + +import ( + "runtime" + "sync" + "sync/atomic" +) + +type shardValuePool struct { + atomic.Int64 + sync.Pool +} + +// NewShardValue constructs a new ShardValue[T] with a shard per CPU. +func NewShardValue[T any]() *ShardValue[T] { + sp := &ShardValue[T]{ + shards: make([]T, runtime.NumCPU()), + } + sp.pool.New = func() any { + i := sp.pool.Add(1) - 1 + return &sp.shards[i%int64(len(sp.shards))] + } + return sp +} + +// One yields a pointer to a single shard value with best-effort P-locality. +func (sp *ShardValue[T]) One(yield func(*T)) { + v := sp.pool.Get().(*T) + yield(v) + sp.pool.Put(v) +} diff --git a/syncs/shardvalue_tailscale.go b/syncs/shardvalue_tailscale.go new file mode 100644 index 0000000000000..8ef778ff3e669 --- /dev/null +++ b/syncs/shardvalue_tailscale.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// TODO(raggi): update build tag after toolchain update +//go:build tailscale_go + +package syncs + +import ( + "runtime" +) + +//lint:ignore U1000 unused under tailscale_go builds. 
+type shardValuePool struct{} + +// NewShardValue constructs a new ShardValue[T] with a shard per CPU. +func NewShardValue[T any]() *ShardValue[T] { + return &ShardValue[T]{shards: make([]T, runtime.NumCPU())} +} + +// One yields a pointer to a single shard value with best-effort P-locality. +func (sp *ShardValue[T]) One(f func(*T)) { + f(&sp.shards[runtime.TailscaleCurrentP()%len(sp.shards)]) +} diff --git a/syncs/shardvalue_test.go b/syncs/shardvalue_test.go new file mode 100644 index 0000000000000..8f6ac6414dee7 --- /dev/null +++ b/syncs/shardvalue_test.go @@ -0,0 +1,119 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syncs + +import ( + "math" + "runtime" + "sync" + "sync/atomic" + "testing" + + "golang.org/x/sys/cpu" +) + +func TestShardValue(t *testing.T) { + type intVal struct { + atomic.Int64 + _ cpu.CacheLinePad + } + + t.Run("One", func(t *testing.T) { + sv := NewShardValue[intVal]() + sv.One(func(v *intVal) { + v.Store(10) + }) + + var v int64 + for i := range sv.shards { + v += sv.shards[i].Load() + } + if v != 10 { + t.Errorf("got %v, want 10", v) + } + }) + + t.Run("All", func(t *testing.T) { + sv := NewShardValue[intVal]() + for i := range sv.shards { + sv.shards[i].Store(int64(i)) + } + + var total int64 + sv.All(func(v *intVal) bool { + total += v.Load() + return true + }) + // triangle coefficient lower one order due to 0 index + want := int64(len(sv.shards) * (len(sv.shards) - 1) / 2) + if total != want { + t.Errorf("got %v, want %v", total, want) + } + }) + + t.Run("Len", func(t *testing.T) { + sv := NewShardValue[intVal]() + if got, want := sv.Len(), runtime.NumCPU(); got != want { + t.Errorf("got %v, want %v", got, want) + } + }) + + t.Run("distribution", func(t *testing.T) { + sv := NewShardValue[intVal]() + + goroutines := 1000 + iterations := 10000 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for i := 0; i < iterations; 
i++ { + sv.One(func(v *intVal) { + v.Add(1) + }) + } + }() + } + wg.Wait() + + var ( + total int64 + distribution []int64 + ) + t.Logf("distribution:") + sv.All(func(v *intVal) bool { + total += v.Load() + distribution = append(distribution, v.Load()) + t.Logf("%d", v.Load()) + return true + }) + + if got, want := total, int64(goroutines*iterations); got != want { + t.Errorf("got %v, want %v", got, want) + } + if got, want := len(distribution), runtime.NumCPU(); got != want { + t.Errorf("got %v, want %v", got, want) + } + + mean := total / int64(len(distribution)) + for _, v := range distribution { + if v < mean/10 || v > mean*10 { + t.Logf("distribution is very unbalanced: %v", distribution) + } + } + t.Logf("mean: %d", mean) + + var standardDev int64 + for _, v := range distribution { + standardDev += ((v - mean) * (v - mean)) + } + standardDev = int64(math.Sqrt(float64(standardDev / int64(len(distribution))))) + t.Logf("stdev: %d", standardDev) + + if standardDev > mean/3 { + t.Logf("standard deviation is too high: %v", standardDev) + } + }) +} diff --git a/syncs/syncs.go b/syncs/syncs.go index acc0c88f2991e..337fca7557f34 100644 --- a/syncs/syncs.go +++ b/syncs/syncs.go @@ -25,6 +25,7 @@ func initClosedChan() <-chan struct{} { } // AtomicValue is the generic version of [atomic.Value]. +// See [MutexValue] for guidance on whether to use this type. type AtomicValue[T any] struct { v atomic.Value } @@ -74,6 +75,67 @@ func (v *AtomicValue[T]) CompareAndSwap(oldV, newV T) (swapped bool) { return v.v.CompareAndSwap(wrappedValue[T]{oldV}, wrappedValue[T]{newV}) } +// MutexValue is a value protected by a mutex. +// +// AtomicValue, [MutexValue], [atomic.Pointer] are similar and +// overlap in their use cases. +// +// - Use [atomic.Pointer] if the value being stored is a pointer and +// you only ever need load and store operations. +// An atomic pointer only occupies 1 word of memory. 
+// +// - Use [MutexValue] if the value being stored is not a pointer or +// you need the ability for a mutex to protect a set of operations +// performed on the value. +// A mutex-guarded value occupies 1 word of memory plus +// the memory representation of T. +// +// - AtomicValue is useful for non-pointer types that happen to +// have the memory layout of a single pointer. +// Examples include a map, channel, func, or a single field struct +// that contains any prior types. +// An atomic value occupies 2 words of memory. +// Consequently, Storing of non-pointer types always allocates. +// +// Note that [AtomicValue] has the ability to report whether it was set +// while [MutexValue] lacks the ability to detect if the value was set +// and it happens to be the zero value of T. If such a use case is +// necessary, then you could consider wrapping T in [opt.Value]. +type MutexValue[T any] struct { + mu sync.Mutex + v T +} + +// WithLock calls f with a pointer to the value while holding the lock. +// The provided pointer must not leak beyond the scope of the call. +func (m *MutexValue[T]) WithLock(f func(p *T)) { + m.mu.Lock() + defer m.mu.Unlock() + f(&m.v) +} + +// Load returns a shallow copy of the underlying value. +func (m *MutexValue[T]) Load() T { + m.mu.Lock() + defer m.mu.Unlock() + return m.v +} + +// Store stores a shallow copy of the provided value. +func (m *MutexValue[T]) Store(v T) { + m.mu.Lock() + defer m.mu.Unlock() + m.v = v +} + +// Swap stores new into m and returns the previous value. +func (m *MutexValue[T]) Swap(new T) (old T) { + m.mu.Lock() + defer m.mu.Unlock() + old, m.v = m.v, new + return old +} + // WaitGroupChan is like a sync.WaitGroup, but has a chan that closes // on completion that you can wait on. 
(This, you can only use the // value once) diff --git a/syncs/syncs_test.go b/syncs/syncs_test.go index ee3711e76587b..901d429486d13 100644 --- a/syncs/syncs_test.go +++ b/syncs/syncs_test.go @@ -8,6 +8,7 @@ import ( "io" "os" "testing" + "time" "github.com/google/go-cmp/cmp" ) @@ -65,6 +66,39 @@ func TestAtomicValue(t *testing.T) { } } +func TestMutexValue(t *testing.T) { + var v MutexValue[time.Time] + if n := int(testing.AllocsPerRun(1000, func() { + v.Store(v.Load()) + v.WithLock(func(*time.Time) {}) + })); n != 0 { + t.Errorf("AllocsPerRun = %d, want 0", n) + } + + now := time.Now() + v.Store(now) + if !v.Load().Equal(now) { + t.Errorf("Load = %v, want %v", v.Load(), now) + } + + var group WaitGroup + var v2 MutexValue[int] + var sum int + for i := range 10 { + group.Go(func() { + old1 := v2.Load() + old2 := v2.Swap(old1 + i) + delta := old2 - old1 + v2.WithLock(func(p *int) { *p += delta }) + }) + sum += i + } + group.Wait() + if v2.Load() != sum { + t.Errorf("Load = %v, want %v", v2.Load(), sum) + } +} + func TestWaitGroupChan(t *testing.T) { wg := NewWaitGroupChan() diff --git a/tailcfg/c2ntypes.go b/tailcfg/c2ntypes.go index 54efb736e6db9..66f95785c4a83 100644 --- a/tailcfg/c2ntypes.go +++ b/tailcfg/c2ntypes.go @@ -102,3 +102,18 @@ type C2NTLSCertInfo struct { // TODO(bradfitz): add fields for whether an ACME fetch is currently in // process and when it started, etc. } + +// C2NVIPServicesResponse is the response (from node to control) from the +// /vip-services handler. +// +// It returns the list of VIPServices that the node is currently serving with +// their port info and whether they are active or not. It also returns a hash of +// the response to allow the control server to detect changes. +type C2NVIPServicesResponse struct { + // VIPServices is the list of VIP services that the node is currently serving. + VIPServices []*VIPService `json:",omitempty"` + + // ServicesHash is the hash of VIPServices to allow the control server to detect + // changes. 
This value matches what is reported in latest [Hostinfo.ServicesHash]. + ServicesHash string +} diff --git a/tailcfg/derpmap.go b/tailcfg/derpmap.go index 056152157fede..e05559f3ed7f1 100644 --- a/tailcfg/derpmap.go +++ b/tailcfg/derpmap.go @@ -96,12 +96,32 @@ type DERPRegion struct { Latitude float64 `json:",omitempty"` Longitude float64 `json:",omitempty"` - // Avoid is whether the client should avoid picking this as its home - // region. The region should only be used if a peer is there. - // Clients already using this region as their home should migrate - // away to a new region without Avoid set. + // Avoid is whether the client should avoid picking this as its home region. + // The region should only be used if a peer is there. Clients already using + // this region as their home should migrate away to a new region without + // Avoid set. + // + // Deprecated: because of bugs in past implementations combined with unclear + // docs that caused people to think the bugs were intentional, this field is + // deprecated. It was never supposed to cause STUN/DERP measurement probes, + // but due to bugs, it sometimes did. And then some parts of the code began + // to rely on that property. But then we were unable to use this field for + // its original purpose, nor its later imagined purpose, because various + // parts of the codebase thought it meant one thing and others thought it + // meant another. But it did something in the middle instead. So we're retiring + // it. Use NoMeasureNoHome instead. Avoid bool `json:",omitempty"` + // NoMeasureNoHome says that this regions should not be measured for its + // latency distance (STUN, HTTPS, etc) or availability (e.g. captive portal + // checks) and should never be selected as the node's home region. However, + // if a peer declares this region as its home, then this client is allowed + // to connect to it for the purpose of communicating with that peer. 
+ // + // This is what the now deprecated Avoid bool was supposed to mean + // originally but had implementation bugs and documentation omissions. + NoMeasureNoHome bool `json:",omitempty"` + // Nodes are the DERP nodes running in this region, in // priority order for the current client. Client TLS // connections should ideally only go to the first entry @@ -139,6 +159,12 @@ type DERPNode struct { // name. If empty, HostName is used. If CertName is non-empty, // HostName is only used for the TCP dial (if IPv4/IPv6 are // not present) + TLS ClientHello. + // + // As a special case, if CertName starts with "sha256-raw:", + // then the rest of the string is a hex-encoded SHA256 of the + // cert to expect. This is used for self-signed certs. + // In this case, the HostName field will typically be an IP + // address literal. CertName string `json:",omitempty"` // IPv4 optionally forces an IPv4 address to use, instead of using DNS. diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 92bf2cd95da15..0a58d8f0cc229 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -5,7 +5,7 @@ // the node and the coordination server. 
package tailcfg -//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile --clonefunc +//go:generate go run tailscale.com/cmd/viewer --type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService --clonefunc import ( "bytes" @@ -25,8 +25,10 @@ import ( "tailscale.com/types/opt" "tailscale.com/types/structs" "tailscale.com/types/tkatype" + "tailscale.com/types/views" "tailscale.com/util/dnsname" "tailscale.com/util/slicesx" + "tailscale.com/util/vizerror" ) // CapabilityVersion represents the client's capability level. That @@ -142,44 +144,89 @@ type CapabilityVersion int // - 97: 2024-06-06: Client understands NodeAttrDisableSplitDNSWhenNoCustomResolvers // - 98: 2024-06-13: iOS/tvOS clients may provide serial number as part of posture information // - 99: 2024-06-14: Client understands NodeAttrDisableLocalDNSOverrideViaNRPT -// - 100: 2024-06-18: Client supports filtertype.Match.SrcCaps (issue #12542) +// - 100: 2024-06-18: Initial support for filtertype.Match.SrcCaps - actually usable in capver 109 (issue #12542) // - 101: 2024-07-01: Client supports SSH agent forwarding when handling connections with /bin/su // - 102: 2024-07-12: NodeAttrDisableMagicSockCryptoRouting support // - 103: 2024-07-24: Client supports NodeAttrDisableCaptivePortalDetection // - 104: 2024-08-03: SelfNodeV6MasqAddrForThisPeer now works // - 105: 2024-08-05: Fixed SSH behavior on systems that use busybox (issue #12849) // - 106: 2024-09-03: fix panic regression from cryptokey routing change (65fe0ba7b5) -const CurrentCapabilityVersion CapabilityVersion = 106 - -type StableID string - +// - 107: 2024-10-30: add App Connector to 
conffile (PR #13942) +// - 108: 2024-11-08: Client sends ServicesHash in Hostinfo, understands c2n GET /vip-services. +// - 109: 2024-11-18: Client supports filtertype.Match.SrcCaps (issue #12542) +// - 110: 2024-12-12: removed never-before-used Tailscale SSH public key support (#14373) +// - 111: 2025-01-14: Client supports a peer having Node.HomeDERP (issue #14636) +// - 112: 2025-01-14: Client interprets AllowedIPs of nil as meaning same as Addresses +// - 113: 2025-01-20: Client communicates to control whether funnel is enabled by sending Hostinfo.IngressEnabled (#14688) +// - 114: 2025-01-30: NodeAttrMaxKeyDuration CapMap defined, clients might use it (no tailscaled code change) (#14829) +// - 115: 2025-03-07: Client understands DERPRegion.NoMeasureNoHome. +// - 116: 2025-05-05: Client serves MagicDNS "AAAA" if NodeAttrMagicDNSPeerAAAA set on self node +const CurrentCapabilityVersion CapabilityVersion = 116 + +// ID is an integer ID for a user, node, or login allocated by the +// control plane. +// +// To be nice, control plane servers should not use int64s that are too large to +// fit in a JavaScript number (see JavaScript's Number.MAX_SAFE_INTEGER). +// The Tailscale-hosted control plane stopped allocating large integers in +// March 2023 but nodes prior to that may have IDs larger than +// MAX_SAFE_INTEGER (2^53 – 1). +// +// IDs must not be zero or negative. type ID int64 +// UserID is an [ID] for a [User]. type UserID ID func (u UserID) IsZero() bool { return u == 0 } +// LoginID is an [ID] for a [Login]. +// +// It is not used in the Tailscale client, but is used in the control plane. type LoginID ID func (u LoginID) IsZero() bool { return u == 0 } +// NodeID is a unique integer ID for a node. +// +// It's global within a control plane URL ("tailscale up --login-server") and is +// (as of 2025-01-06) never re-used even after a node is deleted. 
+// +// To be nice, control plane servers should not use int64s that are too large to +// fit in a JavaScript number (see JavaScript's Number.MAX_SAFE_INTEGER). +// The Tailscale-hosted control plane stopped allocating large integers in +// March 2023 but nodes prior to that may have node IDs larger than +// MAX_SAFE_INTEGER (2^53 – 1). +// +// NodeIDs are not stable across control plane URLs. For more stable URLs, +// see [StableNodeID]. type NodeID ID func (u NodeID) IsZero() bool { return u == 0 } -type StableNodeID StableID +// StableNodeID is a string form of [NodeID]. +// +// Different control plane servers should ideally have different StableNodeID +// suffixes for different sites or regions. +// +// Being a string, it's safer to use in JavaScript without worrying about the +// size of the integer, as documented on [NodeID]. +// +// But in general, Tailscale APIs can accept either a [NodeID] integer or a +// [StableNodeID] string when referring to a node. +type StableNodeID string func (u StableNodeID) IsZero() bool { return u == "" } -// User is an IPN user. +// User is a Tailscale user. // // A user can have multiple logins associated with it (e.g. gmail and github oauth). // (Note: none of our UIs support this yet.) @@ -192,34 +239,30 @@ func (u StableNodeID) IsZero() bool { // have a general gmail address login associated with the user. type User struct { ID UserID - LoginName string `json:"-"` // not stored, filled from Login // TODO REMOVE DisplayName string // if non-empty overrides Login field ProfilePicURL string // if non-empty overrides Login field - Logins []LoginID Created time.Time } +// Login is a user from a specific identity provider, not associated with any +// particular tailnet. type Login struct { _ structs.Incomparable - ID LoginID - Provider string - LoginName string - DisplayName string - ProfilePicURL string + ID LoginID // unused in the Tailscale client + Provider string // "google", "github", "okta_foo", etc. 
+ LoginName string // an email address or "email-ish" string (like alice@github) + DisplayName string // from the IdP + ProfilePicURL string // from the IdP } -// A UserProfile is display-friendly data for a user. +// A UserProfile is display-friendly data for a [User]. // It includes the LoginName for display purposes but *not* the Provider. // It also includes derived data from one of the user's logins. type UserProfile struct { ID UserID LoginName string // "alice@smith.com"; for display purposes only (provider is not listed) DisplayName string // "Alice Smith" - ProfilePicURL string - - // Roles exists for legacy reasons, to keep old macOS clients - // happy. It JSON marshals as []. - Roles emptyStructJSONSlice + ProfilePicURL string `json:",omitempty"` } func (p *UserProfile) Equal(p2 *UserProfile) bool { @@ -235,16 +278,6 @@ func (p *UserProfile) Equal(p2 *UserProfile) bool { p.ProfilePicURL == p2.ProfilePicURL } -type emptyStructJSONSlice struct{} - -var emptyJSONSliceBytes = []byte("[]") - -func (emptyStructJSONSlice) MarshalJSON() ([]byte, error) { - return emptyJSONSliceBytes, nil -} - -func (emptyStructJSONSlice) UnmarshalJSON([]byte) error { return nil } - // RawMessage is a raw encoded JSON value. It implements Marshaler and // Unmarshaler and can be used to delay JSON decoding or precompute a JSON // encoding. @@ -279,6 +312,7 @@ func MarshalCapJSON[T any](capRule T) (RawMessage, error) { return RawMessage(string(bs)), nil } +// Node is a Tailscale device in a tailnet. 
type Node struct { ID NodeID StableID StableNodeID @@ -302,19 +336,37 @@ type Node struct { KeySignature tkatype.MarshaledSignature `json:",omitempty"` Machine key.MachinePublic DiscoKey key.DiscoPublic - Addresses []netip.Prefix // IP addresses of this Node directly - AllowedIPs []netip.Prefix // range of IP addresses to route to this node - Endpoints []netip.AddrPort `json:",omitempty"` // IP+port (public via STUN, and local LANs) - // DERP is this node's home DERP region ID integer, but shoved into an + // Addresses are the IP addresses of this Node directly. + Addresses []netip.Prefix + + // AllowedIPs are the IP ranges to route to this node. + // + // As of CapabilityVersion 112, this may be nil (null or undefined) on the wire + // to mean the same as Addresses. Internally, it is always filled in with + // its possibly-implicit value. + AllowedIPs []netip.Prefix + + Endpoints []netip.AddrPort `json:",omitempty"` // IP+port (public via STUN, and local LANs) + + // LegacyDERPString is this node's home LegacyDERPString region ID integer, but shoved into an // IP:port string for legacy reasons. The IP address is always "127.3.3.40" // (a loopback address (127) followed by the digits over the letters DERP on - // a QWERTY keyboard (3.3.40)). The "port number" is the home DERP region ID + // a QWERTY keyboard (3.3.40)). The "port number" is the home LegacyDERPString region ID // integer. // - // TODO(bradfitz): simplify this legacy mess; add a new HomeDERPRegionID int - // field behind a new capver bump. - DERP string `json:",omitempty"` // DERP-in-IP:port ("127.3.3.40:N") endpoint + // Deprecated: HomeDERP has replaced this, but old servers might still send + // this field. See tailscale/tailscale#14636. Do not use this field in code + // other than in the upgradeNode func, which canonicalizes it to HomeDERP + // if it arrives as a LegacyDERPString string on the wire. 
+ LegacyDERPString string `json:"DERP,omitempty"` // DERP-in-IP:port ("127.3.3.40:N") endpoint + + // HomeDERP is the modern version of the DERP string field, with just an + // integer. The client advertises support for this as of capver 111. + // + // HomeDERP may be zero if not (yet) known, but ideally always be non-zero + // for magicsock connectivity to function normally. + HomeDERP int `json:",omitempty"` // DERP region ID of the node's home DERP Hostinfo HostinfoView Created time.Time @@ -559,6 +611,11 @@ func (n *Node) InitDisplayNames(networkMagicDNSSuffix string) { n.ComputedNameWithHost = nameWithHost } +// MachineStatus is the state of a [Node]'s approval into a tailnet. +// +// A "node" and a "machine" are often 1:1, but technically a Tailscale +// daemon has one machine key and can have multiple nodes (e.g. different +// users on Windows) for that one machine key. type MachineStatus int const ( @@ -780,15 +837,23 @@ type Hostinfo struct { // App is used to disambiguate Tailscale clients that run using tsnet. App string `json:",omitempty"` // "k8s-operator", "golinks", ... 
- Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux - Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown) - DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3") - PushDeviceToken string `json:",omitempty"` // macOS/iOS APNs device token for notifications (and Android in the future) - Hostname string `json:",omitempty"` // name of the host the client runs on - ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections - ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user - NoLogsNoSupport bool `json:",omitempty"` // indicates that the user has opted out of sending logs and support - WireIngress bool `json:",omitempty"` // indicates that the node wants the option to receive ingress connections + Desktop opt.Bool `json:",omitempty"` // if a desktop was detected on Linux + Package string `json:",omitempty"` // Tailscale package to disambiguate ("choco", "appstore", etc; "" for unknown) + DeviceModel string `json:",omitempty"` // mobile phone model ("Pixel 3a", "iPhone12,3") + PushDeviceToken string `json:",omitempty"` // macOS/iOS APNs device token for notifications (and Android in the future) + Hostname string `json:",omitempty"` // name of the host the client runs on + ShieldsUp bool `json:",omitempty"` // indicates whether the host is blocking incoming connections + ShareeNode bool `json:",omitempty"` // indicates this node exists in netmap because it's owned by a shared-to user + NoLogsNoSupport bool `json:",omitempty"` // indicates that the user has opted out of sending logs and support + // WireIngress indicates that the node would like to be wired up server-side + // (DNS, etc) to be able to use Tailscale Funnel, even if it's not currently + // enabled. 
For example, the user might only use it for intermittent + // foreground CLI serve sessions, for which they'd like it to work right + // away, even if it's disabled most of the time. As an optimization, this is + // only sent if IngressEnabled is false, as IngressEnabled implies that this + // option is true. + WireIngress bool `json:",omitempty"` + IngressEnabled bool `json:",omitempty"` // if the node has any funnel endpoint enabled AllowsUpdate bool `json:",omitempty"` // indicates that the node has opted-in to admin-console-drive remote updates Machine string `json:",omitempty"` // the current host's machine type (uname -m) GoArch string `json:",omitempty"` // GOARCH value (of the built binary) @@ -804,16 +869,99 @@ type Hostinfo struct { Userspace opt.Bool `json:",omitempty"` // if the client is running in userspace (netstack) mode UserspaceRouter opt.Bool `json:",omitempty"` // if the client's subnet router is running in userspace (netstack) mode AppConnector opt.Bool `json:",omitempty"` // if the client is running the app-connector service + ServicesHash string `json:",omitempty"` // opaque hash of the most recent list of tailnet services, change in hash indicates config should be fetched via c2n // Location represents geographical location data about a // Tailscale host. Location is optional and only set if // explicitly declared by a node. Location *Location `json:",omitempty"` + TPM *TPMInfo `json:",omitempty"` // TPM device metadata, if available + // NOTE: any new fields containing pointers in this type // require changes to Hostinfo.Equal. } +// TPMInfo contains information about a TPM 2.0 device present on a node. +// All fields are read from TPM_CAP_TPM_PROPERTIES, see Part 2, section 6.13 of +// https://trustedcomputinggroup.org/resource/tpm-library-specification/. +type TPMInfo struct { + // Manufacturer is a 4-letter code from section 4.1 of + // https://trustedcomputinggroup.org/resource/vendor-id-registry/, + // for example "MSFT" for Microsoft. 
+ // Read from TPM_PT_MANUFACTURER. + Manufacturer string `json:",omitempty"` + // Vendor is a vendor ID string, up to 16 characters. + // Read from TPM_PT_VENDOR_STRING_*. + Vendor string `json:",omitempty"` + // Model is a vendor-defined TPM model. + // Read from TPM_PT_VENDOR_TPM_TYPE. + Model int `json:",omitempty"` + // FirmwareVersion is the version number of the firmware. + // Read from TPM_PT_FIRMWARE_VERSION_*. + FirmwareVersion uint64 `json:",omitempty"` + // SpecRevision is the TPM 2.0 spec revision encoded as a single number. All + // revisions can be found at + // https://trustedcomputinggroup.org/resource/tpm-library-specification/. + // Before revision 184, TCG used the "01.83" format for revision 183. + SpecRevision int `json:",omitempty"` +} + +// ServiceName is the name of a service, of the form `svc:dns-label`. Services +// represent some kind of application provided for users of the tailnet with a +// MagicDNS name and possibly dedicated IP addresses. Currently (2024-01-21), +// the only type of service is [VIPService]. +// This is not related to the older [Service] used in [Hostinfo.Services]. +type ServiceName string + +// Validate validates if the service name is formatted correctly. +// We only allow valid DNS labels, since the expectation is that these will be +// used as parts of domain names. All errors are [vizerror.Error]. +func (sn ServiceName) Validate() error { + bareName, ok := strings.CutPrefix(string(sn), "svc:") + if !ok { + return vizerror.Errorf("%q is not a valid service name: must start with 'svc:'", sn) + } + if bareName == "" { + return vizerror.Errorf("%q is not a valid service name: must not be empty after the 'svc:' prefix", sn) + } + return dnsname.ValidLabel(bareName) +} + +// String implements [fmt.Stringer]. +func (sn ServiceName) String() string { + return string(sn) +} + +// WithoutPrefix is the name of the service without the `svc:` prefix, used for +// DNS names. 
If the name does not include the prefix (which means +// [ServiceName.Validate] would return an error) then it returns "". +func (sn ServiceName) WithoutPrefix() string { + bareName, ok := strings.CutPrefix(string(sn), "svc:") + if !ok { + return "" + } + return bareName +} + +// VIPService represents a service created on a tailnet from the +// perspective of a node providing that service. These services +// have an virtual IP (VIP) address pair distinct from the node's IPs. +type VIPService struct { + // Name is the name of the service. The Name uniquely identifies a service + // on a particular tailnet, and so also corresponds uniquely to the pair of + // IP addresses belonging to the VIP service. + Name ServiceName + + // Ports specify which ProtoPorts are made available by this node + // on the service's IPs. + Ports []ProtoPortRange + + // Active specifies whether new requests for the service should be + // sent to this node by control. + Active bool +} + // TailscaleSSHEnabled reports whether or not this node is acting as a // Tailscale SSH server. func (hi *Hostinfo) TailscaleSSHEnabled() bool { @@ -824,14 +972,6 @@ func (hi *Hostinfo) TailscaleSSHEnabled() bool { func (v HostinfoView) TailscaleSSHEnabled() bool { return v.Đļ.TailscaleSSHEnabled() } -// TailscaleFunnelEnabled reports whether or not this node has explicitly -// enabled Funnel. -func (hi *Hostinfo) TailscaleFunnelEnabled() bool { - return hi != nil && hi.WireIngress -} - -func (v HostinfoView) TailscaleFunnelEnabled() bool { return v.Đļ.TailscaleFunnelEnabled() } - // NetInfo contains information about the host's network state. type NetInfo struct { // MappingVariesByDestIP says whether the host's NAT mappings @@ -975,68 +1115,6 @@ func (h *Hostinfo) Equal(h2 *Hostinfo) bool { return reflect.DeepEqual(h, h2) } -// HowUnequal returns a list of paths through Hostinfo where h and h2 differ. 
-// If they differ in nil-ness, the path is "nil", otherwise the path is like -// "ShieldsUp" or "NetInfo.nil" or "NetInfo.PCP". -func (h *Hostinfo) HowUnequal(h2 *Hostinfo) (path []string) { - return appendStructPtrDiff(nil, "", reflect.ValueOf(h), reflect.ValueOf(h2)) -} - -func appendStructPtrDiff(base []string, pfx string, p1, p2 reflect.Value) (ret []string) { - ret = base - if p1.IsNil() && p2.IsNil() { - return base - } - mkPath := func(b string) string { - if pfx == "" { - return b - } - return pfx + "." + b - } - if p1.IsNil() || p2.IsNil() { - return append(base, mkPath("nil")) - } - v1, v2 := p1.Elem(), p2.Elem() - t := v1.Type() - for i, n := 0, t.NumField(); i < n; i++ { - sf := t.Field(i) - switch sf.Type.Kind() { - case reflect.String: - if v1.Field(i).String() != v2.Field(i).String() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Bool: - if v1.Field(i).Bool() != v2.Field(i).Bool() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if v1.Field(i).Int() != v2.Field(i).Int() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - if v1.Field(i).Uint() != v2.Field(i).Uint() { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Slice, reflect.Map: - if !reflect.DeepEqual(v1.Field(i).Interface(), v2.Field(i).Interface()) { - ret = append(ret, mkPath(sf.Name)) - } - continue - case reflect.Ptr: - if sf.Type.Elem().Kind() == reflect.Struct { - ret = appendStructPtrDiff(ret, sf.Name, v1.Field(i), v2.Field(i)) - continue - } - } - panic(fmt.Sprintf("unsupported type at %s: %s", mkPath(sf.Name), sf.Type.String())) - } - return ret -} - // SignatureType specifies a scheme for signing RegisterRequest messages. It // specifies the crypto algorithms to use, the contents of what is signed, and // any other relevant details. 
Historically, requests were unsigned so the zero @@ -1117,11 +1195,11 @@ type RegisterResponseAuth struct { AuthKey string `json:",omitempty"` } -// RegisterRequest is sent by a client to register the key for a node. -// It is encoded to JSON, encrypted with golang.org/x/crypto/nacl/box, -// using the local machine key, and sent to: +// RegisterRequest is a request to register a key for a node. +// +// This is JSON-encoded and sent over the control plane connection to: // -// https://login.tailscale.com/machine/ +// POST https:///machine/register. type RegisterRequest struct { _ structs.Incomparable @@ -1237,10 +1315,9 @@ type Endpoint struct { // The request includes a copy of the client's current set of WireGuard // endpoints and general host information. // -// The request is encoded to JSON, encrypted with golang.org/x/crypto/nacl/box, -// using the local machine key, and sent to: +// This is JSON-encoded and sent over the control plane connection to: // -// https://login.tailscale.com/machine//map +// POST https:///machine/map type MapRequest struct { // Version is incremented whenever the client code changes enough that // we want to signal to the control server that we're capable of something @@ -1336,6 +1413,12 @@ type MapRequest struct { // * "warn-router-unhealthy": client's Router implementation is // having problems. DebugFlags []string `json:",omitempty"` + + // ConnectionHandleForTest, if non-empty, is an opaque string sent by the client that + // identifies this specific connection to the server. The server may choose to + // use this handle to identify the connection for debugging or testing + // purposes. It has no semantic meaning. + ConnectionHandleForTest string `json:",omitempty"` } // PortRange represents a range of UDP or TCP port numbers. 
@@ -1355,7 +1438,7 @@ var PortRangeAny = PortRange{0, 65535} type NetPortRange struct { _ structs.Incomparable IP string // IP, CIDR, Range, or "*" (same formats as FilterRule.SrcIPs) - Bits *int // deprecated; the 2020 way to turn IP into a CIDR. See FilterRule.SrcBits. + Bits *int `json:",omitempty"` // deprecated; the 2020 way to turn IP into a CIDR. See FilterRule.SrcBits. Ports PortRange } @@ -1413,11 +1496,23 @@ const ( // user groups as Kubernetes user groups. This capability is read by // peers that are Tailscale Kubernetes operator instances. PeerCapabilityKubernetes PeerCapability = "tailscale.com/cap/kubernetes" + + // PeerCapabilityRelay grants the ability for a peer to allocate relay + // endpoints. + PeerCapabilityRelay PeerCapability = "tailscale.com/cap/relay" + // PeerCapabilityRelayTarget grants the current node the ability to allocate + // relay endpoints to the peer which has this capability. + PeerCapabilityRelayTarget PeerCapability = "tailscale.com/cap/relay-target" + + // PeerCapabilityTsIDP grants a peer tsidp-specific + // capabilities, such as the ability to add user groups to the OIDC + // claim + PeerCapabilityTsIDP PeerCapability = "tailscale.com/cap/tsidp" ) // NodeCapMap is a map of capabilities to their optional values. It is valid for // a capability to have no values (nil slice); such capabilities can be tested -// for by using the Contains method. +// for by using the [NodeCapMap.Contains] method. // // See [NodeCapability] for more information on keys. type NodeCapMap map[NodeCapability][]RawMessage @@ -1431,12 +1526,19 @@ func (c NodeCapMap) Equal(c2 NodeCapMap) bool { // If cap does not exist in cm, it returns (nil, nil). // It returns an error if the values cannot be unmarshaled into the provided type. 
func UnmarshalNodeCapJSON[T any](cm NodeCapMap, cap NodeCapability) ([]T, error) { - vals, ok := cm[cap] + return UnmarshalNodeCapViewJSON[T](views.MapSliceOf(cm), cap) +} + +// UnmarshalNodeCapViewJSON unmarshals each JSON value in cm.Get(cap) as T. +// If cap does not exist in cm, it returns (nil, nil). +// It returns an error if the values cannot be unmarshaled into the provided type. +func UnmarshalNodeCapViewJSON[T any](cm views.MapSlice[NodeCapability, RawMessage], cap NodeCapability) ([]T, error) { + vals, ok := cm.GetOk(cap) if !ok { return nil, nil } - out := make([]T, 0, len(vals)) - for _, v := range vals { + out := make([]T, 0, vals.Len()) + for _, v := range vals.All() { var t T if err := json.Unmarshal([]byte(v), &t); err != nil { return nil, err @@ -1466,12 +1568,19 @@ type PeerCapMap map[PeerCapability][]RawMessage // If cap does not exist in cm, it returns (nil, nil). // It returns an error if the values cannot be unmarshaled into the provided type. func UnmarshalCapJSON[T any](cm PeerCapMap, cap PeerCapability) ([]T, error) { - vals, ok := cm[cap] + return UnmarshalCapViewJSON[T](views.MapSliceOf(cm), cap) +} + +// UnmarshalCapViewJSON unmarshals each JSON value in cm.Get(cap) as T. +// If cap does not exist in cm, it returns (nil, nil). +// It returns an error if the values cannot be unmarshaled into the provided type. +func UnmarshalCapViewJSON[T any](cm views.MapSlice[PeerCapability, RawMessage], cap PeerCapability) ([]T, error) { + vals, ok := cm.GetOk(cap) if !ok { return nil, nil } - out := make([]T, 0, len(vals)) - for _, v := range vals { + out := make([]T, 0, vals.Len()) + for _, v := range vals.All() { var t T if err := json.Unmarshal([]byte(v), &t); err != nil { return nil, err @@ -1666,9 +1775,14 @@ const ( PingPeerAPI PingType = "peerapi" ) -// PingRequest with no IP and Types is a request to send an HTTP request to prove the -// long-polling client is still connected. 
-// PingRequest with Types and IP, will send a ping to the IP and send a POST +// PingRequest is a request from the control plane to the local node to probe +// something. +// +// A PingRequest with no IP and Types is a request from the control plane to the +// local node to send an HTTP request to a URL to prove the long-polling client +// is still connected. +// +// A PingRequest with Types and IP, will send a ping to the IP and send a POST // request containing a PingResponse to the URL containing results. type PingRequest struct { // URL is the URL to reply to the PingRequest to. @@ -1833,7 +1947,7 @@ type MapResponse struct { // PeersChangedPatch, if non-nil, means that node(s) have changed. // This is a lighter version of the older PeersChanged support that - // only supports certain types of updates + // only supports certain types of updates. // // These are applied after Peers* above, but in practice the // control server should only send these on their own, without @@ -1962,10 +2076,6 @@ type MapResponse struct { // auto-update setting doesn't change if the tailnet admin flips the // default after the node registered. DefaultAutoUpdate opt.Bool `json:",omitempty"` - - // MaxKeyDuration describes the MaxKeyDuration setting for the tailnet. - // If zero, the value is unchanged. - MaxKeyDuration time.Duration `json:",omitempty"` } // ClientVersion is information about the latest client version that's available @@ -2081,7 +2191,8 @@ func (n *Node) Equal(n2 *Node) bool { slicesx.EqualSameNil(n.AllowedIPs, n2.AllowedIPs) && slicesx.EqualSameNil(n.PrimaryRoutes, n2.PrimaryRoutes) && slicesx.EqualSameNil(n.Endpoints, n2.Endpoints) && - n.DERP == n2.DERP && + n.LegacyDERPString == n2.LegacyDERPString && + n.HomeDERP == n2.HomeDERP && n.Cap == n2.Cap && n.Hostinfo.Equal(n2.Hostinfo) && n.Created.Equal(n2.Created) && @@ -2353,20 +2464,57 @@ const ( // automatically when the network state changes. 
NodeAttrDisableCaptivePortalDetection NodeCapability = "disable-captive-portal-detection" + // NodeAttrDisableSkipStatusQueue is set when the node should disable skipping + // of queued netmap.NetworkMap between the controlclient and LocalBackend. + // See tailscale/tailscale#14768. + NodeAttrDisableSkipStatusQueue NodeCapability = "disable-skip-status-queue" + // NodeAttrSSHEnvironmentVariables enables logic for handling environment variables sent // via SendEnv in the SSH server and applying them to the SSH session. NodeAttrSSHEnvironmentVariables NodeCapability = "ssh-env-vars" + + // NodeAttrServiceHost indicates the VIP Services for which the client is + // approved to act as a service host, and which IP addresses are assigned + // to those VIP Services. Any VIP Services that the client is not + // advertising can be ignored. + // Each value of this key in [NodeCapMap] is of type [ServiceIPMappings]. + // If multiple values of this key exist, they should be merged in sequence + // (replace conflicting keys). + NodeAttrServiceHost NodeCapability = "service-host" + + // NodeAttrMaxKeyDuration represents the MaxKeyDuration setting on the + // tailnet. The value of this key in [NodeCapMap] will be only one entry of + // type float64 representing the duration in seconds. This cap will be + // omitted if the tailnet's MaxKeyDuration is the default. + NodeAttrMaxKeyDuration NodeCapability = "tailnet.maxKeyDuration" + + // NodeAttrNativeIPV4 contains the IPV4 address of the node in its + // native tailnet. This is currently only sent to Hello, in its + // peer node list. + NodeAttrNativeIPV4 NodeCapability = "native-ipv4" + + // NodeAttrRelayServer permits the node to act as an underlay UDP relay + // server. There are no expected values for this key in NodeCapMap. + NodeAttrRelayServer NodeCapability = "relay:server" + + // NodeAttrRelayClient permits the node to act as an underlay UDP relay + // client. There are no expected values for this key in NodeCapMap. 
+ NodeAttrRelayClient NodeCapability = "relay:client" + + // NodeAttrMagicDNSPeerAAAA is a capability that tells the node's MagicDNS + // server to answer AAAA queries about its peers. See tailscale/tailscale#1152. + NodeAttrMagicDNSPeerAAAA NodeCapability = "magicdns-aaaa" ) // SetDNSRequest is a request to add a DNS record. // -// This is used for ACME DNS-01 challenges (so people can use -// LetsEncrypt, etc). +// This is used to let tailscaled clients complete their ACME DNS-01 challenges +// (so people can use LetsEncrypt, etc) to get TLS certificates for +// their foo.bar.ts.net MagicDNS names. // -// The request is encoded to JSON, encrypted with golang.org/x/crypto/nacl/box, -// using the local machine key, and sent to: +// This is JSON-encoded and sent over the control plane connection to: // -// https://login.tailscale.com/machine//set-dns +// POST https:///machine/set-dns type SetDNSRequest struct { // Version is the client's capabilities // (CurrentCapabilityVersion) when using the Noise transport. @@ -2396,7 +2544,9 @@ type SetDNSRequest struct { type SetDNSResponse struct{} // HealthChangeRequest is the JSON request body type used to report -// node health changes to https:///machine//update-health. +// node health changes to: +// +// POST https:///machine/update-health. type HealthChangeRequest struct { Subsys string // a health.Subsystem value in string form Error string // or empty if cleared @@ -2406,6 +2556,38 @@ type HealthChangeRequest struct { NodeKey key.NodePublic } +// SetDeviceAttributesRequest is a request to update the +// current node's device posture attributes. +// +// As of 2024-12-30, this is an experimental dev feature +// for internal testing. See tailscale/corp#24690. +// +// This is JSON-encoded and sent over the control plane connection to: +// +// PATCH https:///machine/set-device-attr +type SetDeviceAttributesRequest struct { + // Version is the current binary's [CurrentCapabilityVersion]. 
+ Version CapabilityVersion + + // NodeKey identifies the node to modify. It should be the currently active + // node and is an error if not. + NodeKey key.NodePublic + + // Update is a map of device posture attributes to update. + // Attributes not in the map are left unchanged. + Update AttrUpdate +} + +// AttrUpdate is a map of attributes to update. +// Attributes not in the map are left unchanged. +// The value can be a string, float64, bool, or nil to delete. +// +// See https://tailscale.com/s/api-device-posture-attrs. +// +// TODO(bradfitz): add struct type for specifying optional associated data +// for each attribute value, like an expiry time? +type AttrUpdate map[string]any + // SSHPolicy is the policy for how to handle incoming SSH connections // over Tailscale. type SSHPolicy struct { @@ -2481,16 +2663,13 @@ type SSHPrincipal struct { Any bool `json:"any,omitempty"` // if true, match any connection // TODO(bradfitz): add StableUserID, once that exists - // PubKeys, if non-empty, means that this SSHPrincipal only - // matches if one of these public keys is presented by the user. + // UnusedPubKeys was public key support. It never became an official product + // feature and so as of 2024-12-12 is being removed. + // This stub exists to remind us not to re-use the JSON field name "pubKeys" + // in the future if we bring it back with different semantics. // - // As a special case, if len(PubKeys) == 1 and PubKeys[0] starts - // with "https://", then it's fetched (like https://github.com/username.keys). - // In that case, the following variable expansions are also supported - // in the URL: - // * $LOGINNAME_EMAIL ("foo@bar.com" or "foo@github") - // * $LOGINNAME_LOCALPART (the "foo" from either of the above) - PubKeys []string `json:"pubKeys,omitempty"` + // Deprecated: do not use. It does nothing. + UnusedPubKeys []string `json:"pubKeys,omitempty"` } // SSHAction is how to handle an incoming connection. 
@@ -2575,6 +2754,8 @@ type SSHRecorderFailureAction struct { // SSHEventNotifyRequest is the JSON payload sent to the NotifyURL // for an SSH event. +// +// POST https:///[...varies, sent in SSH policy...] type SSHEventNotifyRequest struct { // EventType is the type of notify request being sent. EventType SSHEventType @@ -2635,9 +2816,9 @@ type SSHRecordingAttempt struct { FailureMessage string } -// QueryFeatureRequest is a request sent to "/machine/feature/query" -// to get instructions on how to enable a feature, such as Funnel, -// for the node's tailnet. +// QueryFeatureRequest is a request sent to "POST /machine/feature/query" to get +// instructions on how to enable a feature, such as Funnel, for the node's +// tailnet. // // See QueryFeatureResponse for response structure. type QueryFeatureRequest struct { @@ -2726,7 +2907,7 @@ type OverTLSPublicKeyResponse struct { // The token can be presented to any resource provider which offers OIDC // Federation. // -// It is JSON-encoded and sent over Noise to "/machine/id-token". +// It is JSON-encoded and sent over Noise to "POST /machine/id-token". type TokenRequest struct { // CapVersion is the client's current CapabilityVersion. CapVersion CapabilityVersion @@ -2841,3 +3022,51 @@ type EarlyNoise struct { // For some request types, the header may have multiple values. (e.g. OldNodeKey // vs NodeKey) const LBHeader = "Ts-Lb" + +// ServiceIPMappings maps ServiceName to lists of IP addresses. This is used +// as the value of the [NodeAttrServiceHost] capability, to inform service hosts +// what IP addresses they need to listen on for each service that they are +// advertising. +// +// This is of the form: +// +// { +// "svc:samba": ["100.65.32.1", "fd7a:115c:a1e0::1234"], +// "svc:web": ["100.102.42.3", "fd7a:115c:a1e0::abcd"], +// } +// +// where the IP addresses are the IPs of the VIP services. These IPs are also +// provided in AllowedIPs, but this lets the client know which services +// correspond to those IPs. 
Any services that don't correspond to a service +// this client is hosting can be ignored. +type ServiceIPMappings map[ServiceName][]netip.Addr + +// ClientAuditAction represents an auditable action that a client can report to the +// control plane. These actions must correspond to the supported actions +// in the control plane. +type ClientAuditAction string + +const ( + // AuditNodeDisconnect action is sent when a node has disconnected + // from the control plane. The details must include a reason in the Details + // field, either generated, or entered by the user. + AuditNodeDisconnect = ClientAuditAction("DISCONNECT_NODE") +) + +// AuditLogRequest represents an audit log request to be sent to the control plane. +// +// This is JSON-encoded and sent over the control plane connection to: +// POST https:///machine/audit-log +type AuditLogRequest struct { + // Version is the client's current CapabilityVersion. + Version CapabilityVersion `json:",omitempty"` + // NodeKey is the client's current node key. + NodeKey key.NodePublic `json:",omitzero"` + // Action is the action to be logged. It must correspond to a known action in the control plane. + Action ClientAuditAction `json:",omitempty"` + // Details is an opaque string, specific to the action being logged. Empty strings may not + // be valid depending on the action being logged. + Details string `json:",omitempty"` + // Timestamp is the time at which the audit log was generated on the node. + Timestamp time.Time `json:",omitzero"` +} diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 61564f3f8bfd4..2c7941d51d7e3 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -26,17 +26,14 @@ func (src *User) Clone() *User { } dst := new(User) *dst = *src - dst.Logins = append(src.Logins[:0:0], src.Logins...) return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
var _UserCloneNeedsRegeneration = User(struct { ID UserID - LoginName string DisplayName string ProfilePicURL string - Logins []LoginID Created time.Time }{}) @@ -102,7 +99,8 @@ var _NodeCloneNeedsRegeneration = Node(struct { Addresses []netip.Prefix AllowedIPs []netip.Prefix Endpoints []netip.AddrPort - DERP string + LegacyDERPString string + HomeDERP int Hostinfo HostinfoView Created time.Time Cap CapabilityVersion @@ -143,6 +141,9 @@ func (src *Hostinfo) Clone() *Hostinfo { if dst.Location != nil { dst.Location = ptr.To(*src.Location) } + if dst.TPM != nil { + dst.TPM = ptr.To(*src.TPM) + } return dst } @@ -168,6 +169,7 @@ var _HostinfoCloneNeedsRegeneration = Hostinfo(struct { ShareeNode bool NoLogsNoSupport bool WireIngress bool + IngressEnabled bool AllowsUpdate bool Machine string GoArch string @@ -183,7 +185,9 @@ var _HostinfoCloneNeedsRegeneration = Hostinfo(struct { Userspace opt.Bool UserspaceRouter opt.Bool AppConnector opt.Bool + ServicesHash string Location *Location + TPM *TPMInfo }{}) // Clone makes a deep copy of NetInfo. @@ -301,7 +305,6 @@ func (src *RegisterResponse) Clone() *RegisterResponse { } dst := new(RegisterResponse) *dst = *src - dst.User = *src.User.Clone() dst.NodeKeySignature = append(src.NodeKeySignature[:0:0], src.NodeKeySignature...) return dst } @@ -417,13 +420,14 @@ func (src *DERPRegion) Clone() *DERPRegion { // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _DERPRegionCloneNeedsRegeneration = DERPRegion(struct { - RegionID int - RegionCode string - RegionName string - Latitude float64 - Longitude float64 - Avoid bool - Nodes []*DERPNode + RegionID int + RegionCode string + RegionName string + Latitude float64 + Longitude float64 + Avoid bool + NoMeasureNoHome bool + Nodes []*DERPNode }{}) // Clone makes a deep copy of DERPMap. 
@@ -555,17 +559,17 @@ func (src *SSHPrincipal) Clone() *SSHPrincipal { } dst := new(SSHPrincipal) *dst = *src - dst.PubKeys = append(src.PubKeys[:0:0], src.PubKeys...) + dst.UnusedPubKeys = append(src.UnusedPubKeys[:0:0], src.UnusedPubKeys...) return dst } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _SSHPrincipalCloneNeedsRegeneration = SSHPrincipal(struct { - Node StableNodeID - NodeIP string - UserLogin string - Any bool - PubKeys []string + Node StableNodeID + NodeIP string + UserLogin string + Any bool + UnusedPubKeys []string }{}) // Clone makes a deep copy of ControlDialPlan. @@ -624,12 +628,30 @@ var _UserProfileCloneNeedsRegeneration = UserProfile(struct { LoginName string DisplayName string ProfilePicURL string - Roles emptyStructJSONSlice +}{}) + +// Clone makes a deep copy of VIPService. +// The result aliases no memory with the original. +func (src *VIPService) Clone() *VIPService { + if src == nil { + return nil + } + dst := new(VIPService) + *dst = *src + dst.Ports = append(src.Ports[:0:0], src.Ports...) + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _VIPServiceCloneNeedsRegeneration = VIPService(struct { + Name ServiceName + Ports []ProtoPortRange + Active bool }{}) // Clone duplicates src into dst and reports whether it succeeded. // To succeed, must be of types <*T, *T> or <*T, **T>, -// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile. +// where T is one of User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService. 
func Clone(dst, src any) bool { switch src := src.(type) { case *User: @@ -803,6 +825,15 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *VIPService: + switch dst := dst.(type) { + case *VIPService: + *dst = *src.Clone() + return true + case **VIPService: + *dst = src.Clone() + return true + } } return false } diff --git a/tailcfg/tailcfg_test.go b/tailcfg/tailcfg_test.go index 0d06366771d6e..079162a150191 100644 --- a/tailcfg/tailcfg_test.go +++ b/tailcfg/tailcfg_test.go @@ -10,7 +10,6 @@ import ( "reflect" "regexp" "strconv" - "strings" "testing" "time" @@ -51,6 +50,7 @@ func TestHostinfoEqual(t *testing.T) { "ShareeNode", "NoLogsNoSupport", "WireIngress", + "IngressEnabled", "AllowsUpdate", "Machine", "GoArch", @@ -66,7 +66,9 @@ func TestHostinfoEqual(t *testing.T) { "Userspace", "UserspaceRouter", "AppConnector", + "ServicesHash", "Location", + "TPM", } if have := fieldsOf(reflect.TypeFor[Hostinfo]()); !reflect.DeepEqual(have, hiHandles) { t.Errorf("Hostinfo.Equal check might be out of sync\nfields: %q\nhandled: %q\n", @@ -240,87 +242,41 @@ func TestHostinfoEqual(t *testing.T) { &Hostinfo{AppConnector: opt.Bool("false")}, false, }, - } - for i, tt := range tests { - got := tt.a.Equal(tt.b) - if got != tt.want { - t.Errorf("%d. 
Equal = %v; want %v", i, got, tt.want) - } - } -} - -func TestHostinfoHowEqual(t *testing.T) { - tests := []struct { - a, b *Hostinfo - want []string - }{ { - a: nil, - b: nil, - want: nil, + &Hostinfo{ServicesHash: "73475cb40a568e8da8a045ced110137e159f890ac4da883b6b17dc651b3a8049"}, + &Hostinfo{ServicesHash: "73475cb40a568e8da8a045ced110137e159f890ac4da883b6b17dc651b3a8049"}, + true, }, { - a: new(Hostinfo), - b: nil, - want: []string{"nil"}, + &Hostinfo{ServicesHash: "084c799cd551dd1d8d5c5f9a5d593b2e931f5e36122ee5c793c1d08a19839cc0"}, + &Hostinfo{}, + false, }, { - a: nil, - b: new(Hostinfo), - want: []string{"nil"}, + &Hostinfo{IngressEnabled: true}, + &Hostinfo{}, + false, }, { - a: new(Hostinfo), - b: new(Hostinfo), - want: nil, + &Hostinfo{IngressEnabled: true}, + &Hostinfo{IngressEnabled: true}, + true, }, { - a: &Hostinfo{ - IPNVersion: "1", - ShieldsUp: false, - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("1.2.3.0/24")}, - }, - b: &Hostinfo{ - IPNVersion: "2", - ShieldsUp: true, - RoutableIPs: []netip.Prefix{netip.MustParsePrefix("1.2.3.0/25")}, - }, - want: []string{"IPNVersion", "ShieldsUp", "RoutableIPs"}, + &Hostinfo{IngressEnabled: false}, + &Hostinfo{}, + true, }, { - a: &Hostinfo{ - IPNVersion: "1", - }, - b: &Hostinfo{ - IPNVersion: "2", - NetInfo: new(NetInfo), - }, - want: []string{"IPNVersion", "NetInfo.nil"}, - }, - { - a: &Hostinfo{ - IPNVersion: "1", - NetInfo: &NetInfo{ - WorkingIPv6: "true", - HavePortMap: true, - LinkType: "foo", - PreferredDERP: 123, - DERPLatency: map[string]float64{ - "foo": 1.0, - }, - }, - }, - b: &Hostinfo{ - IPNVersion: "2", - NetInfo: &NetInfo{}, - }, - want: []string{"IPNVersion", "NetInfo.WorkingIPv6", "NetInfo.HavePortMap", "NetInfo.PreferredDERP", "NetInfo.LinkType", "NetInfo.DERPLatency"}, + &Hostinfo{IngressEnabled: false}, + &Hostinfo{IngressEnabled: true}, + false, }, } for i, tt := range tests { - got := tt.a.HowUnequal(tt.b) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("%d. 
got %q; want %q", i, got, tt.want) + got := tt.a.Equal(tt.b) + if got != tt.want { + t.Errorf("%d. Equal = %v; want %v", i, got, tt.want) } } } @@ -356,7 +312,7 @@ func TestNodeEqual(t *testing.T) { nodeHandles := []string{ "ID", "StableID", "Name", "User", "Sharer", "Key", "KeyExpiry", "KeySignature", "Machine", "DiscoKey", - "Addresses", "AllowedIPs", "Endpoints", "DERP", "Hostinfo", + "Addresses", "AllowedIPs", "Endpoints", "LegacyDERPString", "HomeDERP", "Hostinfo", "Created", "Cap", "Tags", "PrimaryRoutes", "LastSeen", "Online", "MachineAuthorized", "Capabilities", "CapMap", @@ -519,8 +475,13 @@ func TestNodeEqual(t *testing.T) { true, }, { - &Node{DERP: "foo"}, - &Node{DERP: "bar"}, + &Node{LegacyDERPString: "foo"}, + &Node{LegacyDERPString: "bar"}, + false, + }, + { + &Node{HomeDERP: 1}, + &Node{HomeDERP: 2}, false, }, { @@ -655,7 +616,6 @@ func TestCloneUser(t *testing.T) { u *User }{ {"nil_logins", &User{}}, - {"zero_logins", &User{Logins: make([]LoginID, 0)}}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -689,28 +649,6 @@ func TestCloneNode(t *testing.T) { } } -func TestUserProfileJSONMarshalForMac(t *testing.T) { - // Old macOS clients had a bug where they required - // UserProfile.Roles to be non-null. Lock that in - // 1.0.x/1.2.x clients are gone in the wild. - // See mac commit 0242c08a2ca496958027db1208f44251bff8488b (Sep 30). - // It was fixed in at least 1.4.x, and perhaps 1.2.x. 
- j, err := json.Marshal(UserProfile{}) - if err != nil { - t.Fatal(err) - } - const wantSub = `"Roles":[]` - if !strings.Contains(string(j), wantSub) { - t.Fatalf("didn't contain %#q; got: %s", wantSub, j) - } - - // And back: - var up UserProfile - if err := json.Unmarshal(j, &up); err != nil { - t.Fatalf("Unmarshal: %v", err) - } -} - func TestEndpointTypeMarshal(t *testing.T) { eps := []EndpointType{ EndpointUnknownType, diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index a3e19b0dcec7a..c76654887f8ab 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -19,9 +19,9 @@ import ( "tailscale.com/types/views" ) -//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=User,Node,Hostinfo,NetInfo,Login,DNSConfig,RegisterResponse,RegisterResponseAuth,RegisterRequest,DERPHomeParams,DERPRegion,DERPMap,DERPNode,SSHRule,SSHAction,SSHPrincipal,ControlDialPlan,Location,UserProfile,VIPService -// View returns a readonly view of User. +// View returns a read-only view of User. func (p *User) View() UserView { return UserView{Đļ: p} } @@ -37,7 +37,7 @@ type UserView struct { Đļ *User } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v UserView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -66,24 +66,20 @@ func (v *UserView) UnmarshalJSON(b []byte) error { return nil } -func (v UserView) ID() UserID { return v.Đļ.ID } -func (v UserView) LoginName() string { return v.Đļ.LoginName } -func (v UserView) DisplayName() string { return v.Đļ.DisplayName } -func (v UserView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } -func (v UserView) Logins() views.Slice[LoginID] { return views.SliceOf(v.Đļ.Logins) } -func (v UserView) Created() time.Time { return v.Đļ.Created } +func (v UserView) ID() UserID { return v.Đļ.ID } +func (v UserView) DisplayName() string { return v.Đļ.DisplayName } +func (v UserView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } +func (v UserView) Created() time.Time { return v.Đļ.Created } // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _UserViewNeedsRegeneration = User(struct { ID UserID - LoginName string DisplayName string ProfilePicURL string - Logins []LoginID Created time.Time }{}) -// View returns a readonly view of Node. +// View returns a read-only view of Node. func (p *Node) View() NodeView { return NodeView{Đļ: p} } @@ -99,7 +95,7 @@ type NodeView struct { Đļ *Node } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v NodeView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -143,27 +139,18 @@ func (v NodeView) DiscoKey() key.DiscoPublic { return v.Đļ.DiscoK func (v NodeView) Addresses() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.Addresses) } func (v NodeView) AllowedIPs() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.AllowedIPs) } func (v NodeView) Endpoints() views.Slice[netip.AddrPort] { return views.SliceOf(v.Đļ.Endpoints) } -func (v NodeView) DERP() string { return v.Đļ.DERP } +func (v NodeView) LegacyDERPString() string { return v.Đļ.LegacyDERPString } +func (v NodeView) HomeDERP() int { return v.Đļ.HomeDERP } func (v NodeView) Hostinfo() HostinfoView { return v.Đļ.Hostinfo } func (v NodeView) Created() time.Time { return v.Đļ.Created } func (v NodeView) Cap() CapabilityVersion { return v.Đļ.Cap } func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.Đļ.Tags) } func (v NodeView) PrimaryRoutes() views.Slice[netip.Prefix] { return views.SliceOf(v.Đļ.PrimaryRoutes) } -func (v NodeView) LastSeen() *time.Time { - if v.Đļ.LastSeen == nil { - return nil - } - x := *v.Đļ.LastSeen - return &x +func (v NodeView) LastSeen() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.Đļ.LastSeen) } -func (v NodeView) Online() *bool { - if v.Đļ.Online == nil { - return nil - } - x := *v.Đļ.Online - return &x -} +func (v NodeView) Online() views.ValuePointer[bool] { return views.ValuePointerOf(v.Đļ.Online) } func (v NodeView) MachineAuthorized() bool { return v.Đļ.MachineAuthorized } func (v NodeView) Capabilities() views.Slice[NodeCapability] { return views.SliceOf(v.Đļ.Capabilities) } @@ -176,20 +163,12 @@ func (v NodeView) ComputedName() string { return v.Đļ.ComputedName } func (v NodeView) ComputedNameWithHost() string { return v.Đļ.ComputedNameWithHost } func (v NodeView) DataPlaneAuditLogID() string { return v.Đļ.DataPlaneAuditLogID } func (v NodeView) Expired() 
bool { return v.Đļ.Expired } -func (v NodeView) SelfNodeV4MasqAddrForThisPeer() *netip.Addr { - if v.Đļ.SelfNodeV4MasqAddrForThisPeer == nil { - return nil - } - x := *v.Đļ.SelfNodeV4MasqAddrForThisPeer - return &x +func (v NodeView) SelfNodeV4MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.Đļ.SelfNodeV4MasqAddrForThisPeer) } -func (v NodeView) SelfNodeV6MasqAddrForThisPeer() *netip.Addr { - if v.Đļ.SelfNodeV6MasqAddrForThisPeer == nil { - return nil - } - x := *v.Đļ.SelfNodeV6MasqAddrForThisPeer - return &x +func (v NodeView) SelfNodeV6MasqAddrForThisPeer() views.ValuePointer[netip.Addr] { + return views.ValuePointerOf(v.Đļ.SelfNodeV6MasqAddrForThisPeer) } func (v NodeView) IsWireGuardOnly() bool { return v.Đļ.IsWireGuardOnly } @@ -214,7 +193,8 @@ var _NodeViewNeedsRegeneration = Node(struct { Addresses []netip.Prefix AllowedIPs []netip.Prefix Endpoints []netip.AddrPort - DERP string + LegacyDERPString string + HomeDERP int Hostinfo HostinfoView Created time.Time Cap CapabilityVersion @@ -238,7 +218,7 @@ var _NodeViewNeedsRegeneration = Node(struct { ExitNodeDNSResolvers []*dnstype.Resolver }{}) -// View returns a readonly view of Hostinfo. +// View returns a read-only view of Hostinfo. func (p *Hostinfo) View() HostinfoView { return HostinfoView{Đļ: p} } @@ -254,7 +234,7 @@ type HostinfoView struct { Đļ *Hostinfo } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v HostinfoView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -303,6 +283,7 @@ func (v HostinfoView) ShieldsUp() bool { return v.Đļ.Shie func (v HostinfoView) ShareeNode() bool { return v.Đļ.ShareeNode } func (v HostinfoView) NoLogsNoSupport() bool { return v.Đļ.NoLogsNoSupport } func (v HostinfoView) WireIngress() bool { return v.Đļ.WireIngress } +func (v HostinfoView) IngressEnabled() bool { return v.Đļ.IngressEnabled } func (v HostinfoView) AllowsUpdate() bool { return v.Đļ.AllowsUpdate } func (v HostinfoView) Machine() string { return v.Đļ.Machine } func (v HostinfoView) GoArch() string { return v.Đļ.GoArch } @@ -318,13 +299,9 @@ func (v HostinfoView) Cloud() string { return v.Đļ.Clou func (v HostinfoView) Userspace() opt.Bool { return v.Đļ.Userspace } func (v HostinfoView) UserspaceRouter() opt.Bool { return v.Đļ.UserspaceRouter } func (v HostinfoView) AppConnector() opt.Bool { return v.Đļ.AppConnector } -func (v HostinfoView) Location() *Location { - if v.Đļ.Location == nil { - return nil - } - x := *v.Đļ.Location - return &x -} +func (v HostinfoView) ServicesHash() string { return v.Đļ.ServicesHash } +func (v HostinfoView) Location() LocationView { return v.Đļ.Location.View() } +func (v HostinfoView) TPM() views.ValuePointer[TPMInfo] { return views.ValuePointerOf(v.Đļ.TPM) } func (v HostinfoView) Equal(v2 HostinfoView) bool { return v.Đļ.Equal(v2.Đļ) } @@ -350,6 +327,7 @@ var _HostinfoViewNeedsRegeneration = Hostinfo(struct { ShareeNode bool NoLogsNoSupport bool WireIngress bool + IngressEnabled bool AllowsUpdate bool Machine string GoArch string @@ -365,10 +343,12 @@ var _HostinfoViewNeedsRegeneration = Hostinfo(struct { Userspace opt.Bool UserspaceRouter opt.Bool AppConnector opt.Bool + ServicesHash string Location *Location + TPM *TPMInfo }{}) -// View returns a readonly view of NetInfo. +// View returns a read-only view of NetInfo. 
func (p *NetInfo) View() NetInfoView { return NetInfoView{Đļ: p} } @@ -384,7 +364,7 @@ type NetInfoView struct { Đļ *NetInfo } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v NetInfoView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -448,7 +428,7 @@ var _NetInfoViewNeedsRegeneration = NetInfo(struct { FirewallMode string }{}) -// View returns a readonly view of Login. +// View returns a read-only view of Login. func (p *Login) View() LoginView { return LoginView{Đļ: p} } @@ -464,7 +444,7 @@ type LoginView struct { Đļ *Login } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v LoginView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -509,7 +489,7 @@ var _LoginViewNeedsRegeneration = Login(struct { ProfilePicURL string }{}) -// View returns a readonly view of DNSConfig. +// View returns a read-only view of DNSConfig. func (p *DNSConfig) View() DNSConfigView { return DNSConfigView{Đļ: p} } @@ -525,7 +505,7 @@ type DNSConfigView struct { Đļ *DNSConfig } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DNSConfigView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -590,7 +570,7 @@ var _DNSConfigViewNeedsRegeneration = DNSConfig(struct { TempCorpIssue13969 string }{}) -// View returns a readonly view of RegisterResponse. +// View returns a read-only view of RegisterResponse. func (p *RegisterResponse) View() RegisterResponseView { return RegisterResponseView{Đļ: p} } @@ -606,7 +586,7 @@ type RegisterResponseView struct { Đļ *RegisterResponse } -// Valid reports whether underlying value is non-nil. 
+// Valid reports whether v's underlying value is non-nil. func (v RegisterResponseView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -635,7 +615,7 @@ func (v *RegisterResponseView) UnmarshalJSON(b []byte) error { return nil } -func (v RegisterResponseView) User() UserView { return v.Đļ.User.View() } +func (v RegisterResponseView) User() User { return v.Đļ.User } func (v RegisterResponseView) Login() Login { return v.Đļ.Login } func (v RegisterResponseView) NodeKeyExpired() bool { return v.Đļ.NodeKeyExpired } func (v RegisterResponseView) MachineAuthorized() bool { return v.Đļ.MachineAuthorized } @@ -656,7 +636,7 @@ var _RegisterResponseViewNeedsRegeneration = RegisterResponse(struct { Error string }{}) -// View returns a readonly view of RegisterResponseAuth. +// View returns a read-only view of RegisterResponseAuth. func (p *RegisterResponseAuth) View() RegisterResponseAuthView { return RegisterResponseAuthView{Đļ: p} } @@ -672,7 +652,7 @@ type RegisterResponseAuthView struct { Đļ *RegisterResponseAuth } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v RegisterResponseAuthView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -701,12 +681,8 @@ func (v *RegisterResponseAuthView) UnmarshalJSON(b []byte) error { return nil } -func (v RegisterResponseAuthView) Oauth2Token() *Oauth2Token { - if v.Đļ.Oauth2Token == nil { - return nil - } - x := *v.Đļ.Oauth2Token - return &x +func (v RegisterResponseAuthView) Oauth2Token() views.ValuePointer[Oauth2Token] { + return views.ValuePointerOf(v.Đļ.Oauth2Token) } func (v RegisterResponseAuthView) AuthKey() string { return v.Đļ.AuthKey } @@ -718,7 +694,7 @@ var _RegisterResponseAuthViewNeedsRegeneration = RegisterResponseAuth(struct { AuthKey string }{}) -// View returns a readonly view of RegisterRequest. 
+// View returns a read-only view of RegisterRequest. func (p *RegisterRequest) View() RegisterRequestView { return RegisterRequestView{Đļ: p} } @@ -734,7 +710,7 @@ type RegisterRequestView struct { Đļ *RegisterRequest } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v RegisterRequestView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -776,12 +752,8 @@ func (v RegisterRequestView) NodeKeySignature() views.ByteSlice[tkatype.Marshale return views.ByteSliceOf(v.Đļ.NodeKeySignature) } func (v RegisterRequestView) SignatureType() SignatureType { return v.Đļ.SignatureType } -func (v RegisterRequestView) Timestamp() *time.Time { - if v.Đļ.Timestamp == nil { - return nil - } - x := *v.Đļ.Timestamp - return &x +func (v RegisterRequestView) Timestamp() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.Đļ.Timestamp) } func (v RegisterRequestView) DeviceCert() views.ByteSlice[[]byte] { @@ -812,7 +784,7 @@ var _RegisterRequestViewNeedsRegeneration = RegisterRequest(struct { Tailnet string }{}) -// View returns a readonly view of DERPHomeParams. +// View returns a read-only view of DERPHomeParams. func (p *DERPHomeParams) View() DERPHomeParamsView { return DERPHomeParamsView{Đļ: p} } @@ -828,7 +800,7 @@ type DERPHomeParamsView struct { Đļ *DERPHomeParams } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPHomeParamsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -866,7 +838,7 @@ var _DERPHomeParamsViewNeedsRegeneration = DERPHomeParams(struct { RegionScore map[int]float64 }{}) -// View returns a readonly view of DERPRegion. +// View returns a read-only view of DERPRegion. 
func (p *DERPRegion) View() DERPRegionView { return DERPRegionView{Đļ: p} } @@ -882,7 +854,7 @@ type DERPRegionView struct { Đļ *DERPRegion } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPRegionView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -911,28 +883,30 @@ func (v *DERPRegionView) UnmarshalJSON(b []byte) error { return nil } -func (v DERPRegionView) RegionID() int { return v.Đļ.RegionID } -func (v DERPRegionView) RegionCode() string { return v.Đļ.RegionCode } -func (v DERPRegionView) RegionName() string { return v.Đļ.RegionName } -func (v DERPRegionView) Latitude() float64 { return v.Đļ.Latitude } -func (v DERPRegionView) Longitude() float64 { return v.Đļ.Longitude } -func (v DERPRegionView) Avoid() bool { return v.Đļ.Avoid } +func (v DERPRegionView) RegionID() int { return v.Đļ.RegionID } +func (v DERPRegionView) RegionCode() string { return v.Đļ.RegionCode } +func (v DERPRegionView) RegionName() string { return v.Đļ.RegionName } +func (v DERPRegionView) Latitude() float64 { return v.Đļ.Latitude } +func (v DERPRegionView) Longitude() float64 { return v.Đļ.Longitude } +func (v DERPRegionView) Avoid() bool { return v.Đļ.Avoid } +func (v DERPRegionView) NoMeasureNoHome() bool { return v.Đļ.NoMeasureNoHome } func (v DERPRegionView) Nodes() views.SliceView[*DERPNode, DERPNodeView] { return views.SliceOfViews[*DERPNode, DERPNodeView](v.Đļ.Nodes) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
var _DERPRegionViewNeedsRegeneration = DERPRegion(struct { - RegionID int - RegionCode string - RegionName string - Latitude float64 - Longitude float64 - Avoid bool - Nodes []*DERPNode + RegionID int + RegionCode string + RegionName string + Latitude float64 + Longitude float64 + Avoid bool + NoMeasureNoHome bool + Nodes []*DERPNode }{}) -// View returns a readonly view of DERPMap. +// View returns a read-only view of DERPMap. func (p *DERPMap) View() DERPMapView { return DERPMapView{Đļ: p} } @@ -948,7 +922,7 @@ type DERPMapView struct { Đļ *DERPMap } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPMapView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -993,7 +967,7 @@ var _DERPMapViewNeedsRegeneration = DERPMap(struct { OmitDefaultRegions bool }{}) -// View returns a readonly view of DERPNode. +// View returns a read-only view of DERPNode. func (p *DERPNode) View() DERPNodeView { return DERPNodeView{Đļ: p} } @@ -1009,7 +983,7 @@ type DERPNodeView struct { Đļ *DERPNode } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v DERPNodeView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1067,7 +1041,7 @@ var _DERPNodeViewNeedsRegeneration = DERPNode(struct { CanPort80 bool }{}) -// View returns a readonly view of SSHRule. +// View returns a read-only view of SSHRule. func (p *SSHRule) View() SSHRuleView { return SSHRuleView{Đļ: p} } @@ -1083,7 +1057,7 @@ type SSHRuleView struct { Đļ *SSHRule } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v SSHRuleView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1112,12 +1086,8 @@ func (v *SSHRuleView) UnmarshalJSON(b []byte) error { return nil } -func (v SSHRuleView) RuleExpires() *time.Time { - if v.Đļ.RuleExpires == nil { - return nil - } - x := *v.Đļ.RuleExpires - return &x +func (v SSHRuleView) RuleExpires() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.Đļ.RuleExpires) } func (v SSHRuleView) Principals() views.SliceView[*SSHPrincipal, SSHPrincipalView] { @@ -1137,7 +1107,7 @@ var _SSHRuleViewNeedsRegeneration = SSHRule(struct { AcceptEnv []string }{}) -// View returns a readonly view of SSHAction. +// View returns a read-only view of SSHAction. func (p *SSHAction) View() SSHActionView { return SSHActionView{Đļ: p} } @@ -1153,7 +1123,7 @@ type SSHActionView struct { Đļ *SSHAction } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v SSHActionView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1191,12 +1161,8 @@ func (v SSHActionView) HoldAndDelegate() string { return v.Đļ.Hol func (v SSHActionView) AllowLocalPortForwarding() bool { return v.Đļ.AllowLocalPortForwarding } func (v SSHActionView) AllowRemotePortForwarding() bool { return v.Đļ.AllowRemotePortForwarding } func (v SSHActionView) Recorders() views.Slice[netip.AddrPort] { return views.SliceOf(v.Đļ.Recorders) } -func (v SSHActionView) OnRecordingFailure() *SSHRecorderFailureAction { - if v.Đļ.OnRecordingFailure == nil { - return nil - } - x := *v.Đļ.OnRecordingFailure - return &x +func (v SSHActionView) OnRecordingFailure() views.ValuePointer[SSHRecorderFailureAction] { + return views.ValuePointerOf(v.Đļ.OnRecordingFailure) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
@@ -1213,7 +1179,7 @@ var _SSHActionViewNeedsRegeneration = SSHAction(struct { OnRecordingFailure *SSHRecorderFailureAction }{}) -// View returns a readonly view of SSHPrincipal. +// View returns a read-only view of SSHPrincipal. func (p *SSHPrincipal) View() SSHPrincipalView { return SSHPrincipalView{Đļ: p} } @@ -1229,7 +1195,7 @@ type SSHPrincipalView struct { Đļ *SSHPrincipal } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v SSHPrincipalView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1258,22 +1224,24 @@ func (v *SSHPrincipalView) UnmarshalJSON(b []byte) error { return nil } -func (v SSHPrincipalView) Node() StableNodeID { return v.Đļ.Node } -func (v SSHPrincipalView) NodeIP() string { return v.Đļ.NodeIP } -func (v SSHPrincipalView) UserLogin() string { return v.Đļ.UserLogin } -func (v SSHPrincipalView) Any() bool { return v.Đļ.Any } -func (v SSHPrincipalView) PubKeys() views.Slice[string] { return views.SliceOf(v.Đļ.PubKeys) } +func (v SSHPrincipalView) Node() StableNodeID { return v.Đļ.Node } +func (v SSHPrincipalView) NodeIP() string { return v.Đļ.NodeIP } +func (v SSHPrincipalView) UserLogin() string { return v.Đļ.UserLogin } +func (v SSHPrincipalView) Any() bool { return v.Đļ.Any } +func (v SSHPrincipalView) UnusedPubKeys() views.Slice[string] { + return views.SliceOf(v.Đļ.UnusedPubKeys) +} // A compilation failure here means this code must be regenerated, with the command at the top of this file. var _SSHPrincipalViewNeedsRegeneration = SSHPrincipal(struct { - Node StableNodeID - NodeIP string - UserLogin string - Any bool - PubKeys []string + Node StableNodeID + NodeIP string + UserLogin string + Any bool + UnusedPubKeys []string }{}) -// View returns a readonly view of ControlDialPlan. +// View returns a read-only view of ControlDialPlan. 
func (p *ControlDialPlan) View() ControlDialPlanView { return ControlDialPlanView{Đļ: p} } @@ -1289,7 +1257,7 @@ type ControlDialPlanView struct { Đļ *ControlDialPlan } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ControlDialPlanView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1327,7 +1295,7 @@ var _ControlDialPlanViewNeedsRegeneration = ControlDialPlan(struct { Candidates []ControlIPCandidate }{}) -// View returns a readonly view of Location. +// View returns a read-only view of Location. func (p *Location) View() LocationView { return LocationView{Đļ: p} } @@ -1343,7 +1311,7 @@ type LocationView struct { Đļ *Location } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v LocationView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1391,7 +1359,7 @@ var _LocationViewNeedsRegeneration = Location(struct { Priority int }{}) -// View returns a readonly view of UserProfile. +// View returns a read-only view of UserProfile. func (p *UserProfile) View() UserProfileView { return UserProfileView{Đļ: p} } @@ -1407,7 +1375,7 @@ type UserProfileView struct { Đļ *UserProfile } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v UserProfileView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -1440,7 +1408,6 @@ func (v UserProfileView) ID() UserID { return v.Đļ.ID } func (v UserProfileView) LoginName() string { return v.Đļ.LoginName } func (v UserProfileView) DisplayName() string { return v.Đļ.DisplayName } func (v UserProfileView) ProfilePicURL() string { return v.Đļ.ProfilePicURL } -func (v UserProfileView) Roles() emptyStructJSONSlice { return v.Đļ.Roles } func (v UserProfileView) Equal(v2 UserProfileView) bool { return v.Đļ.Equal(v2.Đļ) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -1449,5 +1416,60 @@ var _UserProfileViewNeedsRegeneration = UserProfile(struct { LoginName string DisplayName string ProfilePicURL string - Roles emptyStructJSONSlice +}{}) + +// View returns a read-only view of VIPService. +func (p *VIPService) View() VIPServiceView { + return VIPServiceView{Đļ: p} +} + +// VIPServiceView provides a read-only view over VIPService. +// +// Its methods should only be called if `Valid()` returns true. +type VIPServiceView struct { + // Đļ is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *VIPService +} + +// Valid reports whether v's underlying value is non-nil. +func (v VIPServiceView) Valid() bool { return v.Đļ != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. 
+func (v VIPServiceView) AsStruct() *VIPService { + if v.Đļ == nil { + return nil + } + return v.Đļ.Clone() +} + +func (v VIPServiceView) MarshalJSON() ([]byte, error) { return json.Marshal(v.Đļ) } + +func (v *VIPServiceView) UnmarshalJSON(b []byte) error { + if v.Đļ != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x VIPService + if err := json.Unmarshal(b, &x); err != nil { + return err + } + v.Đļ = &x + return nil +} + +func (v VIPServiceView) Name() ServiceName { return v.Đļ.Name } +func (v VIPServiceView) Ports() views.Slice[ProtoPortRange] { return views.SliceOf(v.Đļ.Ports) } +func (v VIPServiceView) Active() bool { return v.Đļ.Active } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _VIPServiceViewNeedsRegeneration = VIPService(struct { + Name ServiceName + Ports []ProtoPortRange + Active bool }{}) diff --git a/taildrop/send.go b/taildrop/send.go deleted file mode 100644 index 0dff71b2467e2..0000000000000 --- a/taildrop/send.go +++ /dev/null @@ -1,252 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package taildrop - -import ( - "crypto/sha256" - "errors" - "io" - "os" - "path/filepath" - "sync" - "time" - - "tailscale.com/envknob" - "tailscale.com/ipn" - "tailscale.com/tstime" - "tailscale.com/version/distro" -) - -type incomingFileKey struct { - id ClientID - name string // e.g., "foo.jpeg" -} - -type incomingFile struct { - clock tstime.DefaultClock - - started time.Time - size int64 // or -1 if unknown; never 0 - w io.Writer // underlying writer - sendFileNotify func() // called when done - partialPath string // non-empty in direct mode - finalPath string // not used in direct mode - - mu sync.Mutex - copied int64 - done bool - lastNotify time.Time -} - -func (f *incomingFile) Write(p []byte) (n int, err error) { - n, err = f.w.Write(p) - - var needNotify bool - defer func() { - if needNotify 
{ - f.sendFileNotify() - } - }() - if n > 0 { - f.mu.Lock() - defer f.mu.Unlock() - f.copied += int64(n) - now := f.clock.Now() - if f.lastNotify.IsZero() || now.Sub(f.lastNotify) > time.Second { - f.lastNotify = now - needNotify = true - } - } - return n, err -} - -// PutFile stores a file into [Manager.Dir] from a given client id. -// The baseName must be a base filename without any slashes. -// The length is the expected length of content to read from r, -// it may be negative to indicate that it is unknown. -// It returns the length of the entire file. -// -// If there is a failure reading from r, then the partial file is not deleted -// for some period of time. The [Manager.PartialFiles] and [Manager.HashPartialFile] -// methods may be used to list all partial files and to compute the hash for a -// specific partial file. This allows the client to determine whether to resume -// a partial file. While resuming, PutFile may be called again with a non-zero -// offset to specify where to resume receiving data at. -func (m *Manager) PutFile(id ClientID, baseName string, r io.Reader, offset, length int64) (int64, error) { - switch { - case m == nil || m.opts.Dir == "": - return 0, ErrNoTaildrop - case !envknob.CanTaildrop(): - return 0, ErrNoTaildrop - case distro.Get() == distro.Unraid && !m.opts.DirectFileMode: - return 0, ErrNotAccessible - } - dstPath, err := joinDir(m.opts.Dir, baseName) - if err != nil { - return 0, err - } - - redactAndLogError := func(action string, err error) error { - err = redactError(err) - m.opts.Logf("put %v error: %v", action, err) - return err - } - - // Check whether there is an in-progress transfer for the file. 
- partialPath := dstPath + id.partialSuffix() - inFileKey := incomingFileKey{id, baseName} - inFile, loaded := m.incomingFiles.LoadOrInit(inFileKey, func() *incomingFile { - inFile := &incomingFile{ - clock: m.opts.Clock, - started: m.opts.Clock.Now(), - size: length, - sendFileNotify: m.opts.SendFileNotify, - } - if m.opts.DirectFileMode { - inFile.partialPath = partialPath - inFile.finalPath = dstPath - } - return inFile - }) - if loaded { - return 0, ErrFileExists - } - defer m.incomingFiles.Delete(inFileKey) - m.deleter.Remove(filepath.Base(partialPath)) // avoid deleting the partial file while receiving - - // Create (if not already) the partial file with read-write permissions. - f, err := os.OpenFile(partialPath, os.O_CREATE|os.O_RDWR, 0666) - if err != nil { - return 0, redactAndLogError("Create", err) - } - defer func() { - f.Close() // best-effort to cleanup dangling file handles - if err != nil { - m.deleter.Insert(filepath.Base(partialPath)) // mark partial file for eventual deletion - } - }() - inFile.w = f - - // Record that we have started to receive at least one file. - // This is used by the deleter upon a cold-start to scan the directory - // for any files that need to be deleted. - if m.opts.State != nil { - if b, _ := m.opts.State.ReadState(ipn.TaildropReceivedKey); len(b) == 0 { - if err := m.opts.State.WriteState(ipn.TaildropReceivedKey, []byte{1}); err != nil { - m.opts.Logf("WriteState error: %v", err) // non-fatal error - } - } - } - - // A positive offset implies that we are resuming an existing file. - // Seek to the appropriate offset and truncate the file. 
- if offset != 0 { - currLength, err := f.Seek(0, io.SeekEnd) - if err != nil { - return 0, redactAndLogError("Seek", err) - } - if offset < 0 || offset > currLength { - return 0, redactAndLogError("Seek", err) - } - if _, err := f.Seek(offset, io.SeekStart); err != nil { - return 0, redactAndLogError("Seek", err) - } - if err := f.Truncate(offset); err != nil { - return 0, redactAndLogError("Truncate", err) - } - } - - // Copy the contents of the file. - copyLength, err := io.Copy(inFile, r) - if err != nil { - return 0, redactAndLogError("Copy", err) - } - if length >= 0 && copyLength != length { - return 0, redactAndLogError("Copy", errors.New("copied an unexpected number of bytes")) - } - if err := f.Close(); err != nil { - return 0, redactAndLogError("Close", err) - } - fileLength := offset + copyLength - - inFile.mu.Lock() - inFile.done = true - inFile.mu.Unlock() - - // File has been successfully received, rename the partial file - // to the final destination filename. If a file of that name already exists, - // then try multiple times with variations of the filename. - computePartialSum := sync.OnceValues(func() ([sha256.Size]byte, error) { - return sha256File(partialPath) - }) - maxRetries := 10 - for ; maxRetries > 0; maxRetries-- { - // Atomically rename the partial file as the destination file if it doesn't exist. - // Otherwise, it returns the length of the current destination file. - // The operation is atomic. - dstLength, err := func() (int64, error) { - m.renameMu.Lock() - defer m.renameMu.Unlock() - switch fi, err := os.Stat(dstPath); { - case os.IsNotExist(err): - return -1, os.Rename(partialPath, dstPath) - case err != nil: - return -1, err - default: - return fi.Size(), nil - } - }() - if err != nil { - return 0, redactAndLogError("Rename", err) - } - if dstLength < 0 { - break // we successfully renamed; so stop - } - - // Avoid the final rename if a destination file has the same contents. 
- // - // Note: this is best effort and copying files from iOS from the Media Library - // results in processing on the iOS side which means the size and shas of the - // same file can be different. - if dstLength == fileLength { - partialSum, err := computePartialSum() - if err != nil { - return 0, redactAndLogError("Rename", err) - } - dstSum, err := sha256File(dstPath) - if err != nil { - return 0, redactAndLogError("Rename", err) - } - if dstSum == partialSum { - if err := os.Remove(partialPath); err != nil { - return 0, redactAndLogError("Remove", err) - } - break // we successfully found a content match; so stop - } - } - - // Choose a new destination filename and try again. - dstPath = NextFilename(dstPath) - inFile.finalPath = dstPath - } - if maxRetries <= 0 { - return 0, errors.New("too many retries trying to rename partial file") - } - m.totalReceived.Add(1) - m.opts.SendFileNotify() - return fileLength, nil -} - -func sha256File(file string) (out [sha256.Size]byte, err error) { - h := sha256.New() - f, err := os.Open(file) - if err != nil { - return out, err - } - defer f.Close() - if _, err := io.Copy(h, f); err != nil { - return out, err - } - return [sha256.Size]byte(h.Sum(nil)), nil -} diff --git a/tempfork/acme/README.md b/tempfork/acme/README.md new file mode 100644 index 0000000000000..def357fc1e500 --- /dev/null +++ b/tempfork/acme/README.md @@ -0,0 +1,14 @@ +# tempfork/acme + +This is a vendored copy of Tailscale's https://github.com/tailscale/golang-x-crypto, +which is a fork of golang.org/x/crypto/acme. + +See https://github.com/tailscale/tailscale/issues/10238 for unforking +status. + +The https://github.com/tailscale/golang-x-crypto location exists to +let us do rebases from upstream easily, and then we update tempfork/acme +in the same commit we go get github.com/tailscale/golang-x-crypto@main. +See the comment on the TestSyncedToUpstream test for details. That +test should catch that forgotten step. 
+ diff --git a/tempfork/acme/acme.go b/tempfork/acme/acme.go new file mode 100644 index 0000000000000..bbddb95519ed6 --- /dev/null +++ b/tempfork/acme/acme.go @@ -0,0 +1,866 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package acme provides an implementation of the +// Automatic Certificate Management Environment (ACME) spec, +// most famously used by Let's Encrypt. +// +// The initial implementation of this package was based on an early version +// of the spec. The current implementation supports only the modern +// RFC 8555 but some of the old API surface remains for compatibility. +// While code using the old API will still compile, it will return an error. +// Note the deprecation comments to update your code. +// +// See https://tools.ietf.org/html/rfc8555 for the spec. +// +// Most common scenarios will want to use autocert subdirectory instead, +// which provides automatic access to certificates from Let's Encrypt +// and any other ACME-based CA. +package acme + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "math/big" + "net/http" + "strings" + "sync" + "time" +) + +const ( + // LetsEncryptURL is the Directory endpoint of Let's Encrypt CA. + LetsEncryptURL = "https://acme-v02.api.letsencrypt.org/directory" + + // ALPNProto is the ALPN protocol name used by a CA server when validating + // tls-alpn-01 challenges. + // + // Package users must ensure their servers can negotiate the ACME ALPN in + // order for tls-alpn-01 challenge verifications to succeed. + // See the crypto/tls package's Config.NextProtos field. 
+ ALPNProto = "acme-tls/1" +) + +// idPeACMEIdentifier is the OID for the ACME extension for the TLS-ALPN challenge. +// https://tools.ietf.org/html/draft-ietf-acme-tls-alpn-05#section-5.1 +var idPeACMEIdentifier = asn1.ObjectIdentifier{1, 3, 6, 1, 5, 5, 7, 1, 31} + +const ( + maxChainLen = 5 // max depth and breadth of a certificate chain + maxCertSize = 1 << 20 // max size of a certificate, in DER bytes + // Used for decoding certs from application/pem-certificate-chain response, + // the default when in RFC mode. + maxCertChainSize = maxCertSize * maxChainLen + + // Max number of collected nonces kept in memory. + // Expect usual peak of 1 or 2. + maxNonces = 100 +) + +// Client is an ACME client. +// +// The only required field is Key. An example of creating a client with a new key +// is as follows: +// +// key, err := rsa.GenerateKey(rand.Reader, 2048) +// if err != nil { +// log.Fatal(err) +// } +// client := &Client{Key: key} +type Client struct { + // Key is the account key used to register with a CA and sign requests. + // Key.Public() must return a *rsa.PublicKey or *ecdsa.PublicKey. + // + // The following algorithms are supported: + // RS256, ES256, ES384 and ES512. + // See RFC 7518 for more details about the algorithms. + Key crypto.Signer + + // HTTPClient optionally specifies an HTTP client to use + // instead of http.DefaultClient. + HTTPClient *http.Client + + // DirectoryURL points to the CA directory endpoint. + // If empty, LetsEncryptURL is used. + // Mutating this value after a successful call of Client's Discover method + // will have no effect. + DirectoryURL string + + // RetryBackoff computes the duration after which the nth retry of a failed request + // should occur. The value of n for the first call on failure is 1. + // The values of r and resp are the request and response of the last failed attempt. 
+ // If the returned value is negative or zero, no more retries are done and an error + // is returned to the caller of the original method. + // + // Requests which result in a 4xx client error are not retried, + // except for 400 Bad Request due to "bad nonce" errors and 429 Too Many Requests. + // + // If RetryBackoff is nil, a truncated exponential backoff algorithm + // with the ceiling of 10 seconds is used, where each subsequent retry n + // is done after either ("Retry-After" + jitter) or (2^n seconds + jitter), + // preferring the former if "Retry-After" header is found in the resp. + // The jitter is a random value up to 1 second. + RetryBackoff func(n int, r *http.Request, resp *http.Response) time.Duration + + // UserAgent is prepended to the User-Agent header sent to the ACME server, + // which by default is this package's name and version. + // + // Reusable libraries and tools in particular should set this value to be + // identifiable by the server, in case they are causing issues. + UserAgent string + + cacheMu sync.Mutex + dir *Directory // cached result of Client's Discover method + // KID is the key identifier provided by the CA. If not provided it will be + // retrieved from the CA by making a call to the registration endpoint. + KID KeyID + + noncesMu sync.Mutex + nonces map[string]struct{} // nonces collected from previous responses +} + +// accountKID returns a key ID associated with c.Key, the account identity +// provided by the CA during RFC based registration. +// It assumes c.Discover has already been called. +// +// accountKID requires at most one network roundtrip. +// It caches only successful result. +// +// When in pre-RFC mode or when c.getRegRFC responds with an error, accountKID +// returns noKeyID. 
func (c *Client) accountKID(ctx context.Context) KeyID {
	c.cacheMu.Lock()
	defer c.cacheMu.Unlock()
	if c.KID != noKeyID {
		// Either set explicitly by the user or cached from a previous lookup.
		return c.KID
	}
	a, err := c.getRegRFC(ctx)
	if err != nil {
		// Lookup failed; report "no key ID" and leave the cache empty so a
		// later call can retry.
		return noKeyID
	}
	c.KID = KeyID(a.URI)
	return c.KID
}

// errPreRFC is returned by methods that only exist for compatibility with the
// pre-RFC 8555 draft API, which this package no longer implements.
var errPreRFC = errors.New("acme: server does not support the RFC 8555 version of ACME")

// Discover performs ACME server discovery using c.DirectoryURL.
//
// It caches a successful result. So, subsequent calls will not result in
// a network round-trip. This also means mutating c.DirectoryURL after a
// successful call of this method will have no effect.
func (c *Client) Discover(ctx context.Context) (Directory, error) {
	c.cacheMu.Lock()
	defer c.cacheMu.Unlock()
	if c.dir != nil {
		return *c.dir, nil
	}

	res, err := c.get(ctx, c.directoryURL(), wantStatus(http.StatusOK))
	if err != nil {
		return Directory{}, err
	}
	defer res.Body.Close()
	// Opportunistically store the nonce the server may have sent along.
	c.addNonce(res.Header)

	// Wire format of an RFC 8555 directory object, including the
	// draft-ietf-acme-ari renewalInfo extension.
	var v struct {
		Reg         string `json:"newAccount"`
		Authz       string `json:"newAuthz"`
		Order       string `json:"newOrder"`
		Revoke      string `json:"revokeCert"`
		Nonce       string `json:"newNonce"`
		KeyChange   string `json:"keyChange"`
		RenewalInfo string `json:"renewalInfo"`
		Meta        struct {
			Terms        string   `json:"termsOfService"`
			Website      string   `json:"website"`
			CAA          []string `json:"caaIdentities"`
			ExternalAcct bool     `json:"externalAccountRequired"`
		}
	}
	if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
		return Directory{}, err
	}
	if v.Order == "" {
		// newOrder is mandatory in RFC 8555; its absence means the server
		// speaks only the pre-RFC draft protocol.
		return Directory{}, errPreRFC
	}
	c.dir = &Directory{
		RegURL:                  v.Reg,
		AuthzURL:                v.Authz,
		OrderURL:                v.Order,
		RevokeURL:               v.Revoke,
		NonceURL:                v.Nonce,
		KeyChangeURL:            v.KeyChange,
		RenewalInfoURL:          v.RenewalInfo,
		Terms:                   v.Meta.Terms,
		Website:                 v.Meta.Website,
		CAA:                     v.Meta.CAA,
		ExternalAccountRequired: v.Meta.ExternalAcct,
	}
	return *c.dir, nil
}

// directoryURL returns c.DirectoryURL, falling back to LetsEncryptURL when unset.
func (c *Client) directoryURL() string {
	if c.DirectoryURL != "" {
		return c.DirectoryURL
	}
	return LetsEncryptURL
}

// CreateCert was part of the old version of ACME. It is incompatible with RFC 8555.
//
// Deprecated: this was for the pre-RFC 8555 version of ACME. Callers should use CreateOrderCert.
func (c *Client) CreateCert(ctx context.Context, csr []byte, exp time.Duration, bundle bool) (der [][]byte, certURL string, err error) {
	return nil, "", errPreRFC
}

// FetchCert retrieves already issued certificate from the given url, in DER format.
// It retries the request until the certificate is successfully retrieved,
// context is cancelled by the caller or an error response is received.
//
// If the bundle argument is true, the returned value also contains the CA (issuer)
// certificate chain.
//
// FetchCert returns an error if the CA's response or chain was unreasonably large.
// Callers are encouraged to parse the returned value to ensure the certificate is valid
// and has expected features.
func (c *Client) FetchCert(ctx context.Context, url string, bundle bool) ([][]byte, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}
	return c.fetchCertRFC(ctx, url, bundle)
}

// RevokeCert revokes a previously issued certificate cert, provided in DER format.
//
// The key argument, used to sign the request, must be authorized
// to revoke the certificate. It's up to the CA to decide which keys are authorized.
// For instance, the key pair of the certificate may be authorized.
// If the key is nil, c.Key is used instead.
func (c *Client) RevokeCert(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error {
	if _, err := c.Discover(ctx); err != nil {
		return err
	}
	return c.revokeCertRFC(ctx, key, cert, reason)
}

// FetchRenewalInfo retrieves the RenewalInfo from Directory.RenewalInfoURL.
func (c *Client) FetchRenewalInfo(ctx context.Context, leaf []byte) (*RenewalInfo, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}

	// leaf is the end-entity certificate in DER form; its authority key ID
	// and serial number identify it to the ARI endpoint.
	parsedLeaf, err := x509.ParseCertificate(leaf)
	if err != nil {
		return nil, fmt.Errorf("parsing leaf certificate: %w", err)
	}

	renewalURL := c.getRenewalURL(parsedLeaf)

	res, err := c.get(ctx, renewalURL, wantStatus(http.StatusOK))
	if err != nil {
		return nil, fmt.Errorf("fetching renewal info: %w", err)
	}
	defer res.Body.Close()

	var info RenewalInfo
	if err := json.NewDecoder(res.Body).Decode(&info); err != nil {
		return nil, fmt.Errorf("parsing renewal info response: %w", err)
	}
	return &info, nil
}

// getRenewalURL returns the ARI request URL for cert, formed by appending the
// certificate's renewal identifier to Directory.RenewalInfoURL.
func (c *Client) getRenewalURL(cert *x509.Certificate) string {
	// See https://www.ietf.org/archive/id/draft-ietf-acme-ari-04.html#name-the-renewalinfo-resource
	// for how the request URL is built.
	url := c.dir.RenewalInfoURL
	if !strings.HasSuffix(url, "/") {
		url += "/"
	}
	return url + certRenewalIdentifier(cert)
}

// certRenewalIdentifier returns the ARI CertID for cert:
// base64url(AuthorityKeyId) "." base64url(serial number bytes), unpadded.
func certRenewalIdentifier(cert *x509.Certificate) string {
	aki := base64.RawURLEncoding.EncodeToString(cert.AuthorityKeyId)
	serial := base64.RawURLEncoding.EncodeToString(cert.SerialNumber.Bytes())
	return aki + "." + serial
}

// AcceptTOS always returns true to indicate the acceptance of a CA's Terms of Service
// during account registration. See Register method of Client for more details.
func AcceptTOS(tosURL string) bool { return true }

// Register creates a new account with the CA using c.Key.
// It returns the registered account. The account acct is not modified.
//
// The registration may require the caller to agree to the CA's Terms of Service (TOS).
// If so, and the account has not indicated the acceptance of the terms (see Account for details),
// Register calls prompt with a TOS URL provided by the CA. Prompt should report
// whether the caller agrees to the terms. To always accept the terms, the caller can use AcceptTOS.
//
// When interfacing with an RFC-compliant CA, non-RFC 8555 fields of acct are ignored
// and prompt is called if Directory's Terms field is non-zero.
// Also see Error's Instance field for when a CA requires already registered accounts to agree
// to an updated Terms of Service.
func (c *Client) Register(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) {
	if c.Key == nil {
		return nil, errors.New("acme: client.Key must be set to Register")
	}
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}
	return c.registerRFC(ctx, acct, prompt)
}

// GetReg retrieves an existing account associated with c.Key.
//
// The url argument is a legacy artifact of the pre-RFC 8555 API
// and is ignored.
func (c *Client) GetReg(ctx context.Context, url string) (*Account, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}
	return c.getRegRFC(ctx)
}

// UpdateReg updates an existing registration.
// It returns an updated account copy. The provided account is not modified.
//
// The account's URI is ignored and the account URL associated with
// c.Key is used instead.
func (c *Client) UpdateReg(ctx context.Context, acct *Account) (*Account, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}
	return c.updateRegRFC(ctx, acct)
}

// AccountKeyRollover attempts to transition a client's account key to a new key.
// On success client's Key is updated which is not concurrency safe.
// On failure an error will be returned.
// The new key is already registered with the ACME provider if the following is true:
//   - error is of type acme.Error
//   - StatusCode should be 409 (Conflict)
//   - Location header will have the KID of the associated account
//
// More about account key rollover can be found at
// https://tools.ietf.org/html/rfc8555#section-7.3.5.
func (c *Client) AccountKeyRollover(ctx context.Context, newKey crypto.Signer) error {
	return c.accountKeyRollover(ctx, newKey)
}

// Authorize performs the initial step in the pre-authorization flow,
// as opposed to order-based flow.
// The caller will then need to choose from and perform a set of returned
// challenges using c.Accept in order to successfully complete authorization.
//
// Once complete, the caller can use AuthorizeOrder which the CA
// should provision with the already satisfied authorization.
// For pre-RFC CAs, the caller can proceed directly to requesting a certificate
// using CreateCert method.
//
// If an authorization has been previously granted, the CA may return
// a valid authorization which has its Status field set to StatusValid.
//
// More about pre-authorization can be found at
// https://tools.ietf.org/html/rfc8555#section-7.4.1.
func (c *Client) Authorize(ctx context.Context, domain string) (*Authorization, error) {
	return c.authorize(ctx, "dns", domain)
}

// AuthorizeIP is the same as Authorize but requests IP address authorization.
// Clients which successfully obtain such authorization may request to issue
// a certificate for IP addresses.
//
// See the ACME spec extension for more details about IP address identifiers:
// https://tools.ietf.org/html/draft-ietf-acme-ip.
func (c *Client) AuthorizeIP(ctx context.Context, ipaddr string) (*Authorization, error) {
	return c.authorize(ctx, "ip", ipaddr)
}

// authorize is the shared implementation of Authorize and AuthorizeIP.
// typ is the ACME identifier type ("dns" or "ip") and val its value.
func (c *Client) authorize(ctx context.Context, typ, val string) (*Authorization, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}

	type authzID struct {
		Type  string `json:"type"`
		Value string `json:"value"`
	}
	req := struct {
		Resource   string  `json:"resource"`
		Identifier authzID `json:"identifier"`
	}{
		Resource:   "new-authz",
		Identifier: authzID{Type: typ, Value: val},
	}
	res, err := c.post(ctx, nil, c.dir.AuthzURL, req, wantStatus(http.StatusCreated))
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	var v wireAuthz
	if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
		return nil, fmt.Errorf("acme: invalid response: %v", err)
	}
	if v.Status != StatusPending && v.Status != StatusValid {
		return nil, fmt.Errorf("acme: unexpected status: %s", v.Status)
	}
	return v.authorization(res.Header.Get("Location")), nil
}

// GetAuthorization retrieves an authorization identified by the given URL.
//
// If a caller needs to poll an authorization until its status is final,
// see the WaitAuthorization method.
func (c *Client) GetAuthorization(ctx context.Context, url string) (*Authorization, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}

	res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	var v wireAuthz
	if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
		return nil, fmt.Errorf("acme: invalid response: %v", err)
	}
	return v.authorization(url), nil
}

// RevokeAuthorization relinquishes an existing authorization identified
// by the given URL.
// The url argument is an Authorization.URI value.
//
// If successful, the caller will be required to obtain a new authorization
// using the Authorize or AuthorizeOrder methods before being able to request
// a new certificate for the domain associated with the authorization.
//
// It does not revoke existing certificates.
func (c *Client) RevokeAuthorization(ctx context.Context, url string) error {
	if _, err := c.Discover(ctx); err != nil {
		return err
	}

	req := struct {
		Resource string `json:"resource"`
		Status   string `json:"status"`
		Delete   bool   `json:"delete"`
	}{
		Resource: "authz",
		Status:   "deactivated",
		Delete:   true,
	}
	res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK))
	if err != nil {
		return err
	}
	defer res.Body.Close()
	return nil
}

// WaitAuthorization polls an authorization at the given URL
// until it is in one of the final states, StatusValid or StatusInvalid,
// the ACME CA responded with a 4xx error code, or the context is done.
//
// It returns a non-nil Authorization only if its Status is StatusValid.
// In all other cases WaitAuthorization returns an error.
// If the Status is StatusInvalid, the returned error is of type *AuthorizationError.
func (c *Client) WaitAuthorization(ctx context.Context, url string) (*Authorization, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}
	for {
		res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
		if err != nil {
			return nil, err
		}

		var raw wireAuthz
		err = json.NewDecoder(res.Body).Decode(&raw)
		res.Body.Close()
		switch {
		case err != nil:
			// Skip and retry.
		case raw.Status == StatusValid:
			return raw.authorization(url), nil
		case raw.Status == StatusInvalid:
			return nil, raw.error(url)
		}

		// Exponential backoff is implemented in c.get above.
		// This is just to prevent continuously hitting the CA
		// while waiting for a final authorization status.
		d := retryAfter(res.Header.Get("Retry-After"))
		if d == 0 {
			// Given that the fastest challenges TLS-SNI and HTTP-01
			// require a CA to make at least 1 network round trip
			// and most likely persist a challenge state,
			// this default delay seems reasonable.
			d = time.Second
		}
		t := time.NewTimer(d)
		select {
		case <-ctx.Done():
			t.Stop()
			return nil, ctx.Err()
		case <-t.C:
			// Retry.
		}
	}
}

// GetChallenge retrieves the current status of a challenge.
//
// A client typically polls a challenge status using this method.
func (c *Client) GetChallenge(ctx context.Context, url string) (*Challenge, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}

	res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK, http.StatusAccepted))
	if err != nil {
		return nil, err
	}

	defer res.Body.Close()
	v := wireChallenge{URI: url}
	if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
		return nil, fmt.Errorf("acme: invalid response: %v", err)
	}
	return v.challenge(), nil
}

// Accept informs the server that the client accepts one of its challenges
// previously obtained with c.Authorize.
//
// The server will then perform the validation asynchronously.
func (c *Client) Accept(ctx context.Context, chal *Challenge) (*Challenge, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}

	// RFC 8555 requires an empty JSON object as the default challenge response
	// payload; a caller-supplied payload overrides it.
	payload := json.RawMessage("{}")
	if len(chal.Payload) != 0 {
		payload = chal.Payload
	}
	res, err := c.post(ctx, nil, chal.URI, payload, wantStatus(
		http.StatusOK,       // according to the spec
		http.StatusAccepted, // Let's Encrypt: see https://goo.gl/WsJ7VT (acme-divergences.md)
	))
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	var v wireChallenge
	if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
		return nil, fmt.Errorf("acme: invalid response: %v", err)
	}
	return v.challenge(), nil
}

// DNS01ChallengeRecord returns a DNS record value for a dns-01 challenge response.
// A TXT record containing the returned value must be provisioned under
// "_acme-challenge" name of the domain being validated.
//
// The token argument is a Challenge.Token value.
func (c *Client) DNS01ChallengeRecord(token string) (string, error) {
	ka, err := keyAuth(c.Key.Public(), token)
	if err != nil {
		return "", err
	}
	// The record value is the base64url-encoded SHA-256 of the key authorization.
	b := sha256.Sum256([]byte(ka))
	return base64.RawURLEncoding.EncodeToString(b[:]), nil
}

// HTTP01ChallengeResponse returns the response for an http-01 challenge.
// Servers should respond with the value to HTTP requests at the URL path
// provided by HTTP01ChallengePath to validate the challenge and prove control
// over a domain name.
//
// The token argument is a Challenge.Token value.
func (c *Client) HTTP01ChallengeResponse(token string) (string, error) {
	return keyAuth(c.Key.Public(), token)
}

// HTTP01ChallengePath returns the URL path at which the response for an http-01 challenge
// should be provided by the servers.
// The response value can be obtained with HTTP01ChallengeResponse.
//
// The token argument is a Challenge.Token value.
func (c *Client) HTTP01ChallengePath(token string) string {
	return "/.well-known/acme-challenge/" + token
}

// TLSSNI01ChallengeCert creates a certificate for TLS-SNI-01 challenge response.
//
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec.
func (c *Client) TLSSNI01ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
	ka, err := keyAuth(c.Key.Public(), token)
	if err != nil {
		return tls.Certificate{}, "", err
	}
	// The SAN is the hex SHA-256 of the key authorization, split in half
	// around a dot, under the reserved .acme.invalid zone.
	b := sha256.Sum256([]byte(ka))
	h := hex.EncodeToString(b[:])
	name = fmt.Sprintf("%s.%s.acme.invalid", h[:32], h[32:])
	cert, err = tlsChallengeCert([]string{name}, opt)
	if err != nil {
		return tls.Certificate{}, "", err
	}
	return cert, name, nil
}

// TLSSNI02ChallengeCert creates a certificate for TLS-SNI-02 challenge response.
//
// Deprecated: This challenge type is unused in both draft-02 and RFC versions of the ACME spec.
func (c *Client) TLSSNI02ChallengeCert(token string, opt ...CertOption) (cert tls.Certificate, name string, err error) {
	// sanA is derived from the token, sanB from the key authorization.
	b := sha256.Sum256([]byte(token))
	h := hex.EncodeToString(b[:])
	sanA := fmt.Sprintf("%s.%s.token.acme.invalid", h[:32], h[32:])

	ka, err := keyAuth(c.Key.Public(), token)
	if err != nil {
		return tls.Certificate{}, "", err
	}
	b = sha256.Sum256([]byte(ka))
	h = hex.EncodeToString(b[:])
	sanB := fmt.Sprintf("%s.%s.ka.acme.invalid", h[:32], h[32:])

	cert, err = tlsChallengeCert([]string{sanA, sanB}, opt)
	if err != nil {
		return tls.Certificate{}, "", err
	}
	return cert, sanA, nil
}

// TLSALPN01ChallengeCert creates a certificate for TLS-ALPN-01 challenge response.
// Servers can present the certificate to validate the challenge and prove control
// over a domain name. For more details on TLS-ALPN-01 see
// https://tools.ietf.org/html/draft-shoemaker-acme-tls-alpn-00#section-3
//
// The token argument is a Challenge.Token value.
// If a WithKey option is provided, its private part signs the returned cert,
// and the public part is used to specify the signee.
// If no WithKey option is provided, a new ECDSA key is generated using P-256 curve.
//
// The returned certificate is valid for the next 24 hours and must be presented only when
// the server name in the TLS ClientHello matches the domain, and the special acme-tls/1 ALPN protocol
// has been specified.
func (c *Client) TLSALPN01ChallengeCert(token, domain string, opt ...CertOption) (cert tls.Certificate, err error) {
	ka, err := keyAuth(c.Key.Public(), token)
	if err != nil {
		return tls.Certificate{}, err
	}
	// The challenge response is carried in a critical acmeIdentifier X.509
	// extension holding the DER-encoded SHA-256 of the key authorization.
	shasum := sha256.Sum256([]byte(ka))
	extValue, err := asn1.Marshal(shasum[:])
	if err != nil {
		return tls.Certificate{}, err
	}
	acmeExtension := pkix.Extension{
		Id:       idPeACMEIdentifier,
		Critical: true,
		Value:    extValue,
	}

	tmpl := defaultTLSChallengeCertTemplate()

	// A caller-supplied template replaces the default; all other options are
	// forwarded to tlsChallengeCert unchanged.
	var newOpt []CertOption
	for _, o := range opt {
		switch o := o.(type) {
		case *certOptTemplate:
			t := *(*x509.Certificate)(o) // shallow copy is ok
			tmpl = &t
		default:
			newOpt = append(newOpt, o)
		}
	}
	tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, acmeExtension)
	newOpt = append(newOpt, WithTemplate(tmpl))
	return tlsChallengeCert([]string{domain}, newOpt)
}

// popNonce returns a nonce value previously stored with c.addNonce
// or fetches a fresh one from c.dir.NonceURL.
// If NonceURL is empty, it first tries c.directoryURL() and, failing that,
// the provided url.
func (c *Client) popNonce(ctx context.Context, url string) (string, error) {
	c.noncesMu.Lock()
	defer c.noncesMu.Unlock()
	if len(c.nonces) == 0 {
		if c.dir != nil && c.dir.NonceURL != "" {
			return c.fetchNonce(ctx, c.dir.NonceURL)
		}
		dirURL := c.directoryURL()
		v, err := c.fetchNonce(ctx, dirURL)
		if err != nil && url != dirURL {
			// Fall back to the request URL itself as a nonce source.
			v, err = c.fetchNonce(ctx, url)
		}
		return v, err
	}
	// Pop an arbitrary stored nonce; map iteration order is unspecified.
	var nonce string
	for nonce = range c.nonces {
		delete(c.nonces, nonce)
		break
	}
	return nonce, nil
}

// clearNonces clears any stored nonces.
func (c *Client) clearNonces() {
	c.noncesMu.Lock()
	defer c.noncesMu.Unlock()
	c.nonces = make(map[string]struct{})
}

// addNonce stores a nonce value found in h (if any) for future use.
func (c *Client) addNonce(h http.Header) {
	v := nonceFromHeader(h)
	if v == "" {
		return
	}
	c.noncesMu.Lock()
	defer c.noncesMu.Unlock()
	if len(c.nonces) >= maxNonces {
		// Cap the pool; drop the nonce rather than grow without bound.
		return
	}
	if c.nonces == nil {
		c.nonces = make(map[string]struct{})
	}
	c.nonces[v] = struct{}{}
}

// fetchNonce requests a fresh nonce from url with a HEAD request and returns
// the Replay-Nonce header value.
func (c *Client) fetchNonce(ctx context.Context, url string) (string, error) {
	r, err := http.NewRequest("HEAD", url, nil)
	if err != nil {
		return "", err
	}
	resp, err := c.doNoRetry(ctx, r)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	nonce := nonceFromHeader(resp.Header)
	if nonce == "" {
		if resp.StatusCode > 299 {
			return "", responseError(resp)
		}
		return "", errors.New("acme: nonce not found")
	}
	return nonce, nil
}

// nonceFromHeader returns the Replay-Nonce header value, if any.
func nonceFromHeader(h http.Header) string {
	return h.Get("Replay-Nonce")
}

// linkHeader returns URI-Reference values of all Link headers
// with relation-type rel.
// See https://tools.ietf.org/html/rfc5988#section-5 for details.
func linkHeader(h http.Header, rel string) []string {
	var links []string
	for _, v := range h["Link"] {
		parts := strings.Split(v, ";")
		for _, p := range parts {
			p = strings.TrimSpace(p)
			if !strings.HasPrefix(p, "rel=") {
				continue
			}
			if v := strings.Trim(p[4:], `"`); v == rel {
				// parts[0] is the <URI-Reference>; strip the angle brackets.
				links = append(links, strings.Trim(parts[0], "<>"))
			}
		}
	}
	return links
}

// keyAuth generates a key authorization string for a given token.
func keyAuth(pub crypto.PublicKey, token string) (string, error) {
	th, err := JWKThumbprint(pub)
	if err != nil {
		return "", err
	}
	return fmt.Sprintf("%s.%s", token, th), nil
}

// defaultTLSChallengeCertTemplate is a template used to create challenge certs for TLS challenges.
func defaultTLSChallengeCertTemplate() *x509.Certificate {
	return &x509.Certificate{
		SerialNumber:          big.NewInt(1),
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(24 * time.Hour),
		BasicConstraintsValid: true,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
}

// tlsChallengeCert creates a temporary certificate for TLS-SNI challenges
// with the given SANs and auto-generated public/private key pair.
// The Subject Common Name is set to the first SAN to aid debugging.
// To create a cert with a custom key pair, specify WithKey option.
+func tlsChallengeCert(san []string, opt []CertOption) (tls.Certificate, error) { + var key crypto.Signer + tmpl := defaultTLSChallengeCertTemplate() + for _, o := range opt { + switch o := o.(type) { + case *certOptKey: + if key != nil { + return tls.Certificate{}, errors.New("acme: duplicate key option") + } + key = o.key + case *certOptTemplate: + t := *(*x509.Certificate)(o) // shallow copy is ok + tmpl = &t + default: + // package's fault, if we let this happen: + panic(fmt.Sprintf("unsupported option type %T", o)) + } + } + if key == nil { + var err error + if key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader); err != nil { + return tls.Certificate{}, err + } + } + tmpl.DNSNames = san + if len(san) > 0 { + tmpl.Subject.CommonName = san[0] + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, key.Public(), key) + if err != nil { + return tls.Certificate{}, err + } + return tls.Certificate{ + Certificate: [][]byte{der}, + PrivateKey: key, + }, nil +} + +// encodePEM returns b encoded as PEM with block of type typ. +func encodePEM(typ string, b []byte) []byte { + pb := &pem.Block{Type: typ, Bytes: b} + return pem.EncodeToMemory(pb) +} + +// timeNow is time.Now, except in tests which can mess with it. +var timeNow = time.Now diff --git a/tempfork/acme/acme_test.go b/tempfork/acme/acme_test.go new file mode 100644 index 0000000000000..f0c45aea9ab03 --- /dev/null +++ b/tempfork/acme/acme_test.go @@ -0,0 +1,970 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package acme + +import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/hex" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "math/big" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "sort" + "strings" + "testing" + "time" +) + +// newTestClient creates a client with a non-nil Directory so that it skips +// the discovery which is otherwise done on the first call of almost every +// exported method. +func newTestClient() *Client { + return &Client{ + Key: testKeyEC, + dir: &Directory{}, // skip discovery + } +} + +// newTestClientWithMockDirectory creates a client with a non-nil Directory +// that contains mock field values. +func newTestClientWithMockDirectory() *Client { + return &Client{ + Key: testKeyEC, + dir: &Directory{ + RenewalInfoURL: "https://example.com/acme/renewal-info/", + }, + } +} + +// Decodes a JWS-encoded request and unmarshals the decoded JSON into a provided +// interface. 
+func decodeJWSRequest(t *testing.T, v interface{}, r io.Reader) { + // Decode request + var req struct{ Payload string } + if err := json.NewDecoder(r).Decode(&req); err != nil { + t.Fatal(err) + } + payload, err := base64.RawURLEncoding.DecodeString(req.Payload) + if err != nil { + t.Fatal(err) + } + err = json.Unmarshal(payload, v) + if err != nil { + t.Fatal(err) + } +} + +type jwsHead struct { + Alg string + Nonce string + URL string `json:"url"` + KID string `json:"kid"` + JWK map[string]string `json:"jwk"` +} + +func decodeJWSHead(r io.Reader) (*jwsHead, error) { + var req struct{ Protected string } + if err := json.NewDecoder(r).Decode(&req); err != nil { + return nil, err + } + b, err := base64.RawURLEncoding.DecodeString(req.Protected) + if err != nil { + return nil, err + } + var head jwsHead + if err := json.Unmarshal(b, &head); err != nil { + return nil, err + } + return &head, nil +} + +func TestRegisterWithoutKey(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "test-nonce") + return + } + w.WriteHeader(http.StatusCreated) + fmt.Fprint(w, `{}`) + })) + defer ts.Close() + // First verify that using a complete client results in success. 
+ c := Client{ + Key: testKeyEC, + DirectoryURL: ts.URL, + dir: &Directory{RegURL: ts.URL}, + } + if _, err := c.Register(context.Background(), &Account{}, AcceptTOS); err != nil { + t.Fatalf("c.Register() = %v; want success with a complete test client", err) + } + c.Key = nil + if _, err := c.Register(context.Background(), &Account{}, AcceptTOS); err == nil { + t.Error("c.Register() from client without key succeeded, wanted error") + } +} + +func TestAuthorize(t *testing.T) { + tt := []struct{ typ, value string }{ + {"dns", "example.com"}, + {"ip", "1.2.3.4"}, + } + for _, test := range tt { + t.Run(test.typ, func(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "test-nonce") + return + } + if r.Method != "POST" { + t.Errorf("r.Method = %q; want POST", r.Method) + } + + var j struct { + Resource string + Identifier struct { + Type string + Value string + } + } + decodeJWSRequest(t, &j, r.Body) + + // Test request + if j.Resource != "new-authz" { + t.Errorf("j.Resource = %q; want new-authz", j.Resource) + } + if j.Identifier.Type != test.typ { + t.Errorf("j.Identifier.Type = %q; want %q", j.Identifier.Type, test.typ) + } + if j.Identifier.Value != test.value { + t.Errorf("j.Identifier.Value = %q; want %q", j.Identifier.Value, test.value) + } + + w.Header().Set("Location", "https://ca.tld/acme/auth/1") + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `{ + "identifier": {"type":%q,"value":%q}, + "status":"pending", + "challenges":[ + { + "type":"http-01", + "status":"pending", + "uri":"https://ca.tld/acme/challenge/publickey/id1", + "token":"token1" + }, + { + "type":"tls-sni-01", + "status":"pending", + "uri":"https://ca.tld/acme/challenge/publickey/id2", + "token":"token2" + } + ], + "combinations":[[0],[1]] + }`, test.typ, test.value) + })) + defer ts.Close() + + var ( + auth *Authorization + err error + ) + cl := Client{ + Key: testKeyEC, + 
DirectoryURL: ts.URL, + dir: &Directory{AuthzURL: ts.URL}, + } + switch test.typ { + case "dns": + auth, err = cl.Authorize(context.Background(), test.value) + case "ip": + auth, err = cl.AuthorizeIP(context.Background(), test.value) + default: + t.Fatalf("unknown identifier type: %q", test.typ) + } + if err != nil { + t.Fatal(err) + } + + if auth.URI != "https://ca.tld/acme/auth/1" { + t.Errorf("URI = %q; want https://ca.tld/acme/auth/1", auth.URI) + } + if auth.Status != "pending" { + t.Errorf("Status = %q; want pending", auth.Status) + } + if auth.Identifier.Type != test.typ { + t.Errorf("Identifier.Type = %q; want %q", auth.Identifier.Type, test.typ) + } + if auth.Identifier.Value != test.value { + t.Errorf("Identifier.Value = %q; want %q", auth.Identifier.Value, test.value) + } + + if n := len(auth.Challenges); n != 2 { + t.Fatalf("len(auth.Challenges) = %d; want 2", n) + } + + c := auth.Challenges[0] + if c.Type != "http-01" { + t.Errorf("c.Type = %q; want http-01", c.Type) + } + if c.URI != "https://ca.tld/acme/challenge/publickey/id1" { + t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id1", c.URI) + } + if c.Token != "token1" { + t.Errorf("c.Token = %q; want token1", c.Token) + } + + c = auth.Challenges[1] + if c.Type != "tls-sni-01" { + t.Errorf("c.Type = %q; want tls-sni-01", c.Type) + } + if c.URI != "https://ca.tld/acme/challenge/publickey/id2" { + t.Errorf("c.URI = %q; want https://ca.tld/acme/challenge/publickey/id2", c.URI) + } + if c.Token != "token2" { + t.Errorf("c.Token = %q; want token2", c.Token) + } + + combs := [][]int{{0}, {1}} + if !reflect.DeepEqual(auth.Combinations, combs) { + t.Errorf("auth.Combinations: %+v\nwant: %+v\n", auth.Combinations, combs) + } + + }) + } +} + +func TestAuthorizeValid(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "nonce") + return + } + w.WriteHeader(http.StatusCreated) + 
w.Write([]byte(`{"status":"valid"}`)) + })) + defer ts.Close() + client := Client{ + Key: testKey, + DirectoryURL: ts.URL, + dir: &Directory{AuthzURL: ts.URL}, + } + _, err := client.Authorize(context.Background(), "example.com") + if err != nil { + t.Errorf("err = %v", err) + } +} + +func TestWaitAuthorization(t *testing.T) { + t.Run("wait loop", func(t *testing.T) { + var count int + authz, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + count++ + w.Header().Set("Retry-After", "0") + if count > 1 { + fmt.Fprintf(w, `{"status":"valid"}`) + return + } + fmt.Fprintf(w, `{"status":"pending"}`) + }) + if err != nil { + t.Fatalf("non-nil error: %v", err) + } + if authz == nil { + t.Fatal("authz is nil") + } + }) + t.Run("invalid status", func(t *testing.T) { + _, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{"status":"invalid"}`) + }) + if _, ok := err.(*AuthorizationError); !ok { + t.Errorf("err is %v (%T); want non-nil *AuthorizationError", err, err) + } + }) + t.Run("invalid status with error returns the authorization error", func(t *testing.T) { + _, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{ + "type": "dns-01", + "status": "invalid", + "error": { + "type": "urn:ietf:params:acme:error:caa", + "detail": "CAA record for prevents issuance", + "status": 403 + }, + "url": "https://acme-v02.api.letsencrypt.org/acme/chall-v3/xxx/xxx", + "token": "xxx", + "validationRecord": [ + { + "hostname": "" + } + ] + }`) + }) + + want := &AuthorizationError{ + Errors: []error{ + (&wireError{ + Status: 403, + Type: "urn:ietf:params:acme:error:caa", + Detail: "CAA record for prevents issuance", + }).error(nil), + }, + } + + _, ok := err.(*AuthorizationError) + if !ok { + t.Errorf("err is %T; want non-nil *AuthorizationError", err) + } + + if err.Error() != want.Error() { + 
t.Errorf("err is %v; want %v", err, want) + } + }) + t.Run("non-retriable error", func(t *testing.T) { + const code = http.StatusBadRequest + _, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(code) + }) + res, ok := err.(*Error) + if !ok { + t.Fatalf("err is %v (%T); want a non-nil *Error", err, err) + } + if res.StatusCode != code { + t.Errorf("res.StatusCode = %d; want %d", res.StatusCode, code) + } + }) + for _, code := range []int{http.StatusTooManyRequests, http.StatusInternalServerError} { + t.Run(fmt.Sprintf("retriable %d error", code), func(t *testing.T) { + var count int + authz, err := runWaitAuthorization(context.Background(), t, func(w http.ResponseWriter, r *http.Request) { + count++ + w.Header().Set("Retry-After", "0") + if count > 1 { + fmt.Fprintf(w, `{"status":"valid"}`) + return + } + w.WriteHeader(code) + }) + if err != nil { + t.Fatalf("non-nil error: %v", err) + } + if authz == nil { + t.Fatal("authz is nil") + } + }) + } + t.Run("context cancel", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := runWaitAuthorization(ctx, t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Retry-After", "60") + fmt.Fprintf(w, `{"status":"pending"}`) + time.AfterFunc(1*time.Millisecond, cancel) + }) + if err == nil { + t.Error("err is nil") + } + }) +} + +func runWaitAuthorization(ctx context.Context, t *testing.T, h http.HandlerFunc) (*Authorization, error) { + t.Helper() + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", fmt.Sprintf("bad-test-nonce-%v", time.Now().UnixNano())) + h(w, r) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, + dir: &Directory{}, + KID: "some-key-id", // set to avoid lookup attempt + } + return client.WaitAuthorization(ctx, ts.URL) +} + +func TestRevokeAuthorization(t *testing.T) { + 
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "HEAD" { + w.Header().Set("Replay-Nonce", "nonce") + return + } + switch r.URL.Path { + case "/1": + var req struct { + Resource string + Status string + Delete bool + } + decodeJWSRequest(t, &req, r.Body) + if req.Resource != "authz" { + t.Errorf("req.Resource = %q; want authz", req.Resource) + } + if req.Status != "deactivated" { + t.Errorf("req.Status = %q; want deactivated", req.Status) + } + if !req.Delete { + t.Errorf("req.Delete is false") + } + case "/2": + w.WriteHeader(http.StatusBadRequest) + } + })) + defer ts.Close() + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, // don't dial outside of localhost + dir: &Directory{}, // don't do discovery + } + ctx := context.Background() + if err := client.RevokeAuthorization(ctx, ts.URL+"/1"); err != nil { + t.Errorf("err = %v", err) + } + if client.RevokeAuthorization(ctx, ts.URL+"/2") == nil { + t.Error("nil error") + } +} + +func TestFetchCertCancel(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + <-r.Context().Done() + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusBadRequest) + })) + defer ts.Close() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + var err error + go func() { + cl := newTestClient() + _, err = cl.FetchCert(ctx, ts.URL, false) + close(done) + }() + cancel() + <-done + if err != context.Canceled { + t.Errorf("err = %v; want %v", err, context.Canceled) + } +} + +func TestFetchCertDepth(t *testing.T) { + var count byte + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + if count > maxChainLen+1 { + t.Errorf("count = %d; want at most %d", count, maxChainLen+1) + w.WriteHeader(http.StatusInternalServerError) + } + w.Header().Set("Link", fmt.Sprintf("<%s>;rel=up", ts.URL)) + 
w.Write([]byte{count}) + })) + defer ts.Close() + cl := newTestClient() + _, err := cl.FetchCert(context.Background(), ts.URL, true) + if err == nil { + t.Errorf("err is nil") + } +} + +func TestFetchCertBreadth(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < maxChainLen+1; i++ { + w.Header().Add("Link", fmt.Sprintf("<%s>;rel=up", ts.URL)) + } + w.Write([]byte{1}) + })) + defer ts.Close() + cl := newTestClient() + _, err := cl.FetchCert(context.Background(), ts.URL, true) + if err == nil { + t.Errorf("err is nil") + } +} + +func TestFetchCertSize(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + b := bytes.Repeat([]byte{1}, maxCertSize+1) + w.Write(b) + })) + defer ts.Close() + cl := newTestClient() + _, err := cl.FetchCert(context.Background(), ts.URL, false) + if err == nil { + t.Errorf("err is nil") + } +} + +const ( + leafPEM = `-----BEGIN CERTIFICATE----- +MIIEizCCAvOgAwIBAgIRAITApw7R8HSs7GU7cj8dEyUwDQYJKoZIhvcNAQELBQAw +gYUxHjAcBgNVBAoTFW1rY2VydCBkZXZlbG9wbWVudCBDQTEtMCsGA1UECwwkY3Bh +bG1lckBwdW1wa2luLmxvY2FsIChDaHJpcyBQYWxtZXIpMTQwMgYDVQQDDCtta2Nl +cnQgY3BhbG1lckBwdW1wa2luLmxvY2FsIChDaHJpcyBQYWxtZXIpMB4XDTIzMDcx +MjE4MjIxNloXDTI1MTAxMjE4MjIxNlowWDEnMCUGA1UEChMebWtjZXJ0IGRldmVs +b3BtZW50IGNlcnRpZmljYXRlMS0wKwYDVQQLDCRjcGFsbWVyQHB1bXBraW4ubG9j +YWwgKENocmlzIFBhbG1lcikwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIB +AQDNDO8P4MI9jaqVcPtF8C4GgHnTP5EK3U9fgyGApKGxTpicMQkA6z4GXwUP/Fvq +7RuCU9Wg7By5VetKIHF7FxkxWkUMrssr7mV8v6mRCh/a5GqDs14aj5ucjLQAJV74 +tLAdrCiijQ1fkPWc82fob+LkfKWGCWw7Cxf6ZtEyC8jz/DnfQXUvOiZS729ndGF7 +FobKRfIoirD+GI2NTYIp3LAUFSPR6HXTe7HAg8J81VoUKli8z504+FebfMmHePm/ +zIfiI0njAj4czOlZD56/oLsV0WRUizFjafHHUFz1HVdfFw8Qf9IOOTydYOe8M5i0 +lVbVO5G+HP+JDn3cr9MT41B9AgMBAAGjgaEwgZ4wDgYDVR0PAQH/BAQDAgWgMBMG +A1UdJQQMMAoGCCsGAQUFBwMBMB8GA1UdIwQYMBaAFPpL4Q0O7Z7voTkjn2rrFCsf 
+s8TbMFYGA1UdEQRPME2CC2V4YW1wbGUuY29tgg0qLmV4YW1wbGUuY29tggxleGFt +cGxlLnRlc3SCCWxvY2FsaG9zdIcEfwAAAYcQAAAAAAAAAAAAAAAAAAAAATANBgkq +hkiG9w0BAQsFAAOCAYEAMlOb7lrHuSxwcnAu7mL1ysTGqKn1d2TyDJAN5W8YFY+4 +XLpofNkK2UzZ0t9LQRnuFUcjmfqmfplh5lpC7pKmtL4G5Qcdc+BczQWcopbxd728 +sht9BKRkH+Bo1I+1WayKKNXW+5bsMv4CH641zxaMBlzjEnPvwKkNaGLMH3x5lIeX +GGgkKNXwVtINmyV+lTNVtu2IlHprxJGCjRfEuX7mEv6uRnqz3Wif+vgyh3MBgM/1 +dUOsTBNH4a6Jl/9VPSOfRdQOStqIlwTa/J1bhTvivsYt1+eWjLnsQJLgZQqwKvYH +BJ30gAk1oNnuSkx9dHbx4mO+4mB9oIYUALXUYakb8JHTOnuMSj9qelVj5vjVxl9q +KRitptU+kLYRA4HSgUXrhDIm4Q6D/w8/ascPqQ3HxPIDFLe+gTofEjqnnsnQB29L +gWpI8l5/MtXAOMdW69eEovnADc2pgaiif0T+v9nNKBc5xfDZHnrnqIqVzQEwL5Qv +niQI8IsWD5LcQ1Eg7kCq +-----END CERTIFICATE-----` +) + +func TestGetRenewalURL(t *testing.T) { + leaf, _ := pem.Decode([]byte(leafPEM)) + + parsedLeaf, err := x509.ParseCertificate(leaf.Bytes) + if err != nil { + t.Fatal(err) + } + + client := newTestClientWithMockDirectory() + urlString := client.getRenewalURL(parsedLeaf) + + parsedURL, err := url.Parse(urlString) + if err != nil { + t.Fatal(err) + } + if scheme := parsedURL.Scheme; scheme == "" { + t.Fatalf("malformed URL scheme: %q from %q", scheme, urlString) + } + if host := parsedURL.Host; host == "" { + t.Fatalf("malformed URL host: %q from %q", host, urlString) + } + if parsedURL.RawQuery != "" { + t.Fatalf("malformed URL: should not have a query") + } + path := parsedURL.EscapedPath() + slash := strings.LastIndex(path, "/") + if slash == -1 { + t.Fatalf("malformed URL path: %q from %q", path, urlString) + } + certID := path[slash+1:] + if certID == "" { + t.Fatalf("missing certificate identifier in URL path: %q from %q", path, urlString) + } + certIDParts := strings.Split(certID, ".") + if len(certIDParts) != 2 { + t.Fatalf("certificate identifier should consist of 2 base64-encoded values separated by a dot: %q from %q", certID, urlString) + } + if _, err := base64.RawURLEncoding.DecodeString(certIDParts[0]); err != nil { + t.Fatalf("malformed AKI part in 
certificate identifier: %q from %q: %v", certIDParts[0], urlString, err) + } + if _, err := base64.RawURLEncoding.DecodeString(certIDParts[1]); err != nil { + t.Fatalf("malformed Serial part in certificate identifier: %q from %q: %v", certIDParts[1], urlString, err) + } + +} + +func TestUnmarshalRenewalInfo(t *testing.T) { + renewalInfoJSON := `{ + "suggestedWindow": { + "start": "2021-01-03T00:00:00Z", + "end": "2021-01-07T00:00:00Z" + }, + "explanationURL": "https://example.com/docs/example-mass-reissuance-event" + }` + expectedStart := time.Date(2021, time.January, 3, 0, 0, 0, 0, time.UTC) + expectedEnd := time.Date(2021, time.January, 7, 0, 0, 0, 0, time.UTC) + + var info RenewalInfo + if err := json.Unmarshal([]byte(renewalInfoJSON), &info); err != nil { + t.Fatal(err) + } + if _, err := url.Parse(info.ExplanationURL); err != nil { + t.Fatal(err) + } + if !info.SuggestedWindow.Start.Equal(expectedStart) { + t.Fatalf("%v != %v", expectedStart, info.SuggestedWindow.Start) + } + if !info.SuggestedWindow.End.Equal(expectedEnd) { + t.Fatalf("%v != %v", expectedEnd, info.SuggestedWindow.End) + } +} + +func TestNonce_add(t *testing.T) { + var c Client + c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) + c.addNonce(http.Header{"Replay-Nonce": {}}) + c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) + + nonces := map[string]struct{}{"nonce": {}} + if !reflect.DeepEqual(c.nonces, nonces) { + t.Errorf("c.nonces = %q; want %q", c.nonces, nonces) + } +} + +func TestNonce_addMax(t *testing.T) { + c := &Client{nonces: make(map[string]struct{})} + for i := 0; i < maxNonces; i++ { + c.nonces[fmt.Sprintf("%d", i)] = struct{}{} + } + c.addNonce(http.Header{"Replay-Nonce": {"nonce"}}) + if n := len(c.nonces); n != maxNonces { + t.Errorf("len(c.nonces) = %d; want %d", n, maxNonces) + } +} + +func TestNonce_fetch(t *testing.T) { + tests := []struct { + code int + nonce string + }{ + {http.StatusOK, "nonce1"}, + {http.StatusBadRequest, "nonce2"}, + {http.StatusOK, ""}, + } + var 
i int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" { + t.Errorf("%d: r.Method = %q; want HEAD", i, r.Method) + } + w.Header().Set("Replay-Nonce", tests[i].nonce) + w.WriteHeader(tests[i].code) + })) + defer ts.Close() + for ; i < len(tests); i++ { + test := tests[i] + c := newTestClient() + n, err := c.fetchNonce(context.Background(), ts.URL) + if n != test.nonce { + t.Errorf("%d: n=%q; want %q", i, n, test.nonce) + } + switch { + case err == nil && test.nonce == "": + t.Errorf("%d: n=%q, err=%v; want non-nil error", i, n, err) + case err != nil && test.nonce != "": + t.Errorf("%d: n=%q, err=%v; want %q", i, n, err, test.nonce) + } + } +} + +func TestNonce_fetchError(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + })) + defer ts.Close() + c := newTestClient() + _, err := c.fetchNonce(context.Background(), ts.URL) + e, ok := err.(*Error) + if !ok { + t.Fatalf("err is %T; want *Error", err) + } + if e.StatusCode != http.StatusTooManyRequests { + t.Errorf("e.StatusCode = %d; want %d", e.StatusCode, http.StatusTooManyRequests) + } +} + +func TestNonce_popWhenEmpty(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "HEAD" { + t.Errorf("r.Method = %q; want HEAD", r.Method) + } + switch r.URL.Path { + case "/dir-with-nonce": + w.Header().Set("Replay-Nonce", "dirnonce") + case "/new-nonce": + w.Header().Set("Replay-Nonce", "newnonce") + case "/dir-no-nonce", "/empty": + // No nonce in the header. 
+		default:
+			t.Errorf("Unknown URL: %s", r.URL)
+		}
+	}))
+	defer ts.Close()
+	ctx := context.Background()
+
+	tt := []struct {
+		dirURL, popURL, nonce string
+		wantOK                bool
+	}{
+		{ts.URL + "/dir-with-nonce", ts.URL + "/new-nonce", "dirnonce", true},
+		{ts.URL + "/dir-no-nonce", ts.URL + "/new-nonce", "newnonce", true},
+		{ts.URL + "/dir-no-nonce", ts.URL + "/empty", "", false},
+	}
+	for _, test := range tt {
+		t.Run(fmt.Sprintf("nonce:%s wantOK:%v", test.nonce, test.wantOK), func(t *testing.T) {
+			c := Client{DirectoryURL: test.dirURL}
+			v, err := c.popNonce(ctx, test.popURL)
+			if !test.wantOK {
+				if err == nil {
+					t.Fatalf("c.popNonce(%q) returned nil error", test.popURL)
+				}
+				return
+			}
+			if err != nil {
+				t.Fatalf("c.popNonce(%q): %v", test.popURL, err)
+			}
+			if v != test.nonce {
+				t.Errorf("c.popNonce(%q) = %q; want %q", test.popURL, v, test.nonce)
+			}
+		})
+	}
+}
+
+func TestLinkHeader(t *testing.T) {
+	h := http.Header{"Link": {
+		`<https://example.com/acme/new-authz>;rel="next"`,
+		`<https://example.com/acme/recover-reg>; rel=recover`,
+		`<https://example.com/acme/terms>; foo=bar; rel="terms-of-service"`,
+		`<dup>;rel="next"`,
+	}}
+	tests := []struct {
+		rel string
+		out []string
+	}{
+		{"next", []string{"https://example.com/acme/new-authz", "dup"}},
+		{"recover", []string{"https://example.com/acme/recover-reg"}},
+		{"terms-of-service", []string{"https://example.com/acme/terms"}},
+		{"empty", nil},
+	}
+	for i, test := range tests {
+		if v := linkHeader(h, test.rel); !reflect.DeepEqual(v, test.out) {
+			t.Errorf("%d: linkHeader(%q): %v; want %v", i, test.rel, v, test.out)
+		}
+	}
+}
+
+func TestTLSSNI01ChallengeCert(t *testing.T) {
+	const (
+		token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA"
+		// echo -n <token.testKeyECThumbprint> | shasum -a 256
+		san = "dbbd5eefe7b4d06eb9d1d9f5acb4c7cd.a27d320e4b30332f0b6cb441734ad7b0.acme.invalid"
+	)
+
+	tlscert, name, err := newTestClient().TLSSNI01ChallengeCert(token)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if n := len(tlscert.Certificate); n != 1 {
+		t.Fatalf("len(tlscert.Certificate) = %d; want 1", n)
+	}
+	cert, err := 
x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Fatal(err) + } + if len(cert.DNSNames) != 1 || cert.DNSNames[0] != san { + t.Fatalf("cert.DNSNames = %v; want %q", cert.DNSNames, san) + } + if cert.DNSNames[0] != name { + t.Errorf("cert.DNSNames[0] != name: %q vs %q", cert.DNSNames[0], name) + } + if cn := cert.Subject.CommonName; cn != san { + t.Errorf("cert.Subject.CommonName = %q; want %q", cn, san) + } +} + +func TestTLSSNI02ChallengeCert(t *testing.T) { + const ( + token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA" + // echo -n evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA | shasum -a 256 + sanA = "7ea0aaa69214e71e02cebb18bb867736.09b730209baabf60e43d4999979ff139.token.acme.invalid" + // echo -n | shasum -a 256 + sanB = "dbbd5eefe7b4d06eb9d1d9f5acb4c7cd.a27d320e4b30332f0b6cb441734ad7b0.ka.acme.invalid" + ) + + tlscert, name, err := newTestClient().TLSSNI02ChallengeCert(token) + if err != nil { + t.Fatal(err) + } + + if n := len(tlscert.Certificate); n != 1 { + t.Fatalf("len(tlscert.Certificate) = %d; want 1", n) + } + cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Fatal(err) + } + names := []string{sanA, sanB} + if !reflect.DeepEqual(cert.DNSNames, names) { + t.Fatalf("cert.DNSNames = %v;\nwant %v", cert.DNSNames, names) + } + sort.Strings(cert.DNSNames) + i := sort.SearchStrings(cert.DNSNames, name) + if i >= len(cert.DNSNames) || cert.DNSNames[i] != name { + t.Errorf("%v doesn't have %q", cert.DNSNames, name) + } + if cn := cert.Subject.CommonName; cn != sanA { + t.Errorf("CommonName = %q; want %q", cn, sanA) + } +} + +func TestTLSALPN01ChallengeCert(t *testing.T) { + const ( + token = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA" + keyAuth = "evaGxfADs6pSRb2LAv9IZf17Dt3juxGJ-PCt92wr-oA." 
+ testKeyECThumbprint + // echo -n | shasum -a 256 + h = "0420dbbd5eefe7b4d06eb9d1d9f5acb4c7cda27d320e4b30332f0b6cb441734ad7b0" + domain = "example.com" + ) + + extValue, err := hex.DecodeString(h) + if err != nil { + t.Fatal(err) + } + + tlscert, err := newTestClient().TLSALPN01ChallengeCert(token, domain) + if err != nil { + t.Fatal(err) + } + + if n := len(tlscert.Certificate); n != 1 { + t.Fatalf("len(tlscert.Certificate) = %d; want 1", n) + } + cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Fatal(err) + } + names := []string{domain} + if !reflect.DeepEqual(cert.DNSNames, names) { + t.Fatalf("cert.DNSNames = %v;\nwant %v", cert.DNSNames, names) + } + if cn := cert.Subject.CommonName; cn != domain { + t.Errorf("CommonName = %q; want %q", cn, domain) + } + acmeExts := []pkix.Extension{} + for _, ext := range cert.Extensions { + if idPeACMEIdentifier.Equal(ext.Id) { + acmeExts = append(acmeExts, ext) + } + } + if len(acmeExts) != 1 { + t.Errorf("acmeExts = %v; want exactly one", acmeExts) + } + if !acmeExts[0].Critical { + t.Errorf("acmeExt.Critical = %v; want true", acmeExts[0].Critical) + } + if bytes.Compare(acmeExts[0].Value, extValue) != 0 { + t.Errorf("acmeExt.Value = %v; want %v", acmeExts[0].Value, extValue) + } + +} + +func TestTLSChallengeCertOpt(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + t.Fatal(err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{Organization: []string{"Test"}}, + DNSNames: []string{"should-be-overwritten"}, + } + opts := []CertOption{WithKey(key), WithTemplate(tmpl)} + + client := newTestClient() + cert1, _, err := client.TLSSNI01ChallengeCert("token", opts...) + if err != nil { + t.Fatal(err) + } + cert2, _, err := client.TLSSNI02ChallengeCert("token", opts...) 
+ if err != nil { + t.Fatal(err) + } + + for i, tlscert := range []tls.Certificate{cert1, cert2} { + // verify generated cert private key + tlskey, ok := tlscert.PrivateKey.(*rsa.PrivateKey) + if !ok { + t.Errorf("%d: tlscert.PrivateKey is %T; want *rsa.PrivateKey", i, tlscert.PrivateKey) + continue + } + if tlskey.D.Cmp(key.D) != 0 { + t.Errorf("%d: tlskey.D = %v; want %v", i, tlskey.D, key.D) + } + // verify generated cert public key + x509Cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + tlspub, ok := x509Cert.PublicKey.(*rsa.PublicKey) + if !ok { + t.Errorf("%d: x509Cert.PublicKey is %T; want *rsa.PublicKey", i, x509Cert.PublicKey) + continue + } + if tlspub.N.Cmp(key.N) != 0 { + t.Errorf("%d: tlspub.N = %v; want %v", i, tlspub.N, key.N) + } + // verify template option + sn := big.NewInt(2) + if x509Cert.SerialNumber.Cmp(sn) != 0 { + t.Errorf("%d: SerialNumber = %v; want %v", i, x509Cert.SerialNumber, sn) + } + org := []string{"Test"} + if !reflect.DeepEqual(x509Cert.Subject.Organization, org) { + t.Errorf("%d: Subject.Organization = %+v; want %+v", i, x509Cert.Subject.Organization, org) + } + for _, v := range x509Cert.DNSNames { + if !strings.HasSuffix(v, ".acme.invalid") { + t.Errorf("%d: invalid DNSNames element: %q", i, v) + } + } + } +} + +func TestHTTP01Challenge(t *testing.T) { + const ( + token = "xxx" + // thumbprint is precomputed for testKeyEC in jws_test.go + value = token + "." + testKeyECThumbprint + urlpath = "/.well-known/acme-challenge/" + token + ) + client := newTestClient() + val, err := client.HTTP01ChallengeResponse(token) + if err != nil { + t.Fatal(err) + } + if val != value { + t.Errorf("val = %q; want %q", val, value) + } + if path := client.HTTP01ChallengePath(token); path != urlpath { + t.Errorf("path = %q; want %q", path, urlpath) + } +} + +func TestDNS01ChallengeRecord(t *testing.T) { + // echo -n xxx. 
| \ + // openssl dgst -binary -sha256 | \ + // base64 | tr -d '=' | tr '/+' '_-' + const value = "8DERMexQ5VcdJ_prpPiA0mVdp7imgbCgjsG4SqqNMIo" + + val, err := newTestClient().DNS01ChallengeRecord("xxx") + if err != nil { + t.Fatal(err) + } + if val != value { + t.Errorf("val = %q; want %q", val, value) + } +} diff --git a/tempfork/acme/http.go b/tempfork/acme/http.go new file mode 100644 index 0000000000000..d92ff232fe983 --- /dev/null +++ b/tempfork/acme/http.go @@ -0,0 +1,344 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "bytes" + "context" + "crypto" + "crypto/rand" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "runtime/debug" + "strconv" + "strings" + "time" +) + +// retryTimer encapsulates common logic for retrying unsuccessful requests. +// It is not safe for concurrent use. +type retryTimer struct { + // backoffFn provides backoff delay sequence for retries. + // See Client.RetryBackoff doc comment. + backoffFn func(n int, r *http.Request, res *http.Response) time.Duration + // n is the current retry attempt. + n int +} + +func (t *retryTimer) inc() { + t.n++ +} + +// backoff pauses the current goroutine as described in Client.RetryBackoff. 
+func (t *retryTimer) backoff(ctx context.Context, r *http.Request, res *http.Response) error {
+	d := t.backoffFn(t.n, r, res)
+	if d <= 0 {
+		return fmt.Errorf("acme: no more retries for %s; tried %d time(s)", r.URL, t.n)
+	}
+	wakeup := time.NewTimer(d)
+	defer wakeup.Stop()
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case <-wakeup.C:
+		return nil
+	}
+}
+
+func (c *Client) retryTimer() *retryTimer {
+	f := c.RetryBackoff
+	if f == nil {
+		f = defaultBackoff
+	}
+	return &retryTimer{backoffFn: f}
+}
+
+// defaultBackoff provides default Client.RetryBackoff implementation
+// using a truncated exponential backoff algorithm,
+// as described in Client.RetryBackoff.
+//
+// The n argument is always bounded between 1 and 30.
+// The returned value is always greater than 0.
+func defaultBackoff(n int, r *http.Request, res *http.Response) time.Duration {
+	const max = 10 * time.Second
+	var jitter time.Duration
+	if x, err := rand.Int(rand.Reader, big.NewInt(1000)); err == nil {
+		// Set the minimum to 1ms to avoid a case where
+		// an invalid Retry-After value is parsed into 0 below,
+		// resulting in the 0 returned value which would unintentionally
+		// stop the retries.
+		jitter = (1 + time.Duration(x.Int64())) * time.Millisecond
+	}
+	if v, ok := res.Header["Retry-After"]; ok {
+		return retryAfter(v[0]) + jitter
+	}
+
+	if n < 1 {
+		n = 1
+	}
+	if n > 30 {
+		n = 30
+	}
+	d := time.Duration(1<<uint(n-1))*time.Second + jitter
+	if d > max {
+		return max
+	}
+	return d
+}
+
+// retryAfter parses a Retry-After HTTP header value,
+// trying to convert v into an int (seconds) or use http.ParseTime otherwise.
+// It returns zero value if v cannot be parsed.
+func retryAfter(v string) time.Duration {
+	if i, err := strconv.Atoi(v); err == nil {
+		return time.Duration(i) * time.Second
+	}
+	t, err := http.ParseTime(v)
+	if err != nil {
+		return 0
+	}
+	return t.Sub(timeNow())
+}
+
+// resOkay is a function that reports whether the provided response is okay.
+// It is expected to keep the response body unread. +type resOkay func(*http.Response) bool + +// wantStatus returns a function which reports whether the code +// matches the status code of a response. +func wantStatus(codes ...int) resOkay { + return func(res *http.Response) bool { + for _, code := range codes { + if code == res.StatusCode { + return true + } + } + return false + } +} + +// get issues an unsigned GET request to the specified URL. +// It returns a non-error value only when ok reports true. +// +// get retries unsuccessful attempts according to c.RetryBackoff +// until the context is done or a non-retriable error is received. +func (c *Client) get(ctx context.Context, url string, ok resOkay) (*http.Response, error) { + retry := c.retryTimer() + for { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + res, err := c.doNoRetry(ctx, req) + switch { + case err != nil: + return nil, err + case ok(res): + return res, nil + case isRetriable(res.StatusCode): + retry.inc() + resErr := responseError(res) + res.Body.Close() + // Ignore the error value from retry.backoff + // and return the one from last retry, as received from the CA. + if retry.backoff(ctx, req, res) != nil { + return nil, resErr + } + default: + defer res.Body.Close() + return nil, responseError(res) + } + } +} + +// postAsGet is POST-as-GET, a replacement for GET in RFC 8555 +// as described in https://tools.ietf.org/html/rfc8555#section-6.3. +// It makes a POST request in KID form with zero JWS payload. +// See nopayload doc comments in jws.go. +func (c *Client) postAsGet(ctx context.Context, url string, ok resOkay) (*http.Response, error) { + return c.post(ctx, nil, url, noPayload, ok) +} + +// post issues a signed POST request in JWS format using the provided key +// to the specified URL. If key is nil, c.Key is used instead. +// It returns a non-error value only when ok reports true. 
+// +// post retries unsuccessful attempts according to c.RetryBackoff +// until the context is done or a non-retriable error is received. +// It uses postNoRetry to make individual requests. +func (c *Client) post(ctx context.Context, key crypto.Signer, url string, body interface{}, ok resOkay) (*http.Response, error) { + retry := c.retryTimer() + for { + res, req, err := c.postNoRetry(ctx, key, url, body) + if err != nil { + return nil, err + } + if ok(res) { + return res, nil + } + resErr := responseError(res) + res.Body.Close() + switch { + // Check for bad nonce before isRetriable because it may have been returned + // with an unretriable response code such as 400 Bad Request. + case isBadNonce(resErr): + // Consider any previously stored nonce values to be invalid. + c.clearNonces() + case !isRetriable(res.StatusCode): + return nil, resErr + } + retry.inc() + // Ignore the error value from retry.backoff + // and return the one from last retry, as received from the CA. + if err := retry.backoff(ctx, req, res); err != nil { + return nil, resErr + } + } +} + +// postNoRetry signs the body with the given key and POSTs it to the provided url. +// It is used by c.post to retry unsuccessful attempts. +// The body argument must be JSON-serializable. +// +// If key argument is nil, c.Key is used to sign the request. +// If key argument is nil and c.accountKID returns a non-zero keyID, +// the request is sent in KID form. Otherwise, JWK form is used. +// +// In practice, when interfacing with RFC-compliant CAs most requests are sent in KID form +// and JWK is used only when KID is unavailable: new account endpoint and certificate +// revocation requests authenticated by a cert key. +// See jwsEncodeJSON for other details. 
+func (c *Client) postNoRetry(ctx context.Context, key crypto.Signer, url string, body interface{}) (*http.Response, *http.Request, error) { + kid := noKeyID + if key == nil { + if c.Key == nil { + return nil, nil, errors.New("acme: Client.Key must be populated to make POST requests") + } + key = c.Key + kid = c.accountKID(ctx) + } + nonce, err := c.popNonce(ctx, url) + if err != nil { + return nil, nil, err + } + b, err := jwsEncodeJSON(body, key, kid, nonce, url) + if err != nil { + return nil, nil, err + } + req, err := http.NewRequest("POST", url, bytes.NewReader(b)) + if err != nil { + return nil, nil, err + } + req.Header.Set("Content-Type", "application/jose+json") + res, err := c.doNoRetry(ctx, req) + if err != nil { + return nil, nil, err + } + c.addNonce(res.Header) + return res, req, nil +} + +// doNoRetry issues a request req, replacing its context (if any) with ctx. +func (c *Client) doNoRetry(ctx context.Context, req *http.Request) (*http.Response, error) { + req.Header.Set("User-Agent", c.userAgent()) + res, err := c.httpClient().Do(req.WithContext(ctx)) + if err != nil { + select { + case <-ctx.Done(): + // Prefer the unadorned context error. + // (The acme package had tests assuming this, previously from ctxhttp's + // behavior, predating net/http supporting contexts natively) + // TODO(bradfitz): reconsider this in the future. But for now this + // requires no test updates. + return nil, ctx.Err() + default: + return nil, err + } + } + return res, nil +} + +func (c *Client) httpClient() *http.Client { + if c.HTTPClient != nil { + return c.HTTPClient + } + return http.DefaultClient +} + +// packageVersion is the version of the module that contains this package, for +// sending as part of the User-Agent header. +var packageVersion string + +func init() { + // Set packageVersion if the binary was built in modules mode and x/crypto + // was not replaced with a different module. 
+ info, ok := debug.ReadBuildInfo() + if !ok { + return + } + for _, m := range info.Deps { + if m.Path != "golang.org/x/crypto" { + continue + } + if m.Replace == nil { + packageVersion = m.Version + } + break + } +} + +// userAgent returns the User-Agent header value. It includes the package name, +// the module version (if available), and the c.UserAgent value (if set). +func (c *Client) userAgent() string { + ua := "golang.org/x/crypto/acme" + if packageVersion != "" { + ua += "@" + packageVersion + } + if c.UserAgent != "" { + ua = c.UserAgent + " " + ua + } + return ua +} + +// isBadNonce reports whether err is an ACME "badnonce" error. +func isBadNonce(err error) bool { + // According to the spec badNonce is urn:ietf:params:acme:error:badNonce. + // However, ACME servers in the wild return their versions of the error. + // See https://tools.ietf.org/html/draft-ietf-acme-acme-02#section-5.4 + // and https://github.com/letsencrypt/boulder/blob/0e07eacb/docs/acme-divergences.md#section-66. + ae, ok := err.(*Error) + return ok && strings.HasSuffix(strings.ToLower(ae.ProblemType), ":badnonce") +} + +// isRetriable reports whether a request can be retried +// based on the response status code. +// +// Note that a "bad nonce" error is returned with a non-retriable 400 Bad Request code. +// Callers should parse the response and check with isBadNonce. +func isRetriable(code int) bool { + return code <= 399 || code >= 500 || code == http.StatusTooManyRequests +} + +// responseError creates an error of Error type from resp. 
+func responseError(resp *http.Response) error { + // don't care if ReadAll returns an error: + // json.Unmarshal will fail in that case anyway + b, _ := io.ReadAll(resp.Body) + e := &wireError{Status: resp.StatusCode} + if err := json.Unmarshal(b, e); err != nil { + // this is not a regular error response: + // populate detail with anything we received, + // e.Status will already contain HTTP response code value + e.Detail = string(b) + if e.Detail == "" { + e.Detail = resp.Status + } + } + return e.error(resp.Header) +} diff --git a/tempfork/acme/http_test.go b/tempfork/acme/http_test.go new file mode 100644 index 0000000000000..d124e4e219abe --- /dev/null +++ b/tempfork/acme/http_test.go @@ -0,0 +1,255 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" +) + +func TestDefaultBackoff(t *testing.T) { + tt := []struct { + nretry int + retryAfter string // Retry-After header + out time.Duration // expected min; max = min + jitter + }{ + {-1, "", time.Second}, // verify the lower bound is 1 + {0, "", time.Second}, // verify the lower bound is 1 + {100, "", 10 * time.Second}, // verify the ceiling + {1, "3600", time.Hour}, // verify the header value is used + {1, "", 1 * time.Second}, + {2, "", 2 * time.Second}, + {3, "", 4 * time.Second}, + {4, "", 8 * time.Second}, + } + for i, test := range tt { + r := httptest.NewRequest("GET", "/", nil) + resp := &http.Response{Header: http.Header{}} + if test.retryAfter != "" { + resp.Header.Set("Retry-After", test.retryAfter) + } + d := defaultBackoff(test.nretry, r, resp) + max := test.out + time.Second // + max jitter + if d < test.out || max < d { + t.Errorf("%d: defaultBackoff(%v) = %v; want between %v and %v", i, test.nretry, d, test.out, max) + } + } +} + +func 
TestErrorResponse(t *testing.T) { + s := `{ + "status": 400, + "type": "urn:acme:error:xxx", + "detail": "text" + }` + res := &http.Response{ + StatusCode: 400, + Status: "400 Bad Request", + Body: io.NopCloser(strings.NewReader(s)), + Header: http.Header{"X-Foo": {"bar"}}, + } + err := responseError(res) + v, ok := err.(*Error) + if !ok { + t.Fatalf("err = %+v (%T); want *Error type", err, err) + } + if v.StatusCode != 400 { + t.Errorf("v.StatusCode = %v; want 400", v.StatusCode) + } + if v.ProblemType != "urn:acme:error:xxx" { + t.Errorf("v.ProblemType = %q; want urn:acme:error:xxx", v.ProblemType) + } + if v.Detail != "text" { + t.Errorf("v.Detail = %q; want text", v.Detail) + } + if !reflect.DeepEqual(v.Header, res.Header) { + t.Errorf("v.Header = %+v; want %+v", v.Header, res.Header) + } +} + +func TestPostWithRetries(t *testing.T) { + var count int + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + count++ + w.Header().Set("Replay-Nonce", fmt.Sprintf("nonce%d", count)) + if r.Method == "HEAD" { + // We expect the client to do 2 head requests to fetch + // nonces, one to start and another after getting badNonce + return + } + + head, err := decodeJWSHead(r.Body) + switch { + case err != nil: + t.Errorf("decodeJWSHead: %v", err) + case head.Nonce == "": + t.Error("head.Nonce is empty") + case head.Nonce == "nonce1": + // Return a badNonce error to force the call to retry. + w.Header().Set("Retry-After", "0") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"type":"urn:ietf:params:acme:error:badNonce"}`)) + return + } + // Make client.Authorize happy; we're not testing its result. 
+ w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"status":"valid"}`)) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, + dir: &Directory{AuthzURL: ts.URL}, + } + // This call will fail with badNonce, causing a retry + if _, err := client.Authorize(context.Background(), "example.com"); err != nil { + t.Errorf("client.Authorize 1: %v", err) + } + if count != 3 { + t.Errorf("total requests count: %d; want 3", count) + } +} + +func TestRetryErrorType(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "nonce") + w.WriteHeader(http.StatusTooManyRequests) + w.Write([]byte(`{"type":"rateLimited"}`)) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + RetryBackoff: func(n int, r *http.Request, res *http.Response) time.Duration { + // Do no retries. + return 0 + }, + dir: &Directory{AuthzURL: ts.URL}, + } + + t.Run("post", func(t *testing.T) { + testRetryErrorType(t, func() error { + _, err := client.Authorize(context.Background(), "example.com") + return err + }) + }) + t.Run("get", func(t *testing.T) { + testRetryErrorType(t, func() error { + _, err := client.GetAuthorization(context.Background(), ts.URL) + return err + }) + }) +} + +func testRetryErrorType(t *testing.T, callClient func() error) { + t.Helper() + err := callClient() + if err == nil { + t.Fatal("client.Authorize returned nil error") + } + acmeErr, ok := err.(*Error) + if !ok { + t.Fatalf("err is %v (%T); want *Error", err, err) + } + if acmeErr.StatusCode != http.StatusTooManyRequests { + t.Errorf("acmeErr.StatusCode = %d; want %d", acmeErr.StatusCode, http.StatusTooManyRequests) + } + if acmeErr.ProblemType != "rateLimited" { + t.Errorf("acmeErr.ProblemType = %q; want 'rateLimited'", acmeErr.ProblemType) + } +} + +func TestRetryBackoffArgs(t *testing.T) { + const resCode = http.StatusInternalServerError + ts := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "test-nonce") + w.WriteHeader(resCode) + })) + defer ts.Close() + + // Canceled in backoff. + ctx, cancel := context.WithCancel(context.Background()) + + var nretry int + backoff := func(n int, r *http.Request, res *http.Response) time.Duration { + nretry++ + if n != nretry { + t.Errorf("n = %d; want %d", n, nretry) + } + if nretry == 3 { + cancel() + } + + if r == nil { + t.Error("r is nil") + } + if res.StatusCode != resCode { + t.Errorf("res.StatusCode = %d; want %d", res.StatusCode, resCode) + } + return time.Millisecond + } + + client := &Client{ + Key: testKey, + RetryBackoff: backoff, + dir: &Directory{AuthzURL: ts.URL}, + } + if _, err := client.Authorize(ctx, "example.com"); err == nil { + t.Error("err is nil") + } + if nretry != 3 { + t.Errorf("nretry = %d; want 3", nretry) + } +} + +func TestUserAgent(t *testing.T) { + for _, custom := range []string{"", "CUSTOM_UA"} { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Log(r.UserAgent()) + if s := "golang.org/x/crypto/acme"; !strings.Contains(r.UserAgent(), s) { + t.Errorf("expected User-Agent to contain %q, got %q", s, r.UserAgent()) + } + if !strings.Contains(r.UserAgent(), custom) { + t.Errorf("expected User-Agent to contain %q, got %q", custom, r.UserAgent()) + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"newOrder": "sure"}`)) + })) + defer ts.Close() + + client := &Client{ + Key: testKey, + DirectoryURL: ts.URL, + UserAgent: custom, + } + if _, err := client.Discover(context.Background()); err != nil { + t.Errorf("client.Discover: %v", err) + } + } +} + +func TestAccountKidLoop(t *testing.T) { + // if Client.postNoRetry is called with a nil key argument + // then Client.Key must be set, otherwise we fall into an + // infinite loop (which also causes a deadlock). 
+ client := &Client{dir: &Directory{OrderURL: ":)"}} + _, _, err := client.postNoRetry(context.Background(), nil, "", nil) + if err == nil { + t.Fatal("Client.postNoRetry didn't fail with a nil key") + } + expected := "acme: Client.Key must be populated to make POST requests" + if err.Error() != expected { + t.Fatalf("Unexpected error returned: wanted %q, got %q", expected, err.Error()) + } +} diff --git a/tempfork/acme/jws.go b/tempfork/acme/jws.go new file mode 100644 index 0000000000000..b38828d85935c --- /dev/null +++ b/tempfork/acme/jws.go @@ -0,0 +1,257 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "crypto" + "crypto/ecdsa" + "crypto/hmac" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + _ "crypto/sha512" // need for EC keys + "encoding/asn1" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "math/big" +) + +// KeyID is the account key identity provided by a CA during registration. +type KeyID string + +// noKeyID indicates that jwsEncodeJSON should compute and use JWK instead of a KID. +// See jwsEncodeJSON for details. +const noKeyID = KeyID("") + +// noPayload indicates jwsEncodeJSON will encode zero-length octet string +// in a JWS request. This is called POST-as-GET in RFC 8555 and is used to make +// authenticated GET requests via POSTing with an empty payload. +// See https://tools.ietf.org/html/rfc8555#section-6.3 for more details. +const noPayload = "" + +// noNonce indicates that the nonce should be omitted from the protected header. +// See jwsEncodeJSON for details. +const noNonce = "" + +// jsonWebSignature can be easily serialized into a JWS following +// https://tools.ietf.org/html/rfc7515#section-3.2. 
+type jsonWebSignature struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Sig string `json:"signature"` +} + +// jwsEncodeJSON signs claimset using provided key and a nonce. +// The result is serialized in JSON format containing either kid or jwk +// fields based on the provided KeyID value. +// +// The claimset is marshalled using json.Marshal unless it is a string. +// In which case it is inserted directly into the message. +// +// If kid is non-empty, its quoted value is inserted in the protected header +// as "kid" field value. Otherwise, JWK is computed using jwkEncode and inserted +// as "jwk" field value. The "jwk" and "kid" fields are mutually exclusive. +// +// If nonce is non-empty, its quoted value is inserted in the protected header. +// +// See https://tools.ietf.org/html/rfc7515#section-7. +func jwsEncodeJSON(claimset interface{}, key crypto.Signer, kid KeyID, nonce, url string) ([]byte, error) { + if key == nil { + return nil, errors.New("nil key") + } + alg, sha := jwsHasher(key.Public()) + if alg == "" || !sha.Available() { + return nil, ErrUnsupportedKey + } + headers := struct { + Alg string `json:"alg"` + KID string `json:"kid,omitempty"` + JWK json.RawMessage `json:"jwk,omitempty"` + Nonce string `json:"nonce,omitempty"` + URL string `json:"url"` + }{ + Alg: alg, + Nonce: nonce, + URL: url, + } + switch kid { + case noKeyID: + jwk, err := jwkEncode(key.Public()) + if err != nil { + return nil, err + } + headers.JWK = json.RawMessage(jwk) + default: + headers.KID = string(kid) + } + phJSON, err := json.Marshal(headers) + if err != nil { + return nil, err + } + phead := base64.RawURLEncoding.EncodeToString([]byte(phJSON)) + var payload string + if val, ok := claimset.(string); ok { + payload = val + } else { + cs, err := json.Marshal(claimset) + if err != nil { + return nil, err + } + payload = base64.RawURLEncoding.EncodeToString(cs) + } + hash := sha.New() + hash.Write([]byte(phead + "." 
+ payload)) + sig, err := jwsSign(key, sha, hash.Sum(nil)) + if err != nil { + return nil, err + } + enc := jsonWebSignature{ + Protected: phead, + Payload: payload, + Sig: base64.RawURLEncoding.EncodeToString(sig), + } + return json.Marshal(&enc) +} + +// jwsWithMAC creates and signs a JWS using the given key and the HS256 +// algorithm. kid and url are included in the protected header. rawPayload +// should not be base64-URL-encoded. +func jwsWithMAC(key []byte, kid, url string, rawPayload []byte) (*jsonWebSignature, error) { + if len(key) == 0 { + return nil, errors.New("acme: cannot sign JWS with an empty MAC key") + } + header := struct { + Algorithm string `json:"alg"` + KID string `json:"kid"` + URL string `json:"url,omitempty"` + }{ + // Only HMAC-SHA256 is supported. + Algorithm: "HS256", + KID: kid, + URL: url, + } + rawProtected, err := json.Marshal(header) + if err != nil { + return nil, err + } + protected := base64.RawURLEncoding.EncodeToString(rawProtected) + payload := base64.RawURLEncoding.EncodeToString(rawPayload) + + h := hmac.New(sha256.New, key) + if _, err := h.Write([]byte(protected + "." + payload)); err != nil { + return nil, err + } + mac := h.Sum(nil) + + return &jsonWebSignature{ + Protected: protected, + Payload: payload, + Sig: base64.RawURLEncoding.EncodeToString(mac), + }, nil +} + +// jwkEncode encodes public part of an RSA or ECDSA key into a JWK. +// The result is also suitable for creating a JWK thumbprint. +// https://tools.ietf.org/html/rfc7517 +func jwkEncode(pub crypto.PublicKey) (string, error) { + switch pub := pub.(type) { + case *rsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.3.1 + n := pub.N + e := big.NewInt(int64(pub.E)) + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. 
+ return fmt.Sprintf(`{"e":"%s","kty":"RSA","n":"%s"}`, + base64.RawURLEncoding.EncodeToString(e.Bytes()), + base64.RawURLEncoding.EncodeToString(n.Bytes()), + ), nil + case *ecdsa.PublicKey: + // https://tools.ietf.org/html/rfc7518#section-6.2.1 + p := pub.Curve.Params() + n := p.BitSize / 8 + if p.BitSize%8 != 0 { + n++ + } + x := pub.X.Bytes() + if n > len(x) { + x = append(make([]byte, n-len(x)), x...) + } + y := pub.Y.Bytes() + if n > len(y) { + y = append(make([]byte, n-len(y)), y...) + } + // Field order is important. + // See https://tools.ietf.org/html/rfc7638#section-3.3 for details. + return fmt.Sprintf(`{"crv":"%s","kty":"EC","x":"%s","y":"%s"}`, + p.Name, + base64.RawURLEncoding.EncodeToString(x), + base64.RawURLEncoding.EncodeToString(y), + ), nil + } + return "", ErrUnsupportedKey +} + +// jwsSign signs the digest using the given key. +// The hash is unused for ECDSA keys. +func jwsSign(key crypto.Signer, hash crypto.Hash, digest []byte) ([]byte, error) { + switch pub := key.Public().(type) { + case *rsa.PublicKey: + return key.Sign(rand.Reader, digest, hash) + case *ecdsa.PublicKey: + sigASN1, err := key.Sign(rand.Reader, digest, hash) + if err != nil { + return nil, err + } + + var rs struct{ R, S *big.Int } + if _, err := asn1.Unmarshal(sigASN1, &rs); err != nil { + return nil, err + } + + rb, sb := rs.R.Bytes(), rs.S.Bytes() + size := pub.Params().BitSize / 8 + if size%8 > 0 { + size++ + } + sig := make([]byte, size*2) + copy(sig[size-len(rb):], rb) + copy(sig[size*2-len(sb):], sb) + return sig, nil + } + return nil, ErrUnsupportedKey +} + +// jwsHasher indicates suitable JWS algorithm name and a hash function +// to use for signing a digest with the provided key. +// It returns ("", 0) if the key is not supported. 
+func jwsHasher(pub crypto.PublicKey) (string, crypto.Hash) { + switch pub := pub.(type) { + case *rsa.PublicKey: + return "RS256", crypto.SHA256 + case *ecdsa.PublicKey: + switch pub.Params().Name { + case "P-256": + return "ES256", crypto.SHA256 + case "P-384": + return "ES384", crypto.SHA384 + case "P-521": + return "ES512", crypto.SHA512 + } + } + return "", 0 +} + +// JWKThumbprint creates a JWK thumbprint out of pub +// as specified in https://tools.ietf.org/html/rfc7638. +func JWKThumbprint(pub crypto.PublicKey) (string, error) { + jwk, err := jwkEncode(pub) + if err != nil { + return "", err + } + b := sha256.Sum256([]byte(jwk)) + return base64.RawURLEncoding.EncodeToString(b[:]), nil +} diff --git a/tempfork/acme/jws_test.go b/tempfork/acme/jws_test.go new file mode 100644 index 0000000000000..d5f00ba2d3245 --- /dev/null +++ b/tempfork/acme/jws_test.go @@ -0,0 +1,550 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package acme + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "math/big" + "testing" +) + +// The following shell command alias is used in the comments +// throughout this file: +// alias b64raw="base64 -w0 | tr -d '=' | tr '/+' '_-'" + +const ( + // Modulus in raw base64: + // 4xgZ3eRPkwoRvy7qeRUbmMDe0V-xH9eWLdu0iheeLlrmD2mqWXfP9IeSKApbn34 + // g8TuAS9g5zhq8ELQ3kmjr-KV86GAMgI6VAcGlq3QrzpTCf_30Ab7-zawrfRaFON + // a1HwEzPY1KHnGVkxJc85gNkwYI9SY2RHXtvln3zs5wITNrdosqEXeaIkVYBEhbh + // Nu54pp3kxo6TuWLi9e6pXeWetEwmlBwtWZlPoib2j3TxLBksKZfoyFyek380mHg + // JAumQ_I2fjj98_97mk3ihOY4AgVdCDj1z_GCoZkG5Rq7nbCGyosyKWyDX00Zs-n + // NqVhoLeIvXC4nnWdJMZ6rogxyQQ + testKeyPEM = ` +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA4xgZ3eRPkwoRvy7qeRUbmMDe0V+xH9eWLdu0iheeLlrmD2mq +WXfP9IeSKApbn34g8TuAS9g5zhq8ELQ3kmjr+KV86GAMgI6VAcGlq3QrzpTCf/30 +Ab7+zawrfRaFONa1HwEzPY1KHnGVkxJc85gNkwYI9SY2RHXtvln3zs5wITNrdosq +EXeaIkVYBEhbhNu54pp3kxo6TuWLi9e6pXeWetEwmlBwtWZlPoib2j3TxLBksKZf +oyFyek380mHgJAumQ/I2fjj98/97mk3ihOY4AgVdCDj1z/GCoZkG5Rq7nbCGyosy +KWyDX00Zs+nNqVhoLeIvXC4nnWdJMZ6rogxyQQIDAQABAoIBACIEZTOI1Kao9nmV +9IeIsuaR1Y61b9neOF/MLmIVIZu+AAJFCMB4Iw11FV6sFodwpEyeZhx2WkpWVN+H +r19eGiLX3zsL0DOdqBJoSIHDWCCMxgnYJ6nvS0nRxX3qVrBp8R2g12Ub+gNPbmFm +ecf/eeERIVxfifd9VsyRu34eDEvcmKFuLYbElFcPh62xE3x12UZvV/sN7gXbawpP +G+w255vbE5MoaKdnnO83cTFlcHvhn24M/78qP7Te5OAeelr1R89kYxQLpuGe4fbS +zc6E3ym5Td6urDetGGrSY1Eu10/8sMusX+KNWkm+RsBRbkyKq72ks/qKpOxOa+c6 +9gm+Y8ECgYEA/iNUyg1ubRdH11p82l8KHtFC1DPE0V1gSZsX29TpM5jS4qv46K+s +8Ym1zmrORM8x+cynfPx1VQZQ34EYeCMIX212ryJ+zDATl4NE0I4muMvSiH9vx6Xc +7FmhNnaYzPsBL5Tm9nmtQuP09YEn8poiOJFiDs/4olnD5ogA5O4THGkCgYEA5MIL +qWYBUuqbEWLRtMruUtpASclrBqNNsJEsMGbeqBJmoMxdHeSZckbLOrqm7GlMyNRJ +Ne/5uWRGSzaMYuGmwsPpERzqEvYFnSrpjW5YtXZ+JtxFXNVfm9Z1gLLgvGpOUCIU +RbpoDckDe1vgUuk3y5+DjZihs+rqIJ45XzXTzBkCgYBWuf3segruJZy5rEKhTv+o 
+JqeUvRn0jNYYKFpLBeyTVBrbie6GkbUGNIWbrK05pC+c3K9nosvzuRUOQQL1tJbd +4gA3oiD9U4bMFNr+BRTHyZ7OQBcIXdz3t1qhuHVKtnngIAN1p25uPlbRFUNpshnt +jgeVoHlsBhApcs5DUc+pyQKBgDzeHPg/+g4z+nrPznjKnktRY1W+0El93kgi+J0Q +YiJacxBKEGTJ1MKBb8X6sDurcRDm22wMpGfd9I5Cv2v4GsUsF7HD/cx5xdih+G73 +c4clNj/k0Ff5Nm1izPUno4C+0IOl7br39IPmfpSuR6wH/h6iHQDqIeybjxyKvT1G +N0rRAoGBAKGD+4ZI/E1MoJ5CXB8cDDMHagbE3cq/DtmYzE2v1DFpQYu5I4PCm5c7 +EQeIP6dZtv8IMgtGIb91QX9pXvP0aznzQKwYIA8nZgoENCPfiMTPiEDT9e/0lObO +9XWsXpbSTsRPj0sv1rB+UzBJ0PgjK4q2zOF0sNo7b1+6nlM3BWPx +-----END RSA PRIVATE KEY----- +` + + // This thumbprint is for the testKey defined above. + testKeyThumbprint = "6nicxzh6WETQlrvdchkz-U3e3DOQZ4heJKU63rfqMqQ" + + // openssl ecparam -name secp256k1 -genkey -noout + testKeyECPEM = ` +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIK07hGLr0RwyUdYJ8wbIiBS55CjnkMD23DWr+ccnypWLoAoGCCqGSM49 +AwEHoUQDQgAE5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HThqIrvawF5 +QAaS/RNouybCiRhRjI3EaxLkQwgrCw0gqQ== +-----END EC PRIVATE KEY----- +` + // openssl ecparam -name secp384r1 -genkey -noout + testKeyEC384PEM = ` +-----BEGIN EC PRIVATE KEY----- +MIGkAgEBBDAQ4lNtXRORWr1bgKR1CGysr9AJ9SyEk4jiVnlUWWUChmSNL+i9SLSD +Oe/naPqXJ6CgBwYFK4EEACKhZANiAAQzKtj+Ms0vHoTX5dzv3/L5YMXOWuI5UKRj +JigpahYCqXD2BA1j0E/2xt5vlPf+gm0PL+UHSQsCokGnIGuaHCsJAp3ry0gHQEke +WYXapUUFdvaK1R2/2hn5O+eiQM8YzCg= +-----END EC PRIVATE KEY----- +` + // openssl ecparam -name secp521r1 -genkey -noout + testKeyEC512PEM = ` +-----BEGIN EC PRIVATE KEY----- +MIHcAgEBBEIBSNZKFcWzXzB/aJClAb305ibalKgtDA7+70eEkdPt28/3LZMM935Z +KqYHh/COcxuu3Kt8azRAUz3gyr4zZKhlKUSgBwYFK4EEACOhgYkDgYYABAHUNKbx +7JwC7H6pa2sV0tERWhHhB3JmW+OP6SUgMWryvIKajlx73eS24dy4QPGrWO9/ABsD +FqcRSkNVTXnIv6+0mAF25knqIBIg5Q8M9BnOu9GGAchcwt3O7RDHmqewnJJDrbjd +GGnm6rb+NnWR9DIopM0nKNkToWoF/hzopxu4Ae/GsQ== +-----END EC PRIVATE KEY----- +` + // 1. openssl ec -in key.pem -noout -text + // 2. remove first byte, 04 (the header); the rest is X and Y + // 3. 
convert each with: echo | xxd -r -p | b64raw + testKeyECPubX = "5lhEug5xK4xBDZ2nAbaxLtaLiv85bxJ7ePd1dkO23HQ" + testKeyECPubY = "4aiK72sBeUAGkv0TaLsmwokYUYyNxGsS5EMIKwsNIKk" + testKeyEC384PubX = "MyrY_jLNLx6E1-Xc79_y-WDFzlriOVCkYyYoKWoWAqlw9gQNY9BP9sbeb5T3_oJt" + testKeyEC384PubY = "Dy_lB0kLAqJBpyBrmhwrCQKd68tIB0BJHlmF2qVFBXb2itUdv9oZ-TvnokDPGMwo" + testKeyEC512PubX = "AdQ0pvHsnALsfqlraxXS0RFaEeEHcmZb44_pJSAxavK8gpqOXHvd5Lbh3LhA8atY738AGwMWpxFKQ1VNeci_r7SY" + testKeyEC512PubY = "AXbmSeogEiDlDwz0Gc670YYByFzC3c7tEMeap7CckkOtuN0Yaebqtv42dZH0MiikzSco2ROhagX-HOinG7gB78ax" + + // echo -n '{"crv":"P-256","kty":"EC","x":"","y":""}' | \ + // openssl dgst -binary -sha256 | b64raw + testKeyECThumbprint = "zedj-Bd1Zshp8KLePv2MB-lJ_Hagp7wAwdkA0NUTniU" +) + +var ( + testKey *rsa.PrivateKey + testKeyEC *ecdsa.PrivateKey + testKeyEC384 *ecdsa.PrivateKey + testKeyEC512 *ecdsa.PrivateKey +) + +func init() { + testKey = parseRSA(testKeyPEM, "testKeyPEM") + testKeyEC = parseEC(testKeyECPEM, "testKeyECPEM") + testKeyEC384 = parseEC(testKeyEC384PEM, "testKeyEC384PEM") + testKeyEC512 = parseEC(testKeyEC512PEM, "testKeyEC512PEM") +} + +func decodePEM(s, name string) []byte { + d, _ := pem.Decode([]byte(s)) + if d == nil { + panic("no block found in " + name) + } + return d.Bytes +} + +func parseRSA(s, name string) *rsa.PrivateKey { + b := decodePEM(s, name) + k, err := x509.ParsePKCS1PrivateKey(b) + if err != nil { + panic(fmt.Sprintf("%s: %v", name, err)) + } + return k +} + +func parseEC(s, name string) *ecdsa.PrivateKey { + b := decodePEM(s, name) + k, err := x509.ParseECPrivateKey(b) + if err != nil { + panic(fmt.Sprintf("%s: %v", name, err)) + } + return k +} + +func TestJWSEncodeJSON(t *testing.T) { + claims := struct{ Msg string }{"Hello JWS"} + // JWS signed with testKey and "nonce" as the nonce value + // JSON-serialized JWS fields are split for easier testing + const ( + // {"alg":"RS256","jwk":{"e":"AQAB","kty":"RSA","n":"..."},"nonce":"nonce","url":"url"} + protected = 
"eyJhbGciOiJSUzI1NiIsImp3ayI6eyJlIjoiQVFBQiIsImt0eSI6" + + "IlJTQSIsIm4iOiI0eGdaM2VSUGt3b1J2eTdxZVJVYm1NRGUwVi14" + + "SDllV0xkdTBpaGVlTGxybUQybXFXWGZQOUllU0tBcGJuMzRnOFR1" + + "QVM5ZzV6aHE4RUxRM2ttanItS1Y4NkdBTWdJNlZBY0dscTNRcnpw" + + "VENmXzMwQWI3LXphd3JmUmFGT05hMUh3RXpQWTFLSG5HVmt4SmM4" + + "NWdOa3dZSTlTWTJSSFh0dmxuM3pzNXdJVE5yZG9zcUVYZWFJa1ZZ" + + "QkVoYmhOdTU0cHAza3hvNlR1V0xpOWU2cFhlV2V0RXdtbEJ3dFda" + + "bFBvaWIyajNUeExCa3NLWmZveUZ5ZWszODBtSGdKQXVtUV9JMmZq" + + "ajk4Xzk3bWszaWhPWTRBZ1ZkQ0RqMXpfR0NvWmtHNVJxN25iQ0d5" + + "b3N5S1d5RFgwMFpzLW5OcVZob0xlSXZYQzRubldkSk1aNnJvZ3h5" + + "UVEifSwibm9uY2UiOiJub25jZSIsInVybCI6InVybCJ9" + // {"Msg":"Hello JWS"} + payload = "eyJNc2ciOiJIZWxsbyBKV1MifQ" + // printf '.' | openssl dgst -binary -sha256 -sign testKey | b64raw + signature = "YFyl_xz1E7TR-3E1bIuASTr424EgCvBHjt25WUFC2VaDjXYV0Rj_" + + "Hd3dJ_2IRqBrXDZZ2n4ZeA_4mm3QFwmwyeDwe2sWElhb82lCZ8iX" + + "uFnjeOmSOjx-nWwPa5ibCXzLq13zZ-OBV1Z4oN_TuailQeRoSfA3" + + "nO8gG52mv1x2OMQ5MAFtt8jcngBLzts4AyhI6mBJ2w7Yaj3ZCriq" + + "DWA3GLFvvHdW1Ba9Z01wtGT2CuZI7DUk_6Qj1b3BkBGcoKur5C9i" + + "bUJtCkABwBMvBQNyD3MmXsrRFRTgvVlyU_yMaucYm7nmzEr_2PaQ" + + "50rFt_9qOfJ4sfbLtG1Wwae57BQx1g" + ) + + b, err := jwsEncodeJSON(claims, testKey, noKeyID, "nonce", "url") + if err != nil { + t.Fatal(err) + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Fatal(err) + } + if jws.Protected != protected { + t.Errorf("protected:\n%s\nwant:\n%s", jws.Protected, protected) + } + if jws.Payload != payload { + t.Errorf("payload:\n%s\nwant:\n%s", jws.Payload, payload) + } + if jws.Signature != signature { + t.Errorf("signature:\n%s\nwant:\n%s", jws.Signature, signature) + } +} + +func TestJWSEncodeNoNonce(t *testing.T) { + kid := KeyID("https://example.org/account/1") + claims := "RawString" + const ( + // {"alg":"ES256","kid":"https://example.org/account/1","nonce":"nonce","url":"url"} + protected = 
"eyJhbGciOiJFUzI1NiIsImtpZCI6Imh0dHBzOi8vZXhhbXBsZS5vcmcvYWNjb3VudC8xIiwidXJsIjoidXJsIn0" + // "Raw String" + payload = "RawString" + ) + + b, err := jwsEncodeJSON(claims, testKeyEC, kid, "", "url") + if err != nil { + t.Fatal(err) + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Fatal(err) + } + if jws.Protected != protected { + t.Errorf("protected:\n%s\nwant:\n%s", jws.Protected, protected) + } + if jws.Payload != payload { + t.Errorf("payload:\n%s\nwant:\n%s", jws.Payload, payload) + } + + sig, err := base64.RawURLEncoding.DecodeString(jws.Signature) + if err != nil { + t.Fatalf("jws.Signature: %v", err) + } + r, s := big.NewInt(0), big.NewInt(0) + r.SetBytes(sig[:len(sig)/2]) + s.SetBytes(sig[len(sig)/2:]) + h := sha256.Sum256([]byte(protected + "." + payload)) + if !ecdsa.Verify(testKeyEC.Public().(*ecdsa.PublicKey), h[:], r, s) { + t.Error("invalid signature") + } +} + +func TestJWSEncodeKID(t *testing.T) { + kid := KeyID("https://example.org/account/1") + claims := struct{ Msg string }{"Hello JWS"} + // JWS signed with testKeyEC + const ( + // {"alg":"ES256","kid":"https://example.org/account/1","nonce":"nonce","url":"url"} + protected = "eyJhbGciOiJFUzI1NiIsImtpZCI6Imh0dHBzOi8vZXhhbXBsZS5" + + "vcmcvYWNjb3VudC8xIiwibm9uY2UiOiJub25jZSIsInVybCI6InVybCJ9" + // {"Msg":"Hello JWS"} + payload = "eyJNc2ciOiJIZWxsbyBKV1MifQ" + ) + + b, err := jwsEncodeJSON(claims, testKeyEC, kid, "nonce", "url") + if err != nil { + t.Fatal(err) + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Fatal(err) + } + if jws.Protected != protected { + t.Errorf("protected:\n%s\nwant:\n%s", jws.Protected, protected) + } + if jws.Payload != payload { + t.Errorf("payload:\n%s\nwant:\n%s", jws.Payload, payload) + } + + sig, err := base64.RawURLEncoding.DecodeString(jws.Signature) + if err != nil { + t.Fatalf("jws.Signature: %v", err) + } + r, s := big.NewInt(0), 
big.NewInt(0) + r.SetBytes(sig[:len(sig)/2]) + s.SetBytes(sig[len(sig)/2:]) + h := sha256.Sum256([]byte(protected + "." + payload)) + if !ecdsa.Verify(testKeyEC.Public().(*ecdsa.PublicKey), h[:], r, s) { + t.Error("invalid signature") + } +} + +func TestJWSEncodeJSONEC(t *testing.T) { + tt := []struct { + key *ecdsa.PrivateKey + x, y string + alg, crv string + }{ + {testKeyEC, testKeyECPubX, testKeyECPubY, "ES256", "P-256"}, + {testKeyEC384, testKeyEC384PubX, testKeyEC384PubY, "ES384", "P-384"}, + {testKeyEC512, testKeyEC512PubX, testKeyEC512PubY, "ES512", "P-521"}, + } + for i, test := range tt { + claims := struct{ Msg string }{"Hello JWS"} + b, err := jwsEncodeJSON(claims, test.key, noKeyID, "nonce", "url") + if err != nil { + t.Errorf("%d: %v", i, err) + continue + } + var jws struct{ Protected, Payload, Signature string } + if err := json.Unmarshal(b, &jws); err != nil { + t.Errorf("%d: %v", i, err) + continue + } + + b, err = base64.RawURLEncoding.DecodeString(jws.Protected) + if err != nil { + t.Errorf("%d: jws.Protected: %v", i, err) + } + var head struct { + Alg string + Nonce string + URL string `json:"url"` + KID string `json:"kid"` + JWK struct { + Crv string + Kty string + X string + Y string + } `json:"jwk"` + } + if err := json.Unmarshal(b, &head); err != nil { + t.Errorf("%d: jws.Protected: %v", i, err) + } + if head.Alg != test.alg { + t.Errorf("%d: head.Alg = %q; want %q", i, head.Alg, test.alg) + } + if head.Nonce != "nonce" { + t.Errorf("%d: head.Nonce = %q; want nonce", i, head.Nonce) + } + if head.URL != "url" { + t.Errorf("%d: head.URL = %q; want 'url'", i, head.URL) + } + if head.KID != "" { + // We used noKeyID in jwsEncodeJSON: expect no kid value. 
+ t.Errorf("%d: head.KID = %q; want empty", i, head.KID) + } + if head.JWK.Crv != test.crv { + t.Errorf("%d: head.JWK.Crv = %q; want %q", i, head.JWK.Crv, test.crv) + } + if head.JWK.Kty != "EC" { + t.Errorf("%d: head.JWK.Kty = %q; want EC", i, head.JWK.Kty) + } + if head.JWK.X != test.x { + t.Errorf("%d: head.JWK.X = %q; want %q", i, head.JWK.X, test.x) + } + if head.JWK.Y != test.y { + t.Errorf("%d: head.JWK.Y = %q; want %q", i, head.JWK.Y, test.y) + } + } +} + +type customTestSigner struct { + sig []byte + pub crypto.PublicKey +} + +func (s *customTestSigner) Public() crypto.PublicKey { return s.pub } +func (s *customTestSigner) Sign(io.Reader, []byte, crypto.SignerOpts) ([]byte, error) { + return s.sig, nil +} + +func TestJWSEncodeJSONCustom(t *testing.T) { + claims := struct{ Msg string }{"hello"} + const ( + // printf '{"Msg":"hello"}' | b64raw + payload = "eyJNc2ciOiJoZWxsbyJ9" + // printf 'testsig' | b64raw + testsig = "dGVzdHNpZw" + + // the example P256 curve point from https://tools.ietf.org/html/rfc7515#appendix-A.3.1 + // encoded as ASN.1â€Ļ + es256stdsig = "MEUCIA7RIVN5Y2xIPC9/FVgH1AKjsigDOvl8fheBmsMWnqZlAiEA" + + "xQoH04w8cOXY8S2vCEpUgKZlkMXyk1Cajz9/ioOjVNU" + // â€Ļand RFC7518 (https://tools.ietf.org/html/rfc7518#section-3.4) + es256jwsig = "DtEhU3ljbEg8L38VWAfUAqOyKAM6-Xx-F4GawxaepmXFCgfTjDxw" + + "5djxLa8ISlSApmWQxfKTUJqPP3-Kg6NU1Q" + + // printf '{"alg":"ES256","jwk":{"crv":"P-256","kty":"EC","x":,"y":},"nonce":"nonce","url":"url"}' | b64raw + es256phead = "eyJhbGciOiJFUzI1NiIsImp3ayI6eyJjcnYiOiJQLTI1NiIsImt0" + + "eSI6IkVDIiwieCI6IjVsaEV1ZzV4SzR4QkRaMm5BYmF4THRhTGl2" + + "ODVieEo3ZVBkMWRrTzIzSFEiLCJ5IjoiNGFpSzcyc0JlVUFHa3Yw" + + "VGFMc213b2tZVVl5TnhHc1M1RU1JS3dzTklLayJ9LCJub25jZSI6" + + "Im5vbmNlIiwidXJsIjoidXJsIn0" + + // {"alg":"RS256","jwk":{"e":"AQAB","kty":"RSA","n":"..."},"nonce":"nonce","url":"url"} + rs256phead = "eyJhbGciOiJSUzI1NiIsImp3ayI6eyJlIjoiQVFBQiIsImt0eSI6" + + "IlJTQSIsIm4iOiI0eGdaM2VSUGt3b1J2eTdxZVJVYm1NRGUwVi14" + + 
"SDllV0xkdTBpaGVlTGxybUQybXFXWGZQOUllU0tBcGJuMzRnOFR1" + + "QVM5ZzV6aHE4RUxRM2ttanItS1Y4NkdBTWdJNlZBY0dscTNRcnpw" + + "VENmXzMwQWI3LXphd3JmUmFGT05hMUh3RXpQWTFLSG5HVmt4SmM4" + + "NWdOa3dZSTlTWTJSSFh0dmxuM3pzNXdJVE5yZG9zcUVYZWFJa1ZZ" + + "QkVoYmhOdTU0cHAza3hvNlR1V0xpOWU2cFhlV2V0RXdtbEJ3dFda" + + "bFBvaWIyajNUeExCa3NLWmZveUZ5ZWszODBtSGdKQXVtUV9JMmZq" + + "ajk4Xzk3bWszaWhPWTRBZ1ZkQ0RqMXpfR0NvWmtHNVJxN25iQ0d5" + + "b3N5S1d5RFgwMFpzLW5OcVZob0xlSXZYQzRubldkSk1aNnJvZ3h5" + + "UVEifSwibm9uY2UiOiJub25jZSIsInVybCI6InVybCJ9" + ) + + tt := []struct { + alg, phead string + pub crypto.PublicKey + stdsig, jwsig string + }{ + {"ES256", es256phead, testKeyEC.Public(), es256stdsig, es256jwsig}, + {"RS256", rs256phead, testKey.Public(), testsig, testsig}, + } + for _, tc := range tt { + tc := tc + t.Run(tc.alg, func(t *testing.T) { + stdsig, err := base64.RawStdEncoding.DecodeString(tc.stdsig) + if err != nil { + t.Errorf("couldn't decode test vector: %v", err) + } + signer := &customTestSigner{ + sig: stdsig, + pub: tc.pub, + } + + b, err := jwsEncodeJSON(claims, signer, noKeyID, "nonce", "url") + if err != nil { + t.Fatal(err) + } + var j jsonWebSignature + if err := json.Unmarshal(b, &j); err != nil { + t.Fatal(err) + } + if j.Protected != tc.phead { + t.Errorf("j.Protected = %q\nwant %q", j.Protected, tc.phead) + } + if j.Payload != payload { + t.Errorf("j.Payload = %q\nwant %q", j.Payload, payload) + } + if j.Sig != tc.jwsig { + t.Errorf("j.Sig = %q\nwant %q", j.Sig, tc.jwsig) + } + }) + } +} + +func TestJWSWithMAC(t *testing.T) { + // Example from RFC 7520 Section 4.4.3. + // https://tools.ietf.org/html/rfc7520#section-4.4.3 + b64Key := "hJtXIZ2uSN5kbQfbtTNWbpdmhkV8FJG-Onbc6mxCcYg" + rawPayload := []byte("It\xe2\x80\x99s a dangerous business, Frodo, going out your " + + "door. 
You step onto the road, and if you don't keep your feet, " + + "there\xe2\x80\x99s no knowing where you might be swept off " + + "to.") + protected := "eyJhbGciOiJIUzI1NiIsImtpZCI6IjAxOGMwYWU1LTRkOWItNDcxYi1iZmQ2LW" + + "VlZjMxNGJjNzAzNyJ9" + payload := "SXTigJlzIGEgZGFuZ2Vyb3VzIGJ1c2luZXNzLCBGcm9kbywg" + + "Z29pbmcgb3V0IHlvdXIgZG9vci4gWW91IHN0ZXAgb250byB0aGUgcm9h" + + "ZCwgYW5kIGlmIHlvdSBkb24ndCBrZWVwIHlvdXIgZmVldCwgdGhlcmXi" + + "gJlzIG5vIGtub3dpbmcgd2hlcmUgeW91IG1pZ2h0IGJlIHN3ZXB0IG9m" + + "ZiB0by4" + sig := "s0h6KThzkfBBBkLspW1h84VsJZFTsPPqMDA7g1Md7p0" + + key, err := base64.RawURLEncoding.DecodeString(b64Key) + if err != nil { + t.Fatalf("unable to decode key: %q", b64Key) + } + got, err := jwsWithMAC(key, "018c0ae5-4d9b-471b-bfd6-eef314bc7037", "", rawPayload) + if err != nil { + t.Fatalf("jwsWithMAC() = %q", err) + } + if got.Protected != protected { + t.Errorf("got.Protected = %q\nwant %q", got.Protected, protected) + } + if got.Payload != payload { + t.Errorf("got.Payload = %q\nwant %q", got.Payload, payload) + } + if got.Sig != sig { + t.Errorf("got.Signature = %q\nwant %q", got.Sig, sig) + } +} + +func TestJWSWithMACError(t *testing.T) { + p := "{}" + if _, err := jwsWithMAC(nil, "", "", []byte(p)); err == nil { + t.Errorf("jwsWithMAC(nil, ...) 
= success; want err") + } +} + +func TestJWKThumbprintRSA(t *testing.T) { + // Key example from RFC 7638 + const base64N = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAt" + + "VT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn6" + + "4tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FD" + + "W2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n9" + + "1CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINH" + + "aQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw" + const base64E = "AQAB" + const expected = "NzbLsXh8uDCcd-6MNwXF4W_7noWXFZAfHkxZsRGC9Xs" + + b, err := base64.RawURLEncoding.DecodeString(base64N) + if err != nil { + t.Fatalf("Error parsing example key N: %v", err) + } + n := new(big.Int).SetBytes(b) + + b, err = base64.RawURLEncoding.DecodeString(base64E) + if err != nil { + t.Fatalf("Error parsing example key E: %v", err) + } + e := new(big.Int).SetBytes(b) + + pub := &rsa.PublicKey{N: n, E: int(e.Uint64())} + th, err := JWKThumbprint(pub) + if err != nil { + t.Error(err) + } + if th != expected { + t.Errorf("thumbprint = %q; want %q", th, expected) + } +} + +func TestJWKThumbprintEC(t *testing.T) { + // Key example from RFC 7520 + // expected was computed with + // printf '{"crv":"P-521","kty":"EC","x":"","y":""}' | \ + // openssl dgst -binary -sha256 | b64raw + const ( + base64X = "AHKZLLOsCOzz5cY97ewNUajB957y-C-U88c3v13nmGZx6sYl_oJXu9A5RkT" + + "KqjqvjyekWF-7ytDyRXYgCF5cj0Kt" + base64Y = "AdymlHvOiLxXkEhayXQnNCvDX4h9htZaCJN34kfmC6pV5OhQHiraVySsUda" + + "QkAgDPrwQrJmbnX9cwlGfP-HqHZR1" + expected = "dHri3SADZkrush5HU_50AoRhcKFryN-PI6jPBtPL55M" + ) + + b, err := base64.RawURLEncoding.DecodeString(base64X) + if err != nil { + t.Fatalf("Error parsing example key X: %v", err) + } + x := new(big.Int).SetBytes(b) + + b, err = base64.RawURLEncoding.DecodeString(base64Y) + if err != nil { + t.Fatalf("Error parsing example key Y: %v", err) + } + y := new(big.Int).SetBytes(b) + + pub := &ecdsa.PublicKey{Curve: elliptic.P521(), 
X: x, Y: y} + th, err := JWKThumbprint(pub) + if err != nil { + t.Error(err) + } + if th != expected { + t.Errorf("thumbprint = %q; want %q", th, expected) + } +} + +func TestJWKThumbprintErrUnsupportedKey(t *testing.T) { + _, err := JWKThumbprint(struct{}{}) + if err != ErrUnsupportedKey { + t.Errorf("err = %q; want %q", err, ErrUnsupportedKey) + } +} diff --git a/tempfork/acme/rfc8555.go b/tempfork/acme/rfc8555.go new file mode 100644 index 0000000000000..3eaf935fdd77b --- /dev/null +++ b/tempfork/acme/rfc8555.go @@ -0,0 +1,486 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package acme + +import ( + "context" + "crypto" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "io" + "net/http" + "time" +) + +// DeactivateReg permanently disables an existing account associated with c.Key. +// A deactivated account can no longer request certificate issuance or access +// resources related to the account, such as orders or authorizations. +// +// It only works with CAs implementing RFC 8555. +func (c *Client) DeactivateReg(ctx context.Context) error { + if _, err := c.Discover(ctx); err != nil { // required by c.accountKID + return err + } + url := string(c.accountKID(ctx)) + if url == "" { + return ErrNoAccount + } + req := json.RawMessage(`{"status": "deactivated"}`) + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return err + } + res.Body.Close() + return nil +} + +// registerRFC is equivalent to c.Register but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. 
+func (c *Client) registerRFC(ctx context.Context, acct *Account, prompt func(tosURL string) bool) (*Account, error) { + c.cacheMu.Lock() // guard c.kid access + defer c.cacheMu.Unlock() + + req := struct { + TermsAgreed bool `json:"termsOfServiceAgreed,omitempty"` + Contact []string `json:"contact,omitempty"` + ExternalAccountBinding *jsonWebSignature `json:"externalAccountBinding,omitempty"` + }{ + Contact: acct.Contact, + } + if c.dir.Terms != "" { + req.TermsAgreed = prompt(c.dir.Terms) + } + + // set 'externalAccountBinding' field if requested + if acct.ExternalAccountBinding != nil { + eabJWS, err := c.encodeExternalAccountBinding(acct.ExternalAccountBinding) + if err != nil { + return nil, fmt.Errorf("acme: failed to encode external account binding: %v", err) + } + req.ExternalAccountBinding = eabJWS + } + + res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus( + http.StatusOK, // account with this key already registered + http.StatusCreated, // new account created + )) + if err != nil { + return nil, err + } + + defer res.Body.Close() + a, err := responseAccount(res) + if err != nil { + return nil, err + } + // Cache Account URL even if we return an error to the caller. + // It is by all means a valid and usable "kid" value for future requests. + c.KID = KeyID(a.URI) + if res.StatusCode == http.StatusOK { + return nil, ErrAccountAlreadyExists + } + return a, nil +} + +// encodeExternalAccountBinding will encode an external account binding stanza +// as described in https://tools.ietf.org/html/rfc8555#section-7.3.4. +func (c *Client) encodeExternalAccountBinding(eab *ExternalAccountBinding) (*jsonWebSignature, error) { + jwk, err := jwkEncode(c.Key.Public()) + if err != nil { + return nil, err + } + return jwsWithMAC(eab.Key, eab.KID, c.dir.RegURL, []byte(jwk)) +} + +// updateRegRFC is equivalent to c.UpdateReg but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. 
+func (c *Client) updateRegRFC(ctx context.Context, a *Account) (*Account, error) { + url := string(c.accountKID(ctx)) + if url == "" { + return nil, ErrNoAccount + } + req := struct { + Contact []string `json:"contact,omitempty"` + }{ + Contact: a.Contact, + } + res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseAccount(res) +} + +// getRegRFC is equivalent to c.GetReg but for CAs implementing RFC 8555. +// It expects c.Discover to have already been called. +func (c *Client) getRegRFC(ctx context.Context) (*Account, error) { + req := json.RawMessage(`{"onlyReturnExisting": true}`) + res, err := c.post(ctx, c.Key, c.dir.RegURL, req, wantStatus(http.StatusOK)) + if e, ok := err.(*Error); ok && e.ProblemType == "urn:ietf:params:acme:error:accountDoesNotExist" { + return nil, ErrNoAccount + } + if err != nil { + return nil, err + } + + defer res.Body.Close() + return responseAccount(res) +} + +func responseAccount(res *http.Response) (*Account, error) { + var v struct { + Status string + Contact []string + Orders string + } + if err := json.NewDecoder(res.Body).Decode(&v); err != nil { + return nil, fmt.Errorf("acme: invalid account response: %v", err) + } + return &Account{ + URI: res.Header.Get("Location"), + Status: v.Status, + Contact: v.Contact, + OrdersURL: v.Orders, + }, nil +} + +// accountKeyRollover attempts to perform account key rollover. +// On success it will change client.Key to the new key. 
+func (c *Client) accountKeyRollover(ctx context.Context, newKey crypto.Signer) error { + dir, err := c.Discover(ctx) // Also required by c.accountKID + if err != nil { + return err + } + kid := c.accountKID(ctx) + if kid == noKeyID { + return ErrNoAccount + } + oldKey, err := jwkEncode(c.Key.Public()) + if err != nil { + return err + } + payload := struct { + Account string `json:"account"` + OldKey json.RawMessage `json:"oldKey"` + }{ + Account: string(kid), + OldKey: json.RawMessage(oldKey), + } + inner, err := jwsEncodeJSON(payload, newKey, noKeyID, noNonce, dir.KeyChangeURL) + if err != nil { + return err + } + + res, err := c.post(ctx, nil, dir.KeyChangeURL, base64.RawURLEncoding.EncodeToString(inner), wantStatus(http.StatusOK)) + if err != nil { + return err + } + defer res.Body.Close() + c.Key = newKey + return nil +} + +// AuthorizeOrder initiates the order-based application for certificate issuance, +// as opposed to pre-authorization in Authorize. +// It is only supported by CAs implementing RFC 8555. +// +// The caller then needs to fetch each authorization with GetAuthorization, +// identify those with StatusPending status and fulfill a challenge using Accept. +// Once all authorizations are satisfied, the caller will typically want to poll +// order status using WaitOrder until it's in StatusReady state. +// To finalize the order and obtain a certificate, the caller submits a CSR with CreateOrderCert. 
+func (c *Client) AuthorizeOrder(ctx context.Context, id []AuthzID, opt ...OrderOption) (*Order, error) { + dir, err := c.Discover(ctx) + if err != nil { + return nil, err + } + + req := struct { + Identifiers []wireAuthzID `json:"identifiers"` + NotBefore string `json:"notBefore,omitempty"` + NotAfter string `json:"notAfter,omitempty"` + Replaces string `json:"replaces,omitempty"` + }{} + for _, v := range id { + req.Identifiers = append(req.Identifiers, wireAuthzID{ + Type: v.Type, + Value: v.Value, + }) + } + for _, o := range opt { + switch o := o.(type) { + case orderNotBeforeOpt: + req.NotBefore = time.Time(o).Format(time.RFC3339) + case orderNotAfterOpt: + req.NotAfter = time.Time(o).Format(time.RFC3339) + case orderReplacesCert: + req.Replaces = certRenewalIdentifier(o.cert) + case orderReplacesCertDER: + cert, err := x509.ParseCertificate(o) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate being replaced: %w", err) + } + req.Replaces = certRenewalIdentifier(cert) + default: + // Package's fault if we let this happen. + panic(fmt.Sprintf("unsupported order option type %T", o)) + } + } + + res, err := c.post(ctx, nil, dir.OrderURL, req, wantStatus(http.StatusCreated)) + if err != nil { + return nil, err + } + defer res.Body.Close() + return responseOrder(res) +} + +// GetOrder retrives an order identified by the given URL. +// For orders created with AuthorizeOrder, the url value is Order.URI. +// +// If a caller needs to poll an order until its status is final, +// see the WaitOrder method. 
func (c *Client) GetOrder(ctx context.Context, url string) (*Order, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}

	// RFC 8555 requires POST-as-GET for authenticated resource fetches.
	res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	return responseOrder(res)
}

// WaitOrder polls an order from the given URL until it is in one of the final states,
// StatusReady, StatusValid or StatusInvalid, the CA responded with a non-retryable error
// or the context is done.
//
// It returns a non-nil Order only if its Status is StatusReady or StatusValid.
// In all other cases WaitOrder returns an error.
// If the Status is StatusInvalid, the returned error is of type *OrderError.
func (c *Client) WaitOrder(ctx context.Context, url string) (*Order, error) {
	if _, err := c.Discover(ctx); err != nil {
		return nil, err
	}
	for {
		res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
		if err != nil {
			return nil, err
		}
		o, err := responseOrder(res)
		res.Body.Close()
		switch {
		case err != nil:
			// Skip and retry.
		case o.Status == StatusInvalid:
			return nil, &OrderError{OrderURL: o.URI, Status: o.Status}
		case o.Status == StatusReady || o.Status == StatusValid:
			return o, nil
		}

		// Honor the CA's Retry-After hint; fall back to a short fixed delay.
		d := retryAfter(res.Header.Get("Retry-After"))
		if d == 0 {
			// Default retry-after.
			// Same reasoning as in WaitAuthorization.
			d = time.Second
		}
		t := time.NewTimer(d)
		select {
		case <-ctx.Done():
			t.Stop()
			return nil, ctx.Err()
		case <-t.C:
			// Retry.
		}
	}
}

// responseOrder decodes the CA's JSON order object from res into an *Order,
// taking the order URI from the Location response header.
func responseOrder(res *http.Response) (*Order, error) {
	var v struct {
		Status         string
		Expires        time.Time
		Identifiers    []wireAuthzID
		NotBefore      time.Time
		NotAfter       time.Time
		Error          *wireError
		Authorizations []string
		Finalize       string
		Certificate    string
	}
	if err := json.NewDecoder(res.Body).Decode(&v); err != nil {
		return nil, fmt.Errorf("acme: error reading order: %v", err)
	}
	o := &Order{
		URI:         res.Header.Get("Location"),
		Status:      v.Status,
		Expires:     v.Expires,
		NotBefore:   v.NotBefore,
		NotAfter:    v.NotAfter,
		AuthzURLs:   v.Authorizations,
		FinalizeURL: v.Finalize,
		CertURL:     v.Certificate,
	}
	for _, id := range v.Identifiers {
		o.Identifiers = append(o.Identifiers, AuthzID{Type: id.Type, Value: id.Value})
	}
	if v.Error != nil {
		o.Error = v.Error.error(nil /* headers */)
	}
	return o, nil
}

// CreateOrderCert submits the CSR (Certificate Signing Request) to a CA at the specified URL.
// The URL is the FinalizeURL field of an Order created with AuthorizeOrder.
//
// If the bundle argument is true, the returned value also contain the CA (issuer)
// certificate chain. Otherwise, only a leaf certificate is returned.
// The returned URL can be used to re-fetch the certificate using FetchCert.
//
// This method is only supported by CAs implementing RFC 8555. See CreateCert for pre-RFC CAs.
//
// CreateOrderCert returns an error if the CA's response is unreasonably large.
// Callers are encouraged to parse the returned value to ensure the certificate is valid and has the expected features.
func (c *Client) CreateOrderCert(ctx context.Context, url string, csr []byte, bundle bool) (der [][]byte, certURL string, err error) {
	if _, err := c.Discover(ctx); err != nil { // required by c.accountKID
		return nil, "", err
	}

	// RFC describes this as "finalize order" request.
	req := struct {
		CSR string `json:"csr"`
	}{
		CSR: base64.RawURLEncoding.EncodeToString(csr),
	}
	res, err := c.post(ctx, nil, url, req, wantStatus(http.StatusOK))
	if err != nil {
		return nil, "", err
	}
	defer res.Body.Close()
	o, err := responseOrder(res)
	if err != nil {
		return nil, "", err
	}

	// Wait for CA to issue the cert if they haven't.
	if o.Status != StatusValid {
		o, err = c.WaitOrder(ctx, o.URI)
	}
	if err != nil {
		return nil, "", err
	}
	// The only acceptable status post finalize and WaitOrder is "valid".
	if o.Status != StatusValid {
		return nil, "", &OrderError{OrderURL: o.URI, Status: o.Status}
	}
	crt, err := c.fetchCertRFC(ctx, o.CertURL, bundle)
	return crt, o.CertURL, err
}

// fetchCertRFC downloads issued certificate from the given URL.
// It expects the CA to respond with PEM-encoded certificate chain.
//
// The URL argument is the CertURL field of Order.
func (c *Client) fetchCertRFC(ctx context.Context, url string, bundle bool) ([][]byte, error) {
	res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	// Get all the bytes up to a sane maximum.
	// Account very roughly for base64 overhead.
	const max = maxCertChainSize + maxCertChainSize/33
	b, err := io.ReadAll(io.LimitReader(res.Body, max+1))
	if err != nil {
		return nil, fmt.Errorf("acme: fetch cert response stream: %v", err)
	}
	if len(b) > max {
		return nil, errors.New("acme: certificate chain is too big")
	}

	// Decode PEM chain.
	var chain [][]byte
	for {
		var p *pem.Block
		p, b = pem.Decode(b)
		if p == nil {
			break
		}
		if p.Type != "CERTIFICATE" {
			return nil, fmt.Errorf("acme: invalid PEM cert type %q", p.Type)
		}

		chain = append(chain, p.Bytes)
		if !bundle {
			// Caller asked for the leaf only.
			return chain, nil
		}
		if len(chain) > maxChainLen {
			return nil, errors.New("acme: certificate chain is too long")
		}
	}
	if len(chain) == 0 {
		return nil, errors.New("acme: certificate chain is empty")
	}
	return chain, nil
}

// revokeCertRFC sends a cert revocation request in either JWK form when key
// is non-nil or KID form otherwise.
func (c *Client) revokeCertRFC(ctx context.Context, key crypto.Signer, cert []byte, reason CRLReasonCode) error {
	req := &struct {
		Cert   string `json:"certificate"`
		Reason int    `json:"reason"`
	}{
		Cert:   base64.RawURLEncoding.EncodeToString(cert),
		Reason: int(reason),
	}
	res, err := c.post(ctx, key, c.dir.RevokeURL, req, wantStatus(http.StatusOK))
	if err != nil {
		if isAlreadyRevoked(err) {
			// Assume it is not an error to revoke an already revoked cert.
			return nil
		}
		return err
	}
	defer res.Body.Close()
	return nil
}

// isAlreadyRevoked reports whether err is the CA's "alreadyRevoked" problem response.
func isAlreadyRevoked(err error) bool {
	e, ok := err.(*Error)
	return ok && e.ProblemType == "urn:ietf:params:acme:error:alreadyRevoked"
}

// ListCertAlternates retrieves any alternate certificate chain URLs for the
// given certificate chain URL. These alternate URLs can be passed to FetchCert
// in order to retrieve the alternate certificate chains.
//
// If there are no alternate issuer certificate chains, a nil slice will be
// returned.
func (c *Client) ListCertAlternates(ctx context.Context, url string) ([]string, error) {
	if _, err := c.Discover(ctx); err != nil { // required by c.accountKID
		return nil, err
	}

	res, err := c.postAsGet(ctx, url, wantStatus(http.StatusOK))
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	// We don't need the body but we need to discard it so we don't end up
	// preventing keep-alive.
	if _, err := io.Copy(io.Discard, res.Body); err != nil {
		return nil, fmt.Errorf("acme: cert alternates response stream: %v", err)
	}
	// The alternates are advertised via Link rel="alternate" response headers.
	alts := linkHeader(res.Header, "alternate")
	return alts, nil
}
diff --git a/tempfork/acme/rfc8555_test.go b/tempfork/acme/rfc8555_test.go
new file mode 100644
index 0000000000000..ec51a7a5ed654
--- /dev/null
+++ b/tempfork/acme/rfc8555_test.go
@@ -0,0 +1,1024 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package acme

import (
	"bytes"
	"context"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"errors"
	"fmt"
	"io"
	"math/big"
	"net/http"
	"net/http/httptest"
	"reflect"
	"strings"
	"sync"
	"testing"
	"time"
)

// While contents of this file is pertinent only to RFC8555,
// it is complementary to the tests in the other _test.go files
// many of which are valid for both pre- and RFC8555.
// This will make it easier to clean up the tests once non-RFC compliant
// code is removed.

// TestRFC_Discover verifies that Discover maps every field of an RFC 8555
// directory document, including the meta section, onto the Directory struct.
func TestRFC_Discover(t *testing.T) {
	const (
		nonce       = "https://example.com/acme/new-nonce"
		reg         = "https://example.com/acme/new-acct"
		order       = "https://example.com/acme/new-order"
		authz       = "https://example.com/acme/new-authz"
		revoke      = "https://example.com/acme/revoke-cert"
		keychange   = "https://example.com/acme/key-change"
		metaTerms   = "https://example.com/acme/terms/2017-5-30"
		metaWebsite = "https://www.example.com/"
		metaCAA     = "example.com"
	)
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprintf(w, `{
			"newNonce": %q,
			"newAccount": %q,
			"newOrder": %q,
			"newAuthz": %q,
			"revokeCert": %q,
			"keyChange": %q,
			"meta": {
				"termsOfService": %q,
				"website": %q,
				"caaIdentities": [%q],
				"externalAccountRequired": true
			}
		}`, nonce, reg, order, authz, revoke, keychange, metaTerms, metaWebsite, metaCAA)
	}))
	defer ts.Close()
	c := &Client{DirectoryURL: ts.URL}
	dir, err := c.Discover(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	if dir.NonceURL != nonce {
		t.Errorf("dir.NonceURL = %q; want %q", dir.NonceURL, nonce)
	}
	if dir.RegURL != reg {
		t.Errorf("dir.RegURL = %q; want %q", dir.RegURL, reg)
	}
	if dir.OrderURL != order {
		t.Errorf("dir.OrderURL = %q; want %q", dir.OrderURL, order)
	}
	if dir.AuthzURL != authz {
		t.Errorf("dir.AuthzURL = %q; want %q", dir.AuthzURL, authz)
	}
	if dir.RevokeURL != revoke {
		t.Errorf("dir.RevokeURL = %q; want %q", dir.RevokeURL, revoke)
	}
	if dir.KeyChangeURL != keychange {
		t.Errorf("dir.KeyChangeURL = %q; want %q", dir.KeyChangeURL, keychange)
	}
	if dir.Terms != metaTerms {
		t.Errorf("dir.Terms = %q; want %q", dir.Terms, metaTerms)
	}
	if dir.Website != metaWebsite {
		t.Errorf("dir.Website = %q; want %q", dir.Website, metaWebsite)
	}
	if len(dir.CAA) == 0 || dir.CAA[0] != metaCAA {
		t.Errorf("dir.CAA = %q; want [%q]", dir.CAA, metaCAA)
	}
	if !dir.ExternalAccountRequired {
		t.Error("dir.Meta.ExternalAccountRequired is false")
	}
}

// TestRFC_popNonce verifies the nonce cache: a stored nonce is consumed first,
// then a fresh one is fetched from Directory.NonceURL, and an exhausted
// server surfaces as an error.
func TestRFC_popNonce(t *testing.T) {
	var count int
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// The Client uses only Directory.NonceURL when specified.
		// Expect no other URL paths.
		if r.URL.Path != "/new-nonce" {
			t.Errorf("r.URL.Path = %q; want /new-nonce", r.URL.Path)
		}
		if count > 0 {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		count++
		w.Header().Set("Replay-Nonce", "second")
	}))
	cl := &Client{
		DirectoryURL: ts.URL,
		dir:          &Directory{NonceURL: ts.URL + "/new-nonce"},
	}
	cl.addNonce(http.Header{"Replay-Nonce": {"first"}})

	for i, nonce := range []string{"first", "second"} {
		v, err := cl.popNonce(context.Background(), "")
		if err != nil {
			t.Errorf("%d: cl.popNonce: %v", i, err)
		}
		if v != nonce {
			t.Errorf("%d: cl.popNonce = %q; want %q", i, v, nonce)
		}
	}
	// No more nonces and server replies with an error past first nonce fetch.
	// Expected to fail.
	if _, err := cl.popNonce(context.Background(), ""); err == nil {
		t.Error("last cl.popNonce returned nil error")
	}
}

// TestRFC_postKID verifies that post() signs requests in KID form (account URL
// in the protected header, no embedded JWK) once an account has been registered.
func TestRFC_postKID(t *testing.T) {
	var ts *httptest.Server
	ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/new-nonce":
			w.Header().Set("Replay-Nonce", "nonce")
		case "/new-account":
			w.Header().Set("Location", "/account-1")
			w.Write([]byte(`{"status":"valid"}`))
		case "/post":
			b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx
			head, err := decodeJWSHead(bytes.NewReader(b))
			if err != nil {
				t.Errorf("decodeJWSHead: %v", err)
				return
			}
			if head.KID != "/account-1" {
				t.Errorf("head.KID = %q; want /account-1", head.KID)
			}
			if len(head.JWK) != 0 {
				t.Errorf("head.JWK = %q; want zero map", head.JWK)
			}
			if v := ts.URL + "/post"; head.URL != v {
				t.Errorf("head.URL = %q; want %q", head.URL, v)
			}

			var payload struct{ Msg string }
			decodeJWSRequest(t, &payload, bytes.NewReader(b))
			if payload.Msg != "ping" {
				t.Errorf("payload.Msg = %q; want ping", payload.Msg)
			}
			w.Write([]byte("pong"))
		default:
			t.Errorf("unhandled %s %s", r.Method, r.URL)
			w.WriteHeader(http.StatusBadRequest)
		}
	}))
	defer ts.Close()

	ctx := context.Background()
	cl := &Client{
		Key:          testKey,
		DirectoryURL: ts.URL,
		dir: &Directory{
			NonceURL: ts.URL + "/new-nonce",
			RegURL:   ts.URL + "/new-account",
			OrderURL: "/force-rfc-mode",
		},
	}
	req := json.RawMessage(`{"msg":"ping"}`)
	res, err := cl.post(ctx, nil /* use kid */, ts.URL+"/post", req, wantStatus(http.StatusOK))
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	b, _ := io.ReadAll(res.Body) // don't care about err - just checking b
	if string(b) != "pong" {
		t.Errorf("res.Body = %q; want pong", b)
	}
}

// acmeServer simulates a subset of RFC 8555 compliant CA.
//
// TODO: We also have x/crypto/acme/autocert/acmetest and startACMEServerStub in autocert_test.go.
+// It feels like this acmeServer is a sweet spot between usefulness and added complexity. +// Also, acmetest and startACMEServerStub were both written for draft-02, no RFC support. +// The goal is to consolidate all into one ACME test server. +type acmeServer struct { + ts *httptest.Server + handler map[string]http.HandlerFunc // keyed by r.URL.Path + + mu sync.Mutex + nnonce int +} + +func newACMEServer() *acmeServer { + return &acmeServer{handler: make(map[string]http.HandlerFunc)} +} + +func (s *acmeServer) handle(path string, f func(http.ResponseWriter, *http.Request)) { + s.handler[path] = http.HandlerFunc(f) +} + +func (s *acmeServer) start() { + s.ts = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + + // Directory request. + if r.URL.Path == "/" { + fmt.Fprintf(w, `{ + "newNonce": %q, + "newAccount": %q, + "newOrder": %q, + "newAuthz": %q, + "revokeCert": %q, + "keyChange": %q, + "meta": {"termsOfService": %q} + }`, + s.url("/acme/new-nonce"), + s.url("/acme/new-account"), + s.url("/acme/new-order"), + s.url("/acme/new-authz"), + s.url("/acme/revoke-cert"), + s.url("/acme/key-change"), + s.url("/terms"), + ) + return + } + + // All other responses contain a nonce value unconditionally. 
+ w.Header().Set("Replay-Nonce", s.nonce()) + if r.URL.Path == "/acme/new-nonce" { + return + } + + h := s.handler[r.URL.Path] + if h == nil { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "Unhandled %s", r.URL.Path) + return + } + h.ServeHTTP(w, r) + })) +} + +func (s *acmeServer) close() { + s.ts.Close() +} + +func (s *acmeServer) url(path string) string { + return s.ts.URL + path +} + +func (s *acmeServer) nonce() string { + s.mu.Lock() + defer s.mu.Unlock() + s.nnonce++ + return fmt.Sprintf("nonce%d", s.nnonce) +} + +func (s *acmeServer) error(w http.ResponseWriter, e *wireError) { + w.WriteHeader(e.Status) + json.NewEncoder(w).Encode(e) +} + +func TestRFC_Register(t *testing.T) { + const email = "mailto:user@example.org" + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusCreated) // 201 means new account created + fmt.Fprintf(w, `{ + "status": "valid", + "contact": [%q], + "orders": %q + }`, email, s.url("/accounts/1/orders")) + + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) == 0 { + t.Error("head.JWK is empty") + } + + var req struct{ Contact []string } + decodeJWSRequest(t, &req, bytes.NewReader(b)) + if len(req.Contact) != 1 || req.Contact[0] != email { + t.Errorf("req.Contact = %q; want [%q]", req.Contact, email) + } + }) + s.start() + defer s.close() + + ctx := context.Background() + cl := &Client{ + Key: testKeyEC, + DirectoryURL: s.url("/"), + } + + var didPrompt bool + a := &Account{Contact: []string{email}} + acct, err := cl.Register(ctx, a, func(tos string) bool { + didPrompt = true + terms := s.url("/terms") + if tos != terms { + t.Errorf("tos = %q; want %q", tos, terms) + } + return true + }) + if err != nil { + t.Fatal(err) + } + okAccount := &Account{ + URI: 
s.url("/accounts/1"), + Status: StatusValid, + Contact: []string{email}, + OrdersURL: s.url("/accounts/1/orders"), + } + if !reflect.DeepEqual(acct, okAccount) { + t.Errorf("acct = %+v; want %+v", acct, okAccount) + } + if !didPrompt { + t.Error("tos prompt wasn't called") + } + if v := cl.accountKID(ctx); v != KeyID(okAccount.URI) { + t.Errorf("account kid = %q; want %q", v, okAccount.URI) + } +} + +func TestRFC_RegisterExternalAccountBinding(t *testing.T) { + eab := &ExternalAccountBinding{ + KID: "kid-1", + Key: []byte("secret"), + } + + type protected struct { + Algorithm string `json:"alg"` + KID string `json:"kid"` + URL string `json:"url"` + } + const email = "mailto:user@example.org" + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + if r.Method != "POST" { + t.Errorf("r.Method = %q; want POST", r.Method) + } + + var j struct { + Protected string + Contact []string + TermsOfServiceAgreed bool + ExternalaccountBinding struct { + Protected string + Payload string + Signature string + } + } + decodeJWSRequest(t, &j, r.Body) + protData, err := base64.RawURLEncoding.DecodeString(j.ExternalaccountBinding.Protected) + if err != nil { + t.Fatal(err) + } + + var prot protected + err = json.Unmarshal(protData, &prot) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(j.Contact, []string{email}) { + t.Errorf("j.Contact = %v; want %v", j.Contact, []string{email}) + } + if !j.TermsOfServiceAgreed { + t.Error("j.TermsOfServiceAgreed = false; want true") + } + + // Ensure same KID. + if prot.KID != eab.KID { + t.Errorf("j.ExternalAccountBinding.KID = %s; want %s", prot.KID, eab.KID) + } + // Ensure expected Algorithm. + if prot.Algorithm != "HS256" { + t.Errorf("j.ExternalAccountBinding.Alg = %s; want %s", + prot.Algorithm, "HS256") + } + + // Ensure same URL as outer JWS. 
+ url := fmt.Sprintf("http://%s/acme/new-account", r.Host) + if prot.URL != url { + t.Errorf("j.ExternalAccountBinding.URL = %s; want %s", + prot.URL, url) + } + + // Ensure payload is base64URL encoded string of JWK in outer JWS + jwk, err := jwkEncode(testKeyEC.Public()) + if err != nil { + t.Fatal(err) + } + decodedPayload, err := base64.RawURLEncoding.DecodeString(j.ExternalaccountBinding.Payload) + if err != nil { + t.Fatal(err) + } + if jwk != string(decodedPayload) { + t.Errorf("j.ExternalAccountBinding.Payload = %s; want %s", decodedPayload, jwk) + } + + // Check signature on inner external account binding JWS + hmac := hmac.New(sha256.New, []byte("secret")) + _, err = hmac.Write([]byte(j.ExternalaccountBinding.Protected + "." + j.ExternalaccountBinding.Payload)) + if err != nil { + t.Fatal(err) + } + mac := hmac.Sum(nil) + encodedMAC := base64.RawURLEncoding.EncodeToString(mac) + + if !bytes.Equal([]byte(encodedMAC), []byte(j.ExternalaccountBinding.Signature)) { + t.Errorf("j.ExternalAccountBinding.Signature = %v; want %v", + []byte(j.ExternalaccountBinding.Signature), encodedMAC) + } + + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusCreated) + b, _ := json.Marshal([]string{email}) + fmt.Fprintf(w, `{"status":"valid","orders":"%s","contact":%s}`, s.url("/accounts/1/orders"), b) + }) + s.start() + defer s.close() + + ctx := context.Background() + cl := &Client{ + Key: testKeyEC, + DirectoryURL: s.url("/"), + } + + var didPrompt bool + a := &Account{Contact: []string{email}, ExternalAccountBinding: eab} + acct, err := cl.Register(ctx, a, func(tos string) bool { + didPrompt = true + terms := s.url("/terms") + if tos != terms { + t.Errorf("tos = %q; want %q", tos, terms) + } + return true + }) + if err != nil { + t.Fatal(err) + } + okAccount := &Account{ + URI: s.url("/accounts/1"), + Status: StatusValid, + Contact: []string{email}, + OrdersURL: s.url("/accounts/1/orders"), + } + if !reflect.DeepEqual(acct, okAccount) { + 
t.Errorf("acct = %+v; want %+v", acct, okAccount) + } + if !didPrompt { + t.Error("tos prompt wasn't called") + } + if v := cl.accountKID(ctx); v != KeyID(okAccount.URI) { + t.Errorf("account kid = %q; want %q", v, okAccount.URI) + } +} + +func TestRFC_RegisterExisting(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) // 200 means account already exists + w.Write([]byte(`{"status": "valid"}`)) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + _, err := cl.Register(context.Background(), &Account{}, AcceptTOS) + if err != ErrAccountAlreadyExists { + t.Errorf("err = %v; want %v", err, ErrAccountAlreadyExists) + } + kid := KeyID(s.url("/accounts/1")) + if v := cl.accountKID(context.Background()); v != kid { + t.Errorf("account kid = %q; want %q", v, kid) + } +} + +func TestRFC_UpdateReg(t *testing.T) { + const email = "mailto:user@example.org" + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + var didUpdate bool + s.handle("/accounts/1", func(w http.ResponseWriter, r *http.Request) { + didUpdate = true + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) != 0 { + t.Error("head.JWK is non-zero") + } + kid := s.url("/accounts/1") + if head.KID != kid { + t.Errorf("head.KID = %q; want %q", head.KID, kid) + } + + var req struct{ Contact []string } + decodeJWSRequest(t, &req, bytes.NewReader(b)) + if len(req.Contact) != 1 || 
req.Contact[0] != email { + t.Errorf("req.Contact = %q; want [%q]", req.Contact, email) + } + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + _, err := cl.UpdateReg(context.Background(), &Account{Contact: []string{email}}) + if err != nil { + t.Error(err) + } + if !didUpdate { + t.Error("UpdateReg didn't update the account") + } +} + +func TestRFC_GetReg(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + + head, err := decodeJWSHead(r.Body) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) == 0 { + t.Error("head.JWK is empty") + } + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + acct, err := cl.GetReg(context.Background(), "") + if err != nil { + t.Fatal(err) + } + okAccount := &Account{ + URI: s.url("/accounts/1"), + Status: StatusValid, + } + if !reflect.DeepEqual(acct, okAccount) { + t.Errorf("acct = %+v; want %+v", acct, okAccount) + } +} + +func TestRFC_GetRegNoAccount(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + s.error(w, &wireError{ + Status: http.StatusBadRequest, + Type: "urn:ietf:params:acme:error:accountDoesNotExist", + }) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if _, err := cl.GetReg(context.Background(), ""); err != ErrNoAccount { + t.Errorf("err = %v; want %v", err, ErrNoAccount) + } +} + +func TestRFC_GetRegOtherError(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if _, err := 
cl.GetReg(context.Background(), ""); err == nil || err == ErrNoAccount { + t.Errorf("GetReg: %v; want any other non-nil err", err) + } +} + +func TestRFC_AccountKeyRollover(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + s.handle("/acme/key-change", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if err := cl.AccountKeyRollover(context.Background(), testKeyEC384); err != nil { + t.Errorf("AccountKeyRollover: %v, wanted no error", err) + } else if cl.Key != testKeyEC384 { + t.Error("AccountKeyRollover did not rotate the client key") + } +} + +func TestRFC_DeactivateReg(t *testing.T) { + const email = "mailto:user@example.org" + curStatus := StatusValid + + type account struct { + Status string `json:"status"` + Contact []string `json:"contact"` + AcceptTOS bool `json:"termsOfServiceAgreed"` + Orders string `json:"orders"` + } + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) // 200 means existing account + json.NewEncoder(w).Encode(account{ + Status: curStatus, + Contact: []string{email}, + AcceptTOS: true, + Orders: s.url("/accounts/1/orders"), + }) + + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) == 0 { + t.Error("head.JWK is empty") + } + + var req struct { + Status string `json:"status"` + Contact []string `json:"contact"` + AcceptTOS bool `json:"termsOfServiceAgreed"` + OnlyExisting bool `json:"onlyReturnExisting"` + } + decodeJWSRequest(t, &req, 
bytes.NewReader(b)) + if !req.OnlyExisting { + t.Errorf("req.OnlyReturnExisting = %t; want = %t", req.OnlyExisting, true) + } + }) + s.handle("/accounts/1", func(w http.ResponseWriter, r *http.Request) { + if curStatus == StatusValid { + curStatus = StatusDeactivated + w.WriteHeader(http.StatusOK) + } else { + s.error(w, &wireError{ + Status: http.StatusUnauthorized, + Type: "urn:ietf:params:acme:error:unauthorized", + }) + } + var req account + b, _ := io.ReadAll(r.Body) // check err later in decodeJWSxxx + head, err := decodeJWSHead(bytes.NewReader(b)) + if err != nil { + t.Errorf("decodeJWSHead: %v", err) + return + } + if len(head.JWK) != 0 { + t.Error("head.JWK is not empty") + } + if !strings.HasSuffix(head.KID, "/accounts/1") { + t.Errorf("head.KID = %q; want suffix /accounts/1", head.KID) + } + + decodeJWSRequest(t, &req, bytes.NewReader(b)) + if req.Status != StatusDeactivated { + t.Errorf("req.Status = %q; want = %q", req.Status, StatusDeactivated) + } + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if err := cl.DeactivateReg(context.Background()); err != nil { + t.Errorf("DeactivateReg: %v, wanted no error", err) + } + if err := cl.DeactivateReg(context.Background()); err == nil { + t.Errorf("DeactivateReg: %v, wanted error for unauthorized", err) + } +} + +func TestRF_DeactivateRegNoAccount(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + s.error(w, &wireError{ + Status: http.StatusBadRequest, + Type: "urn:ietf:params:acme:error:accountDoesNotExist", + }) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + if err := cl.DeactivateReg(context.Background()); !errors.Is(err, ErrNoAccount) { + t.Errorf("DeactivateReg: %v, wanted ErrNoAccount", err) + } +} + +func TestRFC_AuthorizeOrder(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r 
*http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + s.handle("/acme/new-order", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusCreated) + fmt.Fprintf(w, `{ + "status": "pending", + "expires": "2019-09-01T00:00:00Z", + "notBefore": "2019-08-31T00:00:00Z", + "notAfter": "2019-09-02T00:00:00Z", + "identifiers": [{"type":"dns", "value":"example.org"}], + "authorizations": [%q] + }`, s.url("/authz/1")) + }) + s.start() + defer s.close() + + prevCertDER, _ := pem.Decode([]byte(leafPEM)) + prevCert, err := x509.ParseCertificate(prevCertDER.Bytes) + if err != nil { + t.Fatal(err) + } + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + o, err := cl.AuthorizeOrder(context.Background(), DomainIDs("example.org"), + WithOrderNotBefore(time.Date(2019, 8, 31, 0, 0, 0, 0, time.UTC)), + WithOrderNotAfter(time.Date(2019, 9, 2, 0, 0, 0, 0, time.UTC)), + WithOrderReplacesCert(prevCert), + ) + if err != nil { + t.Fatal(err) + } + okOrder := &Order{ + URI: s.url("/orders/1"), + Status: StatusPending, + Expires: time.Date(2019, 9, 1, 0, 0, 0, 0, time.UTC), + NotBefore: time.Date(2019, 8, 31, 0, 0, 0, 0, time.UTC), + NotAfter: time.Date(2019, 9, 2, 0, 0, 0, 0, time.UTC), + Identifiers: []AuthzID{AuthzID{Type: "dns", Value: "example.org"}}, + AuthzURLs: []string{s.url("/authz/1")}, + } + if !reflect.DeepEqual(o, okOrder) { + t.Errorf("AuthorizeOrder = %+v; want %+v", o, okOrder) + } +} + +func TestRFC_GetOrder(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + s.handle("/orders/1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusOK) + 
w.Write([]byte(`{ + "status": "invalid", + "expires": "2019-09-01T00:00:00Z", + "notBefore": "2019-08-31T00:00:00Z", + "notAfter": "2019-09-02T00:00:00Z", + "identifiers": [{"type":"dns", "value":"example.org"}], + "authorizations": ["/authz/1"], + "finalize": "/orders/1/fin", + "certificate": "/orders/1/cert", + "error": {"type": "badRequest"} + }`)) + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + o, err := cl.GetOrder(context.Background(), s.url("/orders/1")) + if err != nil { + t.Fatal(err) + } + okOrder := &Order{ + URI: s.url("/orders/1"), + Status: StatusInvalid, + Expires: time.Date(2019, 9, 1, 0, 0, 0, 0, time.UTC), + NotBefore: time.Date(2019, 8, 31, 0, 0, 0, 0, time.UTC), + NotAfter: time.Date(2019, 9, 2, 0, 0, 0, 0, time.UTC), + Identifiers: []AuthzID{AuthzID{Type: "dns", Value: "example.org"}}, + AuthzURLs: []string{"/authz/1"}, + FinalizeURL: "/orders/1/fin", + CertURL: "/orders/1/cert", + Error: &Error{ProblemType: "badRequest"}, + } + if !reflect.DeepEqual(o, okOrder) { + t.Errorf("GetOrder = %+v\nwant %+v", o, okOrder) + } +} + +func TestRFC_WaitOrder(t *testing.T) { + for _, st := range []string{StatusReady, StatusValid} { + t.Run(st, func(t *testing.T) { + testWaitOrderStatus(t, st) + }) + } +} + +func testWaitOrderStatus(t *testing.T, okStatus string) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + var count int + s.handle("/orders/1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusOK) + s := StatusPending + if count > 0 { + s = okStatus + } + fmt.Fprintf(w, `{"status": %q}`, s) + count++ + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + order, err := cl.WaitOrder(context.Background(), 
s.url("/orders/1")) + if err != nil { + t.Fatalf("WaitOrder: %v", err) + } + if order.Status != okStatus { + t.Errorf("order.Status = %q; want %q", order.Status, okStatus) + } +} + +func TestRFC_WaitOrderError(t *testing.T) { + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status": "valid"}`)) + }) + var count int + s.handle("/orders/1", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/orders/1")) + w.WriteHeader(http.StatusOK) + s := StatusPending + if count > 0 { + s = StatusInvalid + } + fmt.Fprintf(w, `{"status": %q}`, s) + count++ + }) + s.start() + defer s.close() + + cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")} + _, err := cl.WaitOrder(context.Background(), s.url("/orders/1")) + if err == nil { + t.Fatal("WaitOrder returned nil error") + } + e, ok := err.(*OrderError) + if !ok { + t.Fatalf("err = %v (%T); want OrderError", err, err) + } + if e.OrderURL != s.url("/orders/1") { + t.Errorf("e.OrderURL = %q; want %q", e.OrderURL, s.url("/orders/1")) + } + if e.Status != StatusInvalid { + t.Errorf("e.Status = %q; want %q", e.Status, StatusInvalid) + } +} + +func TestRFC_CreateOrderCert(t *testing.T) { + q := &x509.CertificateRequest{ + Subject: pkix.Name{CommonName: "example.org"}, + } + csr, err := x509.CreateCertificateRequest(rand.Reader, q, testKeyEC) + if err != nil { + t.Fatal(err) + } + + tmpl := &x509.Certificate{SerialNumber: big.NewInt(1)} + leaf, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &testKeyEC.PublicKey, testKeyEC) + if err != nil { + t.Fatal(err) + } + + s := newACMEServer() + s.handle("/acme/new-account", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Location", s.url("/accounts/1")) + w.Write([]byte(`{"status": "valid"}`)) + }) + var count int + s.handle("/pleaseissue", func(w http.ResponseWriter, r *http.Request) { 
+		w.Header().Set("Location", s.url("/pleaseissue"))
+		st := StatusProcessing
+		if count > 0 {
+			st = StatusValid
+		}
+		fmt.Fprintf(w, `{"status":%q, "certificate":%q}`, st, s.url("/crt"))
+		count++
+	})
+	s.handle("/crt", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/pem-certificate-chain")
+		pem.Encode(w, &pem.Block{Type: "CERTIFICATE", Bytes: leaf})
+	})
+	s.start()
+	defer s.close()
+	ctx := context.Background()
+
+	cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")}
+	cert, curl, err := cl.CreateOrderCert(ctx, s.url("/pleaseissue"), csr, true)
+	if err != nil {
+		t.Fatalf("CreateOrderCert: %v", err)
+	}
+	if _, err := x509.ParseCertificate(cert[0]); err != nil {
+		t.Errorf("ParseCertificate: %v", err)
+	}
+	if !reflect.DeepEqual(cert[0], leaf) {
+		t.Errorf("cert and leaf bytes don't match")
+	}
+	if u := s.url("/crt"); curl != u {
+		t.Errorf("curl = %q; want %q", curl, u)
+	}
+}
+
+func TestRFC_AlreadyRevokedCert(t *testing.T) {
+	s := newACMEServer()
+	s.handle("/acme/revoke-cert", func(w http.ResponseWriter, r *http.Request) {
+		s.error(w, &wireError{
+			Status: http.StatusBadRequest,
+			Type:   "urn:ietf:params:acme:error:alreadyRevoked",
+		})
+	})
+	s.start()
+	defer s.close()
+
+	cl := &Client{Key: testKeyEC, DirectoryURL: s.url("/")}
+	err := cl.RevokeCert(context.Background(), testKeyEC, []byte{0}, CRLReasonUnspecified)
+	if err != nil {
+		t.Fatalf("RevokeCert: %v", err)
+	}
+}
+
+func TestRFC_ListCertAlternates(t *testing.T) {
+	s := newACMEServer()
+	s.handle("/crt", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/pem-certificate-chain")
+		w.Header().Add("Link", `<https://example.com/crt/2>;rel="alternate"`)
+		w.Header().Add("Link", `<https://example.com/crt/3>; rel="alternate"`)
+		w.Header().Add("Link", `<https://example.com/acme>; rel="index"`)
+	})
+	s.handle("/crt2", func(w http.ResponseWriter, r *http.Request) {
+		w.Header().Set("Content-Type", "application/pem-certificate-chain")
+	})
+	s.start()
+	defer s.close()
+
+	cl := 
&Client{Key: testKeyEC, DirectoryURL: s.url("/")} + crts, err := cl.ListCertAlternates(context.Background(), s.url("/crt")) + if err != nil { + t.Fatalf("ListCertAlternates: %v", err) + } + want := []string{"https://example.com/crt/2", "https://example.com/crt/3"} + if !reflect.DeepEqual(crts, want) { + t.Errorf("ListCertAlternates(/crt): %v; want %v", crts, want) + } + crts, err = cl.ListCertAlternates(context.Background(), s.url("/crt2")) + if err != nil { + t.Fatalf("ListCertAlternates: %v", err) + } + if crts != nil { + t.Errorf("ListCertAlternates(/crt2): %v; want nil", crts) + } +} diff --git a/tempfork/acme/sync_to_upstream_test.go b/tempfork/acme/sync_to_upstream_test.go new file mode 100644 index 0000000000000..e22c8c1a86d01 --- /dev/null +++ b/tempfork/acme/sync_to_upstream_test.go @@ -0,0 +1,70 @@ +package acme + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + _ "github.com/tailscale/golang-x-crypto/acme" // so it's on disk for the test +) + +// Verify that the files tempfork/acme/*.go (other than this test file) match the +// files in "github.com/tailscale/golang-x-crypto/acme" which is where we develop +// our fork of golang.org/x/crypto/acme and merge with upstream, but then we vendor +// just its acme package into tailscale.com/tempfork/acme. +// +// Development workflow: +// +// - make a change in github.com/tailscale/golang-x-crypto/acme +// - merge it (ideally with golang.org/x/crypto/acme too) +// - rebase github.com/tailscale/golang-x-crypto/acme with upstream x/crypto/acme +// as needed +// - in the tailscale.com repo, run "go get github.com/tailscale/golang-x-crypto/acme@main" +// - run go test ./tempfork/acme to watch it fail; the failure includes +// a shell command you should run to copy the *.go files from tailscale/golang-x-crypto +// to tailscale.com. +// - watch tests pass. git add it all. 
+// - send PR to tailscale.com
+func TestSyncedToUpstream(t *testing.T) {
+	const pkg = "github.com/tailscale/golang-x-crypto/acme"
+	out, err := exec.Command("go", "list", "-f", "{{.Dir}}", pkg).Output()
+	if err != nil {
+		t.Fatalf("failed to find %s's location on disk: %v", pkg, err)
+	}
+	xDir := strings.TrimSpace(string(out))
+
+	t.Logf("at %s", xDir)
+	scanDir := func(dir string) map[string]string {
+		m := map[string]string{} // filename => Go contents
+		ents, err := os.ReadDir(dir)
+		if err != nil {
+			t.Fatal(err)
+		}
+		for _, de := range ents {
+			name := de.Name()
+			if name == "sync_to_upstream_test.go" {
+				continue
+			}
+			if !strings.HasSuffix(name, ".go") {
+				continue
+			}
+			b, err := os.ReadFile(filepath.Join(dir, name))
+			if err != nil {
+				t.Fatal(err)
+			}
+			m[name] = strings.ReplaceAll(string(b), "\r", "")
+		}
+
+		return m
+	}
+
+	want := scanDir(xDir)
+	got := scanDir(".")
+	if diff := cmp.Diff(want, got); diff != "" {
+		t.Errorf("files differ (-want +got):\n%s", diff)
+		t.Errorf("to fix, run from module root:\n\ncp %s/*.go ./tempfork/acme && ./tool/go mod tidy\n", xDir)
+	}
+}
diff --git a/tempfork/acme/types.go b/tempfork/acme/types.go
new file mode 100644
index 0000000000000..0142469d8cb0d
--- /dev/null
+++ b/tempfork/acme/types.go
@@ -0,0 +1,667 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package acme
+
+import (
+	"crypto"
+	"crypto/x509"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"strings"
+	"time"
+)
+
+// ACME status values of Account, Order, Authorization and Challenge objects.
+// See https://tools.ietf.org/html/rfc8555#section-7.1.6 for details. 
+const ( + StatusDeactivated = "deactivated" + StatusExpired = "expired" + StatusInvalid = "invalid" + StatusPending = "pending" + StatusProcessing = "processing" + StatusReady = "ready" + StatusRevoked = "revoked" + StatusUnknown = "unknown" + StatusValid = "valid" +) + +// CRLReasonCode identifies the reason for a certificate revocation. +type CRLReasonCode int + +// CRL reason codes as defined in RFC 5280. +const ( + CRLReasonUnspecified CRLReasonCode = 0 + CRLReasonKeyCompromise CRLReasonCode = 1 + CRLReasonCACompromise CRLReasonCode = 2 + CRLReasonAffiliationChanged CRLReasonCode = 3 + CRLReasonSuperseded CRLReasonCode = 4 + CRLReasonCessationOfOperation CRLReasonCode = 5 + CRLReasonCertificateHold CRLReasonCode = 6 + CRLReasonRemoveFromCRL CRLReasonCode = 8 + CRLReasonPrivilegeWithdrawn CRLReasonCode = 9 + CRLReasonAACompromise CRLReasonCode = 10 +) + +var ( + // ErrUnsupportedKey is returned when an unsupported key type is encountered. + ErrUnsupportedKey = errors.New("acme: unknown key type; only RSA and ECDSA are supported") + + // ErrAccountAlreadyExists indicates that the Client's key has already been registered + // with the CA. It is returned by Register method. + ErrAccountAlreadyExists = errors.New("acme: account already exists") + + // ErrNoAccount indicates that the Client's key has not been registered with the CA. + ErrNoAccount = errors.New("acme: account does not exist") +) + +// A Subproblem describes an ACME subproblem as reported in an Error. +type Subproblem struct { + // Type is a URI reference that identifies the problem type, + // typically in a "urn:acme:error:xxx" form. + Type string + // Detail is a human-readable explanation specific to this occurrence of the problem. + Detail string + // Instance indicates a URL that the client should direct a human user to visit + // in order for instructions on how to agree to the updated Terms of Service. 
+ // In such an event CA sets StatusCode to 403, Type to + // "urn:ietf:params:acme:error:userActionRequired", and adds a Link header with relation + // "terms-of-service" containing the latest TOS URL. + Instance string + // Identifier may contain the ACME identifier that the error is for. + Identifier *AuthzID +} + +func (sp Subproblem) String() string { + str := fmt.Sprintf("%s: ", sp.Type) + if sp.Identifier != nil { + str += fmt.Sprintf("[%s: %s] ", sp.Identifier.Type, sp.Identifier.Value) + } + str += sp.Detail + return str +} + +// Error is an ACME error, defined in Problem Details for HTTP APIs doc +// http://tools.ietf.org/html/draft-ietf-appsawg-http-problem. +type Error struct { + // StatusCode is The HTTP status code generated by the origin server. + StatusCode int + // ProblemType is a URI reference that identifies the problem type, + // typically in a "urn:acme:error:xxx" form. + ProblemType string + // Detail is a human-readable explanation specific to this occurrence of the problem. + Detail string + // Instance indicates a URL that the client should direct a human user to visit + // in order for instructions on how to agree to the updated Terms of Service. + // In such an event CA sets StatusCode to 403, ProblemType to + // "urn:ietf:params:acme:error:userActionRequired" and a Link header with relation + // "terms-of-service" containing the latest TOS URL. + Instance string + // Header is the original server error response headers. + // It may be nil. + Header http.Header + // Subproblems may contain more detailed information about the individual problems + // that caused the error. This field is only sent by RFC 8555 compatible ACME + // servers. Defined in RFC 8555 Section 6.7.1. 
+ Subproblems []Subproblem +} + +func (e *Error) Error() string { + str := fmt.Sprintf("%d %s: %s", e.StatusCode, e.ProblemType, e.Detail) + if len(e.Subproblems) > 0 { + str += fmt.Sprintf("; subproblems:") + for _, sp := range e.Subproblems { + str += fmt.Sprintf("\n\t%s", sp) + } + } + return str +} + +// AuthorizationError indicates that an authorization for an identifier +// did not succeed. +// It contains all errors from Challenge items of the failed Authorization. +type AuthorizationError struct { + // URI uniquely identifies the failed Authorization. + URI string + + // Identifier is an AuthzID.Value of the failed Authorization. + Identifier string + + // Errors is a collection of non-nil error values of Challenge items + // of the failed Authorization. + Errors []error +} + +func (a *AuthorizationError) Error() string { + e := make([]string, len(a.Errors)) + for i, err := range a.Errors { + e[i] = err.Error() + } + + if a.Identifier != "" { + return fmt.Sprintf("acme: authorization error for %s: %s", a.Identifier, strings.Join(e, "; ")) + } + + return fmt.Sprintf("acme: authorization error: %s", strings.Join(e, "; ")) +} + +// OrderError is returned from Client's order related methods. +// It indicates the order is unusable and the clients should start over with +// AuthorizeOrder. +// +// The clients can still fetch the order object from CA using GetOrder +// to inspect its state. +type OrderError struct { + OrderURL string + Status string +} + +func (oe *OrderError) Error() string { + return fmt.Sprintf("acme: order %s status: %s", oe.OrderURL, oe.Status) +} + +// RateLimit reports whether err represents a rate limit error and +// any Retry-After duration returned by the server. 
+// +// See the following for more details on rate limiting: +// https://tools.ietf.org/html/draft-ietf-acme-acme-05#section-5.6 +func RateLimit(err error) (time.Duration, bool) { + e, ok := err.(*Error) + if !ok { + return 0, false + } + // Some CA implementations may return incorrect values. + // Use case-insensitive comparison. + if !strings.HasSuffix(strings.ToLower(e.ProblemType), ":ratelimited") { + return 0, false + } + if e.Header == nil { + return 0, true + } + return retryAfter(e.Header.Get("Retry-After")), true +} + +// Account is a user account. It is associated with a private key. +// Non-RFC 8555 fields are empty when interfacing with a compliant CA. +type Account struct { + // URI is the account unique ID, which is also a URL used to retrieve + // account data from the CA. + // When interfacing with RFC 8555-compliant CAs, URI is the "kid" field + // value in JWS signed requests. + URI string + + // Contact is a slice of contact info used during registration. + // See https://tools.ietf.org/html/rfc8555#section-7.3 for supported + // formats. + Contact []string + + // Status indicates current account status as returned by the CA. + // Possible values are StatusValid, StatusDeactivated, and StatusRevoked. + Status string + + // OrdersURL is a URL from which a list of orders submitted by this account + // can be fetched. + OrdersURL string + + // The terms user has agreed to. + // A value not matching CurrentTerms indicates that the user hasn't agreed + // to the actual Terms of Service of the CA. + // + // It is non-RFC 8555 compliant. Package users can store the ToS they agree to + // during Client's Register call in the prompt callback function. + AgreedTerms string + + // Actual terms of a CA. + // + // It is non-RFC 8555 compliant. Use Directory's Terms field. + // When a CA updates their terms and requires an account agreement, + // a URL at which instructions to do so is available in Error's Instance field. 
+ CurrentTerms string + + // Authz is the authorization URL used to initiate a new authz flow. + // + // It is non-RFC 8555 compliant. Use Directory's AuthzURL or OrderURL. + Authz string + + // Authorizations is a URI from which a list of authorizations + // granted to this account can be fetched via a GET request. + // + // It is non-RFC 8555 compliant and is obsoleted by OrdersURL. + Authorizations string + + // Certificates is a URI from which a list of certificates + // issued for this account can be fetched via a GET request. + // + // It is non-RFC 8555 compliant and is obsoleted by OrdersURL. + Certificates string + + // ExternalAccountBinding represents an arbitrary binding to an account of + // the CA which the ACME server is tied to. + // See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details. + ExternalAccountBinding *ExternalAccountBinding +} + +// ExternalAccountBinding contains the data needed to form a request with +// an external account binding. +// See https://tools.ietf.org/html/rfc8555#section-7.3.4 for more details. +type ExternalAccountBinding struct { + // KID is the Key ID of the symmetric MAC key that the CA provides to + // identify an external account from ACME. + KID string + + // Key is the bytes of the symmetric key that the CA provides to identify + // the account. Key must correspond to the KID. + Key []byte +} + +func (e *ExternalAccountBinding) String() string { + return fmt.Sprintf("&{KID: %q, Key: redacted}", e.KID) +} + +// Directory is ACME server discovery data. +// See https://tools.ietf.org/html/rfc8555#section-7.1.1 for more details. +type Directory struct { + // NonceURL indicates an endpoint where to fetch fresh nonce values from. + NonceURL string + + // RegURL is an account endpoint URL, allowing for creating new accounts. + // Pre-RFC 8555 CAs also allow modifying existing accounts at this URL. 
+ RegURL string + + // OrderURL is used to initiate the certificate issuance flow + // as described in RFC 8555. + OrderURL string + + // AuthzURL is used to initiate identifier pre-authorization flow. + // Empty string indicates the flow is unsupported by the CA. + AuthzURL string + + // CertURL is a new certificate issuance endpoint URL. + // It is non-RFC 8555 compliant and is obsoleted by OrderURL. + CertURL string + + // RevokeURL is used to initiate a certificate revocation flow. + RevokeURL string + + // KeyChangeURL allows to perform account key rollover flow. + KeyChangeURL string + + // RenewalInfoURL allows to perform certificate renewal using the ACME + // Renewal Information (ARI) Extension. + RenewalInfoURL string + + // Terms is a URI identifying the current terms of service. + Terms string + + // Website is an HTTP or HTTPS URL locating a website + // providing more information about the ACME server. + Website string + + // CAA consists of lowercase hostname elements, which the ACME server + // recognises as referring to itself for the purposes of CAA record validation + // as defined in RFC 6844. + CAA []string + + // ExternalAccountRequired indicates that the CA requires for all account-related + // requests to include external account binding information. + ExternalAccountRequired bool +} + +// Order represents a client's request for a certificate. +// It tracks the request flow progress through to issuance. +type Order struct { + // URI uniquely identifies an order. + URI string + + // Status represents the current status of the order. + // It indicates which action the client should take. + // + // Possible values are StatusPending, StatusReady, StatusProcessing, StatusValid and StatusInvalid. + // Pending means the CA does not believe that the client has fulfilled the requirements. + // Ready indicates that the client has fulfilled all the requirements and can submit a CSR + // to obtain a certificate. 
This is done with Client's CreateOrderCert. + // Processing means the certificate is being issued. + // Valid indicates the CA has issued the certificate. It can be downloaded + // from the Order's CertURL. This is done with Client's FetchCert. + // Invalid means the certificate will not be issued. Users should consider this order + // abandoned. + Status string + + // Expires is the timestamp after which CA considers this order invalid. + Expires time.Time + + // Identifiers contains all identifier objects which the order pertains to. + Identifiers []AuthzID + + // NotBefore is the requested value of the notBefore field in the certificate. + NotBefore time.Time + + // NotAfter is the requested value of the notAfter field in the certificate. + NotAfter time.Time + + // AuthzURLs represents authorizations to complete before a certificate + // for identifiers specified in the order can be issued. + // It also contains unexpired authorizations that the client has completed + // in the past. + // + // Authorization objects can be fetched using Client's GetAuthorization method. + // + // The required authorizations are dictated by CA policies. + // There may not be a 1:1 relationship between the identifiers and required authorizations. + // Required authorizations can be identified by their StatusPending status. + // + // For orders in the StatusValid or StatusInvalid state these are the authorizations + // which were completed. + AuthzURLs []string + + // FinalizeURL is the endpoint at which a CSR is submitted to obtain a certificate + // once all the authorizations are satisfied. + FinalizeURL string + + // CertURL points to the certificate that has been issued in response to this order. + CertURL string + + // The error that occurred while processing the order as received from a CA, if any. + Error *Error +} + +// OrderOption allows customizing Client.AuthorizeOrder call. 
+type OrderOption interface { + privateOrderOpt() +} + +// WithOrderNotBefore sets order's NotBefore field. +func WithOrderNotBefore(t time.Time) OrderOption { + return orderNotBeforeOpt(t) +} + +// WithOrderNotAfter sets order's NotAfter field. +func WithOrderNotAfter(t time.Time) OrderOption { + return orderNotAfterOpt(t) +} + +type orderNotBeforeOpt time.Time + +func (orderNotBeforeOpt) privateOrderOpt() {} + +type orderNotAfterOpt time.Time + +func (orderNotAfterOpt) privateOrderOpt() {} + +// WithOrderReplacesCert indicates that this Order is for a replacement of an +// existing certificate. +// See https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5 +func WithOrderReplacesCert(cert *x509.Certificate) OrderOption { + return orderReplacesCert{cert} +} + +type orderReplacesCert struct { + cert *x509.Certificate +} + +func (orderReplacesCert) privateOrderOpt() {} + +// WithOrderReplacesCertDER indicates that this Order is for a replacement of +// an existing DER-encoded certificate. +// See https://datatracker.ietf.org/doc/html/draft-ietf-acme-ari-03#section-5 +func WithOrderReplacesCertDER(der []byte) OrderOption { + return orderReplacesCertDER(der) +} + +type orderReplacesCertDER []byte + +func (orderReplacesCertDER) privateOrderOpt() {} + +// Authorization encodes an authorization response. +type Authorization struct { + // URI uniquely identifies a authorization. + URI string + + // Status is the current status of an authorization. + // Possible values are StatusPending, StatusValid, StatusInvalid, StatusDeactivated, + // StatusExpired and StatusRevoked. + Status string + + // Identifier is what the account is authorized to represent. + Identifier AuthzID + + // The timestamp after which the CA considers the authorization invalid. + Expires time.Time + + // Wildcard is true for authorizations of a wildcard domain name. 
+ Wildcard bool + + // Challenges that the client needs to fulfill in order to prove possession + // of the identifier (for pending authorizations). + // For valid authorizations, the challenge that was validated. + // For invalid authorizations, the challenge that was attempted and failed. + // + // RFC 8555 compatible CAs require users to fuflfill only one of the challenges. + Challenges []*Challenge + + // A collection of sets of challenges, each of which would be sufficient + // to prove possession of the identifier. + // Clients must complete a set of challenges that covers at least one set. + // Challenges are identified by their indices in the challenges array. + // If this field is empty, the client needs to complete all challenges. + // + // This field is unused in RFC 8555. + Combinations [][]int +} + +// AuthzID is an identifier that an account is authorized to represent. +type AuthzID struct { + Type string // The type of identifier, "dns" or "ip". + Value string // The identifier itself, e.g. "example.org". +} + +// DomainIDs creates a slice of AuthzID with "dns" identifier type. +func DomainIDs(names ...string) []AuthzID { + a := make([]AuthzID, len(names)) + for i, v := range names { + a[i] = AuthzID{Type: "dns", Value: v} + } + return a +} + +// IPIDs creates a slice of AuthzID with "ip" identifier type. +// Each element of addr is textual form of an address as defined +// in RFC 1123 Section 2.1 for IPv4 and in RFC 5952 Section 4 for IPv6. +func IPIDs(addr ...string) []AuthzID { + a := make([]AuthzID, len(addr)) + for i, v := range addr { + a[i] = AuthzID{Type: "ip", Value: v} + } + return a +} + +// wireAuthzID is ACME JSON representation of authorization identifier objects. +type wireAuthzID struct { + Type string `json:"type"` + Value string `json:"value"` +} + +// wireAuthz is ACME JSON representation of Authorization objects. 
+type wireAuthz struct { + Identifier wireAuthzID + Status string + Expires time.Time + Wildcard bool + Challenges []wireChallenge + Combinations [][]int + Error *wireError +} + +func (z *wireAuthz) authorization(uri string) *Authorization { + a := &Authorization{ + URI: uri, + Status: z.Status, + Identifier: AuthzID{Type: z.Identifier.Type, Value: z.Identifier.Value}, + Expires: z.Expires, + Wildcard: z.Wildcard, + Challenges: make([]*Challenge, len(z.Challenges)), + Combinations: z.Combinations, // shallow copy + } + for i, v := range z.Challenges { + a.Challenges[i] = v.challenge() + } + return a +} + +func (z *wireAuthz) error(uri string) *AuthorizationError { + err := &AuthorizationError{ + URI: uri, + Identifier: z.Identifier.Value, + } + + if z.Error != nil { + err.Errors = append(err.Errors, z.Error.error(nil)) + } + + for _, raw := range z.Challenges { + if raw.Error != nil { + err.Errors = append(err.Errors, raw.Error.error(nil)) + } + } + + return err +} + +// Challenge encodes a returned CA challenge. +// Its Error field may be non-nil if the challenge is part of an Authorization +// with StatusInvalid. +type Challenge struct { + // Type is the challenge type, e.g. "http-01", "tls-alpn-01", "dns-01". + Type string + + // URI is where a challenge response can be posted to. + URI string + + // Token is a random value that uniquely identifies the challenge. + Token string + + // Status identifies the status of this challenge. + // In RFC 8555, possible values are StatusPending, StatusProcessing, StatusValid, + // and StatusInvalid. + Status string + + // Validated is the time at which the CA validated this challenge. + // Always zero value in pre-RFC 8555. + Validated time.Time + + // Error indicates the reason for an authorization failure + // when this challenge was used. + // The type of a non-nil value is *Error. 
+ Error error + + // Payload is the JSON-formatted payload that the client sends + // to the server to indicate it is ready to respond to the challenge. + // When unset, it defaults to an empty JSON object: {}. + // For most challenges, the client must not set Payload, + // see https://tools.ietf.org/html/rfc8555#section-7.5.1. + // Payload is used only for newer challenges (such as "device-attest-01") + // where the client must send additional data for the server to validate + // the challenge. + Payload json.RawMessage +} + +// wireChallenge is ACME JSON challenge representation. +type wireChallenge struct { + URL string `json:"url"` // RFC + URI string `json:"uri"` // pre-RFC + Type string + Token string + Status string + Validated time.Time + Error *wireError +} + +func (c *wireChallenge) challenge() *Challenge { + v := &Challenge{ + URI: c.URL, + Type: c.Type, + Token: c.Token, + Status: c.Status, + } + if v.URI == "" { + v.URI = c.URI // c.URL was empty; use legacy + } + if v.Status == "" { + v.Status = StatusPending + } + if c.Error != nil { + v.Error = c.Error.error(nil) + } + return v +} + +// wireError is a subset of fields of the Problem Details object +// as described in https://tools.ietf.org/html/rfc7807#section-3.1. +type wireError struct { + Status int + Type string + Detail string + Instance string + Subproblems []Subproblem +} + +func (e *wireError) error(h http.Header) *Error { + err := &Error{ + StatusCode: e.Status, + ProblemType: e.Type, + Detail: e.Detail, + Instance: e.Instance, + Header: h, + Subproblems: e.Subproblems, + } + return err +} + +// CertOption is an optional argument type for the TLS ChallengeCert methods for +// customizing a temporary certificate for TLS-based challenges. +type CertOption interface { + privateCertOpt() +} + +// WithKey creates an option holding a private/public key pair. +// The private part signs a certificate, and the public part represents the signee. 
+func WithKey(key crypto.Signer) CertOption { + return &certOptKey{key} +} + +type certOptKey struct { + key crypto.Signer +} + +func (*certOptKey) privateCertOpt() {} + +// WithTemplate creates an option for specifying a certificate template. +// See x509.CreateCertificate for template usage details. +// +// In TLS ChallengeCert methods, the template is also used as parent, +// resulting in a self-signed certificate. +// The DNSNames field of t is always overwritten for tls-sni challenge certs. +func WithTemplate(t *x509.Certificate) CertOption { + return (*certOptTemplate)(t) +} + +type certOptTemplate x509.Certificate + +func (*certOptTemplate) privateCertOpt() {} + +// RenewalInfoWindow describes the time frame during which the ACME client +// should attempt to renew, using the ACME Renewal Info Extension. +type RenewalInfoWindow struct { + Start time.Time `json:"start"` + End time.Time `json:"end"` +} + +// RenewalInfo describes the suggested renewal window for a given certificate, +// returned from an ACME server, using the ACME Renewal Info Extension. +type RenewalInfo struct { + SuggestedWindow RenewalInfoWindow `json:"suggestedWindow"` + ExplanationURL string `json:"explanationURL"` +} diff --git a/tempfork/acme/types_test.go b/tempfork/acme/types_test.go new file mode 100644 index 0000000000000..59ce7e7602ca3 --- /dev/null +++ b/tempfork/acme/types_test.go @@ -0,0 +1,219 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package acme + +import ( + "errors" + "net/http" + "reflect" + "testing" + "time" +) + +func TestExternalAccountBindingString(t *testing.T) { + eab := ExternalAccountBinding{ + KID: "kid", + Key: []byte("key"), + } + got := eab.String() + want := `&{KID: "kid", Key: redacted}` + if got != want { + t.Errorf("eab.String() = %q, want: %q", got, want) + } +} + +func TestRateLimit(t *testing.T) { + now := time.Date(2017, 04, 27, 10, 0, 0, 0, time.UTC) + f := timeNow + defer func() { timeNow = f }() + timeNow = func() time.Time { return now } + + h120, hTime := http.Header{}, http.Header{} + h120.Set("Retry-After", "120") + hTime.Set("Retry-After", "Tue Apr 27 11:00:00 2017") + + err1 := &Error{ + ProblemType: "urn:ietf:params:acme:error:nolimit", + Header: h120, + } + err2 := &Error{ + ProblemType: "urn:ietf:params:acme:error:rateLimited", + Header: h120, + } + err3 := &Error{ + ProblemType: "urn:ietf:params:acme:error:rateLimited", + Header: nil, + } + err4 := &Error{ + ProblemType: "urn:ietf:params:acme:error:rateLimited", + Header: hTime, + } + + tt := []struct { + err error + res time.Duration + ok bool + }{ + {nil, 0, false}, + {errors.New("dummy"), 0, false}, + {err1, 0, false}, + {err2, 2 * time.Minute, true}, + {err3, 0, true}, + {err4, time.Hour, true}, + } + for i, test := range tt { + res, ok := RateLimit(test.err) + if ok != test.ok { + t.Errorf("%d: RateLimit(%+v): ok = %v; want %v", i, test.err, ok, test.ok) + continue + } + if res != test.res { + t.Errorf("%d: RateLimit(%+v) = %v; want %v", i, test.err, res, test.res) + } + } +} + +func TestAuthorizationError(t *testing.T) { + tests := []struct { + desc string + err *AuthorizationError + msg string + }{ + { + desc: "when auth error identifier is set", + err: &AuthorizationError{ + Identifier: "domain.com", + Errors: []error{ + (&wireError{ + Status: 403, + Type: "urn:ietf:params:acme:error:caa", + Detail: "CAA record for domain.com prevents issuance", + }).error(nil), + }, + }, + msg: "acme: 
authorization error for domain.com: 403 urn:ietf:params:acme:error:caa: CAA record for domain.com prevents issuance", + }, + + { + desc: "when auth error identifier is unset", + err: &AuthorizationError{ + Errors: []error{ + (&wireError{ + Status: 403, + Type: "urn:ietf:params:acme:error:caa", + Detail: "CAA record for domain.com prevents issuance", + }).error(nil), + }, + }, + msg: "acme: authorization error: 403 urn:ietf:params:acme:error:caa: CAA record for domain.com prevents issuance", + }, + } + + for _, tt := range tests { + if tt.err.Error() != tt.msg { + t.Errorf("got: %s\nwant: %s", tt.err, tt.msg) + } + } +} + +func TestSubproblems(t *testing.T) { + tests := []struct { + wire wireError + expectedOut Error + }{ + { + wire: wireError{ + Status: 1, + Type: "urn:error", + Detail: "it's an error", + }, + expectedOut: Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + }, + }, + { + wire: wireError{ + Status: 1, + Type: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + }, + }, + }, + expectedOut: Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + }, + }, + }, + }, + { + wire: wireError{ + Status: 1, + Type: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + Identifier: &AuthzID{Type: "dns", Value: "example"}, + }, + }, + }, + expectedOut: Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + Identifier: &AuthzID{Type: "dns", Value: "example"}, + }, + }, + }, + }, + } + + for _, tc := range tests { + out := tc.wire.error(nil) + if !reflect.DeepEqual(*out, tc.expectedOut) { + t.Errorf("Unexpected error: wanted %v, got %v", tc.expectedOut, *out) 
+ } + } +} + +func TestErrorStringerWithSubproblems(t *testing.T) { + err := Error{ + StatusCode: 1, + ProblemType: "urn:error", + Detail: "it's an error", + Subproblems: []Subproblem{ + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + }, + { + Type: "urn:error:sub", + Detail: "it's a subproblem", + Identifier: &AuthzID{Type: "dns", Value: "example"}, + }, + }, + } + expectedStr := "1 urn:error: it's an error; subproblems:\n\turn:error:sub: it's a subproblem\n\turn:error:sub: [dns: example] it's a subproblem" + if err.Error() != expectedStr { + t.Errorf("Unexpected error string: wanted %q, got %q", expectedStr, err.Error()) + } +} diff --git a/tempfork/acme/version_go112.go b/tempfork/acme/version_go112.go new file mode 100644 index 0000000000000..cc5fab604b8d0 --- /dev/null +++ b/tempfork/acme/version_go112.go @@ -0,0 +1,27 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.12 + +package acme + +import "runtime/debug" + +func init() { + // Set packageVersion if the binary was built in modules mode and x/crypto + // was not replaced with a different module. 
+ info, ok := debug.ReadBuildInfo() + if !ok { + return + } + for _, m := range info.Deps { + if m.Path != "golang.org/x/crypto" { + continue + } + if m.Replace == nil { + packageVersion = m.Version + } + break + } +} diff --git a/tempfork/gliderlabs/ssh/agent.go b/tempfork/gliderlabs/ssh/agent.go index 86a5bce7f8ebc..99e84c1e5c64c 100644 --- a/tempfork/gliderlabs/ssh/agent.go +++ b/tempfork/gliderlabs/ssh/agent.go @@ -7,7 +7,7 @@ import ( "path" "sync" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) const ( diff --git a/tempfork/gliderlabs/ssh/context.go b/tempfork/gliderlabs/ssh/context.go index d43de6f09c8a5..505a43dbf3ffe 100644 --- a/tempfork/gliderlabs/ssh/context.go +++ b/tempfork/gliderlabs/ssh/context.go @@ -6,7 +6,7 @@ import ( "net" "sync" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // contextKey is a value for use with context.WithValue. It's used as @@ -55,8 +55,6 @@ var ( // ContextKeyPublicKey is a context key for use with Contexts in this package. // The associated value will be of type PublicKey. ContextKeyPublicKey = &contextKey{"public-key"} - - ContextKeySendAuthBanner = &contextKey{"send-auth-banner"} ) // Context is a package specific context interface. It exposes connection @@ -91,8 +89,6 @@ type Context interface { // SetValue allows you to easily write new values into the underlying context. 
SetValue(key, value interface{}) - - SendAuthBanner(banner string) error } type sshContext struct { @@ -121,7 +117,6 @@ func applyConnMetadata(ctx Context, conn gossh.ConnMetadata) { ctx.SetValue(ContextKeyUser, conn.User()) ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr()) ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr()) - ctx.SetValue(ContextKeySendAuthBanner, conn.SendAuthBanner) } func (ctx *sshContext) SetValue(key, value interface{}) { @@ -158,7 +153,3 @@ func (ctx *sshContext) LocalAddr() net.Addr { func (ctx *sshContext) Permissions() *Permissions { return ctx.Value(ContextKeyPermissions).(*Permissions) } - -func (ctx *sshContext) SendAuthBanner(msg string) error { - return ctx.Value(ContextKeySendAuthBanner).(func(string) error)(msg) -} diff --git a/tempfork/gliderlabs/ssh/options.go b/tempfork/gliderlabs/ssh/options.go index aa87a4f39db9e..29c8ef141842b 100644 --- a/tempfork/gliderlabs/ssh/options.go +++ b/tempfork/gliderlabs/ssh/options.go @@ -3,7 +3,7 @@ package ssh import ( "os" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // PasswordAuth returns a functional option that sets PasswordHandler on the server. 
diff --git a/tempfork/gliderlabs/ssh/options_test.go b/tempfork/gliderlabs/ssh/options_test.go index 7cf6f376c6a88..47342b0f67923 100644 --- a/tempfork/gliderlabs/ssh/options_test.go +++ b/tempfork/gliderlabs/ssh/options_test.go @@ -8,7 +8,7 @@ import ( "sync/atomic" "testing" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, *gossh.Client, func()) { diff --git a/tempfork/gliderlabs/ssh/server.go b/tempfork/gliderlabs/ssh/server.go index 1086a72caf0e5..473e5fbd6fc8f 100644 --- a/tempfork/gliderlabs/ssh/server.go +++ b/tempfork/gliderlabs/ssh/server.go @@ -8,7 +8,7 @@ import ( "sync" "time" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // ErrServerClosed is returned by the Server's Serve, ListenAndServe, diff --git a/tempfork/gliderlabs/ssh/session.go b/tempfork/gliderlabs/ssh/session.go index 0a4a21e534401..a7a9a3eebd96f 100644 --- a/tempfork/gliderlabs/ssh/session.go +++ b/tempfork/gliderlabs/ssh/session.go @@ -9,7 +9,7 @@ import ( "sync" "github.com/anmitsu/go-shlex" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) // Session provides access to information about an SSH session and methods diff --git a/tempfork/gliderlabs/ssh/session_test.go b/tempfork/gliderlabs/ssh/session_test.go index a60be5ec12d4e..fe61a9d96be9b 100644 --- a/tempfork/gliderlabs/ssh/session_test.go +++ b/tempfork/gliderlabs/ssh/session_test.go @@ -9,7 +9,7 @@ import ( "net" "testing" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) func (srv *Server) serveOnce(l net.Listener) error { diff --git a/tempfork/gliderlabs/ssh/ssh.go b/tempfork/gliderlabs/ssh/ssh.go index 644cb257d9afa..54bd31ec2fcb4 100644 --- a/tempfork/gliderlabs/ssh/ssh.go +++ b/tempfork/gliderlabs/ssh/ssh.go @@ -4,7 +4,7 @@ import ( "crypto/subtle" "net" - 
gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) type Signal string @@ -105,7 +105,7 @@ type Pty struct { // requested by the client as part of the pty-req. These are outlined as // part of https://datatracker.ietf.org/doc/html/rfc4254#section-8. // - // The opcodes are defined as constants in github.com/tailscale/golang-x-crypto/ssh (VINTR,VQUIT,etc.). + // The opcodes are defined as constants in golang.org/x/crypto/ssh (VINTR,VQUIT,etc.). // Boolean opcodes have values 0 or 1. Modes gossh.TerminalModes } diff --git a/tempfork/gliderlabs/ssh/tcpip.go b/tempfork/gliderlabs/ssh/tcpip.go index 056a0c7343daf..335fda65754ea 100644 --- a/tempfork/gliderlabs/ssh/tcpip.go +++ b/tempfork/gliderlabs/ssh/tcpip.go @@ -7,7 +7,7 @@ import ( "strconv" "sync" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) const ( diff --git a/tempfork/gliderlabs/ssh/tcpip_test.go b/tempfork/gliderlabs/ssh/tcpip_test.go index 118b5d53ac4a1..b3ba60a9bb6b8 100644 --- a/tempfork/gliderlabs/ssh/tcpip_test.go +++ b/tempfork/gliderlabs/ssh/tcpip_test.go @@ -10,7 +10,7 @@ import ( "strings" "testing" - gossh "github.com/tailscale/golang-x-crypto/ssh" + gossh "golang.org/x/crypto/ssh" ) var sampleServerResponse = []byte("Hello world") diff --git a/tempfork/gliderlabs/ssh/util.go b/tempfork/gliderlabs/ssh/util.go index e3b5716a3ab55..3bee06dcdef39 100644 --- a/tempfork/gliderlabs/ssh/util.go +++ b/tempfork/gliderlabs/ssh/util.go @@ -5,7 +5,7 @@ import ( "crypto/rsa" "encoding/binary" - "github.com/tailscale/golang-x-crypto/ssh" + "golang.org/x/crypto/ssh" ) func generateSigner() (ssh.Signer, error) { diff --git a/tempfork/gliderlabs/ssh/wrap.go b/tempfork/gliderlabs/ssh/wrap.go index 17867d7518dd1..d1f2b161e6932 100644 --- a/tempfork/gliderlabs/ssh/wrap.go +++ b/tempfork/gliderlabs/ssh/wrap.go @@ -1,6 +1,6 @@ package ssh -import gossh "github.com/tailscale/golang-x-crypto/ssh" +import gossh "golang.org/x/crypto/ssh" // PublicKey is 
an abstraction of different types of public keys. type PublicKey interface { diff --git a/tempfork/httprec/httprec.go b/tempfork/httprec/httprec.go new file mode 100644 index 0000000000000..13786aaf60e05 --- /dev/null +++ b/tempfork/httprec/httprec.go @@ -0,0 +1,258 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httprec is a copy of the Go standard library's httptest.ResponseRecorder +// type, which we want to use in non-test code without pulling in the rest of +// the httptest package and its test certs, etc. +package httprec + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/textproto" + "strconv" + "strings" + + "golang.org/x/net/http/httpguts" +) + +// ResponseRecorder is an implementation of [http.ResponseWriter] that +// records its mutations for later inspection in tests. +type ResponseRecorder struct { + // Code is the HTTP response code set by WriteHeader. + // + // Note that if a Handler never calls WriteHeader or Write, + // this might end up being 0, rather than the implicit + // http.StatusOK. To get the implicit value, use the Result + // method. + Code int + + // HeaderMap contains the headers explicitly set by the Handler. + // It is an internal detail. + // + // Deprecated: HeaderMap exists for historical compatibility + // and should not be used. To access the headers returned by a handler, + // use the Response.Header map as returned by the Result method. + HeaderMap http.Header + + // Body is the buffer to which the Handler's Write calls are sent. + // If nil, the Writes are silently discarded. + Body *bytes.Buffer + + // Flushed is whether the Handler called Flush. + Flushed bool + + result *http.Response // cache of Result's return value + snapHeader http.Header // snapshot of HeaderMap at first Write + wroteHeader bool +} + +// NewRecorder returns an initialized [ResponseRecorder]. 
+func NewRecorder() *ResponseRecorder { + return &ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + Code: 200, + } +} + +// DefaultRemoteAddr is the default remote address to return in RemoteAddr if +// an explicit DefaultRemoteAddr isn't set on [ResponseRecorder]. +const DefaultRemoteAddr = "1.2.3.4" + +// Header implements [http.ResponseWriter]. It returns the response +// headers to mutate within a handler. To test the headers that were +// written after a handler completes, use the [ResponseRecorder.Result] method and see +// the returned Response value's Header. +func (rw *ResponseRecorder) Header() http.Header { + m := rw.HeaderMap + if m == nil { + m = make(http.Header) + rw.HeaderMap = m + } + return m +} + +// writeHeader writes a header if it was not written yet and +// detects Content-Type if needed. +// +// bytes or str are the beginning of the response body. +// We pass both to avoid unnecessarily generate garbage +// in rw.WriteString which was created for performance reasons. +// Non-nil bytes win. +func (rw *ResponseRecorder) writeHeader(b []byte, str string) { + if rw.wroteHeader { + return + } + if len(str) > 512 { + str = str[:512] + } + + m := rw.Header() + + _, hasType := m["Content-Type"] + hasTE := m.Get("Transfer-Encoding") != "" + if !hasType && !hasTE { + if b == nil { + b = []byte(str) + } + m.Set("Content-Type", http.DetectContentType(b)) + } + + rw.WriteHeader(200) +} + +// Write implements http.ResponseWriter. The data in buf is written to +// rw.Body, if not nil. +func (rw *ResponseRecorder) Write(buf []byte) (int, error) { + rw.writeHeader(buf, "") + if rw.Body != nil { + rw.Body.Write(buf) + } + return len(buf), nil +} + +// WriteString implements [io.StringWriter]. The data in str is written +// to rw.Body, if not nil. 
+func (rw *ResponseRecorder) WriteString(str string) (int, error) { + rw.writeHeader(nil, str) + if rw.Body != nil { + rw.Body.WriteString(str) + } + return len(str), nil +} + +func checkWriteHeaderCode(code int) { + // Issue 22880: require valid WriteHeader status codes. + // For now we only enforce that it's three digits. + // In the future we might block things over 599 (600 and above aren't defined + // at https://httpwg.org/specs/rfc7231.html#status.codes) + // and we might block under 200 (once we have more mature 1xx support). + // But for now any three digits. + // + // We used to send "HTTP/1.1 000 0" on the wire in responses but there's + // no equivalent bogus thing we can realistically send in HTTP/2, + // so we'll consistently panic instead and help people find their bugs + // early. (We can't return an error from WriteHeader even if we wanted to.) + if code < 100 || code > 999 { + panic(fmt.Sprintf("invalid WriteHeader code %v", code)) + } +} + +// WriteHeader implements [http.ResponseWriter]. +func (rw *ResponseRecorder) WriteHeader(code int) { + if rw.wroteHeader { + return + } + + checkWriteHeaderCode(code) + rw.Code = code + rw.wroteHeader = true + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + rw.snapHeader = rw.HeaderMap.Clone() +} + +// Flush implements [http.Flusher]. To test whether Flush was +// called, see rw.Flushed. +func (rw *ResponseRecorder) Flush() { + if !rw.wroteHeader { + rw.WriteHeader(200) + } + rw.Flushed = true +} + +// Result returns the response generated by the handler. +// +// The returned Response will have at least its StatusCode, +// Header, Body, and optionally Trailer populated. +// More fields may be populated in the future, so callers should +// not DeepEqual the result in tests. +// +// The Response.Header is a snapshot of the headers at the time of the +// first write call, or at the time of this call, if the handler never +// did a write. 
+// +// The Response.Body is guaranteed to be non-nil and Body.Read call is +// guaranteed to not return any error other than [io.EOF]. +// +// Result must only be called after the handler has finished running. +func (rw *ResponseRecorder) Result() *http.Response { + if rw.result != nil { + return rw.result + } + if rw.snapHeader == nil { + rw.snapHeader = rw.HeaderMap.Clone() + } + res := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: rw.Code, + Header: rw.snapHeader, + } + rw.result = res + if res.StatusCode == 0 { + res.StatusCode = 200 + } + res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) + if rw.Body != nil { + res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) + } else { + res.Body = http.NoBody + } + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) + + if trailers, ok := rw.snapHeader["Trailer"]; ok { + res.Trailer = make(http.Header, len(trailers)) + for _, k := range trailers { + for _, k := range strings.Split(k, ",") { + k = http.CanonicalHeaderKey(textproto.TrimString(k)) + if !httpguts.ValidTrailerHeader(k) { + // Ignore since forbidden by RFC 7230, section 4.1.2. + continue + } + vv, ok := rw.HeaderMap[k] + if !ok { + continue + } + vv2 := make([]string, len(vv)) + copy(vv2, vv) + res.Trailer[k] = vv2 + } + } + } + for k, vv := range rw.HeaderMap { + if !strings.HasPrefix(k, http.TrailerPrefix) { + continue + } + if res.Trailer == nil { + res.Trailer = make(http.Header) + } + for _, v := range vv { + res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) + } + } + return res +} + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +// +// This a modified version of same function found in net/http/transfer.go. This +// one just ignores an invalid header. 
+func parseContentLength(cl string) int64 { + cl = textproto.TrimString(cl) + if cl == "" { + return -1 + } + n, err := strconv.ParseUint(cl, 10, 63) + if err != nil { + return -1 + } + return int64(n) +} diff --git a/tempfork/sshtest/README.md b/tempfork/sshtest/README.md new file mode 100644 index 0000000000000..30c74f5254844 --- /dev/null +++ b/tempfork/sshtest/README.md @@ -0,0 +1,9 @@ +# sshtest + +This contains packages that are forked & locally hacked up for use +in tests. + +Notably, `golang.org/x/crypto/ssh` was copied to +`tailscale.com/tempfork/sshtest/ssh` to permit adding behaviors specific +to testing (for testing Tailscale SSH) that aren't necessarily desirable +to have upstream. diff --git a/tempfork/sshtest/ssh/benchmark_test.go b/tempfork/sshtest/ssh/benchmark_test.go new file mode 100644 index 0000000000000..b356330b469d2 --- /dev/null +++ b/tempfork/sshtest/ssh/benchmark_test.go @@ -0,0 +1,127 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package ssh + +import ( + "errors" + "fmt" + "io" + "net" + "testing" +) + +type server struct { + *ServerConn + chans <-chan NewChannel +} + +func newServer(c net.Conn, conf *ServerConfig) (*server, error) { + sconn, chans, reqs, err := NewServerConn(c, conf) + if err != nil { + return nil, err + } + go DiscardRequests(reqs) + return &server{sconn, chans}, nil +} + +func (s *server) Accept() (NewChannel, error) { + n, ok := <-s.chans + if !ok { + return nil, io.EOF + } + return n, nil +} + +func sshPipe() (Conn, *server, error) { + c1, c2, err := netPipe() + if err != nil { + return nil, nil, err + } + + clientConf := ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + } + serverConf := ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + done := make(chan *server, 1) + go func() { + server, err := newServer(c2, &serverConf) + if err != nil { + done <- nil + } + done <- server + }() + + client, _, reqs, err := NewClientConn(c1, "", &clientConf) + if err != nil { + return nil, nil, err + } + + server := <-done + if server == nil { + return nil, nil, errors.New("server handshake failed.") + } + go DiscardRequests(reqs) + + return client, server, nil +} + +func BenchmarkEndToEnd(b *testing.B) { + b.StopTimer() + + client, server, err := sshPipe() + if err != nil { + b.Fatalf("sshPipe: %v", err) + } + + defer client.Close() + defer server.Close() + + size := (1 << 20) + input := make([]byte, size) + output := make([]byte, size) + b.SetBytes(int64(size)) + done := make(chan int, 1) + + go func() { + newCh, err := server.Accept() + if err != nil { + panic(fmt.Sprintf("Client: %v", err)) + } + ch, incoming, err := newCh.Accept() + if err != nil { + panic(fmt.Sprintf("Accept: %v", err)) + } + go DiscardRequests(incoming) + for i := 0; i < b.N; i++ { + if _, err := io.ReadFull(ch, output); err != nil { + panic(fmt.Sprintf("ReadFull: %v", err)) + } + } + ch.Close() + done <- 1 + }() + + ch, in, err := 
client.OpenChannel("speed", nil) + if err != nil { + b.Fatalf("OpenChannel: %v", err) + } + go DiscardRequests(in) + + b.ResetTimer() + b.StartTimer() + for i := 0; i < b.N; i++ { + if _, err := ch.Write(input); err != nil { + b.Fatalf("WriteFull: %v", err) + } + } + ch.Close() + b.StopTimer() + + <-done +} diff --git a/tempfork/sshtest/ssh/buffer.go b/tempfork/sshtest/ssh/buffer.go new file mode 100644 index 0000000000000..1ab07d078db16 --- /dev/null +++ b/tempfork/sshtest/ssh/buffer.go @@ -0,0 +1,97 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" +) + +// buffer provides a linked list buffer for data exchange +// between producer and consumer. Theoretically the buffer is +// of unlimited capacity as it does no allocation of its own. +type buffer struct { + // protects concurrent access to head, tail and closed + *sync.Cond + + head *element // the buffer that will be read first + tail *element // the buffer that will be read last + + closed bool +} + +// An element represents a single link in a linked list. +type element struct { + buf []byte + next *element +} + +// newBuffer returns an empty buffer that is not closed. +func newBuffer() *buffer { + e := new(element) + b := &buffer{ + Cond: newCond(), + head: e, + tail: e, + } + return b +} + +// write makes buf available for Read to receive. +// buf must not be modified after the call to write. +func (b *buffer) write(buf []byte) { + b.Cond.L.Lock() + e := &element{buf: buf} + b.tail.next = e + b.tail = e + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// eof closes the buffer. Reads from the buffer once all +// the data has been consumed will receive io.EOF. +func (b *buffer) eof() { + b.Cond.L.Lock() + b.closed = true + b.Cond.Signal() + b.Cond.L.Unlock() +} + +// Read reads data from the internal buffer in buf. 
Reads will block +// if no data is available, or until the buffer is closed. +func (b *buffer) Read(buf []byte) (n int, err error) { + b.Cond.L.Lock() + defer b.Cond.L.Unlock() + + for len(buf) > 0 { + // if there is data in b.head, copy it + if len(b.head.buf) > 0 { + r := copy(buf, b.head.buf) + buf, b.head.buf = buf[r:], b.head.buf[r:] + n += r + continue + } + // if there is a next buffer, make it the head + if len(b.head.buf) == 0 && b.head != b.tail { + b.head = b.head.next + continue + } + + // if at least one byte has been copied, return + if n > 0 { + break + } + + // if nothing was read, and there is nothing outstanding + // check to see if the buffer is closed. + if b.closed { + err = io.EOF + break + } + // out of buffers, wait for producer + b.Cond.Wait() + } + return +} diff --git a/tempfork/sshtest/ssh/buffer_test.go b/tempfork/sshtest/ssh/buffer_test.go new file mode 100644 index 0000000000000..d5781cb3da997 --- /dev/null +++ b/tempfork/sshtest/ssh/buffer_test.go @@ -0,0 +1,87 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package ssh + +import ( + "io" + "testing" +) + +var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") + +func TestBufferReadwrite(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + r, _ := b.Read(make([]byte, 10)) + if r != 10 { + t.Fatalf("Expected written == read == 10, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + r, _ = b.Read(make([]byte, 10)) + if r != 5 { + t.Fatalf("Expected written == read == 5, written: 5, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:10]) + r, _ = b.Read(make([]byte, 5)) + if r != 5 { + t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) + } + + b = newBuffer() + b.write(alphabet[:5]) + b.write(alphabet[5:15]) + r, _ = b.Read(make([]byte, 10)) + r2, _ := b.Read(make([]byte, 10)) + if r != 10 || r2 != 5 || 15 != r+r2 { + t.Fatal("Expected written == read == 15") + } +} + +func TestBufferClose(t *testing.T) { + b := newBuffer() + b.write(alphabet[:10]) + b.eof() + _, err := b.Read(make([]byte, 5)) + if err != nil { + t.Fatal("expected read of 5 to not return EOF") + } + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err := b.Read(make([]byte, 5)) + r2, err2 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || err != nil || err2 != nil { + t.Fatal("expected reads of 5 and 5") + } + + b = newBuffer() + b.write(alphabet[:10]) + b.eof() + r, err = b.Read(make([]byte, 5)) + r2, err2 = b.Read(make([]byte, 10)) + r3, err3 := b.Read(make([]byte, 10)) + if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { + t.Fatal("expected reads of 5 and 5 and 0, with EOF") + } + + b = newBuffer() + b.write(make([]byte, 5)) + b.write(make([]byte, 10)) + b.eof() + r, err = b.Read(make([]byte, 9)) + r2, err2 = b.Read(make([]byte, 3)) + r3, err3 = b.Read(make([]byte, 3)) + r4, err4 := b.Read(make([]byte, 10)) + if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { + t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", 
err, err2, err3, err4) + } + if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { + t.Fatal("Expected written == read == 15", r, r2, r3, r4) + } +} diff --git a/tempfork/sshtest/ssh/certs.go b/tempfork/sshtest/ssh/certs.go new file mode 100644 index 0000000000000..27d0e14aa99c7 --- /dev/null +++ b/tempfork/sshtest/ssh/certs.go @@ -0,0 +1,611 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "sort" + "time" +) + +// Certificate algorithm names from [PROTOCOL.certkeys]. These values can appear +// in Certificate.Type, PublicKey.Type, and ClientConfig.HostKeyAlgorithms. +// Unlike key algorithm names, these are not passed to AlgorithmSigner nor +// returned by MultiAlgorithmSigner and don't appear in the Signature.Format +// field. +const ( + CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" + CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" + CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" + CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" + CertAlgoSKECDSA256v01 = "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com" + CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com" + CertAlgoSKED25519v01 = "sk-ssh-ed25519-cert-v01@openssh.com" + + // CertAlgoRSASHA256v01 and CertAlgoRSASHA512v01 can't appear as a + // Certificate.Type (or PublicKey.Type), but only in + // ClientConfig.HostKeyAlgorithms. + CertAlgoRSASHA256v01 = "rsa-sha2-256-cert-v01@openssh.com" + CertAlgoRSASHA512v01 = "rsa-sha2-512-cert-v01@openssh.com" +) + +const ( + // Deprecated: use CertAlgoRSAv01. + CertSigAlgoRSAv01 = CertAlgoRSAv01 + // Deprecated: use CertAlgoRSASHA256v01. + CertSigAlgoRSASHA2256v01 = CertAlgoRSASHA256v01 + // Deprecated: use CertAlgoRSASHA512v01. 
+ CertSigAlgoRSASHA2512v01 = CertAlgoRSASHA512v01 +) + +// Certificate types distinguish between host and user +// certificates. The values can be set in the CertType field of +// Certificate. +const ( + UserCert = 1 + HostCert = 2 +) + +// Signature represents a cryptographic signature. +type Signature struct { + Format string + Blob []byte + Rest []byte `ssh:"rest"` +} + +// CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that +// a certificate does not expire. +const CertTimeInfinity = 1<<64 - 1 + +// An Certificate represents an OpenSSH certificate as defined in +// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the +// PublicKey interface, so it can be unmarshaled using +// ParsePublicKey. +type Certificate struct { + Nonce []byte + Key PublicKey + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []string + ValidAfter uint64 + ValidBefore uint64 + Permissions + Reserved []byte + SignatureKey PublicKey + Signature *Signature +} + +// genericCertData holds the key-independent part of the certificate data. +// Overall, certificates contain an nonce, public key fields and +// key-independent fields. +type genericCertData struct { + Serial uint64 + CertType uint32 + KeyId string + ValidPrincipals []byte + ValidAfter uint64 + ValidBefore uint64 + CriticalOptions []byte + Extensions []byte + Reserved []byte + SignatureKey []byte + Signature []byte +} + +func marshalStringList(namelist []string) []byte { + var to []byte + for _, name := range namelist { + s := struct{ N string }{name} + to = append(to, Marshal(&s)...) 
+ } + return to +} + +type optionsTuple struct { + Key string + Value []byte +} + +type optionsTupleValue struct { + Value string +} + +// serialize a map of critical options or extensions +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty string value +func marshalTuples(tups map[string]string) []byte { + keys := make([]string, 0, len(tups)) + for key := range tups { + keys = append(keys, key) + } + sort.Strings(keys) + + var ret []byte + for _, key := range keys { + s := optionsTuple{Key: key} + if value := tups[key]; len(value) > 0 { + s.Value = Marshal(&optionsTupleValue{value}) + } + ret = append(ret, Marshal(&s)...) + } + return ret +} + +// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, +// we need two length prefixes for a non-empty option value +func parseTuples(in []byte) (map[string]string, error) { + tups := map[string]string{} + var lastKey string + var haveLastKey bool + + for len(in) > 0 { + var key, val, extra []byte + var ok bool + + if key, in, ok = parseString(in); !ok { + return nil, errShortRead + } + keyStr := string(key) + // according to [PROTOCOL.certkeys], the names must be in + // lexical order. 
+ if haveLastKey && keyStr <= lastKey { + return nil, fmt.Errorf("ssh: certificate options are not in lexical order") + } + lastKey, haveLastKey = keyStr, true + // the next field is a data field, which if non-empty has a string embedded + if val, in, ok = parseString(in); !ok { + return nil, errShortRead + } + if len(val) > 0 { + val, extra, ok = parseString(val) + if !ok { + return nil, errShortRead + } + if len(extra) > 0 { + return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") + } + tups[keyStr] = string(val) + } else { + tups[keyStr] = "" + } + } + return tups, nil +} + +func parseCert(in []byte, privAlgo string) (*Certificate, error) { + nonce, rest, ok := parseString(in) + if !ok { + return nil, errShortRead + } + + key, rest, err := parsePubKey(rest, privAlgo) + if err != nil { + return nil, err + } + + var g genericCertData + if err := Unmarshal(rest, &g); err != nil { + return nil, err + } + + c := &Certificate{ + Nonce: nonce, + Key: key, + Serial: g.Serial, + CertType: g.CertType, + KeyId: g.KeyId, + ValidAfter: g.ValidAfter, + ValidBefore: g.ValidBefore, + } + + for principals := g.ValidPrincipals; len(principals) > 0; { + principal, rest, ok := parseString(principals) + if !ok { + return nil, errShortRead + } + c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) + principals = rest + } + + c.CriticalOptions, err = parseTuples(g.CriticalOptions) + if err != nil { + return nil, err + } + c.Extensions, err = parseTuples(g.Extensions) + if err != nil { + return nil, err + } + c.Reserved = g.Reserved + k, err := ParsePublicKey(g.SignatureKey) + if err != nil { + return nil, err + } + + c.SignatureKey = k + c.Signature, rest, ok = parseSignatureBody(g.Signature) + if !ok || len(rest) > 0 { + return nil, errors.New("ssh: signature parse error") + } + + return c, nil +} + +type openSSHCertSigner struct { + pub *Certificate + signer Signer +} + +type algorithmOpenSSHCertSigner struct { + *openSSHCertSigner + 
algorithmSigner AlgorithmSigner +} + +// NewCertSigner returns a Signer that signs with the given Certificate, whose +// private key is held by signer. It returns an error if the public key in cert +// doesn't match the key used by signer. +func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { + if !bytes.Equal(cert.Key.Marshal(), signer.PublicKey().Marshal()) { + return nil, errors.New("ssh: signer and cert have different public key") + } + + switch s := signer.(type) { + case MultiAlgorithmSigner: + return &multiAlgorithmSigner{ + AlgorithmSigner: &algorithmOpenSSHCertSigner{ + &openSSHCertSigner{cert, signer}, s}, + supportedAlgorithms: s.Algorithms(), + }, nil + case AlgorithmSigner: + return &algorithmOpenSSHCertSigner{ + &openSSHCertSigner{cert, signer}, s}, nil + default: + return &openSSHCertSigner{cert, signer}, nil + } +} + +func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + return s.signer.Sign(rand, data) +} + +func (s *openSSHCertSigner) PublicKey() PublicKey { + return s.pub +} + +func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +const sourceAddressCriticalOption = "source-address" + +// CertChecker does the work of verifying a certificate. Its methods +// can be plugged into ClientConfig.HostKeyCallback and +// ServerConfig.PublicKeyCallback. For the CertChecker to work, +// minimally, the IsAuthority callback should be set. +type CertChecker struct { + // SupportedCriticalOptions lists the CriticalOptions that the + // server application layer understands. These are only used + // for user certificates. + SupportedCriticalOptions []string + + // IsUserAuthority should return true if the key is recognized as an + // authority for the given user certificate. This allows for + // certificates to be signed by other certificates. 
This must be set + // if this CertChecker will be checking user certificates. + IsUserAuthority func(auth PublicKey) bool + + // IsHostAuthority should report whether the key is recognized as + // an authority for this host. This allows for certificates to be + // signed by other keys, and for those other keys to only be valid + // signers for particular hostnames. This must be set if this + // CertChecker will be checking host certificates. + IsHostAuthority func(auth PublicKey, address string) bool + + // Clock is used for verifying time stamps. If nil, time.Now + // is used. + Clock func() time.Time + + // UserKeyFallback is called when CertChecker.Authenticate encounters a + // public key that is not a certificate. It must implement validation + // of user keys or else, if nil, all such keys are rejected. + UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) + + // HostKeyFallback is called when CertChecker.CheckHostKey encounters a + // public key that is not a certificate. It must implement host key + // validation or else, if nil, all such keys are rejected. + HostKeyFallback HostKeyCallback + + // IsRevoked is called for each certificate so that revocation checking + // can be implemented. It should return true if the given certificate + // is revoked and false otherwise. If nil, no certificates are + // considered to have been revoked. + IsRevoked func(cert *Certificate) bool +} + +// CheckHostKey checks a host key certificate. This method can be +// plugged into ClientConfig.HostKeyCallback. 
+func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { + cert, ok := key.(*Certificate) + if !ok { + if c.HostKeyFallback != nil { + return c.HostKeyFallback(addr, remote, key) + } + return errors.New("ssh: non-certificate host key") + } + if cert.CertType != HostCert { + return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) + } + if !c.IsHostAuthority(cert.SignatureKey, addr) { + return fmt.Errorf("ssh: no authorities for hostname: %v", addr) + } + + hostname, _, err := net.SplitHostPort(addr) + if err != nil { + return err + } + + // Pass hostname only as principal for host certificates (consistent with OpenSSH) + return c.CheckCert(hostname, cert) +} + +// Authenticate checks a user certificate. Authenticate can be used as +// a value for ServerConfig.PublicKeyCallback. +func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { + cert, ok := pubKey.(*Certificate) + if !ok { + if c.UserKeyFallback != nil { + return c.UserKeyFallback(conn, pubKey) + } + return nil, errors.New("ssh: normal key pairs not accepted") + } + + if cert.CertType != UserCert { + return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) + } + if !c.IsUserAuthority(cert.SignatureKey) { + return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") + } + + if err := c.CheckCert(conn.User(), cert); err != nil { + return nil, err + } + + return &cert.Permissions, nil +} + +// CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and +// the signature of the certificate. 
+func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { + if c.IsRevoked != nil && c.IsRevoked(cert) { + return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial) + } + + for opt := range cert.CriticalOptions { + // sourceAddressCriticalOption will be enforced by + // serverAuthenticate + if opt == sourceAddressCriticalOption { + continue + } + + found := false + for _, supp := range c.SupportedCriticalOptions { + if supp == opt { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) + } + } + + if len(cert.ValidPrincipals) > 0 { + // By default, certs are valid for all users/hosts. + found := false + for _, p := range cert.ValidPrincipals { + if p == principal { + found = true + break + } + } + if !found { + return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) + } + } + + clock := c.Clock + if clock == nil { + clock = time.Now + } + + unixNow := clock().Unix() + if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { + return fmt.Errorf("ssh: cert is not yet valid") + } + if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { + return fmt.Errorf("ssh: cert has expired") + } + if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { + return fmt.Errorf("ssh: certificate signature does not verify") + } + + return nil +} + +// SignCert signs the certificate with an authority, setting the Nonce, +// SignatureKey, and Signature fields. If the authority implements the +// MultiAlgorithmSigner interface the first algorithm in the list is used. This +// is useful if you want to sign with a specific algorithm. 
+func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { + c.Nonce = make([]byte, 32) + if _, err := io.ReadFull(rand, c.Nonce); err != nil { + return err + } + c.SignatureKey = authority.PublicKey() + + if v, ok := authority.(MultiAlgorithmSigner); ok { + if len(v.Algorithms()) == 0 { + return errors.New("the provided authority has no signature algorithm") + } + // Use the first algorithm in the list. + sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), v.Algorithms()[0]) + if err != nil { + return err + } + c.Signature = sig + return nil + } else if v, ok := authority.(AlgorithmSigner); ok && v.PublicKey().Type() == KeyAlgoRSA { + // Default to KeyAlgoRSASHA512 for ssh-rsa signers. + // TODO: consider using KeyAlgoRSASHA256 as default. + sig, err := v.SignWithAlgorithm(rand, c.bytesForSigning(), KeyAlgoRSASHA512) + if err != nil { + return err + } + c.Signature = sig + return nil + } + + sig, err := authority.Sign(rand, c.bytesForSigning()) + if err != nil { + return err + } + c.Signature = sig + return nil +} + +// certKeyAlgoNames is a mapping from known certificate algorithm names to the +// corresponding public key signature algorithm. +// +// This map must be kept in sync with the one in agent/client.go. +var certKeyAlgoNames = map[string]string{ + CertAlgoRSAv01: KeyAlgoRSA, + CertAlgoRSASHA256v01: KeyAlgoRSASHA256, + CertAlgoRSASHA512v01: KeyAlgoRSASHA512, + CertAlgoDSAv01: KeyAlgoDSA, + CertAlgoECDSA256v01: KeyAlgoECDSA256, + CertAlgoECDSA384v01: KeyAlgoECDSA384, + CertAlgoECDSA521v01: KeyAlgoECDSA521, + CertAlgoSKECDSA256v01: KeyAlgoSKECDSA256, + CertAlgoED25519v01: KeyAlgoED25519, + CertAlgoSKED25519v01: KeyAlgoSKED25519, +} + +// underlyingAlgo returns the signature algorithm associated with algo (which is +// an advertised or negotiated public key or host key algorithm). These are +// usually the same, except for certificate algorithms. 
+func underlyingAlgo(algo string) string { + if a, ok := certKeyAlgoNames[algo]; ok { + return a + } + return algo +} + +// certificateAlgo returns the certificate algorithms that uses the provided +// underlying signature algorithm. +func certificateAlgo(algo string) (certAlgo string, ok bool) { + for certName, algoName := range certKeyAlgoNames { + if algoName == algo { + return certName, true + } + } + return "", false +} + +func (cert *Certificate) bytesForSigning() []byte { + c2 := *cert + c2.Signature = nil + out := c2.Marshal() + // Drop trailing signature length. + return out[:len(out)-4] +} + +// Marshal serializes c into OpenSSH's wire format. It is part of the +// PublicKey interface. +func (c *Certificate) Marshal() []byte { + generic := genericCertData{ + Serial: c.Serial, + CertType: c.CertType, + KeyId: c.KeyId, + ValidPrincipals: marshalStringList(c.ValidPrincipals), + ValidAfter: uint64(c.ValidAfter), + ValidBefore: uint64(c.ValidBefore), + CriticalOptions: marshalTuples(c.CriticalOptions), + Extensions: marshalTuples(c.Extensions), + Reserved: c.Reserved, + SignatureKey: c.SignatureKey.Marshal(), + } + if c.Signature != nil { + generic.Signature = Marshal(c.Signature) + } + genericBytes := Marshal(&generic) + keyBytes := c.Key.Marshal() + _, keyBytes, _ = parseString(keyBytes) + prefix := Marshal(&struct { + Name string + Nonce []byte + Key []byte `ssh:"rest"` + }{c.Type(), c.Nonce, keyBytes}) + + result := make([]byte, 0, len(prefix)+len(genericBytes)) + result = append(result, prefix...) + result = append(result, genericBytes...) + return result +} + +// Type returns the certificate algorithm name. It is part of the PublicKey interface. +func (c *Certificate) Type() string { + certName, ok := certificateAlgo(c.Key.Type()) + if !ok { + panic("unknown certificate type for key type " + c.Key.Type()) + } + return certName +} + +// Verify verifies a signature against the certificate's public +// key. It is part of the PublicKey interface. 
+func (c *Certificate) Verify(data []byte, sig *Signature) error { + return c.Key.Verify(data, sig) +} + +func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { + format, in, ok := parseString(in) + if !ok { + return + } + + out = &Signature{ + Format: string(format), + } + + if out.Blob, in, ok = parseString(in); !ok { + return + } + + switch out.Format { + case KeyAlgoSKECDSA256, CertAlgoSKECDSA256v01, KeyAlgoSKED25519, CertAlgoSKED25519v01: + out.Rest = in + return out, nil, ok + } + + return out, in, ok +} + +func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { + sigBytes, rest, ok := parseString(in) + if !ok { + return + } + + out, trailing, ok := parseSignatureBody(sigBytes) + if !ok || len(trailing) > 0 { + return nil, nil, false + } + return +} diff --git a/tempfork/sshtest/ssh/certs_test.go b/tempfork/sshtest/ssh/certs_test.go new file mode 100644 index 0000000000000..6208bb37a97e2 --- /dev/null +++ b/tempfork/sshtest/ssh/certs_test.go @@ -0,0 +1,406 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "fmt" + "io" + "net" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/ssh/testdata" +) + +func TestParseCert(t *testing.T) { + authKeyBytes := bytes.TrimSuffix(testdata.SSHCertificates["rsa"], []byte(" host.example.com\n")) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + + if _, ok := key.(*Certificate); !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. 
+ marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3 +// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub +// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN +// Critical Options: +// +// force-command /bin/sleep +// source-address 192.168.1.0/24 +// +// Extensions: +// +// permit-X11-forwarding +// permit-agent-forwarding +// permit-port-forwarding +// permit-pty +// permit-user-rc +const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com 
AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` + +func TestParseCertWithOptions(t *testing.T) { + opts := map[string]string{ + "source-address": "192.168.1.0/24", + "force-command": "/bin/sleep", + } + exts := map[string]string{ + "permit-X11-forwarding": "", + "permit-agent-forwarding": "", + "permit-port-forwarding": "", + "permit-pty": "", + "permit-user-rc": "", + } + authKeyBytes := []byte(exampleSSHCertWithOptions) + + key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) + if err != nil { 
+ t.Fatalf("ParseAuthorizedKey: %v", err) + } + if len(rest) > 0 { + t.Errorf("rest: got %q, want empty", rest) + } + cert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + if !reflect.DeepEqual(cert.CriticalOptions, opts) { + t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) + } + if !reflect.DeepEqual(cert.Extensions, exts) { + t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) + } + marshaled := MarshalAuthorizedKey(key) + // Before comparison, remove the trailing newline that + // MarshalAuthorizedKey adds. + marshaled = marshaled[:len(marshaled)-1] + if !bytes.Equal(authKeyBytes, marshaled) { + t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) + } +} + +func TestValidateCert(t *testing.T) { + key, _, _, _, err := ParseAuthorizedKey(testdata.SSHCertificates["rsa-user-testcertificate"]) + if err != nil { + t.Fatalf("ParseAuthorizedKey: %v", err) + } + validCert, ok := key.(*Certificate) + if !ok { + t.Fatalf("got %v (%T), want *Certificate", key, key) + } + checker := CertChecker{} + checker.IsUserAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) + } + + if err := checker.CheckCert("testcertificate", validCert); err != nil { + t.Errorf("Unable to validate certificate: %v", err) + } + invalidCert := &Certificate{ + Key: testPublicKeys["rsa"], + SignatureKey: testPublicKeys["ecdsa"], + ValidBefore: CertTimeInfinity, + Signature: &Signature{}, + } + if err := checker.CheckCert("testcertificate", invalidCert); err == nil { + t.Error("Invalid cert signature passed validation") + } +} + +func TestValidateCertTime(t *testing.T) { + cert := Certificate{ + ValidPrincipals: []string{"user"}, + Key: testPublicKeys["rsa"], + ValidAfter: 50, + ValidBefore: 100, + } + + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + for ts, ok := range map[int64]bool{ + 25: 
false, + 50: true, + 99: true, + 100: false, + 125: false, + } { + checker := CertChecker{ + Clock: func() time.Time { return time.Unix(ts, 0) }, + } + checker.IsUserAuthority = func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), + testPublicKeys["ecdsa"].Marshal()) + } + + if v := checker.CheckCert("user", &cert); (v == nil) != ok { + t.Errorf("Authenticate(%d): %v", ts, v) + } + } +} + +// TODO(hanwen): tests for +// +// host keys: +// * fallbacks + +func TestHostKeyCert(t *testing.T) { + cert := &Certificate{ + ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"}, + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: HostCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + + checker := &CertChecker{ + IsHostAuthority: func(p PublicKey, addr string) bool { + return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) + }, + } + + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Errorf("NewCertSigner: %v", err) + } + + for _, test := range []struct { + addr string + succeed bool + certSignerAlgorithms []string // Empty means no algorithm restrictions. + clientHostKeyAlgorithms []string + }{ + {addr: "hostname:22", succeed: true}, + { + addr: "hostname:22", + succeed: true, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{CertAlgoRSASHA512v01}, + }, + { + addr: "hostname:22", + succeed: false, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{CertAlgoRSAv01}, + }, + { + addr: "hostname:22", + succeed: false, + certSignerAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + clientHostKeyAlgorithms: []string{KeyAlgoRSASHA512}, // Not a certificate algorithm. 
+ }, + {addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22' + {addr: "lasthost:22", succeed: false}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errc := make(chan error) + + go func() { + conf := ServerConfig{ + NoClientAuth: true, + } + if len(test.certSignerAlgorithms) > 0 { + mas, err := NewSignerWithAlgorithms(certSigner.(AlgorithmSigner), test.certSignerAlgorithms) + if err != nil { + errc <- err + return + } + conf.AddHostKey(mas) + } else { + conf.AddHostKey(certSigner) + } + _, _, _, err := NewServerConn(c1, &conf) + errc <- err + }() + + config := &ClientConfig{ + User: "user", + HostKeyCallback: checker.CheckHostKey, + HostKeyAlgorithms: test.clientHostKeyAlgorithms, + } + _, _, _, err = NewClientConn(c2, test.addr, config) + + if (err == nil) != test.succeed { + t.Errorf("NewClientConn(%q): %v", test.addr, err) + } + + err = <-errc + if (err == nil) != test.succeed { + t.Errorf("NewServerConn(%q): %v", test.addr, err) + } + } +} + +type legacyRSASigner struct { + Signer +} + +func (s *legacyRSASigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + v, ok := s.Signer.(AlgorithmSigner) + if !ok { + return nil, fmt.Errorf("invalid signer") + } + return v.SignWithAlgorithm(rand, data, KeyAlgoRSA) +} + +func TestCertTypes(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSignerSHA256, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer SHA256: %v", err) + } + // Algorithms are in order of preference, we expect rsa-sha2-512 to be used. 
+ multiAlgoSignerSHA512, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer SHA512: %v", err) + } + + var testVars = []struct { + name string + signer Signer + algo string + }{ + {CertAlgoECDSA256v01, testSigners["ecdsap256"], ""}, + {CertAlgoECDSA384v01, testSigners["ecdsap384"], ""}, + {CertAlgoECDSA521v01, testSigners["ecdsap521"], ""}, + {CertAlgoED25519v01, testSigners["ed25519"], ""}, + {CertAlgoRSAv01, testSigners["rsa"], KeyAlgoRSASHA256}, + {"legacyRSASigner", &legacyRSASigner{testSigners["rsa"]}, KeyAlgoRSA}, + {"multiAlgoRSASignerSHA256", multiAlgoSignerSHA256, KeyAlgoRSASHA256}, + {"multiAlgoRSASignerSHA512", multiAlgoSignerSHA512, KeyAlgoRSASHA512}, + {CertAlgoDSAv01, testSigners["dsa"], ""}, + } + + k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("error generating host key: %v", err) + } + + signer, err := NewSignerFromKey(k) + if err != nil { + t.Fatalf("error generating signer for ssh listener: %v", err) + } + + conf := &ServerConfig{ + PublicKeyCallback: func(c ConnMetadata, k PublicKey) (*Permissions, error) { + return new(Permissions), nil + }, + } + conf.AddHostKey(signer) + + for _, m := range testVars { + t.Run(m.name, func(t *testing.T) { + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewServerConn(c1, conf) + + priv := m.signer + if err != nil { + t.Fatalf("error generating ssh pubkey: %v", err) + } + + cert := &Certificate{ + CertType: UserCert, + Key: priv.PublicKey(), + } + cert.SignCert(rand.Reader, priv) + + certSigner, err := NewCertSigner(cert, priv) + if err != nil { + t.Fatalf("error generating cert signer: %v", err) + } + + if m.algo != "" && cert.Signature.Format != m.algo { + t.Errorf("expected %q signature format, got %q", m.algo, cert.Signature.Format) + } + + config := &ClientConfig{ + User: "user", 
+ HostKeyCallback: func(h string, r net.Addr, k PublicKey) error { return nil }, + Auth: []AuthMethod{PublicKeys(certSigner)}, + } + + _, _, _, err = NewClientConn(c2, "", config) + if err != nil { + t.Fatalf("error connecting: %v", err) + } + }) + } +} + +func TestCertSignWithMultiAlgorithmSigner(t *testing.T) { + type testcase struct { + sigAlgo string + algorithms []string + } + cases := []testcase{ + { + sigAlgo: KeyAlgoRSA, + algorithms: []string{KeyAlgoRSA, KeyAlgoRSASHA512}, + }, + { + sigAlgo: KeyAlgoRSASHA256, + algorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSA, KeyAlgoRSASHA512}, + }, + { + sigAlgo: KeyAlgoRSASHA512, + algorithms: []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256}, + }, + } + + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + + for _, c := range cases { + t.Run(c.sigAlgo, func(t *testing.T) { + signer, err := NewSignerWithAlgorithms(testSigners["rsa"].(AlgorithmSigner), c.algorithms) + if err != nil { + t.Fatalf("NewSignerWithAlgorithms error: %v", err) + } + if err := cert.SignCert(rand.Reader, signer); err != nil { + t.Fatalf("SignCert error: %v", err) + } + if cert.Signature.Format != c.sigAlgo { + t.Fatalf("got signature format %q, want %q", cert.Signature.Format, c.sigAlgo) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/channel.go b/tempfork/sshtest/ssh/channel.go new file mode 100644 index 0000000000000..cc0bb7ab64814 --- /dev/null +++ b/tempfork/sshtest/ssh/channel.go @@ -0,0 +1,645 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "log" + "sync" +) + +const ( + minPacketLength = 9 + // channelMaxPacket contains the maximum number of bytes that will be + // sent in a single packet. As per RFC 4253, section 6.1, 32k is also + // the minimum. 
+ channelMaxPacket = 1 << 15 + // We follow OpenSSH here. + channelWindowSize = 64 * channelMaxPacket +) + +// NewChannel represents an incoming request to a channel. It must either be +// accepted for use by calling Accept, or rejected by calling Reject. +type NewChannel interface { + // Accept accepts the channel creation request. It returns the Channel + // and a Go channel containing SSH requests. The Go channel must be + // serviced otherwise the Channel will hang. + Accept() (Channel, <-chan *Request, error) + + // Reject rejects the channel creation request. After calling + // this, no other methods on the Channel may be called. + Reject(reason RejectionReason, message string) error + + // ChannelType returns the type of the channel, as supplied by the + // client. + ChannelType() string + + // ExtraData returns the arbitrary payload for this channel, as supplied + // by the client. This data is specific to the channel type. + ExtraData() []byte +} + +// A Channel is an ordered, reliable, flow-controlled, duplex stream +// that is multiplexed over an SSH connection. +type Channel interface { + // Read reads up to len(data) bytes from the channel. + Read(data []byte) (int, error) + + // Write writes len(data) bytes to the channel. + Write(data []byte) (int, error) + + // Close signals end of channel use. No data may be sent after this + // call. + Close() error + + // CloseWrite signals the end of sending in-band + // data. Requests may still be sent, and the other side may + // still send data + CloseWrite() error + + // SendRequest sends a channel request. If wantReply is true, + // it will wait for a reply and return the result as a + // boolean, otherwise the return value will be false. Channel + // requests are out-of-band messages so they may be sent even + // if the data stream is closed or blocked by flow control. + // If the channel is closed before a reply is returned, io.EOF + // is returned. 
+ SendRequest(name string, wantReply bool, payload []byte) (bool, error) + + // Stderr returns an io.ReadWriter that writes to this channel + // with the extended data type set to stderr. Stderr may + // safely be read and written from a different goroutine than + // Read and Write respectively. + Stderr() io.ReadWriter +} + +// Request is a request sent outside of the normal stream of +// data. Requests can either be specific to an SSH channel, or they +// can be global. +type Request struct { + Type string + WantReply bool + Payload []byte + + ch *channel + mux *mux +} + +// Reply sends a response to a request. It must be called for all requests +// where WantReply is true and is a no-op otherwise. The payload argument is +// ignored for replies to channel-specific requests. +func (r *Request) Reply(ok bool, payload []byte) error { + if !r.WantReply { + return nil + } + + if r.ch == nil { + return r.mux.ackRequest(ok, payload) + } + + return r.ch.ackRequest(ok) +} + +// RejectionReason is an enumeration used when rejecting channel creation +// requests. See RFC 4254, section 5.1. +type RejectionReason uint32 + +const ( + Prohibited RejectionReason = iota + 1 + ConnectionFailed + UnknownChannelType + ResourceShortage +) + +// String converts the rejection reason to human readable form. +func (r RejectionReason) String() string { + switch r { + case Prohibited: + return "administratively prohibited" + case ConnectionFailed: + return "connect failed" + case UnknownChannelType: + return "unknown channel type" + case ResourceShortage: + return "resource shortage" + } + return fmt.Sprintf("unknown reason %d", int(r)) +} + +func min(a uint32, b int) uint32 { + if a < uint32(b) { + return a + } + return uint32(b) +} + +type channelDirection uint8 + +const ( + channelInbound channelDirection = iota + channelOutbound +) + +// channel is an implementation of the Channel interface that works +// with the mux class. 
+type channel struct {
+	// R/O after creation
+
+	// chanType is the type of the channel, as supplied in the
+	// channel-open request (see ChannelType).
+	chanType string
+	// extraData is the type-specific payload from the channel-open
+	// request (see ExtraData).
+	extraData []byte
+	// localId and remoteId are the channel identifiers used by our
+	// side and by the peer, respectively. remoteId is patched into
+	// every outgoing channel message.
+	localId, remoteId uint32
+
+	// maxIncomingPayload and maxRemotePayload are the maximum
+	// payload sizes of normal and extended data packets for
+	// receiving and sending, respectively. The wire packet will
+	// be 9 or 13 bytes larger (excluding encryption overhead).
+	maxIncomingPayload uint32
+	maxRemotePayload   uint32
+
+	// mux is the connection multiplexer this channel is attached to.
+	mux *mux
+
+	// decided is set to true if an accept or reject message has been sent
+	// (for outbound channels) or received (for inbound channels).
+	decided bool
+
+	// direction contains either channelOutbound, for channels created
+	// locally, or channelInbound, for channels created by the peer.
+	direction channelDirection
+
+	// Pending internal channel messages.
+	msg chan interface{}
+
+	// Since requests have no ID, there can be only one request
+	// with WantReply=true outstanding. This lock is held by a
+	// goroutine that has such an outgoing request pending.
+	sentRequestMu sync.Mutex
+
+	// incomingRequests delivers channel requests received from the
+	// peer (see handlePacket) to the user of the channel.
+	incomingRequests chan *Request
+
+	// sentEOF is set once we have sent msgChannelEOF; subsequent
+	// writes fail with io.EOF.
+	sentEOF bool
+
+	// thread-safe data
+
+	// remoteWin tracks the peer's receive window; writers block on
+	// it until the peer grants more space.
+	remoteWin window
+	// pending and extPending buffer incoming standard and extended
+	// (stderr, code 1) data until it is consumed by Read.
+	pending    *buffer
+	extPending *buffer
+
+	// windowMu protects myWindow, the flow-control window, and myConsumed,
+	// the number of bytes consumed since we last increased myWindow
+	windowMu   sync.Mutex
+	myWindow   uint32
+	myConsumed uint32
+
+	// writeMu serializes calls to mux.conn.writePacket() and
+	// protects sentClose and packetPool. This mutex must be
+	// different from windowMu, as writePacket can block if there
+	// is a key exchange pending.
+	writeMu   sync.Mutex
+	sentClose bool
+
+	// packetPool has a buffer for each extended channel ID to
+	// save allocations during writes.
+	packetPool map[uint32][]byte
+}
+
+// writePacket sends a packet. If the packet is a channel close, it updates
+// sentClose. This method takes the lock c.writeMu.
+func (ch *channel) writePacket(packet []byte) error {
+	ch.writeMu.Lock()
+	if ch.sentClose {
+		// Once a close has gone out, all further sends fail.
+		ch.writeMu.Unlock()
+		return io.EOF
+	}
+	ch.sentClose = (packet[0] == msgChannelClose)
+	err := ch.mux.conn.writePacket(packet)
+	ch.writeMu.Unlock()
+	return err
+}
+
+// sendMessage marshals msg and sends it on the underlying connection,
+// first patching the peer's channel ID into the packet header.
+func (ch *channel) sendMessage(msg interface{}) error {
+	if debugMux {
+		log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
+	}
+
+	p := Marshal(msg)
+	binary.BigEndian.PutUint32(p[1:], ch.remoteId)
+	return ch.writePacket(p)
+}
+
+// WriteExtended writes data to a specific extended stream. These streams are
+// used, for example, for stderr.
+func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
+	if ch.sentEOF {
+		return 0, io.EOF
+	}
+	// 1 byte message type, 4 bytes remoteId, 4 bytes data length
+	opCode := byte(msgChannelData)
+	headerLength := uint32(9)
+	if extendedCode > 0 {
+		headerLength += 4
+		opCode = msgChannelExtendedData
+	}
+
+	ch.writeMu.Lock()
+	packet := ch.packetPool[extendedCode]
+	// We don't remove the buffer from packetPool, so
+	// WriteExtended calls from different goroutines will be
+	// flagged as errors by the race detector.
+	ch.writeMu.Unlock()
+
+	for len(data) > 0 {
+		space := min(ch.maxRemotePayload, len(data))
+		// reserve may grant less than requested; send only what
+		// the remote window currently allows.
+		if space, err = ch.remoteWin.reserve(space); err != nil {
+			return n, err
+		}
+		if want := headerLength + space; uint32(cap(packet)) < want {
+			packet = make([]byte, want)
+		} else {
+			packet = packet[:want]
+		}
+
+		todo := data[:space]
+
+		packet[0] = opCode
+		binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
+		if extendedCode > 0 {
+			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
+		}
+		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
+		copy(packet[headerLength:], todo)
+		if err = ch.writePacket(packet); err != nil {
+			return n, err
+		}
+
+		n += len(todo)
+		data = data[len(todo):]
+	}
+
+	ch.writeMu.Lock()
+	ch.packetPool[extendedCode] = packet
+	ch.writeMu.Unlock()
+
+	return n, err
+}
+
+// handleData processes an incoming msgChannelData or msgChannelExtendedData
+// packet: it validates the length fields, debits the local flow-control
+// window, and appends the payload to the appropriate pending buffer.
+func (ch *channel) handleData(packet []byte) error {
+	headerLen := 9
+	isExtendedData := packet[0] == msgChannelExtendedData
+	if isExtendedData {
+		// Extended data carries an extra 4-byte data-type code.
+		headerLen = 13
+	}
+	if len(packet) < headerLen {
+		// malformed data packet
+		return parseError(packet[0])
+	}
+
+	var extended uint32
+	if isExtendedData {
+		extended = binary.BigEndian.Uint32(packet[5:])
+	}
+
+	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
+	if length == 0 {
+		return nil
+	}
+	if length > ch.maxIncomingPayload {
+		// TODO(hanwen): should send Disconnect?
+		return errors.New("ssh: incoming packet exceeds maximum payload size")
+	}
+
+	data := packet[headerLen:]
+	if length != uint32(len(data)) {
+		return errors.New("ssh: wrong packet length")
+	}
+
+	ch.windowMu.Lock()
+	if ch.myWindow < length {
+		// The peer sent more than we advertised.
+		ch.windowMu.Unlock()
+		// TODO(hanwen): should send Disconnect with reason?
+		return errors.New("ssh: remote side wrote too much")
+	}
+	ch.myWindow -= length
+	ch.windowMu.Unlock()
+
+	if extended == 1 {
+		// Extended data type 1 is stderr (see Stderr).
+		ch.extPending.write(data)
+	} else if extended > 0 {
+		// discard other extended data.
+	} else {
+		ch.pending.write(data)
+	}
+	return nil
+}
+
+// adjustWindow records that adj bytes of incoming data have been consumed
+// and, once enough has accumulated, grants the peer more window by sending
+// a windowAdjustMsg.
+func (c *channel) adjustWindow(adj uint32) error {
+	c.windowMu.Lock()
+	// Since myConsumed and myWindow are managed on our side, and can never
+	// exceed the initial window setting, we don't worry about overflow.
+	c.myConsumed += adj
+	var sendAdj uint32
+	if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
+		(c.myWindow < channelWindowSize/2) {
+		sendAdj = c.myConsumed
+		c.myConsumed = 0
+		c.myWindow += sendAdj
+	}
+	c.windowMu.Unlock()
+	if sendAdj == 0 {
+		return nil
+	}
+	return c.sendMessage(windowAdjustMsg{
+		AdditionalBytes: sendAdj,
+	})
+}
+
+// ReadExtended reads from the standard stream (extended == 0) or the
+// stderr stream (extended == 1), replenishing the peer's window for any
+// bytes consumed.
+func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
+	switch extended {
+	case 1:
+		n, err = c.extPending.Read(data)
+	case 0:
+		n, err = c.pending.Read(data)
+	default:
+		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
+	}
+
+	if n > 0 {
+		err = c.adjustWindow(uint32(n))
+		// sendWindowAdjust can return io.EOF if the remote
+		// peer has closed the connection, however we want to
+		// defer forwarding io.EOF to the caller of Read until
+		// the buffer has been drained.
+		if n > 0 && err == io.EOF {
+			err = nil
+		}
+	}
+
+	return n, err
+}
+
+// close tears down the local side of the channel: it signals EOF to
+// readers, closes the message and request channels, marks the channel
+// closed so further writes fail, and unblocks any blocked writers.
+func (c *channel) close() {
+	c.pending.eof()
+	c.extPending.eof()
+	close(c.msg)
+	close(c.incomingRequests)
+	c.writeMu.Lock()
+	// This is not necessary for a normal channel teardown, but if
+	// there was another error, it is.
+	c.sentClose = true
+	c.writeMu.Unlock()
+	// Unblock writers.
+	c.remoteWin.close()
+}
+
+// responseMessageReceived is called when a success or failure message is
+// received on a channel to check that such a message is reasonable for the
+// given channel.
+func (ch *channel) responseMessageReceived() error { + if ch.direction == channelInbound { + return errors.New("ssh: channel response message received on inbound channel") + } + if ch.decided { + return errors.New("ssh: duplicate response received for channel") + } + ch.decided = true + return nil +} + +func (ch *channel) handlePacket(packet []byte) error { + switch packet[0] { + case msgChannelData, msgChannelExtendedData: + return ch.handleData(packet) + case msgChannelClose: + ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId}) + ch.mux.chanList.remove(ch.localId) + ch.close() + return nil + case msgChannelEOF: + // RFC 4254 is mute on how EOF affects dataExt messages but + // it is logical to signal EOF at the same time. + ch.extPending.eof() + ch.pending.eof() + return nil + } + + decoded, err := decode(packet) + if err != nil { + return err + } + + switch msg := decoded.(type) { + case *channelOpenFailureMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + ch.mux.chanList.remove(msg.PeersID) + ch.msg <- msg + case *channelOpenConfirmMsg: + if err := ch.responseMessageReceived(); err != nil { + return err + } + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) + } + ch.remoteId = msg.MyID + ch.maxRemotePayload = msg.MaxPacketSize + ch.remoteWin.add(msg.MyWindow) + ch.msg <- msg + case *windowAdjustMsg: + if !ch.remoteWin.add(msg.AdditionalBytes) { + return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) + } + case *channelRequestMsg: + req := Request{ + Type: msg.Request, + WantReply: msg.WantReply, + Payload: msg.RequestSpecificData, + ch: ch, + } + + ch.incomingRequests <- &req + default: + ch.msg <- msg + } + return nil +} + +func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { + ch := &channel{ + remoteWin: window{Cond: newCond()}, + myWindow: 
channelWindowSize, + pending: newBuffer(), + extPending: newBuffer(), + direction: direction, + incomingRequests: make(chan *Request, chanSize), + msg: make(chan interface{}, chanSize), + chanType: chanType, + extraData: extraData, + mux: m, + packetPool: make(map[uint32][]byte), + } + ch.localId = m.chanList.add(ch) + return ch +} + +var errUndecided = errors.New("ssh: must Accept or Reject channel") +var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") + +type extChannel struct { + code uint32 + ch *channel +} + +func (e *extChannel) Write(data []byte) (n int, err error) { + return e.ch.WriteExtended(data, e.code) +} + +func (e *extChannel) Read(data []byte) (n int, err error) { + return e.ch.ReadExtended(data, e.code) +} + +func (ch *channel) Accept() (Channel, <-chan *Request, error) { + if ch.decided { + return nil, nil, errDecidedAlready + } + ch.maxIncomingPayload = channelMaxPacket + confirm := channelOpenConfirmMsg{ + PeersID: ch.remoteId, + MyID: ch.localId, + MyWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + } + ch.decided = true + if err := ch.sendMessage(confirm); err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (ch *channel) Reject(reason RejectionReason, message string) error { + if ch.decided { + return errDecidedAlready + } + reject := channelOpenFailureMsg{ + PeersID: ch.remoteId, + Reason: reason, + Message: message, + Language: "en", + } + ch.decided = true + return ch.sendMessage(reject) +} + +func (ch *channel) Read(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.ReadExtended(data, 0) +} + +func (ch *channel) Write(data []byte) (int, error) { + if !ch.decided { + return 0, errUndecided + } + return ch.WriteExtended(data, 0) +} + +func (ch *channel) CloseWrite() error { + if !ch.decided { + return errUndecided + } + ch.sentEOF = true + return ch.sendMessage(channelEOFMsg{ + PeersID: ch.remoteId}) +} + +func (ch *channel) 
Close() error { + if !ch.decided { + return errUndecided + } + + return ch.sendMessage(channelCloseMsg{ + PeersID: ch.remoteId}) +} + +// Extended returns an io.ReadWriter that sends and receives data on the given, +// SSH extended stream. Such streams are used, for example, for stderr. +func (ch *channel) Extended(code uint32) io.ReadWriter { + if !ch.decided { + return nil + } + return &extChannel{code, ch} +} + +func (ch *channel) Stderr() io.ReadWriter { + return ch.Extended(1) +} + +func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + if !ch.decided { + return false, errUndecided + } + + if wantReply { + ch.sentRequestMu.Lock() + defer ch.sentRequestMu.Unlock() + } + + msg := channelRequestMsg{ + PeersID: ch.remoteId, + Request: name, + WantReply: wantReply, + RequestSpecificData: payload, + } + + if err := ch.sendMessage(msg); err != nil { + return false, err + } + + if wantReply { + m, ok := (<-ch.msg) + if !ok { + return false, io.EOF + } + switch m.(type) { + case *channelRequestFailureMsg: + return false, nil + case *channelRequestSuccessMsg: + return true, nil + default: + return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) + } + } + + return false, nil +} + +// ackRequest either sends an ack or nack to the channel request. +func (ch *channel) ackRequest(ok bool) error { + if !ch.decided { + return errUndecided + } + + var msg interface{} + if !ok { + msg = channelRequestFailureMsg{ + PeersID: ch.remoteId, + } + } else { + msg = channelRequestSuccessMsg{ + PeersID: ch.remoteId, + } + } + return ch.sendMessage(msg) +} + +func (ch *channel) ChannelType() string { + return ch.chanType +} + +func (ch *channel) ExtraData() []byte { + return ch.extraData +} diff --git a/tempfork/sshtest/ssh/cipher.go b/tempfork/sshtest/ssh/cipher.go new file mode 100644 index 0000000000000..0533786f4bcc5 --- /dev/null +++ b/tempfork/sshtest/ssh/cipher.go @@ -0,0 +1,789 @@ +// Copyright 2011 The Go Authors. 
All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/rc4" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "hash" + "io" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/poly1305" +) + +const ( + packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. + + // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations + // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC + // indicates implementations SHOULD be able to handle larger packet sizes, but then + // waffles on about reasonable limits. + // + // OpenSSH caps their maxPacket at 256kB so we choose to do + // the same. maxPacket is also used to ensure that uint32 + // length fields do not overflow, so it should remain well + // below 4G. + maxPacket = 256 * 1024 +) + +// noneCipher implements cipher.Stream and provides no encryption. It is used +// by the transport before the first key-exchange. 
+type noneCipher struct{} + +func (c noneCipher) XORKeyStream(dst, src []byte) { + copy(dst, src) +} + +func newAESCTR(key, iv []byte) (cipher.Stream, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return cipher.NewCTR(c, iv), nil +} + +func newRC4(key, iv []byte) (cipher.Stream, error) { + return rc4.NewCipher(key) +} + +type cipherMode struct { + keySize int + ivSize int + create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) +} + +func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + stream, err := createFunc(key, iv) + if err != nil { + return nil, err + } + + var streamDump []byte + if skip > 0 { + streamDump = make([]byte, 512) + } + + for remainingToDump := skip; remainingToDump > 0; { + dumpThisTime := remainingToDump + if dumpThisTime > len(streamDump) { + dumpThisTime = len(streamDump) + } + stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) + remainingToDump -= dumpThisTime + } + + mac := macModes[algs.MAC].new(macKey) + return &streamPacketCipher{ + mac: mac, + etm: macModes[algs.MAC].etm, + macResult: make([]byte, mac.Size()), + cipher: stream, + }, nil + } +} + +// cipherModes documents properties of supported ciphers. Ciphers not included +// are not supported and will not be negotiated, even if explicitly requested in +// ClientConfig.Crypto.Ciphers. +var cipherModes = map[string]*cipherMode{ + // Ciphers from RFC 4344, which introduced many CTR-based ciphers. Algorithms + // are defined in the order specified in the RFC. 
+ "aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + "aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)}, + + // Ciphers from RFC 4345, which introduces security-improved arcfour ciphers. + // They are defined in the order specified in the RFC. + "arcfour128": {16, 0, streamCipherMode(1536, newRC4)}, + "arcfour256": {32, 0, streamCipherMode(1536, newRC4)}, + + // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. + // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and + // RC4) has problems with weak keys, and should be used with caution." + // RFC 4345 introduces improved versions of Arcfour. + "arcfour": {16, 0, streamCipherMode(0, newRC4)}, + + // AEAD ciphers + gcm128CipherID: {16, 12, newGCMCipher}, + gcm256CipherID: {32, 12, newGCMCipher}, + chacha20Poly1305ID: {64, 0, newChaCha20Cipher}, + + // CBC mode is insecure and so is not included in the default config. + // (See https://www.ieee-security.org/TC/SP2013/papers/4977a526.pdf). If absolutely + // needed, it's possible to specify a custom Config to enable it. + // You should expect that an active attacker can recover plaintext if + // you do. + aes128cbcID: {16, aes.BlockSize, newAESCBCCipher}, + + // 3des-cbc is insecure and is not included in the default + // config. + tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher}, +} + +// prefixLen is the length of the packet prefix that contains the packet length +// and number of padding bytes. +const prefixLen = 5 + +// streamPacketCipher is a packetCipher using a stream cipher. +type streamPacketCipher struct { + mac hash.Hash + cipher cipher.Stream + etm bool + + // The following members are to avoid per-packet allocations. 
+ prefix [prefixLen]byte + seqNumBytes [4]byte + padding [2 * packetSizeMultiple]byte + packetData []byte + macResult []byte +} + +// readCipherPacket reads and decrypt a single packet from the reader argument. +func (s *streamPacketCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, s.prefix[:]); err != nil { + return nil, err + } + + var encryptedPaddingLength [1]byte + if s.mac != nil && s.etm { + copy(encryptedPaddingLength[:], s.prefix[4:5]) + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } else { + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + length := binary.BigEndian.Uint32(s.prefix[0:4]) + paddingLength := uint32(s.prefix[4]) + + var macSize uint32 + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + if s.etm { + s.mac.Write(s.prefix[:4]) + s.mac.Write(encryptedPaddingLength[:]) + } else { + s.mac.Write(s.prefix[:]) + } + macSize = uint32(s.mac.Size()) + } + + if length <= paddingLength+1 { + return nil, errors.New("ssh: invalid packet length, packet too small") + } + + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + // the maxPacket check above ensures that length-1+macSize + // does not overflow. 
+ if uint32(cap(s.packetData)) < length-1+macSize { + s.packetData = make([]byte, length-1+macSize) + } else { + s.packetData = s.packetData[:length-1+macSize] + } + + if _, err := io.ReadFull(r, s.packetData); err != nil { + return nil, err + } + mac := s.packetData[length-1:] + data := s.packetData[:length-1] + + if s.mac != nil && s.etm { + s.mac.Write(data) + } + + s.cipher.XORKeyStream(data, data) + + if s.mac != nil { + if !s.etm { + s.mac.Write(data) + } + s.macResult = s.mac.Sum(s.macResult[:0]) + if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { + return nil, errors.New("ssh: MAC failure") + } + } + + return s.packetData[:length-paddingLength-1], nil +} + +// writeCipherPacket encrypts and sends a packet of data to the writer argument +func (s *streamPacketCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + if len(packet) > maxPacket { + return errors.New("ssh: packet too large") + } + + aadlen := 0 + if s.mac != nil && s.etm { + // packet length is not encrypted for EtM modes + aadlen = 4 + } + + paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple + if paddingLength < 4 { + paddingLength += packetSizeMultiple + } + + length := len(packet) + 1 + paddingLength + binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) + s.prefix[4] = byte(paddingLength) + padding := s.padding[:paddingLength] + if _, err := io.ReadFull(rand, padding); err != nil { + return err + } + + if s.mac != nil { + s.mac.Reset() + binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) + s.mac.Write(s.seqNumBytes[:]) + + if s.etm { + // For EtM algorithms, the packet length must stay unencrypted, + // but the following data (padding length) must be encrypted + s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5]) + } + + s.mac.Write(s.prefix[:]) + + if !s.etm { + // For non-EtM algorithms, the algorithm is applied on unencrypted data + s.mac.Write(packet) + s.mac.Write(padding) + } + } + + if !(s.mac != nil 
&& s.etm) { + // For EtM algorithms, the padding length has already been encrypted + // and the packet length must remain unencrypted + s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) + } + + s.cipher.XORKeyStream(packet, packet) + s.cipher.XORKeyStream(padding, padding) + + if s.mac != nil && s.etm { + // For EtM algorithms, packet and padding must be encrypted + s.mac.Write(packet) + s.mac.Write(padding) + } + + if _, err := w.Write(s.prefix[:]); err != nil { + return err + } + if _, err := w.Write(packet); err != nil { + return err + } + if _, err := w.Write(padding); err != nil { + return err + } + + if s.mac != nil { + s.macResult = s.mac.Sum(s.macResult[:0]) + if _, err := w.Write(s.macResult); err != nil { + return err + } + } + + return nil +} + +type gcmCipher struct { + aead cipher.AEAD + prefix [4]byte + iv []byte + buf []byte +} + +func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + aead, err := cipher.NewGCM(c) + if err != nil { + return nil, err + } + + return &gcmCipher{ + aead: aead, + iv: iv, + }, nil +} + +const gcmTagSize = 16 + +func (c *gcmCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { + // Pad out to multiple of 16 bytes. This is different from the + // stream cipher because that encrypts the length too. 
+ padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) + if padding < 4 { + padding += packetSizeMultiple + } + + length := uint32(len(packet) + int(padding) + 1) + binary.BigEndian.PutUint32(c.prefix[:], length) + if _, err := w.Write(c.prefix[:]); err != nil { + return err + } + + if cap(c.buf) < int(length) { + c.buf = make([]byte, length) + } else { + c.buf = c.buf[:length] + } + + c.buf[0] = padding + copy(c.buf[1:], packet) + if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { + return err + } + c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if _, err := w.Write(c.buf); err != nil { + return err + } + c.incIV() + + return nil +} + +func (c *gcmCipher) incIV() { + for i := 4 + 7; i >= 4; i-- { + c.iv[i]++ + if c.iv[i] != 0 { + break + } + } +} + +func (c *gcmCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + if _, err := io.ReadFull(r, c.prefix[:]); err != nil { + return nil, err + } + length := binary.BigEndian.Uint32(c.prefix[:]) + if length > maxPacket { + return nil, errors.New("ssh: max packet length exceeded") + } + + if cap(c.buf) < int(length+gcmTagSize) { + c.buf = make([]byte, length+gcmTagSize) + } else { + c.buf = c.buf[:length+gcmTagSize] + } + + if _, err := io.ReadFull(r, c.buf); err != nil { + return nil, err + } + + plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) + if err != nil { + return nil, err + } + c.incIV() + + if len(plain) == 0 { + return nil, errors.New("ssh: empty packet") + } + + padding := plain[0] + if padding < 4 { + // padding is a byte, so it automatically satisfies + // the maximum size, which is 255. 
+ return nil, fmt.Errorf("ssh: illegal padding %d", padding) + } + + if int(padding+1) >= len(plain) { + return nil, fmt.Errorf("ssh: padding %d too large", padding) + } + plain = plain[1 : length-uint32(padding)] + return plain, nil +} + +// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 +type cbcCipher struct { + mac hash.Hash + macSize uint32 + decrypter cipher.BlockMode + encrypter cipher.BlockMode + + // The following members are to avoid per-packet allocations. + seqNumBytes [4]byte + packetData []byte + macResult []byte + + // Amount of data we should still read to hide which + // verification error triggered. + oracleCamouflage uint32 +} + +func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + cbc := &cbcCipher{ + mac: macModes[algs.MAC].new(macKey), + decrypter: cipher.NewCBCDecrypter(c, iv), + encrypter: cipher.NewCBCEncrypter(c, iv), + packetData: make([]byte, 1024), + } + if cbc.mac != nil { + cbc.macSize = uint32(cbc.mac.Size()) + } + + return cbc, nil +} + +func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) { + c, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + + cbc, err := newCBCCipher(c, key, iv, macKey, algs) + if err != nil { + return nil, err + } + + return cbc, nil +} + +func maxUInt32(a, b int) uint32 { + if a > b { + return uint32(a) + } + return uint32(b) +} + +const ( + cbcMinPacketSizeMultiple = 8 + cbcMinPacketSize = 16 + cbcMinPaddingSize = 4 +) + +// cbcError represents a verification error that may leak information. 
+// cbcError is the error type used for CBC verification failures; its
+// presence triggers the oracle-camouflage read in readCipherPacket.
+type cbcError string
+
+func (e cbcError) Error() string { return string(e) }
+
+// readCipherPacket reads and decrypts one CBC packet. On a verification
+// failure it drains oracleCamouflage bytes from r so that different
+// failure causes are harder to distinguish by I/O behavior.
+func (c *cbcCipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) {
+	p, err := c.readCipherPacketLeaky(seqNum, r)
+	if err != nil {
+		if _, ok := err.(cbcError); ok {
+			// Verification error: read a fixed amount of
+			// data, to make distinguishing between
+			// failing MAC and failing length check more
+			// difficult.
+			io.CopyN(io.Discard, r, int64(c.oracleCamouflage))
+		}
+	}
+	return p, err
+}
+
+// readCipherPacketLeaky does the actual CBC read/decrypt/verify work. It
+// is "leaky" in that the amount read from r before returning an error
+// depends on which check failed; readCipherPacket compensates for that.
+func (c *cbcCipher) readCipherPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) {
+	blockSize := c.decrypter.BlockSize()
+
+	// Read the header, which will include some of the subsequent data in the
+	// case of block ciphers - this is copied back to the payload later.
+	// How many bytes of payload/padding will be read with this first read.
+	firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize)
+	firstBlock := c.packetData[:firstBlockLength]
+	if _, err := io.ReadFull(r, firstBlock); err != nil {
+		return nil, err
+	}
+
+	// Worst-case number of bytes still to read if verification fails.
+	c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength
+
+	c.decrypter.CryptBlocks(firstBlock, firstBlock)
+	length := binary.BigEndian.Uint32(firstBlock[:4])
+	if length > maxPacket {
+		return nil, cbcError("ssh: packet too large")
+	}
+	if length+4 < maxUInt32(cbcMinPacketSize, blockSize) {
+		// The minimum size of a packet is 16 (or the cipher block size, whichever
+		// is larger) bytes.
+		return nil, cbcError("ssh: packet too small")
+	}
+	// The length of the packet (including the length field but not the MAC) must
+	// be a multiple of the block size or 8, whichever is larger.
+	if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 {
+		return nil, cbcError("ssh: invalid packet length multiple")
+	}
+
+	paddingLength := uint32(firstBlock[4])
+	if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 {
+		return nil, cbcError("ssh: invalid packet length")
+	}
+
+	// Positions within the c.packetData buffer:
+	macStart := 4 + length
+	paddingStart := macStart - paddingLength
+
+	// Entire packet size, starting before length, ending at end of mac.
+	entirePacketSize := macStart + c.macSize
+
+	// Ensure c.packetData is large enough for the entire packet data.
+	if uint32(cap(c.packetData)) < entirePacketSize {
+		// Still need to upsize and copy, but this should be rare at runtime, only
+		// on upsizing the packetData buffer.
+		c.packetData = make([]byte, entirePacketSize)
+		copy(c.packetData, firstBlock)
+	} else {
+		c.packetData = c.packetData[:entirePacketSize]
+	}
+
+	n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
+	if err != nil {
+		return nil, err
+	}
+	c.oracleCamouflage -= uint32(n)
+
+	remainingCrypted := c.packetData[firstBlockLength:macStart]
+	c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)
+
+	mac := c.packetData[macStart:]
+	if c.mac != nil {
+		c.mac.Reset()
+		binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
+		c.mac.Write(c.seqNumBytes[:])
+		c.mac.Write(c.packetData[:macStart])
+		c.macResult = c.mac.Sum(c.macResult[:0])
+		if subtle.ConstantTimeCompare(c.macResult, mac) != 1 {
+			return nil, cbcError("ssh: MAC failure")
+		}
+	}
+
+	return c.packetData[prefixLen:paddingStart], nil
+}
+
+// writeCipherPacket pads, MACs, CBC-encrypts and writes one packet.
+func (c *cbcCipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
+	effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize())
+
+	// Length of encrypted portion of the packet (header, payload, padding).
+	// Enforce minimum padding and packet size.
+	encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
+	// Enforce block size.
+	encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
+
+	length := encLength - 4
+	paddingLength := int(length) - (1 + len(packet))
+
+	// Overall buffer contains: header, payload, padding, mac.
+	// Space for the MAC is reserved in the capacity but not the slice length.
+	bufferSize := encLength + c.macSize
+	if uint32(cap(c.packetData)) < bufferSize {
+		c.packetData = make([]byte, encLength, bufferSize)
+	} else {
+		c.packetData = c.packetData[:encLength]
+	}
+
+	p := c.packetData
+
+	// Packet header.
+	binary.BigEndian.PutUint32(p, length)
+	p = p[4:]
+	p[0] = byte(paddingLength)
+
+	// Payload.
+	p = p[1:]
+	copy(p, packet)
+
+	// Padding.
+	p = p[len(packet):]
+	if _, err := io.ReadFull(rand, p); err != nil {
+		return err
+	}
+
+	if c.mac != nil {
+		c.mac.Reset()
+		binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
+		c.mac.Write(c.seqNumBytes[:])
+		c.mac.Write(c.packetData)
+		// The MAC is now appended into the capacity reserved for it earlier.
+		c.packetData = c.mac.Sum(c.packetData)
+	}
+
+	c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength])
+
+	if _, err := w.Write(c.packetData); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
+
+// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
+// AEAD, which is described here:
+//
+//	https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
+//
+// the methods here also implement padding, which RFC 4253 Section 6
+// also requires of stream ciphers.
+type chacha20Poly1305Cipher struct { + lengthKey [32]byte + contentKey [32]byte + buf []byte +} + +func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) { + if len(key) != 64 { + panic(len(key)) + } + + c := &chacha20Poly1305Cipher{ + buf: make([]byte, 256), + } + + copy(c.contentKey[:], key[:32]) + copy(c.lengthKey[:], key[32:]) + return c, nil +} + +func (c *chacha20Poly1305Cipher) readCipherPacket(seqNum uint32, r io.Reader) ([]byte, error) { + nonce := make([]byte, 12) + binary.BigEndian.PutUint32(nonce[8:], seqNum) + s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce) + if err != nil { + return nil, err + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + encryptedLength := c.buf[:4] + if _, err := io.ReadFull(r, encryptedLength); err != nil { + return nil, err + } + + var lenBytes [4]byte + ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce) + if err != nil { + return nil, err + } + ls.XORKeyStream(lenBytes[:], encryptedLength) + + length := binary.BigEndian.Uint32(lenBytes[:]) + if length > maxPacket { + return nil, errors.New("ssh: invalid packet length, packet too large") + } + + contentEnd := 4 + length + packetEnd := contentEnd + poly1305.TagSize + if uint32(cap(c.buf)) < packetEnd { + c.buf = make([]byte, packetEnd) + copy(c.buf[:], encryptedLength) + } else { + c.buf = c.buf[:packetEnd] + } + + if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil { + return nil, err + } + + var mac [poly1305.TagSize]byte + copy(mac[:], c.buf[contentEnd:packetEnd]) + if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) { + return nil, errors.New("ssh: MAC failure") + } + + plain := c.buf[4:contentEnd] + s.XORKeyStream(plain, plain) + + if len(plain) == 0 { + return nil, errors.New("ssh: empty packet") + } + + padding := plain[0] + if padding < 4 { + // padding is a 
byte, so it automatically satisfies
+		// the maximum size, which is 255.
+		return nil, fmt.Errorf("ssh: illegal padding %d", padding)
+	}
+
+	if int(padding)+1 >= len(plain) {
+		return nil, fmt.Errorf("ssh: padding %d too large", padding)
+	}
+
+	// Strip the padding-length byte at the front and the random padding
+	// at the back, leaving just the payload.
+	plain = plain[1 : len(plain)-int(padding)]
+
+	return plain, nil
+}
+
+// writeCipherPacket encrypts and writes one SSH packet to w: the 4-byte
+// length field encrypted under the length key, then the padding-length
+// byte, payload and random padding encrypted under the content key,
+// followed by a Poly1305 tag over everything already emitted
+// (encrypt-then-MAC). The sequence number is the nonce.
+func (c *chacha20Poly1305Cipher) writeCipherPacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
+	nonce := make([]byte, 12)
+	binary.BigEndian.PutUint32(nonce[8:], seqNum)
+	s, err := chacha20.NewUnauthenticatedCipher(c.contentKey[:], nonce)
+	if err != nil {
+		return err
+	}
+	// First 64-byte keystream block: 32 bytes of Poly1305 key, 32 bytes
+	// discarded, so payload encryption starts at the second block.
+	var polyKey, discardBuf [32]byte
+	s.XORKeyStream(polyKey[:], polyKey[:])
+	s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes
+
+	// There is no blocksize, so fall back to multiple of 8 byte
+	// padding, as described in RFC 4253, Sec 6.
+	const packetSizeMultiple = 8
+
+	// RFC 4253 requires at least 4 bytes of padding.
+	padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple
+	if padding < 4 {
+		padding += packetSizeMultiple
+	}
+
+	// size (4 bytes), padding (1), payload, padding, tag. 
+	totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize
+	if cap(c.buf) < totalLength {
+		c.buf = make([]byte, totalLength)
+	} else {
+		c.buf = c.buf[:totalLength]
+	}
+
+	// The length field counts the padding-length byte, payload and
+	// padding, and is encrypted in place under the dedicated length key.
+	binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
+	ls, err := chacha20.NewUnauthenticatedCipher(c.lengthKey[:], nonce)
+	if err != nil {
+		return err
+	}
+	ls.XORKeyStream(c.buf, c.buf[:4])
+	c.buf[4] = byte(padding)
+	copy(c.buf[5:], payload)
+	// Fill the padding region with random bytes.
+	packetEnd := 5 + len(payload) + padding
+	if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil {
+		return err
+	}
+
+	// Encrypt everything after the length field in place.
+	s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
+
+	// The tag authenticates the encrypted length and the ciphertext.
+	var mac [poly1305.TagSize]byte
+	poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)
+
+	copy(c.buf[packetEnd:], mac[:])
+
+	if _, err := w.Write(c.buf); err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/tempfork/sshtest/ssh/cipher_test.go b/tempfork/sshtest/ssh/cipher_test.go
new file mode 100644
index 0000000000000..fe339862c5a9d
--- /dev/null
+++ b/tempfork/sshtest/ssh/cipher_test.go
@@ -0,0 +1,231 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file. 
+ +package ssh + +import ( + "bytes" + "crypto" + "crypto/rand" + "encoding/binary" + "io" + "testing" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/poly1305" +) + +func TestDefaultCiphersExist(t *testing.T) { + for _, cipherAlgo := range supportedCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("supported cipher %q is unknown", cipherAlgo) + } + } + for _, cipherAlgo := range preferredCiphers { + if _, ok := cipherModes[cipherAlgo]; !ok { + t.Errorf("preferred cipher %q is unknown", cipherAlgo) + } + } +} + +func TestPacketCiphers(t *testing.T) { + defaultMac := "hmac-sha2-256" + defaultCipher := "aes128-ctr" + for cipher := range cipherModes { + t.Run("cipher="+cipher, + func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) }) + } + for mac := range macModes { + t.Run("mac="+mac, + func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) }) + } +} + +func testPacketCipher(t *testing.T, cipher, mac string) { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: cipher, + MAC: mac, + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { + t.Fatalf("writeCipherPacket(%q, %q): %v", cipher, mac, err) + } + + packet, err := server.readCipherPacket(0, buf) + if err != nil { + t.Fatalf("readCipherPacket(%q, %q): %v", cipher, mac, err) + } + + if string(packet) != want { + t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want) + } +} + +func TestCBCOracleCounterMeasure(t *testing.T) { + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: aes128cbcID, + MAC: 
"hmac-sha1", + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + want := "bla bla" + input := []byte(want) + buf := &bytes.Buffer{} + if err := client.writeCipherPacket(0, buf, rand.Reader, input); err != nil { + t.Errorf("writeCipherPacket: %v", err) + } + + packetSize := buf.Len() + buf.Write(make([]byte, 2*maxPacket)) + + // We corrupt each byte, but this usually will only test the + // 'packet too large' or 'MAC failure' cases. + lastRead := -1 + for i := 0; i < packetSize; i++ { + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client): %v", err) + } + + fresh := &bytes.Buffer{} + fresh.Write(buf.Bytes()) + fresh.Bytes()[i] ^= 0x01 + + before := fresh.Len() + _, err = server.readCipherPacket(0, fresh) + if err == nil { + t.Errorf("corrupt byte %d: readCipherPacket succeeded ", i) + continue + } + if _, ok := err.(cbcError); !ok { + t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) + continue + } + + after := fresh.Len() + bytesRead := before - after + if bytesRead < maxPacket { + t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) + continue + } + + if i > 0 && bytesRead != lastRead { + t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) + } + lastRead = bytesRead + } +} + +func TestCVE202143565(t *testing.T) { + tests := []struct { + cipher string + constructPacket func(packetCipher) io.Reader + }{ + { + cipher: gcm128CipherID, + constructPacket: func(client packetCipher) io.Reader { + internalCipher := client.(*gcmCipher) + b := &bytes.Buffer{} + prefix := [4]byte{} + if _, err := b.Write(prefix[:]); err != nil { + t.Fatal(err) + } + internalCipher.buf = internalCipher.aead.Seal(internalCipher.buf[:0], internalCipher.iv, []byte{}, prefix[:]) + if _, err := b.Write(internalCipher.buf); err != nil { + t.Fatal(err) + } + 
internalCipher.incIV() + + return b + }, + }, + { + cipher: chacha20Poly1305ID, + constructPacket: func(client packetCipher) io.Reader { + internalCipher := client.(*chacha20Poly1305Cipher) + b := &bytes.Buffer{} + + nonce := make([]byte, 12) + s, err := chacha20.NewUnauthenticatedCipher(internalCipher.contentKey[:], nonce) + if err != nil { + t.Fatal(err) + } + var polyKey, discardBuf [32]byte + s.XORKeyStream(polyKey[:], polyKey[:]) + s.XORKeyStream(discardBuf[:], discardBuf[:]) // skip the next 32 bytes + + internalCipher.buf = make([]byte, 4+poly1305.TagSize) + binary.BigEndian.PutUint32(internalCipher.buf, 0) + ls, err := chacha20.NewUnauthenticatedCipher(internalCipher.lengthKey[:], nonce) + if err != nil { + t.Fatal(err) + } + ls.XORKeyStream(internalCipher.buf, internalCipher.buf[:4]) + if _, err := io.ReadFull(rand.Reader, internalCipher.buf[4:4]); err != nil { + t.Fatal(err) + } + + s.XORKeyStream(internalCipher.buf[4:], internalCipher.buf[4:4]) + + var tag [poly1305.TagSize]byte + poly1305.Sum(&tag, internalCipher.buf[:4], &polyKey) + + copy(internalCipher.buf[4:], tag[:]) + + if _, err := b.Write(internalCipher.buf); err != nil { + t.Fatal(err) + } + + return b + }, + }, + } + + for _, tc := range tests { + mac := "hmac-sha2-256" + + kr := &kexResult{Hash: crypto.SHA1} + algs := directionAlgorithms{ + Cipher: tc.cipher, + MAC: mac, + Compression: "none", + } + client, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) + } + server, err := newPacketCipher(clientKeys, algs, kr) + if err != nil { + t.Fatalf("newPacketCipher(client, %q, %q): %v", tc.cipher, mac, err) + } + + b := tc.constructPacket(client) + + wantErr := "ssh: empty packet" + _, err = server.readCipherPacket(0, b) + if err == nil { + t.Fatalf("readCipherPacket(%q, %q): didn't fail with empty packet", tc.cipher, mac) + } else if err.Error() != wantErr { + t.Fatalf("readCipherPacket(%q, %q): unexpected error, 
got %q, want %q", tc.cipher, mac, err, wantErr) + } + } +} diff --git a/tempfork/sshtest/ssh/client.go b/tempfork/sshtest/ssh/client.go new file mode 100644 index 0000000000000..5876e6421d23c --- /dev/null +++ b/tempfork/sshtest/ssh/client.go @@ -0,0 +1,290 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "net" + "os" + "sync" + "time" +) + +// Client implements a traditional SSH client that supports shells, +// subprocesses, TCP port/streamlocal forwarding and tunneled dialing. +type Client struct { + Conn + + handleForwardsOnce sync.Once // guards calling (*Client).handleForwards + + forwards forwardList // forwarded tcpip connections from the remote side + mu sync.Mutex + channelHandlers map[string]chan NewChannel +} + +// HandleChannelOpen returns a channel on which NewChannel requests +// for the given type are sent. If the type already is being handled, +// nil is returned. The channel is closed when the connection is closed. +func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { + c.mu.Lock() + defer c.mu.Unlock() + if c.channelHandlers == nil { + // The SSH channel has been closed. + c := make(chan NewChannel) + close(c) + return c + } + + ch := c.channelHandlers[channelType] + if ch != nil { + return nil + } + + ch = make(chan NewChannel, chanSize) + c.channelHandlers[channelType] = ch + return ch +} + +// NewClient creates a Client on top of the given connection. 
+func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { + conn := &Client{ + Conn: c, + channelHandlers: make(map[string]chan NewChannel, 1), + } + + go conn.handleGlobalRequests(reqs) + go conn.handleChannelOpens(chans) + go func() { + conn.Wait() + conn.forwards.closeAll() + }() + return conn +} + +// NewClientConn establishes an authenticated SSH connection using c +// as the underlying transport. The Request and NewChannel channels +// must be serviced or the connection will hang. +func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { + fullConf := *config + fullConf.SetDefaults() + if fullConf.HostKeyCallback == nil { + c.Close() + return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback") + } + + conn := &connection{ + sshConn: sshConn{conn: c, user: fullConf.User}, + } + + if err := conn.clientHandshake(addr, &fullConf); err != nil { + c.Close() + return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err) + } + conn.mux = newMux(conn.transport) + return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil +} + +// clientHandshake performs the client side key exchange. See RFC 4253 Section +// 7. 
+func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { + if config.ClientVersion != "" { + c.clientVersion = []byte(config.ClientVersion) + } else { + c.clientVersion = []byte(packageVersion) + } + var err error + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + return err + } + + c.transport = newClientTransport( + newTransport(c.sshConn.conn, config.Rand, true /* is client */), + c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) + if err := c.transport.waitSession(); err != nil { + return err + } + + c.sessionID = c.transport.getSessionID() + return c.clientAuthenticate(config) +} + +// verifyHostKeySignature verifies the host key obtained in the key exchange. +// algo is the negotiated algorithm, and may be a certificate type. +func verifyHostKeySignature(hostKey PublicKey, algo string, result *kexResult) error { + sig, rest, ok := parseSignatureBody(result.Signature) + if len(rest) > 0 || !ok { + return errors.New("ssh: signature parse error") + } + + if a := underlyingAlgo(algo); sig.Format != a { + return fmt.Errorf("ssh: invalid signature algorithm %q, expected %q", sig.Format, a) + } + + return hostKey.Verify(result.H, sig) +} + +// NewSession opens a new Session for this client. (A session is a remote +// execution of a program.) +func (c *Client) NewSession() (*Session, error) { + ch, in, err := c.OpenChannel("session", nil) + if err != nil { + return nil, err + } + return newSession(ch, in) +} + +func (c *Client) handleGlobalRequests(incoming <-chan *Request) { + for r := range incoming { + // This handles keepalive messages and matches + // the behaviour of OpenSSH. + r.Reply(false, nil) + } +} + +// handleChannelOpens channel open messages from the remote side. 
+func (c *Client) handleChannelOpens(in <-chan NewChannel) { + for ch := range in { + c.mu.Lock() + handler := c.channelHandlers[ch.ChannelType()] + c.mu.Unlock() + + if handler != nil { + handler <- ch + } else { + ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) + } + } + + c.mu.Lock() + for _, ch := range c.channelHandlers { + close(ch) + } + c.channelHandlers = nil + c.mu.Unlock() +} + +// Dial starts a client connection to the given SSH server. It is a +// convenience function that connects to the given network address, +// initiates the SSH handshake, and then sets up a Client. For access +// to incoming channels and requests, use net.Dial with NewClientConn +// instead. +func Dial(network, addr string, config *ClientConfig) (*Client, error) { + conn, err := net.DialTimeout(network, addr, config.Timeout) + if err != nil { + return nil, err + } + c, chans, reqs, err := NewClientConn(conn, addr, config) + if err != nil { + return nil, err + } + return NewClient(c, chans, reqs), nil +} + +// HostKeyCallback is the function type used for verifying server +// keys. A HostKeyCallback must return nil if the host key is OK, or +// an error to reject it. It receives the hostname as passed to Dial +// or NewClientConn. The remote address is the RemoteAddr of the +// net.Conn underlying the SSH connection. +type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error + +// BannerCallback is the function type used for treat the banner sent by +// the server. A BannerCallback receives the message sent by the remote server. +type BannerCallback func(message string) error + +// A ClientConfig structure is used to configure a Client. It must not be +// modified after having been passed to an SSH function. +type ClientConfig struct { + // Config contains configuration that is shared between clients and + // servers. + Config + + // User contains the username to authenticate as. 
+ User string + + // Auth contains possible authentication methods to use with the + // server. Only the first instance of a particular RFC 4252 method will + // be used during authentication. + Auth []AuthMethod + + // HostKeyCallback is called during the cryptographic + // handshake to validate the server's host key. The client + // configuration must supply this callback for the connection + // to succeed. The functions InsecureIgnoreHostKey or + // FixedHostKey can be used for simplistic host key checks. + HostKeyCallback HostKeyCallback + + // BannerCallback is called during the SSH dance to display a custom + // server's message. The client configuration can supply this callback to + // handle it as wished. The function BannerDisplayStderr can be used for + // simplistic display on Stderr. + BannerCallback BannerCallback + + // ClientVersion contains the version identification string that will + // be used for the connection. If empty, a reasonable default is used. + ClientVersion string + + // HostKeyAlgorithms lists the public key algorithms that the client will + // accept from the server for host key authentication, in order of + // preference. If empty, a reasonable default is used. Any + // string returned from a PublicKey.Type method may be used, or + // any of the CertAlgo and KeyAlgo constants. + HostKeyAlgorithms []string + + // Timeout is the maximum amount of time for the TCP connection to establish. + // + // A Timeout of zero means no timeout. + Timeout time.Duration + + // SkipNoneAuth allows skipping the initial "none" auth request. This is unusual + // behavior, but it is allowed by [RFC4252 5.2](https://datatracker.ietf.org/doc/html/rfc4252#section-5.2), + // and some clients in the wild behave like this. One such client is the paramiko Python + // library, which is used in pgadmin4 via the sshtunnel library. + // When SkipNoneAuth is true, the client will attempt all configured + // [AuthMethod]s until one works, or it runs out. 
+ SkipNoneAuth bool +} + +// InsecureIgnoreHostKey returns a function that can be used for +// ClientConfig.HostKeyCallback to accept any host key. It should +// not be used for production code. +func InsecureIgnoreHostKey() HostKeyCallback { + return func(hostname string, remote net.Addr, key PublicKey) error { + return nil + } +} + +type fixedHostKey struct { + key PublicKey +} + +func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error { + if f.key == nil { + return fmt.Errorf("ssh: required host key was nil") + } + if !bytes.Equal(key.Marshal(), f.key.Marshal()) { + return fmt.Errorf("ssh: host key mismatch") + } + return nil +} + +// FixedHostKey returns a function for use in +// ClientConfig.HostKeyCallback to accept only a specific host key. +func FixedHostKey(key PublicKey) HostKeyCallback { + hk := &fixedHostKey{key} + return hk.check +} + +// BannerDisplayStderr returns a function that can be used for +// ClientConfig.BannerCallback to display banners on os.Stderr. +func BannerDisplayStderr() BannerCallback { + return func(banner string) error { + _, err := os.Stderr.WriteString(banner) + + return err + } +} diff --git a/tempfork/sshtest/ssh/client_auth.go b/tempfork/sshtest/ssh/client_auth.go new file mode 100644 index 0000000000000..af25a4f01b9e7 --- /dev/null +++ b/tempfork/sshtest/ssh/client_auth.go @@ -0,0 +1,805 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "strings" +) + +type authResult int + +const ( + authFailure authResult = iota + authPartialSuccess + authSuccess +) + +// clientAuthenticate authenticates with the remote server. See RFC 4252. 
+func (c *connection) clientAuthenticate(config *ClientConfig) error {
+	// initiate user auth session
+	if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil {
+		return err
+	}
+	packet, err := c.transport.readPacket()
+	if err != nil {
+		return err
+	}
+	// The server may choose to send a SSH_MSG_EXT_INFO at this point (if we
+	// advertised willingness to receive one, which we always do) or not. See
+	// RFC 8308, Section 2.4.
+	extensions := make(map[string][]byte)
+	if len(packet) > 0 && packet[0] == msgExtInfo {
+		var extInfo extInfoMsg
+		if err := Unmarshal(packet, &extInfo); err != nil {
+			return err
+		}
+		// Each extension is a (name, value) string pair; collect them all
+		// so auth methods (e.g. publickey via "server-sig-algs") can use them.
+		payload := extInfo.Payload
+		for i := uint32(0); i < extInfo.NumExtensions; i++ {
+			name, rest, ok := parseString(payload)
+			if !ok {
+				return parseError(msgExtInfo)
+			}
+			value, rest, ok := parseString(rest)
+			if !ok {
+				return parseError(msgExtInfo)
+			}
+			extensions[string(name)] = value
+			payload = rest
+		}
+		packet, err = c.transport.readPacket()
+		if err != nil {
+			return err
+		}
+	}
+	var serviceAccept serviceAcceptMsg
+	if err := Unmarshal(packet, &serviceAccept); err != nil {
+		return err
+	}
+
+	// during the authentication phase the client first attempts the "none" method
+	// then any untried methods suggested by the server.
+	var tried []string
+	var lastMethods []string
+
+	sessionID := c.transport.getSessionID()
+	var auth AuthMethod
+	// Fork-specific behavior: when SkipNoneAuth is set (see the
+	// ClientConfig.SkipNoneAuth doc), do not issue the initial "none"
+	// request; instead start with the first configured AuthMethod and
+	// seed lastMethods with every configured method so the selection
+	// loop below can fall through to the rest.
+	if !config.SkipNoneAuth {
+		auth = AuthMethod(new(noneAuth))
+	} else if len(config.Auth) > 0 {
+		auth = config.Auth[0]
+		for _, a := range config.Auth {
+			lastMethods = append(lastMethods, a.method())
+		}
+	}
+	for auth != nil {
+		ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand, extensions)
+		if err != nil {
+			// On disconnect, return error immediately
+			if _, ok := err.(*disconnectMsg); ok {
+				return err
+			}
+			// We return the error later if there is no other method left to
+			// try.
+			// A non-disconnect error counts as a failed attempt for this
+			// method; err is kept and only surfaced once no method remains.
+			ok = authFailure
+		}
+		if ok == authSuccess {
+			// success
+			return nil
+		} else if ok == authFailure {
+			if m := auth.method(); !contains(tried, m) {
+				tried = append(tried, m)
+			}
+		}
+		// A nil methods list means "no new information from the server";
+		// reuse the previous suggestions.
+		if methods == nil {
+			methods = lastMethods
+		}
+		lastMethods = methods
+
+		auth = nil
+
+		// Pick the first configured, not-yet-tried method that the server
+		// still advertises as continuable.
+	findNext:
+		for _, a := range config.Auth {
+			candidateMethod := a.method()
+			if contains(tried, candidateMethod) {
+				continue
+			}
+			for _, meth := range methods {
+				if meth == candidateMethod {
+					auth = a
+					break findNext
+				}
+			}
+		}
+
+		if auth == nil && err != nil {
+			// We have an error and there are no other authentication methods to
+			// try, so we return it.
+			return err
+		}
+	}
+	return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", tried)
+}
+
+// contains reports whether e appears in list.
+func contains(list []string, e string) bool {
+	for _, s := range list {
+		if s == e {
+			return true
+		}
+	}
+	return false
+}
+
+// An AuthMethod represents an instance of an RFC 4252 authentication method.
+type AuthMethod interface {
+	// auth authenticates user over transport t.
+	// Returns true if authentication is successful.
+	// If authentication is not successful, a []string of alternative
+	// method names is returned. If the slice is nil, it will be ignored
+	// and the previous set of possible methods will be reused.
+	auth(session []byte, user string, p packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error)
+
+	// method returns the RFC 4252 method name.
+	method() string
+}
+
+// "none" authentication, RFC 4252 section 5.2. 
+type noneAuth int
+
+// auth sends a single "none" userauth request and returns the server's
+// verdict. The session, rand and extensions arguments are unused; the
+// server's reply typically carries the list of methods to try next.
+func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
+	if err := c.writePacket(Marshal(&userAuthRequestMsg{
+		User:    user,
+		Service: serviceSSH,
+		Method:  "none",
+	})); err != nil {
+		return authFailure, nil, err
+	}
+
+	return handleAuthResponse(c)
+}
+
+// method returns the RFC 4252 method name.
+func (n *noneAuth) method() string {
+	return "none"
+}
+
+// passwordCallback is an AuthMethod that fetches the password through
+// a function call, e.g. by prompting the user.
+type passwordCallback func() (password string, err error)
+
+// auth fetches a password via the callback and sends one password
+// userauth request. The password is carried inside the request packet;
+// confidentiality relies on the underlying transport encryption.
+func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) {
+	type passwordAuthMsg struct {
+		User     string `sshtype:"50"`
+		Service  string
+		Method   string
+		Reply    bool
+		Password string
+	}
+
+	pw, err := cb()
+	// REVIEW NOTE: is there a need to support skipping a password attempt?
+	// The program may only find out that the user doesn't have a password
+	// when prompting.
+	if err != nil {
+		return authFailure, nil, err
+	}
+
+	if err := c.writePacket(Marshal(&passwordAuthMsg{
+		User:     user,
+		Service:  serviceSSH,
+		Method:   cb.method(),
+		Reply:    false,
+		Password: pw,
+	})); err != nil {
+		return authFailure, nil, err
+	}
+
+	return handleAuthResponse(c)
+}
+
+// method returns the RFC 4252 method name.
+func (cb passwordCallback) method() string {
+	return "password"
+}
+
+// Password returns an AuthMethod using the given password.
+func Password(secret string) AuthMethod {
+	return passwordCallback(func() (string, error) { return secret, nil })
+}
+
+// PasswordCallback returns an AuthMethod that uses a callback for
+// fetching a password. 
+func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { + return passwordCallback(prompt) +} + +type publickeyAuthMsg struct { + User string `sshtype:"50"` + Service string + Method string + // HasSig indicates to the receiver packet that the auth request is signed and + // should be used for authentication of the request. + HasSig bool + Algoname string + PubKey []byte + // Sig is tagged with "rest" so Marshal will exclude it during + // validateKey + Sig []byte `ssh:"rest"` +} + +// publicKeyCallback is an AuthMethod that uses a set of key +// pairs for authentication. +type publicKeyCallback func() ([]Signer, error) + +func (cb publicKeyCallback) method() string { + return "publickey" +} + +func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (MultiAlgorithmSigner, string, error) { + var as MultiAlgorithmSigner + keyFormat := signer.PublicKey().Type() + + // If the signer implements MultiAlgorithmSigner we use the algorithms it + // support, if it implements AlgorithmSigner we assume it supports all + // algorithms, otherwise only the key format one. + switch s := signer.(type) { + case MultiAlgorithmSigner: + as = s + case AlgorithmSigner: + as = &multiAlgorithmSigner{ + AlgorithmSigner: s, + supportedAlgorithms: algorithmsForKeyFormat(underlyingAlgo(keyFormat)), + } + default: + as = &multiAlgorithmSigner{ + AlgorithmSigner: algorithmSignerWrapper{signer}, + supportedAlgorithms: []string{underlyingAlgo(keyFormat)}, + } + } + + getFallbackAlgo := func() (string, error) { + // Fallback to use if there is no "server-sig-algs" extension or a + // common algorithm cannot be found. We use the public key format if the + // MultiAlgorithmSigner supports it, otherwise we return an error. 
+ if !contains(as.Algorithms(), underlyingAlgo(keyFormat)) { + return "", fmt.Errorf("ssh: no common public key signature algorithm, server only supports %q for key type %q, signer only supports %v", + underlyingAlgo(keyFormat), keyFormat, as.Algorithms()) + } + return keyFormat, nil + } + + extPayload, ok := extensions["server-sig-algs"] + if !ok { + // If there is no "server-sig-algs" extension use the fallback + // algorithm. + algo, err := getFallbackAlgo() + return as, algo, err + } + + // The server-sig-algs extension only carries underlying signature + // algorithm, but we are trying to select a protocol-level public key + // algorithm, which might be a certificate type. Extend the list of server + // supported algorithms to include the corresponding certificate algorithms. + serverAlgos := strings.Split(string(extPayload), ",") + for _, algo := range serverAlgos { + if certAlgo, ok := certificateAlgo(algo); ok { + serverAlgos = append(serverAlgos, certAlgo) + } + } + + // Filter algorithms based on those supported by MultiAlgorithmSigner. + var keyAlgos []string + for _, algo := range algorithmsForKeyFormat(keyFormat) { + if contains(as.Algorithms(), underlyingAlgo(algo)) { + keyAlgos = append(keyAlgos, algo) + } + } + + algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos) + if err != nil { + // If there is no overlap, return the fallback algorithm to support + // servers that fail to list all supported algorithms. + algo, err := getFallbackAlgo() + return as, algo, err + } + return as, algo, nil +} + +func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) { + // Authentication is performed by sending an enquiry to test if a key is + // acceptable to the remote. If the key is acceptable, the client will + // attempt to authenticate with the valid key. If not the client will repeat + // the process with the remaining keys. 
+ + signers, err := cb() + if err != nil { + return authFailure, nil, err + } + var methods []string + var errSigAlgo error + + origSignersLen := len(signers) + for idx := 0; idx < len(signers); idx++ { + signer := signers[idx] + pub := signer.PublicKey() + as, algo, err := pickSignatureAlgorithm(signer, extensions) + if err != nil && errSigAlgo == nil { + // If we cannot negotiate a signature algorithm store the first + // error so we can return it to provide a more meaningful message if + // no other signers work. + errSigAlgo = err + continue + } + ok, err := validateKey(pub, algo, user, c) + if err != nil { + return authFailure, nil, err + } + // OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512 + // in the "server-sig-algs" extension but doesn't support these + // algorithms for certificate authentication, so if the server rejects + // the key try to use the obtained algorithm as if "server-sig-algs" had + // not been implemented if supported from the algorithm signer. + if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 { + if contains(as.Algorithms(), KeyAlgoRSA) { + // We retry using the compat algorithm after all signers have + // been tried normally. 
+ signers = append(signers, &multiAlgorithmSigner{ + AlgorithmSigner: as, + supportedAlgorithms: []string{KeyAlgoRSA}, + }) + } + } + if !ok { + continue + } + + pubKey := pub.Marshal() + data := buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, algo, pubKey) + sign, err := as.SignWithAlgorithm(rand, data, underlyingAlgo(algo)) + if err != nil { + return authFailure, nil, err + } + + // manually wrap the serialized signature in a string + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: algo, + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return authFailure, nil, err + } + var success authResult + success, methods, err = handleAuthResponse(c) + if err != nil { + return authFailure, nil, err + } + + // If authentication succeeds or the list of available methods does not + // contain the "publickey" method, do not attempt to authenticate with any + // other keys. According to RFC 4252 Section 7, the latter can occur when + // additional authentication methods are required. + if success == authSuccess || !contains(methods, cb.method()) { + return success, methods, err + } + } + + return authFailure, methods, errSigAlgo +} + +// validateKey validates the key provided is acceptable to the server. 
+func validateKey(key PublicKey, algo string, user string, c packetConn) (bool, error) { + pubKey := key.Marshal() + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: "publickey", + HasSig: false, + Algoname: algo, + PubKey: pubKey, + } + if err := c.writePacket(Marshal(&msg)); err != nil { + return false, err + } + + return confirmKeyAck(key, c) +} + +func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { + pubKey := key.Marshal() + + for { + packet, err := c.readPacket() + if err != nil { + return false, err + } + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return false, err + } + case msgUserAuthPubKeyOk: + var msg userAuthPubKeyOkMsg + if err := Unmarshal(packet, &msg); err != nil { + return false, err + } + // According to RFC 4252 Section 7 the algorithm in + // SSH_MSG_USERAUTH_PK_OK should match that of the request but some + // servers send the key type instead. OpenSSH allows any algorithm + // that matches the public key, so we do the same. + // https://github.com/openssh/openssh-portable/blob/86bdd385/sshconnect2.c#L709 + if !contains(algorithmsForKeyFormat(key.Type()), msg.Algo) { + return false, nil + } + if !bytes.Equal(msg.PubKey, pubKey) { + return false, nil + } + return true, nil + case msgUserAuthFailure: + return false, nil + default: + return false, unexpectedMessageError(msgUserAuthPubKeyOk, packet[0]) + } + } +} + +// PublicKeys returns an AuthMethod that uses the given key +// pairs. +func PublicKeys(signers ...Signer) AuthMethod { + return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) +} + +// PublicKeysCallback returns an AuthMethod that runs the given +// function to obtain a list of key pairs. 
+func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { + return publicKeyCallback(getSigners) +} + +// handleAuthResponse returns whether the preceding authentication request succeeded +// along with a list of remaining authentication methods to try next and +// an error if an unexpected response was received. +func handleAuthResponse(c packetConn) (authResult, []string, error) { + gotMsgExtInfo := false + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + case msgExtInfo: + // Ignore post-authentication RFC 8308 extensions, once. + if gotMsgExtInfo { + return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + gotMsgExtInfo = true + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) + } + } +} + +func handleBannerResponse(c packetConn, packet []byte) error { + var msg userAuthBannerMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + transport, ok := c.(*handshakeTransport) + if !ok { + return nil + } + + if transport.bannerCallback != nil { + return transport.bannerCallback(msg.Message) + } + + return nil +} + +// KeyboardInteractiveChallenge should print questions, optionally +// disabling echoing (e.g. for passwords), and return all the answers. +// Challenge may be called multiple times in a single session. 
After +// successful authentication, the server may send a challenge with no +// questions, for which the name and instruction messages should be +// printed. RFC 4256 section 3.3 details how the UI should behave for +// both CLI and GUI environments. +type KeyboardInteractiveChallenge func(name, instruction string, questions []string, echos []bool) (answers []string, err error) + +// KeyboardInteractive returns an AuthMethod using a prompt/response +// sequence controlled by the server. +func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { + return challenge +} + +func (cb KeyboardInteractiveChallenge) method() string { + return "keyboard-interactive" +} + +func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + type initiateMsg struct { + User string `sshtype:"50"` + Service string + Method string + Language string + Submethods string + } + + if err := c.writePacket(Marshal(&initiateMsg{ + User: user, + Service: serviceSSH, + Method: "keyboard-interactive", + })); err != nil { + return authFailure, nil, err + } + + gotMsgExtInfo := false + gotUserAuthInfoRequest := false + for { + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + + // like handleAuthResponse, but with less options. + switch packet[0] { + case msgUserAuthBanner: + if err := handleBannerResponse(c, packet); err != nil { + return authFailure, nil, err + } + continue + case msgExtInfo: + // Ignore post-authentication RFC 8308 extensions, once. 
+ if gotMsgExtInfo { + return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + gotMsgExtInfo = true + continue + case msgUserAuthInfoRequest: + // OK + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + if !gotUserAuthInfoRequest { + return authFailure, msg.Methods, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + return authFailure, msg.Methods, nil + case msgUserAuthSuccess: + return authSuccess, nil, nil + default: + return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) + } + + var msg userAuthInfoRequestMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + gotUserAuthInfoRequest = true + + // Manually unpack the prompt/echo pairs. + rest := msg.Prompts + var prompts []string + var echos []bool + for i := 0; i < int(msg.NumPrompts); i++ { + prompt, r, ok := parseString(rest) + if !ok || len(r) == 0 { + return authFailure, nil, errors.New("ssh: prompt format error") + } + prompts = append(prompts, string(prompt)) + echos = append(echos, r[0] != 0) + rest = r[1:] + } + + if len(rest) != 0 { + return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs") + } + + answers, err := cb(msg.Name, msg.Instruction, prompts, echos) + if err != nil { + return authFailure, nil, err + } + + if len(answers) != len(prompts) { + return authFailure, nil, fmt.Errorf("ssh: incorrect number of answers from keyboard-interactive callback %d (expected %d)", len(answers), len(prompts)) + } + responseLength := 1 + 4 + for _, a := range answers { + responseLength += stringLength(len(a)) + } + serialized := make([]byte, responseLength) + p := serialized + p[0] = msgUserAuthInfoResponse + p = p[1:] + p = marshalUint32(p, uint32(len(answers))) + for _, a := range answers { + p = 
marshalString(p, []byte(a)) + } + + if err := c.writePacket(serialized); err != nil { + return authFailure, nil, err + } + } +} + +type retryableAuthMethod struct { + authMethod AuthMethod + maxTries int +} + +func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (ok authResult, methods []string, err error) { + for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ { + ok, methods, err = r.authMethod.auth(session, user, c, rand, extensions) + if ok != authFailure || err != nil { // either success, partial success or error terminate + return ok, methods, err + } + } + return ok, methods, err +} + +func (r *retryableAuthMethod) method() string { + return r.authMethod.method() +} + +// RetryableAuthMethod is a decorator for other auth methods enabling them to +// be retried up to maxTries before considering that AuthMethod itself failed. +// If maxTries is <= 0, will retry indefinitely +// +// This is useful for interactive clients using challenge/response type +// authentication (e.g. Keyboard-Interactive, Password, etc) where the user +// could mistype their response resulting in the server issuing a +// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4 +// [keyboard-interactive]); Without this decorator, the non-retryable +// AuthMethod would be removed from future consideration, and never tried again +// (and so the user would never be able to retry their entry). +func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod { + return &retryableAuthMethod{authMethod: auth, maxTries: maxTries} +} + +// GSSAPIWithMICAuthMethod is an AuthMethod with "gssapi-with-mic" authentication. +// See RFC 4462 section 3 +// gssAPIClient is implementation of the GSSAPIClient interface, see the definition of the interface for details. +// target is the server host you want to log in to. 
+func GSSAPIWithMICAuthMethod(gssAPIClient GSSAPIClient, target string) AuthMethod { + if gssAPIClient == nil { + panic("gss-api client must be not nil with enable gssapi-with-mic") + } + return &gssAPIWithMICCallback{gssAPIClient: gssAPIClient, target: target} +} + +type gssAPIWithMICCallback struct { + gssAPIClient GSSAPIClient + target string +} + +func (g *gssAPIWithMICCallback) auth(session []byte, user string, c packetConn, rand io.Reader, _ map[string][]byte) (authResult, []string, error) { + m := &userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: g.method(), + } + // The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST. + // See RFC 4462 section 3.2. + m.Payload = appendU32(m.Payload, 1) + m.Payload = appendString(m.Payload, string(krb5OID)) + if err := c.writePacket(Marshal(m)); err != nil { + return authFailure, nil, err + } + // The server responds to the SSH_MSG_USERAUTH_REQUEST with either an + // SSH_MSG_USERAUTH_FAILURE if none of the mechanisms are supported or + // with an SSH_MSG_USERAUTH_GSSAPI_RESPONSE. + // See RFC 4462 section 3.3. + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication,so I don't want to check + // selected mech if it is valid. + packet, err := c.readPacket() + if err != nil { + return authFailure, nil, err + } + userAuthGSSAPIResp := &userAuthGSSAPIResponse{} + if err := Unmarshal(packet, userAuthGSSAPIResp); err != nil { + return authFailure, nil, err + } + // Start the loop into the exchange token. + // See RFC 4462 section 3.4. + var token []byte + defer g.gssAPIClient.DeleteSecContext() + for { + // Initiates the establishment of a security context between the application and a remote peer. 
+ nextToken, needContinue, err := g.gssAPIClient.InitSecContext("host@"+g.target, token, false) + if err != nil { + return authFailure, nil, err + } + if len(nextToken) > 0 { + if err := c.writePacket(Marshal(&userAuthGSSAPIToken{ + Token: nextToken, + })); err != nil { + return authFailure, nil, err + } + } + if !needContinue { + break + } + packet, err = c.readPacket() + if err != nil { + return authFailure, nil, err + } + switch packet[0] { + case msgUserAuthFailure: + var msg userAuthFailureMsg + if err := Unmarshal(packet, &msg); err != nil { + return authFailure, nil, err + } + if msg.PartialSuccess { + return authPartialSuccess, msg.Methods, nil + } + return authFailure, msg.Methods, nil + case msgUserAuthGSSAPIError: + userAuthGSSAPIErrorResp := &userAuthGSSAPIError{} + if err := Unmarshal(packet, userAuthGSSAPIErrorResp); err != nil { + return authFailure, nil, err + } + return authFailure, nil, fmt.Errorf("GSS-API Error:\n"+ + "Major Status: %d\n"+ + "Minor Status: %d\n"+ + "Error Message: %s\n", userAuthGSSAPIErrorResp.MajorStatus, userAuthGSSAPIErrorResp.MinorStatus, + userAuthGSSAPIErrorResp.Message) + case msgUserAuthGSSAPIToken: + userAuthGSSAPITokenReq := &userAuthGSSAPIToken{} + if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { + return authFailure, nil, err + } + token = userAuthGSSAPITokenReq.Token + } + } + // Binding Encryption Keys. + // See RFC 4462 section 3.5. 
+ micField := buildMIC(string(session), user, "ssh-connection", "gssapi-with-mic") + micToken, err := g.gssAPIClient.GetMIC(micField) + if err != nil { + return authFailure, nil, err + } + if err := c.writePacket(Marshal(&userAuthGSSAPIMIC{ + MIC: micToken, + })); err != nil { + return authFailure, nil, err + } + return handleAuthResponse(c) +} + +func (g *gssAPIWithMICCallback) method() string { + return "gssapi-with-mic" +} diff --git a/tempfork/sshtest/ssh/client_auth_test.go b/tempfork/sshtest/ssh/client_auth_test.go new file mode 100644 index 0000000000000..ec27133a39b28 --- /dev/null +++ b/tempfork/sshtest/ssh/client_auth_test.go @@ -0,0 +1,1384 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "runtime" + "strings" + "testing" +) + +type keyboardInteractive map[string]string + +func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { + var answers []string + for _, q := range questions { + answers = append(answers, cr[q]) + } + return answers, nil +} + +// reused internally by tests +var clientPassword = "tiger" + +// tryAuth runs a handshake with a given config against an SSH server +// with config serverConfig. Returns both client and server side errors. +func tryAuth(t *testing.T, config *ClientConfig) error { + err, _ := tryAuthBothSides(t, config, nil) + return err +} + +// tryAuthWithGSSAPIWithMICConfig runs a handshake with a given config against an SSH server +// with a given GSSAPIWithMICConfig and config serverConfig. Returns both client and server side errors. 
+func tryAuthWithGSSAPIWithMICConfig(t *testing.T, clientConfig *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) error { + err, _ := tryAuthBothSides(t, clientConfig, gssAPIWithMICConfig) + return err +} + +// tryAuthBothSides runs the handshake and returns the resulting errors from both sides of the connection. +func tryAuthBothSides(t *testing.T, config *ClientConfig, gssAPIWithMICConfig *GSSAPIWithMICConfig) (clientError error, serverAuthErrors []error) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + certChecker := CertChecker{ + IsUserAuthority: func(k PublicKey) bool { + return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) + }, + UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + return nil, nil + } + + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + IsRevoked: func(c *Certificate) bool { + return c.Serial == 666 + }, + } + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + PublicKeyCallback: certChecker.Authenticate, + KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { + ans, err := challenge("user", + "instruction", + []string{"question1", "question2"}, + []bool{true, true}) + if err != nil { + return nil, err + } + ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" + if ok { + challenge("user", "motd", nil, nil) + return nil, nil + } + return nil, errors.New("keyboard-interactive failed") + }, + GSSAPIWithMICConfig: gssAPIWithMICConfig, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + 
serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + + go newServer(c1, serverConfig) + _, _, _, err = NewClientConn(c2, "", config) + return err, serverAuthErrors +} + +type loggingAlgorithmSigner struct { + used []string + AlgorithmSigner +} + +func (l *loggingAlgorithmSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + l.used = append(l.used, "[Sign]") + return l.AlgorithmSigner.Sign(rand, data) +} + +func (l *loggingAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + l.used = append(l.used, algorithm) + return l.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +func TestClientAuthPublicKey(t *testing.T) { + signer := &loggingAlgorithmSigner{AlgorithmSigner: testSigners["rsa"].(AlgorithmSigner)} + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(signer), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if len(signer.used) != 1 || signer.used[0] != KeyAlgoRSASHA256 { + t.Errorf("unexpected Sign/SignWithAlgorithm calls: %q", signer.used) + } +} + +// TestClientAuthNoSHA2 tests a ssh-rsa Signer that doesn't implement AlgorithmSigner. +func TestClientAuthNoSHA2(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(&legacyRSASigner{testSigners["rsa"]}), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +// TestClientAuthThirdKey checks that the third configured can succeed. If we +// were to do three attempts for each key (rsa-sha2-256, rsa-sha2-512, ssh-rsa), +// we'd hit the six maximum attempts before reaching it. 
+func TestClientAuthThirdKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa-openssh-format"], + testSigners["rsa-openssh-format"], testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodFallback(t *testing.T) { + var passwordCalled bool + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PasswordCallback( + func() (string, error) { + passwordCalled = true + return "WRONG", nil + }), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + if passwordCalled { + t.Errorf("password auth tried before public-key auth.") + } +} + +func TestAuthMethodWrongPassword(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + Password("wrong"), + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestAuthMethodKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "answer2", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func 
TestAuthMethodWrongKeyboardInteractive(t *testing.T) { + answers := keyboardInteractive(map[string]string{ + "question1": "answer1", + "question2": "WRONG", + }) + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + KeyboardInteractive(answers.Challenge), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") + } +} + +// the mock server will only authenticate ssh-rsa keys +func TestAuthMethodInvalidPublicKey(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"]), + }, + } + + if err := tryAuth(t, config); err == nil { + t.Fatalf("dsa private key should not have authenticated with rsa public key") + } +} + +// the client should authenticate with the second key +func TestAuthMethodRSAandDSA(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["dsa"], testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with rsa key: %v", err) + } +} + +type invalidAlgSigner struct { + Signer +} + +func (s *invalidAlgSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { + sig, err := s.Signer.Sign(rand, data) + if sig != nil { + sig.Format = "invalid" + } + return sig, err +} + +func TestMethodInvalidAlgorithm(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(&invalidAlgSigner{testSigners["rsa"]}), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + err, serverErrors := tryAuthBothSides(t, config, nil) + if err == nil { + t.Fatalf("login succeeded") + } + + found := false + want := "algorithm \"invalid\"" + + var errStrings []string + for _, err := range serverErrors { + found = found || (err != nil && strings.Contains(err.Error(), want)) + errStrings = append(errStrings, err.Error()) + } + if 
!found { + t.Errorf("server got error %q, want substring %q", errStrings, want) + } +} + +func TestClientHMAC(t *testing.T) { + for _, mac := range supportedMACs { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + Config: Config{ + MACs: []string{mac}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err != nil { + t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) + } + } +} + +// issue 4285. +func TestClientUnsupportedCipher(t *testing.T) { + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + Ciphers: []string{"aes128-cbc"}, // not currently supported + }, + } + if err := tryAuth(t, config); err == nil { + t.Errorf("expected no ciphers in common") + } +} + +func TestClientUnsupportedKex(t *testing.T) { + if os.Getenv("GO_BUILDER_NAME") != "" { + t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198") + } + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(), + }, + Config: Config{ + KeyExchanges: []string{"non-existent-kex"}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { + t.Errorf("got %v, expected 'common algorithm'", err) + } +} + +func TestClientLoginCert(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + } + clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) + + // should succeed + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: 
%v", err) + } + + // corrupted signature + cert.Signature.Blob[0]++ + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with corrupted sig") + } + + // revoked + cert.Serial = 666 + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("revoked cert login succeeded") + } + cert.Serial = 1 + + // sign with wrong key + cert.SignCert(rand.Reader, testSigners["dsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with non-authoritative key") + } + + // host cert + cert.CertType = HostCert + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong type") + } + cert.CertType = UserCert + + // principal specified + cert.ValidPrincipals = []string{"user"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login failed: %v", err) + } + + // wrong principal specified + cert.ValidPrincipals = []string{"fred"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with wrong principal") + } + cert.ValidPrincipals = nil + + // added critical option + cert.CriticalOptions = map[string]string{"root-access": "yes"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login passed with unrecognized critical option") + } + + // allowed source address + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24,::42/120"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := tryAuth(t, clientConfig); err != nil { + t.Errorf("cert login with source-address failed: %v", err) + } + + // disallowed source address + cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42,::42"} + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + if err := 
tryAuth(t, clientConfig); err == nil { + t.Errorf("cert login with source-address succeeded") + } +} + +func testPermissionsPassing(withPermissions bool, t *testing.T) { + serverConfig := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if conn.User() == "nopermissions" { + return nil, nil + } + return &Permissions{}, nil + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if withPermissions { + clientConfig.User = "permissions" + } else { + clientConfig.User = "nopermissions" + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatal(err) + } + if p := serverConn.Permissions; (p != nil) != withPermissions { + t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) + } +} + +func TestPermissionsPassing(t *testing.T) { + testPermissionsPassing(true, t) +} + +func TestNoPermissionsPassing(t *testing.T) { + testPermissionsPassing(false, t) +} + +func TestRetryableAuth(t *testing.T) { + n := 0 + passwords := []string{"WRONG1", "WRONG2"} + + config := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 2), + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if err := tryAuth(t, config); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + if n != 2 { + t.Fatalf("Did not try all passwords") + } +} + +func ExampleRetryableAuthMethod() { + user := "testuser" + NumberOfPrompts := 3 + + // Normally this would be a callback that prompts the user to answer the + // provided questions + Cb := 
func(user, instruction string, questions []string, echos []bool) (answers []string, err error) { + return []string{"answer1", "answer2"}, nil + } + + config := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: user, + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), + }, + } + + host := "mysshserver" + netConn, err := net.Dial("tcp", host) + if err != nil { + log.Fatal(err) + } + + sshConn, _, _, err := NewClientConn(netConn, host, config) + if err != nil { + log.Fatal(err) + } + _ = sshConn +} + +// Test if username is received on server side when NoClientAuth is used +func TestClientAuthNone(t *testing.T) { + user := "testuser" + serverConfig := &ServerConfig{ + NoClientAuth: true, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + clientConfig := &ClientConfig{ + User: user, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + serverConn, err := newServer(c1, serverConfig) + if err != nil { + t.Fatalf("newServer: %v", err) + } + if serverConn.User() != user { + t.Fatalf("server: got %q, want %q", serverConn.User(), user) + } +} + +// Test if authentication attempts are limited on server when MaxAuthTries is set +func TestClientAuthMaxAuthTries(t *testing.T) { + user := "testuser" + + serverConfig := &ServerConfig{ + MaxAuthTries: 2, + PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { + if conn.User() == "testuser" && string(pass) == "right" { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + }) + + for tries := 2; tries < 4; tries++ { + n := tries + clientConfig := &ClientConfig{ + User: user, + Auth: 
[]AuthMethod{ + RetryableAuthMethod(PasswordCallback(func() (string, error) { + n-- + if n == 0 { + return "right", nil + } + return "wrong", nil + }), tries), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + errCh := make(chan error, 1) + + go func() { + _, err := newServer(c1, serverConfig) + errCh <- err + }() + _, _, _, cliErr := NewClientConn(c2, "", clientConfig) + srvErr := <-errCh + + if tries > serverConfig.MaxAuthTries { + if cliErr == nil { + t.Fatalf("client: got no error, want %s", expectedErr) + } else if cliErr.Error() != expectedErr.Error() { + t.Fatalf("client: got %s, want %s", err, expectedErr) + } + var authErr *ServerAuthError + if !errors.As(srvErr, &authErr) { + t.Errorf("expected ServerAuthError, got: %v", srvErr) + } + } else { + if cliErr != nil { + t.Fatalf("client: got %s, want no error", cliErr) + } + } + } +} + +// Test if authentication attempts are correctly limited on server +// when more public keys are provided then MaxAuthTries +func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) { + signers := []Signer{} + for i := 0; i < 6; i++ { + signers = append(signers, testSigners["dsa"]) + } + + validConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, validConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } + + expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ + Reason: 2, + Message: "too many authentication failures", + }) + invalidConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + PublicKeys(append(signers, testSigners["rsa"])...), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + if err := tryAuth(t, invalidConfig); err == nil { + t.Fatalf("client: got no error, want %s", expectedErr) 
+ } else if err.Error() != expectedErr.Error() { + // On Windows we can see a WSAECONNABORTED error + // if the client writes another authentication request + // before the client goroutine reads the disconnection + // message. See issue 50805. + if runtime.GOOS == "windows" && strings.Contains(err.Error(), "wsarecv: An established connection was aborted") { + // OK. + } else { + t.Fatalf("client: got %s, want %s", err, expectedErr) + } + } +} + +// Test whether authentication errors are being properly logged if all +// authentication methods have been exhausted +func TestClientAuthErrorList(t *testing.T) { + publicKeyErr := errors.New("This is an error from PublicKeyCallback") + + clientConfig := &ClientConfig{ + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + serverConfig := &ServerConfig{ + PublicKeyCallback: func(_ ConnMetadata, _ PublicKey) (*Permissions, error) { + return nil, publicKeyErr + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + go NewClientConn(c2, "", clientConfig) + _, err = newServer(c1, serverConfig) + if err == nil { + t.Fatal("newServer: got nil, expected errors") + } + + authErrs, ok := err.(*ServerAuthError) + if !ok { + t.Fatalf("errors: got %T, want *ssh.ServerAuthError", err) + } + for i, e := range authErrs.Errors { + switch i { + case 0: + if e != ErrNoAuth { + t.Fatalf("errors: got error %v, want ErrNoAuth", e) + } + case 1: + if e != publicKeyErr { + t.Fatalf("errors: got %v, want %v", e, publicKeyErr) + } + default: + t.Fatalf("errors: got %v, expected 2 errors", authErrs.Errors) + } + } +} + +func TestAuthMethodGSSAPIWithMIC(t *testing.T) { + type testcase struct { + config *ClientConfig + gssConfig *GSSAPIWithMICConfig + clientWantErr string + serverWantErr string + } + testcases := []*testcase{ + { + config: &ClientConfig{ + User: "testuser", 
+ Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User()+"@DOMAIN" { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + return nil, fmt.Errorf("user is not allowed to login") + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + serverWantErr: "user is not allowed to login", + clientWantErr: "ssh: handshake failed: ssh: unable to authenticate", + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("valid-mic"), + maxRound: 2, + }, "testtarget", 
+ ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User() { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-invalid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + clientWantErr: "ssh: handshake failed: got \"server-invalid-token-1\", want token \"server-valid-token-1\"", + }, + { + config: &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + GSSAPIWithMICAuthMethod( + &FakeClient{ + exchanges: []*exchange{ + { + outToken: "client-valid-token-1", + }, + { + expectedToken: "server-valid-token-1", + }, + }, + mic: []byte("invalid-mic"), + maxRound: 2, + }, "testtarget", + ), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + }, + gssConfig: &GSSAPIWithMICConfig{ + AllowLogin: func(conn ConnMetadata, srcName string) (*Permissions, error) { + if srcName != conn.User() { + return nil, fmt.Errorf("srcName is %s, conn user is %s", srcName, conn.User()) + } + return nil, nil + }, + Server: &FakeServer{ + exchanges: []*exchange{ + { + outToken: "server-valid-token-1", + expectedToken: "client-valid-token-1", + }, + }, + maxRound: 1, + expectedMIC: []byte("valid-mic"), + srcName: "testuser@DOMAIN", + }, + }, + serverWantErr: "got MICToken \"invalid-mic\", want \"valid-mic\"", + clientWantErr: "ssh: handshake failed: ssh: unable to authenticate", + }, + } + + for i, c := range testcases { + clientErr, serverErrs := tryAuthBothSides(t, c.config, c.gssConfig) + if (c.clientWantErr == "") != (clientErr == nil) { + t.Fatalf("client got %v, want %s, case %d", clientErr, c.clientWantErr, i) + } + if (c.serverWantErr == "") != (len(serverErrs) == 2 && serverErrs[1] == nil || len(serverErrs) == 1) { + 
t.Fatalf("server got err %v, want %s", serverErrs, c.serverWantErr) + } + if c.clientWantErr != "" { + if clientErr != nil && !strings.Contains(clientErr.Error(), c.clientWantErr) { + t.Fatalf("client got %v, want %s, case %d", clientErr, c.clientWantErr, i) + } + } + found := false + var errStrings []string + if c.serverWantErr != "" { + for _, err := range serverErrs { + found = found || (err != nil && strings.Contains(err.Error(), c.serverWantErr)) + errStrings = append(errStrings, err.Error()) + } + if !found { + t.Errorf("server got error %q, want substring %q, case %d", errStrings, c.serverWantErr, i) + } + } + } +} + +func TestCompatibleAlgoAndSignatures(t *testing.T) { + type testcase struct { + algo string + sigFormat string + compatible bool + } + testcases := []*testcase{ + { + KeyAlgoRSA, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSA, + KeyAlgoRSASHA256, + true, + }, + { + KeyAlgoRSA, + KeyAlgoRSASHA512, + true, + }, + { + KeyAlgoRSASHA256, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSA, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSASHA256, + true, + }, + { + KeyAlgoRSASHA256, + KeyAlgoRSASHA512, + true, + }, + { + KeyAlgoRSASHA512, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSA, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSASHA256, + true, + }, + { + CertAlgoRSAv01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA256v01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA512v01, + KeyAlgoRSASHA512, + true, + }, + { + CertAlgoRSASHA512v01, + KeyAlgoRSASHA256, + true, + }, + { + CertAlgoRSASHA256v01, + CertAlgoRSAv01, + true, + }, + { + CertAlgoRSAv01, + CertAlgoRSASHA512v01, + true, + }, + { + KeyAlgoECDSA256, + KeyAlgoRSA, + false, + }, + { + KeyAlgoECDSA256, + KeyAlgoECDSA521, + false, + }, + { + KeyAlgoECDSA256, + KeyAlgoECDSA256, + true, + }, + { + KeyAlgoECDSA256, + KeyAlgoED25519, + false, + }, + { + KeyAlgoED25519, + KeyAlgoED25519, + true, + }, + } + + for _, c := range testcases { + if 
isAlgoCompatible(c.algo, c.sigFormat) != c.compatible { + t.Errorf("algorithm %q, signature format %q, expected compatible to be %t", c.algo, c.sigFormat, c.compatible) + } + } +} + +func TestPickSignatureAlgorithm(t *testing.T) { + type testcase struct { + name string + extensions map[string][]byte + } + cases := []testcase{ + { + name: "server with empty server-sig-algs", + extensions: map[string][]byte{ + "server-sig-algs": []byte(``), + }, + }, + { + name: "server with no server-sig-algs", + extensions: nil, + }, + } + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + signer, ok := testSigners["rsa"].(MultiAlgorithmSigner) + if !ok { + t.Fatalf("rsa test signer does not implement the MultiAlgorithmSigner interface") + } + // The signer supports the public key algorithm which is then returned. + _, algo, err := pickSignatureAlgorithm(signer, c.extensions) + if err != nil { + t.Fatalf("got %v, want no error", err) + } + if algo != signer.PublicKey().Type() { + t.Fatalf("got algo %q, want %q", algo, signer.PublicKey().Type()) + } + // Test a signer that uses a certificate algorithm as the public key + // type. 
+			cert := &Certificate{
+				CertType: UserCert,
+				Key:      signer.PublicKey(),
+			}
+			cert.SignCert(rand.Reader, signer)
+
+			certSigner, err := NewCertSigner(cert, signer)
+			if err != nil {
+				t.Fatalf("error generating cert signer: %v", err)
+			}
+			// The signer supports the public key algorithm and the
+			// public key format is a certificate type so the certificate
+			// algorithm matching the key format must be returned.
+			_, algo, err = pickSignatureAlgorithm(certSigner, c.extensions)
+			if err != nil {
+				t.Fatalf("got %v, want no error", err)
+			}
+			if algo != certSigner.PublicKey().Type() {
+				t.Fatalf("got algo %q, want %q", algo, certSigner.PublicKey().Type())
+			}
+			signer, err = NewSignerWithAlgorithms(signer.(AlgorithmSigner), []string{KeyAlgoRSASHA512, KeyAlgoRSASHA256})
+			if err != nil {
+				t.Fatalf("unable to create signer with algorithms: %v", err)
+			}
+			// The signer does not support the public key algorithm so an error
+			// is returned.
+			_, _, err = pickSignatureAlgorithm(signer, c.extensions)
+			if err == nil {
+				t.Fatal("got no error, no common public key signature algorithm error expected")
+			}
+		})
+	}
+}
+
+// configurablePublicKeyCallback is a public key callback that allows
+// configuring the signature algorithm and format. This way we can emulate the
+// behavior of buggy clients.
+type configurablePublicKeyCallback struct { + signer AlgorithmSigner + signatureAlgo string + signatureFormat string +} + +func (cb configurablePublicKeyCallback) method() string { + return "publickey" +} + +func (cb configurablePublicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader, extensions map[string][]byte) (authResult, []string, error) { + pub := cb.signer.PublicKey() + + ok, err := validateKey(pub, cb.signatureAlgo, user, c) + if err != nil { + return authFailure, nil, err + } + if !ok { + return authFailure, nil, fmt.Errorf("invalid public key") + } + + pubKey := pub.Marshal() + data := buildDataSignedForAuth(session, userAuthRequestMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + }, cb.signatureAlgo, pubKey) + sign, err := cb.signer.SignWithAlgorithm(rand, data, underlyingAlgo(cb.signatureFormat)) + if err != nil { + return authFailure, nil, err + } + + s := Marshal(sign) + sig := make([]byte, stringLength(len(s))) + marshalString(sig, s) + msg := publickeyAuthMsg{ + User: user, + Service: serviceSSH, + Method: cb.method(), + HasSig: true, + Algoname: cb.signatureAlgo, + PubKey: pubKey, + Sig: sig, + } + p := Marshal(&msg) + if err := c.writePacket(p); err != nil { + return authFailure, nil, err + } + var success authResult + success, methods, err := handleAuthResponse(c) + if err != nil { + return authFailure, nil, err + } + if success == authSuccess || !contains(methods, cb.method()) { + return success, methods, err + } + + return authFailure, methods, nil +} + +func TestPublicKeyAndAlgoCompatibility(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: 
[]AuthMethod{ + configurablePublicKeyCallback{ + signer: certSigner.(AlgorithmSigner), + signatureAlgo: KeyAlgoRSASHA256, + signatureFormat: KeyAlgoRSASHA256, + }, + }, + } + if err := tryAuth(t, clientConfig); err == nil { + t.Error("cert login passed with incompatible public key type and algorithm") + } +} + +func TestClientAuthGPGAgentCompat(t *testing.T) { + clientConfig := &ClientConfig{ + User: "testuser", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + // algorithm rsa-sha2-512 and signature format ssh-rsa. + configurablePublicKeyCallback{ + signer: testSigners["rsa"].(AlgorithmSigner), + signatureAlgo: KeyAlgoRSASHA512, + signatureFormat: KeyAlgoRSA, + }, + }, + } + if err := tryAuth(t, clientConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestCertAuthOpenSSHCompat(t *testing.T) { + cert := &Certificate{ + Key: testPublicKeys["rsa"], + ValidBefore: CertTimeInfinity, + CertType: UserCert, + } + cert.SignCert(rand.Reader, testSigners["ecdsa"]) + certSigner, err := NewCertSigner(cert, testSigners["rsa"]) + if err != nil { + t.Fatalf("NewCertSigner: %v", err) + } + + clientConfig := &ClientConfig{ + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + Auth: []AuthMethod{ + // algorithm ssh-rsa-cert-v01@openssh.com and signature format + // rsa-sha2-256. 
+ configurablePublicKeyCallback{ + signer: certSigner.(AlgorithmSigner), + signatureAlgo: CertAlgoRSAv01, + signatureFormat: KeyAlgoRSASHA256, + }, + }, + } + if err := tryAuth(t, clientConfig); err != nil { + t.Fatalf("unable to dial remote side: %s", err) + } +} + +func TestKeyboardInteractiveAuthEarlyFail(t *testing.T) { + const maxAuthTries = 2 + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + // Start testserver + serverConfig := &ServerConfig{ + MaxAuthTries: maxAuthTries, + KeyboardInteractiveCallback: func(c ConnMetadata, + client KeyboardInteractiveChallenge) (*Permissions, error) { + // Fail keyboard-interactive authentication early before + // any prompt is sent to client. + return nil, errors.New("keyboard-interactive auth failed") + }, + PasswordCallback: func(c ConnMetadata, + pass []byte) (*Permissions, error) { + if string(pass) == clientPassword { + return nil, nil + } + return nil, errors.New("password auth failed") + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + serverDone := make(chan struct{}) + go func() { + defer func() { serverDone <- struct{}{} }() + conn, chans, reqs, err := NewServerConn(c2, serverConfig) + if err != nil { + return + } + _ = conn.Close() + + discarderDone := make(chan struct{}) + go func() { + defer func() { discarderDone <- struct{}{} }() + DiscardRequests(reqs) + }() + for newChannel := range chans { + newChannel.Reject(Prohibited, + "testserver not accepting requests") + } + + <-discarderDone + }() + + // Connect to testserver, expect KeyboardInteractive() to be not called, + // PasswordCallback() to be called and connection to succeed. 
+ passwordCallbackCalled := false + clientConfig := &ClientConfig{ + User: "testuser", + Auth: []AuthMethod{ + RetryableAuthMethod(KeyboardInteractive(func(name, + instruction string, questions []string, + echos []bool) ([]string, error) { + t.Errorf("unexpected call to KeyboardInteractive()") + return []string{clientPassword}, nil + }), maxAuthTries), + RetryableAuthMethod(PasswordCallback(func() (secret string, + err error) { + t.Logf("PasswordCallback()") + passwordCallbackCalled = true + return clientPassword, nil + }), maxAuthTries), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + conn, _, _, err := NewClientConn(c1, "", clientConfig) + if err != nil { + t.Errorf("unexpected NewClientConn() error: %v", err) + } + if conn != nil { + conn.Close() + } + + // Wait for server to finish. + <-serverDone + + if !passwordCallbackCalled { + t.Errorf("expected PasswordCallback() to be called") + } +} diff --git a/tempfork/sshtest/ssh/client_test.go b/tempfork/sshtest/ssh/client_test.go new file mode 100644 index 0000000000000..2621f0ea5276b --- /dev/null +++ b/tempfork/sshtest/ssh/client_test.go @@ -0,0 +1,367 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "net" + "strings" + "testing" +) + +func TestClientVersion(t *testing.T) { + for _, tt := range []struct { + name string + version string + multiLine string + wantErr bool + }{ + { + name: "default version", + version: packageVersion, + }, + { + name: "custom version", + version: "SSH-2.0-CustomClientVersionString", + }, + { + name: "good multi line version", + version: packageVersion, + multiLine: strings.Repeat("ignored\r\n", 20), + }, + { + name: "bad multi line version", + version: packageVersion, + multiLine: "bad multi line version", + wantErr: true, + }, + { + name: "long multi line version", + version: packageVersion, + multiLine: strings.Repeat("long multi line version\r\n", 50)[:256], + wantErr: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go func() { + if tt.multiLine != "" { + c1.Write([]byte(tt.multiLine)) + } + NewClientConn(c1, "", &ClientConfig{ + ClientVersion: tt.version, + HostKeyCallback: InsecureIgnoreHostKey(), + }) + c1.Close() + }() + conf := &ServerConfig{NoClientAuth: true} + conf.AddHostKey(testSigners["rsa"]) + conn, _, _, err := NewServerConn(c2, conf) + if err == nil == tt.wantErr { + t.Fatalf("got err %v; wantErr %t", err, tt.wantErr) + } + if tt.wantErr { + // Don't verify the version on an expected error. 
+ return + } + if got := string(conn.ClientVersion()); got != tt.version { + t.Fatalf("got %q; want %q", got, tt.version) + } + }) + } +} + +func TestHostKeyCheck(t *testing.T) { + for _, tt := range []struct { + name string + wantError string + key PublicKey + }{ + {"no callback", "must specify HostKeyCallback", nil}, + {"correct key", "", testSigners["rsa"].PublicKey()}, + {"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + + go NewServerConn(c1, serverConf) + clientConf := ClientConfig{ + User: "user", + } + if tt.key != nil { + clientConf.HostKeyCallback = FixedHostKey(tt.key) + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) + } + } +} + +func TestVerifyHostKeySignature(t *testing.T) { + for _, tt := range []struct { + key string + signAlgo string + verifyAlgo string + wantError string + }{ + {"rsa", KeyAlgoRSA, KeyAlgoRSA, ""}, + {"rsa", KeyAlgoRSASHA256, KeyAlgoRSASHA256, ""}, + {"rsa", KeyAlgoRSA, KeyAlgoRSASHA512, `ssh: invalid signature algorithm "ssh-rsa", expected "rsa-sha2-512"`}, + {"ed25519", KeyAlgoED25519, KeyAlgoED25519, ""}, + } { + key := testSigners[tt.key].PublicKey() + s, ok := testSigners[tt.key].(AlgorithmSigner) + if !ok { + t.Fatalf("needed an AlgorithmSigner") + } + sig, err := s.SignWithAlgorithm(rand.Reader, []byte("test"), tt.signAlgo) + if err != nil { + t.Fatalf("couldn't sign: %q", err) + } + + b := bytes.Buffer{} + writeString(&b, []byte(sig.Format)) + writeString(&b, sig.Blob) + + result := kexResult{Signature: 
b.Bytes(), H: []byte("test")} + + err = verifyHostKeySignature(key, tt.verifyAlgo, &result) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("got error %q, expecting %q", err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("succeeded, but want error string %q", tt.wantError) + } + } +} + +func TestBannerCallback(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + BannerCallback: func(conn ConnMetadata) string { + return "Hello World" + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + var receivedBanner string + var bannerCount int + clientConf := ClientConfig{ + Auth: []AuthMethod{ + Password("123"), + }, + User: "user", + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(message string) error { + bannerCount++ + receivedBanner = message + return nil + }, + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + t.Fatal(err) + } + + if bannerCount != 1 { + t.Errorf("got %d banners; want 1", bannerCount) + } + + expected := "Hello World" + if receivedBanner != expected { + t.Fatalf("got %s; want %s", receivedBanner, expected) + } +} + +func TestNewClientConn(t *testing.T) { + errHostKeyMismatch := errors.New("host key mismatch") + + for _, tt := range []struct { + name string + user string + simulateHostKeyMismatch HostKeyCallback + }{ + { + name: "good user field for ConnMetadata", + user: "testuser", + }, + { + name: "empty user field for ConnMetadata", + user: "", + }, + { + name: "host key mismatch", + user: "testuser", + simulateHostKeyMismatch: func(hostname string, remote net.Addr, key PublicKey) error { + return fmt.Errorf("%w: %s", errHostKeyMismatch, 
bytes.TrimSpace(MarshalAuthorizedKey(key))) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: tt.user, + Auth: []AuthMethod{ + Password("testpw"), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + if tt.simulateHostKeyMismatch != nil { + clientConf.HostKeyCallback = tt.simulateHostKeyMismatch + } + + clientConn, _, _, err := NewClientConn(c2, "", clientConf) + if err != nil { + if tt.simulateHostKeyMismatch != nil && errors.Is(err, errHostKeyMismatch) { + return + } + t.Fatal(err) + } + + if userGot := clientConn.User(); userGot != tt.user { + t.Errorf("got user %q; want user %q", userGot, tt.user) + } + }) + } +} + +func TestUnsupportedAlgorithm(t *testing.T) { + for _, tt := range []struct { + name string + config Config + wantError string + }{ + { + "unsupported KEX", + Config{ + KeyExchanges: []string{"unsupported"}, + }, + "no common algorithm", + }, + { + "unsupported and supported KEXs", + Config{ + KeyExchanges: []string{"unsupported", kexAlgoCurve25519SHA256}, + }, + "", + }, + { + "unsupported cipher", + Config{ + Ciphers: []string{"unsupported"}, + }, + "no common algorithm", + }, + { + "unsupported and supported ciphers", + Config{ + Ciphers: []string{"unsupported", chacha20Poly1305ID}, + }, + "", + }, + { + "unsupported MAC", + Config{ + MACs: []string{"unsupported"}, + // MAC is used for non AAED ciphers. + Ciphers: []string{"aes256-ctr"}, + }, + "no common algorithm", + }, + { + "unsupported and supported MACs", + Config{ + MACs: []string{"unsupported", "hmac-sha2-256-etm@openssh.com"}, + // MAC is used for non AAED ciphers. 
+ Ciphers: []string{"aes256-ctr"}, + }, + "", + }, + } { + t.Run(tt.name, func(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + Config: tt.config, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(testSigners["rsa"]) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: "testuser", + Config: tt.config, + Auth: []AuthMethod{ + Password("testpw"), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { + t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) + } + } else if tt.wantError != "" { + t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/common.go b/tempfork/sshtest/ssh/common.go new file mode 100644 index 0000000000000..7e9c2cbc647e2 --- /dev/null +++ b/tempfork/sshtest/ssh/common.go @@ -0,0 +1,476 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto" + "crypto/rand" + "fmt" + "io" + "math" + "sync" + + _ "crypto/sha1" + _ "crypto/sha256" + _ "crypto/sha512" +) + +// These are string constants in the SSH protocol. +const ( + compressionNone = "none" + serviceUserAuth = "ssh-userauth" + serviceSSH = "ssh-connection" +) + +// supportedCiphers lists ciphers we support but might not recommend. 
+var supportedCiphers = []string{ + "aes128-ctr", "aes192-ctr", "aes256-ctr", + "aes128-gcm@openssh.com", gcm256CipherID, + chacha20Poly1305ID, + "arcfour256", "arcfour128", "arcfour", + aes128cbcID, + tripledescbcID, +} + +// preferredCiphers specifies the default preference for ciphers. +var preferredCiphers = []string{ + "aes128-gcm@openssh.com", gcm256CipherID, + chacha20Poly1305ID, + "aes128-ctr", "aes192-ctr", "aes256-ctr", +} + +// supportedKexAlgos specifies the supported key-exchange algorithms in +// preference order. +var supportedKexAlgos = []string{ + kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH, + // P384 and P521 are not constant-time yet, but since we don't + // reuse ephemeral keys, using them for ECDH should be OK. + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA256, kexAlgoDH16SHA512, kexAlgoDH14SHA1, + kexAlgoDH1SHA1, +} + +// serverForbiddenKexAlgos contains key exchange algorithms, that are forbidden +// for the server half. +var serverForbiddenKexAlgos = map[string]struct{}{ + kexAlgoDHGEXSHA1: {}, // server half implementation is only minimal to satisfy the automated tests + kexAlgoDHGEXSHA256: {}, // server half implementation is only minimal to satisfy the automated tests +} + +// preferredKexAlgos specifies the default preference for key-exchange +// algorithms in preference order. The diffie-hellman-group16-sha512 algorithm +// is disabled by default because it is a bit slower than the others. +var preferredKexAlgos = []string{ + kexAlgoCurve25519SHA256, kexAlgoCurve25519SHA256LibSSH, + kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, + kexAlgoDH14SHA256, kexAlgoDH14SHA1, +} + +// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods +// of authenticating servers) in preference order. 
+var supportedHostKeyAlgos = []string{ + CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, + CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, + CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01, + + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSASHA256, KeyAlgoRSASHA512, + KeyAlgoRSA, KeyAlgoDSA, + + KeyAlgoED25519, +} + +// supportedMACs specifies a default set of MAC algorithms in preference order. +// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed +// because they have reached the end of their useful life. +var supportedMACs = []string{ + "hmac-sha2-256-etm@openssh.com", "hmac-sha2-512-etm@openssh.com", "hmac-sha2-256", "hmac-sha2-512", "hmac-sha1", "hmac-sha1-96", +} + +var supportedCompressions = []string{compressionNone} + +// hashFuncs keeps the mapping of supported signature algorithms to their +// respective hashes needed for signing and verification. +var hashFuncs = map[string]crypto.Hash{ + KeyAlgoRSA: crypto.SHA1, + KeyAlgoRSASHA256: crypto.SHA256, + KeyAlgoRSASHA512: crypto.SHA512, + KeyAlgoDSA: crypto.SHA1, + KeyAlgoECDSA256: crypto.SHA256, + KeyAlgoECDSA384: crypto.SHA384, + KeyAlgoECDSA521: crypto.SHA512, + // KeyAlgoED25519 doesn't pre-hash. + KeyAlgoSKECDSA256: crypto.SHA256, + KeyAlgoSKED25519: crypto.SHA256, +} + +// algorithmsForKeyFormat returns the supported signature algorithms for a given +// public key format (PublicKey.Type), in order of preference. See RFC 8332, +// Section 2. See also the note in sendKexInit on backwards compatibility. +func algorithmsForKeyFormat(keyFormat string) []string { + switch keyFormat { + case KeyAlgoRSA: + return []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA} + case CertAlgoRSAv01: + return []string{CertAlgoRSASHA256v01, CertAlgoRSASHA512v01, CertAlgoRSAv01} + default: + return []string{keyFormat} + } +} + +// isRSA returns whether algo is a supported RSA algorithm, including certificate +// algorithms. 
+func isRSA(algo string) bool { + algos := algorithmsForKeyFormat(KeyAlgoRSA) + return contains(algos, underlyingAlgo(algo)) +} + +func isRSACert(algo string) bool { + _, ok := certKeyAlgoNames[algo] + if !ok { + return false + } + return isRSA(algo) +} + +// supportedPubKeyAuthAlgos specifies the supported client public key +// authentication algorithms. Note that this doesn't include certificate types +// since those use the underlying algorithm. This list is sent to the client if +// it supports the server-sig-algs extension. Order is irrelevant. +var supportedPubKeyAuthAlgos = []string{ + KeyAlgoED25519, + KeyAlgoSKED25519, KeyAlgoSKECDSA256, + KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, + KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA, + KeyAlgoDSA, +} + +// unexpectedMessageError results when the SSH message that we received didn't +// match what we wanted. +func unexpectedMessageError(expected, got uint8) error { + return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) +} + +// parseError results from a malformed SSH message. +func parseError(tag uint8) error { + return fmt.Errorf("ssh: parse error in message type %d", tag) +} + +func findCommon(what string, client []string, server []string) (common string, err error) { + for _, c := range client { + for _, s := range server { + if c == s { + return c, nil + } + } + } + return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) +} + +// directionAlgorithms records algorithm choices in one direction (either read or write) +type directionAlgorithms struct { + Cipher string + MAC string + Compression string +} + +// rekeyBytes returns a rekeying intervals in bytes. +func (a *directionAlgorithms) rekeyBytes() int64 { + // According to RFC 4344 block ciphers should rekey after + // 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is + // 128. 
+ switch a.Cipher { + case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcm128CipherID, gcm256CipherID, aes128cbcID: + return 16 * (1 << 32) + + } + + // For others, stick with RFC 4253 recommendation to rekey after 1 Gb of data. + return 1 << 30 +} + +var aeadCiphers = map[string]bool{ + gcm128CipherID: true, + gcm256CipherID: true, + chacha20Poly1305ID: true, +} + +type algorithms struct { + kex string + hostKey string + w directionAlgorithms + r directionAlgorithms +} + +func findAgreedAlgorithms(isClient bool, clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { + result := &algorithms{} + + result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) + if err != nil { + return + } + + result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) + if err != nil { + return + } + + stoc, ctos := &result.w, &result.r + if isClient { + ctos, stoc = stoc, ctos + } + + ctos.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) + if err != nil { + return + } + + stoc.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) + if err != nil { + return + } + + if !aeadCiphers[ctos.Cipher] { + ctos.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) + if err != nil { + return + } + } + + if !aeadCiphers[stoc.Cipher] { + stoc.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) + if err != nil { + return + } + } + + ctos.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) + if err != nil { + return + } + + stoc.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, 
serverKexInit.CompressionServerClient) + if err != nil { + return + } + + return result, nil +} + +// If rekeythreshold is too small, we can't make any progress sending +// stuff. +const minRekeyThreshold uint64 = 256 + +// Config contains configuration data common to both ServerConfig and +// ClientConfig. +type Config struct { + // Rand provides the source of entropy for cryptographic + // primitives. If Rand is nil, the cryptographic random reader + // in package crypto/rand will be used. + Rand io.Reader + + // The maximum number of bytes sent or received after which a + // new key is negotiated. It must be at least 256. If + // unspecified, a size suitable for the chosen cipher is used. + RekeyThreshold uint64 + + // The allowed key exchanges algorithms. If unspecified then a default set + // of algorithms is used. Unsupported values are silently ignored. + KeyExchanges []string + + // The allowed cipher algorithms. If unspecified then a sensible default is + // used. Unsupported values are silently ignored. + Ciphers []string + + // The allowed MAC algorithms. If unspecified then a sensible default is + // used. Unsupported values are silently ignored. + MACs []string +} + +// SetDefaults sets sensible values for unset fields in config. This is +// exported for testing: Configs passed to SSH functions are copied and have +// default values set automatically. +func (c *Config) SetDefaults() { + if c.Rand == nil { + c.Rand = rand.Reader + } + if c.Ciphers == nil { + c.Ciphers = preferredCiphers + } + var ciphers []string + for _, c := range c.Ciphers { + if cipherModes[c] != nil { + // Ignore the cipher if we have no cipherModes definition. + ciphers = append(ciphers, c) + } + } + c.Ciphers = ciphers + + if c.KeyExchanges == nil { + c.KeyExchanges = preferredKexAlgos + } + var kexs []string + for _, k := range c.KeyExchanges { + if kexAlgoMap[k] != nil { + // Ignore the KEX if we have no kexAlgoMap definition. 
+ kexs = append(kexs, k) + } + } + c.KeyExchanges = kexs + + if c.MACs == nil { + c.MACs = supportedMACs + } + var macs []string + for _, m := range c.MACs { + if macModes[m] != nil { + // Ignore the MAC if we have no macModes definition. + macs = append(macs, m) + } + } + c.MACs = macs + + if c.RekeyThreshold == 0 { + // cipher specific default + } else if c.RekeyThreshold < minRekeyThreshold { + c.RekeyThreshold = minRekeyThreshold + } else if c.RekeyThreshold >= math.MaxInt64 { + // Avoid weirdness if somebody uses -1 as a threshold. + c.RekeyThreshold = math.MaxInt64 + } +} + +// buildDataSignedForAuth returns the data that is signed in order to prove +// possession of a private key. See RFC 4252, section 7. algo is the advertised +// algorithm, and may be a certificate type. +func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo string, pubKey []byte) []byte { + data := struct { + Session []byte + Type byte + User string + Service string + Method string + Sign bool + Algo string + PubKey []byte + }{ + sessionID, + msgUserAuthRequest, + req.User, + req.Service, + req.Method, + true, + algo, + pubKey, + } + return Marshal(data) +} + +func appendU16(buf []byte, n uint16) []byte { + return append(buf, byte(n>>8), byte(n)) +} + +func appendU32(buf []byte, n uint32) []byte { + return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendU64(buf []byte, n uint64) []byte { + return append(buf, + byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), + byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) +} + +func appendInt(buf []byte, n int) []byte { + return appendU32(buf, uint32(n)) +} + +func appendString(buf []byte, s string) []byte { + buf = appendU32(buf, uint32(len(s))) + buf = append(buf, s...) + return buf +} + +func appendBool(buf []byte, b bool) []byte { + if b { + return append(buf, 1) + } + return append(buf, 0) +} + +// newCond is a helper to hide the fact that there is no usable zero +// value for sync.Cond. 
+func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } + +// window represents the buffer available to clients +// wishing to write to a channel. +type window struct { + *sync.Cond + win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 + writeWaiters int + closed bool +} + +// add adds win to the amount of window available +// for consumers. +func (w *window) add(win uint32) bool { + // a zero sized window adjust is a noop. + if win == 0 { + return true + } + w.L.Lock() + if w.win+win < win { + w.L.Unlock() + return false + } + w.win += win + // It is unusual that multiple goroutines would be attempting to reserve + // window space, but not guaranteed. Use broadcast to notify all waiters + // that additional window is available. + w.Broadcast() + w.L.Unlock() + return true +} + +// close sets the window to closed, so all reservations fail +// immediately. +func (w *window) close() { + w.L.Lock() + w.closed = true + w.Broadcast() + w.L.Unlock() +} + +// reserve reserves win from the available window capacity. +// If no capacity remains, reserve will block. reserve may +// return less than requested. +func (w *window) reserve(win uint32) (uint32, error) { + var err error + w.L.Lock() + w.writeWaiters++ + w.Broadcast() + for w.win == 0 && !w.closed { + w.Wait() + } + w.writeWaiters-- + if w.win < win { + win = w.win + } + w.win -= win + if w.closed { + err = io.EOF + } + w.L.Unlock() + return win, err +} + +// waitWriterBlocked waits until some goroutine is blocked for further +// writes. It is used in tests only. +func (w *window) waitWriterBlocked() { + w.Cond.L.Lock() + for w.writeWaiters == 0 { + w.Cond.Wait() + } + w.Cond.L.Unlock() +} diff --git a/tempfork/sshtest/ssh/common_test.go b/tempfork/sshtest/ssh/common_test.go new file mode 100644 index 0000000000000..a7beee8e884e7 --- /dev/null +++ b/tempfork/sshtest/ssh/common_test.go @@ -0,0 +1,176 @@ +// Copyright 2019 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "reflect" + "testing" +) + +func TestFindAgreedAlgorithms(t *testing.T) { + initKex := func(k *kexInitMsg) { + if k.KexAlgos == nil { + k.KexAlgos = []string{"kex1"} + } + if k.ServerHostKeyAlgos == nil { + k.ServerHostKeyAlgos = []string{"hostkey1"} + } + if k.CiphersClientServer == nil { + k.CiphersClientServer = []string{"cipher1"} + + } + if k.CiphersServerClient == nil { + k.CiphersServerClient = []string{"cipher1"} + + } + if k.MACsClientServer == nil { + k.MACsClientServer = []string{"mac1"} + + } + if k.MACsServerClient == nil { + k.MACsServerClient = []string{"mac1"} + + } + if k.CompressionClientServer == nil { + k.CompressionClientServer = []string{"compression1"} + + } + if k.CompressionServerClient == nil { + k.CompressionServerClient = []string{"compression1"} + + } + if k.LanguagesClientServer == nil { + k.LanguagesClientServer = []string{"language1"} + + } + if k.LanguagesServerClient == nil { + k.LanguagesServerClient = []string{"language1"} + + } + } + + initDirAlgs := func(a *directionAlgorithms) { + if a.Cipher == "" { + a.Cipher = "cipher1" + } + if a.MAC == "" { + a.MAC = "mac1" + } + if a.Compression == "" { + a.Compression = "compression1" + } + } + + initAlgs := func(a *algorithms) { + if a.kex == "" { + a.kex = "kex1" + } + if a.hostKey == "" { + a.hostKey = "hostkey1" + } + initDirAlgs(&a.r) + initDirAlgs(&a.w) + } + + type testcase struct { + name string + clientIn, serverIn kexInitMsg + wantClient, wantServer algorithms + wantErr bool + } + + cases := []testcase{ + { + name: "standard", + }, + + { + name: "no common hostkey", + serverIn: kexInitMsg{ + ServerHostKeyAlgos: []string{"hostkey2"}, + }, + wantErr: true, + }, + + { + name: "no common kex", + serverIn: kexInitMsg{ + KexAlgos: []string{"kex2"}, + }, + wantErr: true, + }, + + { + name: "no common cipher", + serverIn: kexInitMsg{ + 
CiphersClientServer: []string{"cipher2"}, + }, + wantErr: true, + }, + + { + name: "client decides cipher", + serverIn: kexInitMsg{ + CiphersClientServer: []string{"cipher1", "cipher2"}, + CiphersServerClient: []string{"cipher2", "cipher3"}, + }, + clientIn: kexInitMsg{ + CiphersClientServer: []string{"cipher2", "cipher1"}, + CiphersServerClient: []string{"cipher3", "cipher2"}, + }, + wantClient: algorithms{ + r: directionAlgorithms{ + Cipher: "cipher3", + }, + w: directionAlgorithms{ + Cipher: "cipher2", + }, + }, + wantServer: algorithms{ + w: directionAlgorithms{ + Cipher: "cipher3", + }, + r: directionAlgorithms{ + Cipher: "cipher2", + }, + }, + }, + + // TODO(hanwen): fix and add tests for AEAD ignoring + // the MACs field + } + + for i := range cases { + initKex(&cases[i].clientIn) + initKex(&cases[i].serverIn) + initAlgs(&cases[i].wantClient) + initAlgs(&cases[i].wantServer) + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + serverAlgs, serverErr := findAgreedAlgorithms(false, &c.clientIn, &c.serverIn) + clientAlgs, clientErr := findAgreedAlgorithms(true, &c.clientIn, &c.serverIn) + + serverHasErr := serverErr != nil + clientHasErr := clientErr != nil + if c.wantErr != serverHasErr || c.wantErr != clientHasErr { + t.Fatalf("got client/server error (%v, %v), want hasError %v", + clientErr, serverErr, c.wantErr) + + } + if c.wantErr { + return + } + + if !reflect.DeepEqual(serverAlgs, &c.wantServer) { + t.Errorf("server: got algs %#v, want %#v", serverAlgs, &c.wantServer) + } + if !reflect.DeepEqual(clientAlgs, &c.wantClient) { + t.Errorf("server: got algs %#v, want %#v", clientAlgs, &c.wantClient) + } + }) + } +} diff --git a/tempfork/sshtest/ssh/connection.go b/tempfork/sshtest/ssh/connection.go new file mode 100644 index 0000000000000..8f345ee924e43 --- /dev/null +++ b/tempfork/sshtest/ssh/connection.go @@ -0,0 +1,143 @@ +// Copyright 2013 The Go Authors. All rights reserved. 
+// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "fmt" + "net" +) + +// OpenChannelError is returned if the other side rejects an +// OpenChannel request. +type OpenChannelError struct { + Reason RejectionReason + Message string +} + +func (e *OpenChannelError) Error() string { + return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) +} + +// ConnMetadata holds metadata for the connection. +type ConnMetadata interface { + // User returns the user ID for this connection. + User() string + + // SessionID returns the session hash, also denoted by H. + SessionID() []byte + + // ClientVersion returns the client's version string as hashed + // into the session ID. + ClientVersion() []byte + + // ServerVersion returns the server's version string as hashed + // into the session ID. + ServerVersion() []byte + + // RemoteAddr returns the remote address for this connection. + RemoteAddr() net.Addr + + // LocalAddr returns the local address for this connection. + LocalAddr() net.Addr +} + +// Conn represents an SSH connection for both server and client roles. +// Conn is the basis for implementing an application layer, such +// as ClientConn, which implements the traditional shell access for +// clients. +type Conn interface { + ConnMetadata + + // SendRequest sends a global request, and returns the + // reply. If wantReply is true, it returns the response status + // and payload. See also RFC 4254, section 4. + SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) + + // OpenChannel tries to open an channel. If the request is + // rejected, it returns *OpenChannelError. On success it returns + // the SSH Channel and a Go channel for incoming, out-of-band + // requests. The Go channel must be serviced, or the + // connection will hang. 
+ OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) + + // Close closes the underlying network connection + Close() error + + // Wait blocks until the connection has shut down, and returns the + // error causing the shutdown. + Wait() error + + // TODO(hanwen): consider exposing: + // RequestKeyChange + // Disconnect +} + +// DiscardRequests consumes and rejects all requests from the +// passed-in channel. +func DiscardRequests(in <-chan *Request) { + for req := range in { + if req.WantReply { + req.Reply(false, nil) + } + } +} + +// A connection represents an incoming connection. +type connection struct { + transport *handshakeTransport + sshConn + + // The connection protocol. + *mux +} + +func (c *connection) Close() error { + return c.sshConn.conn.Close() +} + +// sshConn provides net.Conn metadata, but disallows direct reads and +// writes. +type sshConn struct { + conn net.Conn + + user string + sessionID []byte + clientVersion []byte + serverVersion []byte +} + +func dup(src []byte) []byte { + dst := make([]byte, len(src)) + copy(dst, src) + return dst +} + +func (c *sshConn) User() string { + return c.user +} + +func (c *sshConn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +func (c *sshConn) Close() error { + return c.conn.Close() +} + +func (c *sshConn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *sshConn) SessionID() []byte { + return dup(c.sessionID) +} + +func (c *sshConn) ClientVersion() []byte { + return dup(c.clientVersion) +} + +func (c *sshConn) ServerVersion() []byte { + return dup(c.serverVersion) +} diff --git a/tempfork/sshtest/ssh/doc.go b/tempfork/sshtest/ssh/doc.go new file mode 100644 index 0000000000000..f5d352fe3a0b0 --- /dev/null +++ b/tempfork/sshtest/ssh/doc.go @@ -0,0 +1,23 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +/* +Package ssh implements an SSH client and server. + +SSH is a transport security protocol, an authentication protocol and a +family of application protocols. The most typical application level +protocol is a remote shell and this is specifically implemented. However, +the multiplexed nature of SSH is exposed to users that wish to support +others. + +References: + + [PROTOCOL]: https://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL?rev=HEAD + [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD + [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 + +This package does not fall under the stability promise of the Go language itself, +so its API may be changed when pressing needs arise. +*/ +package ssh diff --git a/tempfork/sshtest/ssh/example_test.go b/tempfork/sshtest/ssh/example_test.go new file mode 100644 index 0000000000000..97b3b6aba6f8e --- /dev/null +++ b/tempfork/sshtest/ssh/example_test.go @@ -0,0 +1,400 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh_test + +import ( + "bufio" + "bytes" + "crypto/rand" + "crypto/rsa" + "fmt" + "log" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + + "golang.org/x/crypto/ssh" + "golang.org/x/crypto/ssh/terminal" +) + +func ExampleNewServerConn() { + // Public key authentication is done by comparing + // the public key of a received connection + // with the entries in the authorized_keys file. 
+ authorizedKeysBytes, err := os.ReadFile("authorized_keys") + if err != nil { + log.Fatalf("Failed to load authorized_keys, err: %v", err) + } + + authorizedKeysMap := map[string]bool{} + for len(authorizedKeysBytes) > 0 { + pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes) + if err != nil { + log.Fatal(err) + } + + authorizedKeysMap[string(pubKey.Marshal())] = true + authorizedKeysBytes = rest + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + // Remove to disable password auth. + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + + // Remove to disable public key auth. + PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { + if authorizedKeysMap[string(pubKey.Marshal())] { + return &ssh.Permissions{ + // Record the public key used for authentication. + Extensions: map[string]string{ + "pubkey-fp": ssh.FingerprintSHA256(pubKey), + }, + }, nil + } + return nil, fmt.Errorf("unknown public key for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. 
+ listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection: ", err) + } + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection: ", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + conn, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake: ", err) + } + log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"]) + + var wg sync.WaitGroup + defer wg.Wait() + + // The incoming Request channel must be serviced. + wg.Add(1) + go func() { + ssh.DiscardRequests(reqs) + wg.Done() + }() + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of a shell, the type is + // "session" and ServerShell may be used to present a simple + // terminal interface. + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatalf("Could not accept channel: %v", err) + } + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "shell" request. + wg.Add(1) + go func(in <-chan *ssh.Request) { + for req := range in { + req.Reply(req.Type == "shell", nil) + } + wg.Done() + }(requests) + + term := terminal.NewTerminal(channel, "> ") + + wg.Add(1) + go func() { + defer func() { + channel.Close() + wg.Done() + }() + for { + line, err := term.ReadLine() + if err != nil { + break + } + fmt.Println(line) + } + }() + } +} + +func ExampleServerConfig_AddHostKey() { + // Minimal ServerConfig supporting only password authentication. 
+ config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + if c.User() == "testuser" && string(pass) == "tiger" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := os.ReadFile("id_rsa") + if err != nil { + log.Fatal("Failed to load private key: ", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key: ", err) + } + // Restrict host key algorithms to disable ssh-rsa. + signer, err := ssh.NewSignerWithAlgorithms(private.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256, ssh.KeyAlgoRSASHA512}) + if err != nil { + log.Fatal("Failed to create private key with restricted algorithms: ", err) + } + config.AddHostKey(signer) +} + +func ExampleClientConfig_HostKeyCallback() { + // Every client must provide a host key check. Here is a + // simple-minded parse of OpenSSH's known_hosts file + host := "hostname" + file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")) + if err != nil { + log.Fatal(err) + } + defer file.Close() + + scanner := bufio.NewScanner(file) + var hostKey ssh.PublicKey + for scanner.Scan() { + fields := strings.Split(scanner.Text(), " ") + if len(fields) != 3 { + continue + } + if strings.Contains(fields[0], host) { + var err error + hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) + if err != nil { + log.Fatalf("error parsing %q: %v", fields[2], err) + } + break + } + } + + if hostKey == nil { + log.Fatalf("no hostkey for %s", host) + } + + config := ssh.ClientConfig{ + User: os.Getenv("USER"), + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + _, err = ssh.Dial("tcp", host+":22", &config) + log.Println(err) +} + +func ExampleDial() { + var hostKey ssh.PublicKey + // An SSH client is represented with a ClientConn. 
+ // + // To authenticate with the remote server you must pass at least one + // implementation of AuthMethod via the Auth field in ClientConfig, + // and provide a HostKeyCallback. + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("yourpassword"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + client, err := ssh.Dial("tcp", "yourserver.com:22", config) + if err != nil { + log.Fatal("Failed to dial: ", err) + } + defer client.Close() + + // Each ClientConn can support multiple interactive sessions, + // represented by a Session. + session, err := client.NewSession() + if err != nil { + log.Fatal("Failed to create session: ", err) + } + defer session.Close() + + // Once a Session is created, you can execute a single command on + // the remote side using the Run method. + var b bytes.Buffer + session.Stdout = &b + if err := session.Run("/usr/bin/whoami"); err != nil { + log.Fatal("Failed to run: " + err.Error()) + } + fmt.Println(b.String()) +} + +func ExamplePublicKeys() { + var hostKey ssh.PublicKey + // A public key may be used to authenticate against the remote + // server by using an unencrypted PEM-encoded private key file. + // + // If you have an encrypted private key, the crypto/x509 package + // can be used to decrypt it. + key, err := os.ReadFile("/home/user/.ssh/id_rsa") + if err != nil { + log.Fatalf("unable to read private key: %v", err) + } + + // Create the Signer for this private key. + signer, err := ssh.ParsePrivateKey(key) + if err != nil { + log.Fatalf("unable to parse private key: %v", err) + } + + config := &ssh.ClientConfig{ + User: "user", + Auth: []ssh.AuthMethod{ + // Use the PublicKeys method for remote authentication. + ssh.PublicKeys(signer), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + + // Connect to the remote server and perform the SSH handshake. 
+ client, err := ssh.Dial("tcp", "host.com:22", config) + if err != nil { + log.Fatalf("unable to connect: %v", err) + } + defer client.Close() +} + +func ExampleClient_Listen() { + var hostKey ssh.PublicKey + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + // Dial your ssh server. + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + + // Request the remote side to open port 8080 on all interfaces. + l, err := conn.Listen("tcp", "0.0.0.0:8080") + if err != nil { + log.Fatal("unable to register tcp forward: ", err) + } + defer l.Close() + + // Serve HTTP with your SSH server acting as a reverse proxy. + http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + fmt.Fprintf(resp, "Hello world!\n") + })) +} + +func ExampleSession_RequestPty() { + var hostKey ssh.PublicKey + // Create client config + config := &ssh.ClientConfig{ + User: "username", + Auth: []ssh.AuthMethod{ + ssh.Password("password"), + }, + HostKeyCallback: ssh.FixedHostKey(hostKey), + } + // Connect to ssh server + conn, err := ssh.Dial("tcp", "localhost:22", config) + if err != nil { + log.Fatal("unable to connect: ", err) + } + defer conn.Close() + // Create a session + session, err := conn.NewSession() + if err != nil { + log.Fatal("unable to create session: ", err) + } + defer session.Close() + // Set up terminal modes + modes := ssh.TerminalModes{ + ssh.ECHO: 0, // disable echoing + ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud + ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud + } + // Request pseudo terminal + if err := session.RequestPty("xterm", 40, 80, modes); err != nil { + log.Fatal("request for pseudo terminal failed: ", err) + } + // Start remote shell + if err := session.Shell(); err != nil { + log.Fatal("failed to start shell: ", err) + } +} + +func 
ExampleCertificate_SignCert() { + // Sign a certificate with a specific algorithm. + privateKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + log.Fatal("unable to generate RSA key: ", err) + } + publicKey, err := ssh.NewPublicKey(&privateKey.PublicKey) + if err != nil { + log.Fatal("unable to get RSA public key: ", err) + } + caKey, err := rsa.GenerateKey(rand.Reader, 3072) + if err != nil { + log.Fatal("unable to generate CA key: ", err) + } + signer, err := ssh.NewSignerFromKey(caKey) + if err != nil { + log.Fatal("unable to generate signer from key: ", err) + } + mas, err := ssh.NewSignerWithAlgorithms(signer.(ssh.AlgorithmSigner), []string{ssh.KeyAlgoRSASHA256}) + if err != nil { + log.Fatal("unable to create signer with algorithms: ", err) + } + certificate := ssh.Certificate{ + Key: publicKey, + CertType: ssh.UserCert, + } + if err := certificate.SignCert(rand.Reader, mas); err != nil { + log.Fatal("unable to sign certificate: ", err) + } + // Save the public key to a file and check that rsa-sha-256 is used for + // signing: + // ssh-keygen -L -f + fmt.Println(string(ssh.MarshalAuthorizedKey(&certificate))) +} diff --git a/tempfork/sshtest/ssh/handshake.go b/tempfork/sshtest/ssh/handshake.go new file mode 100644 index 0000000000000..2b42b2b597773 --- /dev/null +++ b/tempfork/sshtest/ssh/handshake.go @@ -0,0 +1,817 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "log" + "net" + "strings" + "sync" +) + +// debugHandshake, if set, prints messages sent and received. Key +// exchange messages are printed as if DH were used, so the debug +// messages are wrong when using ECDH. +const debugHandshake = false + +// chanSize sets the amount of buffering SSH connections. This is +// primarily for testing: setting chanSize=0 uncovers deadlocks more +// quickly. 
+const chanSize = 16 + +// keyingTransport is a packet based transport that supports key +// changes. It need not be thread-safe. It should pass through +// msgNewKeys in both directions. +type keyingTransport interface { + packetConn + + // prepareKeyChange sets up a key change. The key change for a + // direction will be effected if a msgNewKeys message is sent + // or received. + prepareKeyChange(*algorithms, *kexResult) error + + // setStrictMode sets the strict KEX mode, notably triggering + // sequence number resets on sending or receiving msgNewKeys. + // If the sequence number is already > 1 when setStrictMode + // is called, an error is returned. + setStrictMode() error + + // setInitialKEXDone indicates to the transport that the initial key exchange + // was completed + setInitialKEXDone() +} + +// handshakeTransport implements rekeying on top of a keyingTransport +// and offers a thread-safe writePacket() interface. +type handshakeTransport struct { + conn keyingTransport + config *Config + + serverVersion []byte + clientVersion []byte + + // hostKeys is non-empty if we are the server. In that case, + // it contains all host keys that can be used to sign the + // connection. + hostKeys []Signer + + // publicKeyAuthAlgorithms is non-empty if we are the server. In that case, + // it contains the supported client public key authentication algorithms. + publicKeyAuthAlgorithms []string + + // hostKeyAlgorithms is non-empty if we are the client. In that case, + // we accept these key types from the server as host key. + hostKeyAlgorithms []string + + // On read error, incoming is closed, and readError is set. + incoming chan []byte + readError error + + mu sync.Mutex + writeError error + sentInitPacket []byte + sentInitMsg *kexInitMsg + pendingPackets [][]byte // Used when a key exchange is in progress. 
+ writePacketsLeft uint32 + writeBytesLeft int64 + userAuthComplete bool // whether the user authentication phase is complete + + // If the read loop wants to schedule a kex, it pings this + // channel, and the write loop will send out a kex + // message. + requestKex chan struct{} + + // If the other side requests or confirms a kex, its kexInit + // packet is sent here for the write loop to find it. + startKex chan *pendingKex + kexLoopDone chan struct{} // closed (with writeError non-nil) when kexLoop exits + + // data for host key checking + hostKeyCallback HostKeyCallback + dialAddress string + remoteAddr net.Addr + + // bannerCallback is non-empty if we are the client and it has been set in + // ClientConfig. In that case it is called during the user authentication + // dance to handle a custom server's message. + bannerCallback BannerCallback + + // Algorithms agreed in the last key exchange. + algorithms *algorithms + + // Counters exclusively owned by readLoop. + readPacketsLeft uint32 + readBytesLeft int64 + + // The session ID or nil if first kex did not complete yet. + sessionID []byte + + // strictMode indicates if the other side of the handshake indicated + // that we should be following the strict KEX protocol restrictions. + strictMode bool +} + +type pendingKex struct { + otherInit []byte + done chan error +} + +func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { + t := &handshakeTransport{ + conn: conn, + serverVersion: serverVersion, + clientVersion: clientVersion, + incoming: make(chan []byte, chanSize), + requestKex: make(chan struct{}, 1), + startKex: make(chan *pendingKex), + kexLoopDone: make(chan struct{}), + + config: config, + } + t.resetReadThresholds() + t.resetWriteThresholds() + + // We always start with a mandatory key exchange. 
+ t.requestKex <- struct{}{} + return t +} + +func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.dialAddress = dialAddr + t.remoteAddr = addr + t.hostKeyCallback = config.HostKeyCallback + t.bannerCallback = config.BannerCallback + if config.HostKeyAlgorithms != nil { + t.hostKeyAlgorithms = config.HostKeyAlgorithms + } else { + t.hostKeyAlgorithms = supportedHostKeyAlgos + } + go t.readLoop() + go t.kexLoop() + return t +} + +func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { + t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) + t.hostKeys = config.hostKeys + t.publicKeyAuthAlgorithms = config.PublicKeyAuthAlgorithms + go t.readLoop() + go t.kexLoop() + return t +} + +func (t *handshakeTransport) getSessionID() []byte { + return t.sessionID +} + +// waitSession waits for the session to be established. This should be +// the first thing to call after instantiating handshakeTransport. 
+func (t *handshakeTransport) waitSession() error { + p, err := t.readPacket() + if err != nil { + return err + } + if p[0] != msgNewKeys { + return fmt.Errorf("ssh: first packet should be msgNewKeys") + } + + return nil +} + +func (t *handshakeTransport) id() string { + if len(t.hostKeys) > 0 { + return "server" + } + return "client" +} + +func (t *handshakeTransport) printPacket(p []byte, write bool) { + action := "got" + if write { + action = "sent" + } + + if p[0] == msgChannelData || p[0] == msgChannelExtendedData { + log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p)) + } else { + msg, err := decode(p) + log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err) + } +} + +func (t *handshakeTransport) readPacket() ([]byte, error) { + p, ok := <-t.incoming + if !ok { + return nil, t.readError + } + return p, nil +} + +func (t *handshakeTransport) readLoop() { + first := true + for { + p, err := t.readOnePacket(first) + first = false + if err != nil { + t.readError = err + close(t.incoming) + break + } + // If this is the first kex, and strict KEX mode is enabled, + // we don't ignore any messages, as they may be used to manipulate + // the packet sequence numbers. + if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) { + continue + } + t.incoming <- p + } + + // Stop writers too. + t.recordWriteError(t.readError) + + // Unblock the writer should it wait for this. + close(t.startKex) + + // Don't close t.requestKex; it's also written to from writePacket. 
+} + +func (t *handshakeTransport) pushPacket(p []byte) error { + if debugHandshake { + t.printPacket(p, true) + } + return t.conn.writePacket(p) +} + +func (t *handshakeTransport) getWriteError() error { + t.mu.Lock() + defer t.mu.Unlock() + return t.writeError +} + +func (t *handshakeTransport) recordWriteError(err error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.writeError == nil && err != nil { + t.writeError = err + } +} + +func (t *handshakeTransport) requestKeyExchange() { + select { + case t.requestKex <- struct{}{}: + default: + // something already requested a kex, so do nothing. + } +} + +func (t *handshakeTransport) resetWriteThresholds() { + t.writePacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.writeBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.writeBytesLeft = t.algorithms.w.rekeyBytes() + } else { + t.writeBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) kexLoop() { + +write: + for t.getWriteError() == nil { + var request *pendingKex + var sent bool + + for request == nil || !sent { + var ok bool + select { + case request, ok = <-t.startKex: + if !ok { + break write + } + case <-t.requestKex: + // nolint:revive + break + } + + if !sent { + if err := t.sendKexInit(); err != nil { + t.recordWriteError(err) + break + } + sent = true + } + } + + if err := t.getWriteError(); err != nil { + if request != nil { + request.done <- err + } + break + } + + // We're not servicing t.requestKex, but that is OK: + // we never block on sending to t.requestKex. + + // We're not servicing t.startKex, but the remote end + // has just sent us a kexInitMsg, so it can't send + // another key change request, until we close the done + // channel on the pendingKex request. + + err := t.enterKeyExchange(request.otherInit) + + t.mu.Lock() + t.writeError = err + t.sentInitPacket = nil + t.sentInitMsg = nil + + t.resetWriteThresholds() + + // we have completed the key exchange. 
Since the + // reader is still blocked, it is safe to clear out + // the requestKex channel. This avoids the situation + // where: 1) we consumed our own request for the + // initial kex, and 2) the kex from the remote side + // caused another send on the requestKex channel, + clear: + for { + select { + case <-t.requestKex: + // + default: + break clear + } + } + + request.done <- t.writeError + + // kex finished. Push packets that we received while + // the kex was in progress. Don't look at t.startKex + // and don't increment writtenSinceKex: if we trigger + // another kex while we are still busy with the last + // one, things will become very confusing. + for _, p := range t.pendingPackets { + t.writeError = t.pushPacket(p) + if t.writeError != nil { + break + } + } + t.pendingPackets = t.pendingPackets[:0] + t.mu.Unlock() + } + + // Unblock reader. + t.conn.Close() + + // drain startKex channel. We don't service t.requestKex + // because nobody does blocking sends there. + for request := range t.startKex { + request.done <- t.getWriteError() + } + + // Mark that the loop is done so that Close can return. + close(t.kexLoopDone) +} + +// The protocol uses uint32 for packet counters, so we can't let them +// reach 1<<32. We will actually read and write more packets than +// this, though: the other side may send more packets, and after we +// hit this limit on writing we will send a few more packets for the +// key exchange itself. 
+const packetRekeyThreshold = (1 << 31) + +func (t *handshakeTransport) resetReadThresholds() { + t.readPacketsLeft = packetRekeyThreshold + if t.config.RekeyThreshold > 0 { + t.readBytesLeft = int64(t.config.RekeyThreshold) + } else if t.algorithms != nil { + t.readBytesLeft = t.algorithms.r.rekeyBytes() + } else { + t.readBytesLeft = 1 << 30 + } +} + +func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { + p, err := t.conn.readPacket() + if err != nil { + return nil, err + } + + if t.readPacketsLeft > 0 { + t.readPacketsLeft-- + } else { + t.requestKeyExchange() + } + + if t.readBytesLeft > 0 { + t.readBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if debugHandshake { + t.printPacket(p, false) + } + + if first && p[0] != msgKexInit { + return nil, fmt.Errorf("ssh: first packet should be msgKexInit") + } + + if p[0] != msgKexInit { + return p, nil + } + + firstKex := t.sessionID == nil + + kex := pendingKex{ + done: make(chan error, 1), + otherInit: p, + } + t.startKex <- &kex + err = <-kex.done + + if debugHandshake { + log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err) + } + + if err != nil { + return nil, err + } + + t.resetReadThresholds() + + // By default, a key exchange is hidden from higher layers by + // translating it into msgIgnore. + successPacket := []byte{msgIgnore} + if firstKex { + // sendKexInit() for the first kex waits for + // msgNewKeys so the authentication process is + // guaranteed to happen over an encrypted transport. + successPacket = []byte{msgNewKeys} + } + + return successPacket, nil +} + +const ( + kexStrictClient = "kex-strict-c-v00@openssh.com" + kexStrictServer = "kex-strict-s-v00@openssh.com" +) + +// sendKexInit sends a key change message. 
+func (t *handshakeTransport) sendKexInit() error { + t.mu.Lock() + defer t.mu.Unlock() + if t.sentInitMsg != nil { + // kexInits may be sent either in response to the other side, + // or because our side wants to initiate a key change, so we + // may have already sent a kexInit. In that case, don't send a + // second kexInit. + return nil + } + + msg := &kexInitMsg{ + CiphersClientServer: t.config.Ciphers, + CiphersServerClient: t.config.Ciphers, + MACsClientServer: t.config.MACs, + MACsServerClient: t.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + } + io.ReadFull(rand.Reader, msg.Cookie[:]) + + // We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm, + // and possibly to add the ext-info extension algorithm. Since the slice may be the + // user owned KeyExchanges, we create our own slice in order to avoid using user + // owned memory by mistake. + msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info + msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...) + + isServer := len(t.hostKeys) > 0 + if isServer { + for _, k := range t.hostKeys { + // If k is a MultiAlgorithmSigner, we restrict the signature + // algorithms. If k is a AlgorithmSigner, presume it supports all + // signature algorithms associated with the key format. If k is not + // an AlgorithmSigner, we can only assume it only supports the + // algorithms that matches the key format. (This means that Sign + // can't pick a different default). + keyFormat := k.PublicKey().Type() + + switch s := k.(type) { + case MultiAlgorithmSigner: + for _, algo := range algorithmsForKeyFormat(keyFormat) { + if contains(s.Algorithms(), underlyingAlgo(algo)) { + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algo) + } + } + case AlgorithmSigner: + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, algorithmsForKeyFormat(keyFormat)...) 
+ default: + msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat) + } + } + + if t.sessionID == nil { + msg.KexAlgos = append(msg.KexAlgos, kexStrictServer) + } + } else { + msg.ServerHostKeyAlgos = t.hostKeyAlgorithms + + // As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what + // algorithms the server supports for public key authentication. See RFC + // 8308, Section 2.1. + // + // We also send the strict KEX mode extension algorithm, in order to opt + // into the strict KEX mode. + if firstKeyExchange := t.sessionID == nil; firstKeyExchange { + msg.KexAlgos = append(msg.KexAlgos, "ext-info-c") + msg.KexAlgos = append(msg.KexAlgos, kexStrictClient) + } + + } + + packet := Marshal(msg) + + // writePacket destroys the contents, so save a copy. + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + if err := t.pushPacket(packetCopy); err != nil { + return err + } + + t.sentInitMsg = msg + t.sentInitPacket = packet + + return nil +} + +var errSendBannerPhase = errors.New("ssh: SendAuthBanner outside of authentication phase") + +func (t *handshakeTransport) writePacket(p []byte) error { + t.mu.Lock() + defer t.mu.Unlock() + + switch p[0] { + case msgKexInit: + return errors.New("ssh: only handshakeTransport can send kexInit") + case msgNewKeys: + return errors.New("ssh: only handshakeTransport can send newKeys") + case msgUserAuthBanner: + if t.userAuthComplete { + return errSendBannerPhase + } + case msgUserAuthSuccess: + t.userAuthComplete = true + } + + if t.writeError != nil { + return t.writeError + } + + if t.sentInitMsg != nil { + // Copy the packet so the writer can reuse the buffer. 
+ cp := make([]byte, len(p)) + copy(cp, p) + t.pendingPackets = append(t.pendingPackets, cp) + return nil + } + + if t.writeBytesLeft > 0 { + t.writeBytesLeft -= int64(len(p)) + } else { + t.requestKeyExchange() + } + + if t.writePacketsLeft > 0 { + t.writePacketsLeft-- + } else { + t.requestKeyExchange() + } + + if err := t.pushPacket(p); err != nil { + t.writeError = err + } + + return nil +} + +func (t *handshakeTransport) Close() error { + // Close the connection. This should cause the readLoop goroutine to wake up + // and close t.startKex, which will shut down kexLoop if running. + err := t.conn.Close() + + // Wait for the kexLoop goroutine to complete. + // At that point we know that the readLoop goroutine is complete too, + // because kexLoop itself waits for readLoop to close the startKex channel. + <-t.kexLoopDone + + return err +} + +func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { + if debugHandshake { + log.Printf("%s entered key exchange", t.id()) + } + + otherInit := &kexInitMsg{} + if err := Unmarshal(otherInitPacket, otherInit); err != nil { + return err + } + + magics := handshakeMagics{ + clientVersion: t.clientVersion, + serverVersion: t.serverVersion, + clientKexInit: otherInitPacket, + serverKexInit: t.sentInitPacket, + } + + clientInit := otherInit + serverInit := t.sentInitMsg + isClient := len(t.hostKeys) == 0 + if isClient { + clientInit, serverInit = serverInit, clientInit + + magics.clientKexInit = t.sentInitPacket + magics.serverKexInit = otherInitPacket + } + + var err error + t.algorithms, err = findAgreedAlgorithms(isClient, clientInit, serverInit) + if err != nil { + return err + } + + if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) { + t.strictMode = true + if err := t.conn.setStrictMode(); err != nil { + return err + } + } + + // We don't send FirstKexFollows, but we handle receiving it. 
+ // + // RFC 4253 section 7 defines the kex and the agreement method for + // first_kex_packet_follows. It states that the guessed packet + // should be ignored if the "kex algorithm and/or the host + // key algorithm is guessed wrong (server and client have + // different preferred algorithm), or if any of the other + // algorithms cannot be agreed upon". The other algorithms have + // already been checked above so the kex algorithm and host key + // algorithm are checked here. + if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) { + // other side sent a kex message for the wrong algorithm, + // which we have to ignore. + if _, err := t.conn.readPacket(); err != nil { + return err + } + } + + kex, ok := kexAlgoMap[t.algorithms.kex] + if !ok { + return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex) + } + + var result *kexResult + if len(t.hostKeys) > 0 { + result, err = t.server(kex, &magics) + } else { + result, err = t.client(kex, &magics) + } + + if err != nil { + return err + } + + firstKeyExchange := t.sessionID == nil + if firstKeyExchange { + t.sessionID = result.H + } + result.SessionID = t.sessionID + + if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil { + return err + } + if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { + return err + } + + // On the server side, after the first SSH_MSG_NEWKEYS, send a SSH_MSG_EXT_INFO + // message with the server-sig-algs extension if the client supports it. See + // RFC 8308, Sections 2.4 and 3.1, and [PROTOCOL], Section 1.9. 
+ if !isClient && firstKeyExchange && contains(clientInit.KexAlgos, "ext-info-c") { + supportedPubKeyAuthAlgosList := strings.Join(t.publicKeyAuthAlgorithms, ",") + extInfo := &extInfoMsg{ + NumExtensions: 2, + Payload: make([]byte, 0, 4+15+4+len(supportedPubKeyAuthAlgosList)+4+16+4+1), + } + extInfo.Payload = appendInt(extInfo.Payload, len("server-sig-algs")) + extInfo.Payload = append(extInfo.Payload, "server-sig-algs"...) + extInfo.Payload = appendInt(extInfo.Payload, len(supportedPubKeyAuthAlgosList)) + extInfo.Payload = append(extInfo.Payload, supportedPubKeyAuthAlgosList...) + extInfo.Payload = appendInt(extInfo.Payload, len("ping@openssh.com")) + extInfo.Payload = append(extInfo.Payload, "ping@openssh.com"...) + extInfo.Payload = appendInt(extInfo.Payload, 1) + extInfo.Payload = append(extInfo.Payload, "0"...) + if err := t.conn.writePacket(Marshal(extInfo)); err != nil { + return err + } + } + + if packet, err := t.conn.readPacket(); err != nil { + return err + } else if packet[0] != msgNewKeys { + return unexpectedMessageError(msgNewKeys, packet[0]) + } + + if firstKeyExchange { + // Indicates to the transport that the first key exchange is completed + // after receiving SSH_MSG_NEWKEYS. + t.conn.setInitialKEXDone() + } + + return nil +} + +// algorithmSignerWrapper is an AlgorithmSigner that only supports the default +// key format algorithm. +// +// This is technically a violation of the AlgorithmSigner interface, but it +// should be unreachable given where we use this. Anyway, at least it returns an +// error instead of panicing or producing an incorrect signature. 
+type algorithmSignerWrapper struct { + Signer +} + +func (a algorithmSignerWrapper) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm != underlyingAlgo(a.PublicKey().Type()) { + return nil, errors.New("ssh: internal error: algorithmSignerWrapper invoked with non-default algorithm") + } + return a.Sign(rand, data) +} + +func pickHostKey(hostKeys []Signer, algo string) AlgorithmSigner { + for _, k := range hostKeys { + if s, ok := k.(MultiAlgorithmSigner); ok { + if !contains(s.Algorithms(), underlyingAlgo(algo)) { + continue + } + } + + if algo == k.PublicKey().Type() { + return algorithmSignerWrapper{k} + } + + k, ok := k.(AlgorithmSigner) + if !ok { + continue + } + for _, a := range algorithmsForKeyFormat(k.PublicKey().Type()) { + if algo == a { + return k + } + } + } + return nil +} + +func (t *handshakeTransport) server(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { + hostKey := pickHostKey(t.hostKeys, t.algorithms.hostKey) + if hostKey == nil { + return nil, errors.New("ssh: internal error: negotiated unsupported signature type") + } + + r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey, t.algorithms.hostKey) + return r, err +} + +func (t *handshakeTransport) client(kex kexAlgorithm, magics *handshakeMagics) (*kexResult, error) { + result, err := kex.Client(t.conn, t.config.Rand, magics) + if err != nil { + return nil, err + } + + hostKey, err := ParsePublicKey(result.HostKey) + if err != nil { + return nil, err + } + + if err := verifyHostKeySignature(hostKey, t.algorithms.hostKey, result); err != nil { + return nil, err + } + + err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) + if err != nil { + return nil, err + } + + return result, nil +} diff --git a/tempfork/sshtest/ssh/handshake_test.go b/tempfork/sshtest/ssh/handshake_test.go new file mode 100644 index 0000000000000..48628128dbadf --- /dev/null +++ b/tempfork/sshtest/ssh/handshake_test.go @@ -0,0 +1,1022 @@ 
+// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "net" + "reflect" + "runtime" + "strings" + "sync" + "testing" +) + +type testChecker struct { + calls []string +} + +func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + if dialAddr == "bad" { + return fmt.Errorf("dialAddr is bad") + } + + if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { + return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) + } + + t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) + + return nil +} + +// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and +// therefore is buffered (net.Pipe deadlocks if both sides start with +// a write.) +func netPipe() (net.Conn, net.Conn, error) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + listener, err = net.Listen("tcp", "[::1]:0") + if err != nil { + return nil, nil, err + } + } + defer listener.Close() + c1, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + return nil, nil, err + } + + c2, err := listener.Accept() + if err != nil { + c1.Close() + return nil, nil, err + } + + return c1, c2, nil +} + +// noiseTransport inserts ignore messages to check that the read loop +// and the key exchange filters out these messages. 
+type noiseTransport struct { + keyingTransport +} + +func (t *noiseTransport) writePacket(p []byte) error { + ignore := []byte{msgIgnore} + if err := t.keyingTransport.writePacket(ignore); err != nil { + return err + } + debug := []byte{msgDebug, 1, 2, 3} + if err := t.keyingTransport.writePacket(debug); err != nil { + return err + } + + return t.keyingTransport.writePacket(p) +} + +func addNoiseTransport(t keyingTransport) keyingTransport { + return &noiseTransport{t} +} + +// handshakePair creates two handshakeTransports connected with each +// other. If the noise argument is true, both transports will try to +// confuse the other side by sending ignore and debug messages. +func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { + a, b, err := netPipe() + if err != nil { + return nil, nil, err + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + if noise { + trC = addNoiseTransport(trC) + trS = addNoiseTransport(trS) + } + clientConf.SetDefaults() + + v := []byte("version") + client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server = newServerTransport(trS, v, v, serverConf) + + if err := server.waitSession(); err != nil { + return nil, nil, fmt.Errorf("server.waitSession: %v", err) + } + if err := client.waitSession(); err != nil { + return nil, nil, fmt.Errorf("client.waitSession: %v", err) + } + + return client, server, nil +} + +func TestHandshakeBasic(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: 
checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + // Let first kex complete normally. + <-checker.called + + clientDone := make(chan int, 0) + gotHalf := make(chan int, 0) + const N = 20 + errorCh := make(chan error, 1) + + go func() { + defer close(clientDone) + // Client writes a bunch of stuff, and does a key + // change in the middle. This should not confuse the + // handshake in progress. We do this twice, so we test + // that the packet buffer is reset correctly. + for i := 0; i < N; i++ { + p := []byte{msgRequestSuccess, byte(i)} + if err := trC.writePacket(p); err != nil { + errorCh <- err + trC.Close() + return + } + if (i % 10) == 5 { + <-gotHalf + // halfway through, we request a key change. + trC.requestKeyExchange() + + // Wait until we can be sure the key + // change has really started before we + // write more. + <-checker.called + } + if (i % 10) == 7 { + // write some packets until the kex + // completes, to test buffering of + // packets. + checker.waitCall <- 1 + } + } + errorCh <- nil + }() + + // Server checks that client messages come in cleanly + i := 0 + for ; i < N; i++ { + p, err := trS.readPacket() + if err != nil && err != io.EOF { + t.Fatalf("server error: %v", err) + } + if (i % 10) == 5 { + gotHalf <- 1 + } + + want := []byte{msgRequestSuccess, byte(i)} + if bytes.Compare(p, want) != 0 { + t.Errorf("message %d: got %v, want %v", i, p, want) + } + } + <-clientDone + if err := <-errorCh; err != nil { + t.Fatalf("sendPacket: %v", err) + } + if i != N { + t.Errorf("received %d messages, want 10.", i) + } + + close(checker.called) + if _, ok := <-checker.called; ok { + // If all went well, we registered exactly 2 key changes: one + // that establishes the session, and one that we requested + // additionally. 
+ t.Fatalf("got another host key checks after 2 handshakes") + } +} + +func TestForceFirstKex(t *testing.T) { + // like handshakePair, but must access the keyingTransport. + checker := &testChecker{} + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + + // This is the disallowed packet: + trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) + + // Rest of the setup. + trS = newTransport(b, rand.Reader, false) + clientConf.SetDefaults() + + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + server := newServerTransport(trS, v, v, serverConf) + + defer client.Close() + defer server.Close() + + // We setup the initial key exchange, but the remote side + // tries to send serviceRequestMsg in cleartext, which is + // disallowed. 
+ + if err := server.waitSession(); err == nil { + t.Errorf("server first kex init should reject unexpected packet") + } +} + +func TestHandshakeAutoRekeyWrite(t *testing.T) { + checker := &syncChecker{ + called: make(chan int, 10), + waitCall: nil, + } + clientConf := &ClientConfig{HostKeyCallback: checker.Check} + clientConf.RekeyThreshold = 500 + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + input := make([]byte, 251) + input[0] = msgRequestSuccess + + done := make(chan int, 1) + const numPacket = 5 + go func() { + defer close(done) + j := 0 + for ; j < numPacket; j++ { + if p, err := trS.readPacket(); err != nil { + break + } else if !bytes.Equal(input, p) { + t.Errorf("got packet type %d, want %d", p[0], input[0]) + } + } + + if j != numPacket { + t.Errorf("got %d, want 5 messages", j) + } + }() + + <-checker.called + + for i := 0; i < numPacket; i++ { + p := make([]byte, len(input)) + copy(p, input) + if err := trC.writePacket(p); err != nil { + t.Errorf("writePacket: %v", err) + } + if i == 2 { + // Make sure the kex is in progress. 
+ <-checker.called + } + + } + <-done +} + +type syncChecker struct { + waitCall chan int + called chan int +} + +func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { + c.called <- 1 + if c.waitCall != nil { + <-c.waitCall + } + return nil +} + +func TestHandshakeAutoRekeyRead(t *testing.T) { + sync := &syncChecker{ + called: make(chan int, 2), + waitCall: nil, + } + clientConf := &ClientConfig{ + HostKeyCallback: sync.Check, + } + clientConf.RekeyThreshold = 500 + + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + packet := make([]byte, 501) + packet[0] = msgRequestSuccess + if err := trS.writePacket(packet); err != nil { + t.Fatalf("writePacket: %v", err) + } + + // While we read out the packet, a key change will be + // initiated. + errorCh := make(chan error, 1) + go func() { + _, err := trC.readPacket() + errorCh <- err + }() + + if err := <-errorCh; err != nil { + t.Fatalf("readPacket(client): %v", err) + } + + <-sync.called +} + +// errorKeyingTransport generates errors after a given number of +// read/write operations. 
+type errorKeyingTransport struct { + packetConn + readLeft, writeLeft int +} + +func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { + return nil +} + +func (n *errorKeyingTransport) getSessionID() []byte { + return nil +} + +func (n *errorKeyingTransport) writePacket(packet []byte) error { + if n.writeLeft == 0 { + n.Close() + return errors.New("barf") + } + + n.writeLeft-- + return n.packetConn.writePacket(packet) +} + +func (n *errorKeyingTransport) readPacket() ([]byte, error) { + if n.readLeft == 0 { + n.Close() + return nil, errors.New("barf") + } + + n.readLeft-- + return n.packetConn.readPacket() +} + +func (n *errorKeyingTransport) setStrictMode() error { return nil } + +func (n *errorKeyingTransport) setInitialKEXDone() {} + +func TestHandshakeErrorHandlingRead(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1, false) + } +} + +func TestHandshakeErrorHandlingWrite(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i, false) + } +} + +func TestHandshakeErrorHandlingReadCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, i, -1, true) + } +} + +func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) { + for i := 0; i < 20; i++ { + testHandshakeErrorHandlingN(t, -1, i, true) + } +} + +// testHandshakeErrorHandlingN runs handshakes, injecting errors. If +// handshakeTransport deadlocks, the go runtime will detect it and +// panic. 
+func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) { + if (runtime.GOOS == "js" || runtime.GOOS == "wasip1") && runtime.GOARCH == "wasm" { + t.Skipf("skipping on %s/wasm; see golang.org/issue/32840", runtime.GOOS) + } + msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) + + a, b := memPipe() + defer a.Close() + defer b.Close() + + key := testSigners["ecdsa"] + serverConf := Config{RekeyThreshold: minRekeyThreshold} + serverConf.SetDefaults() + serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) + serverConn.hostKeys = []Signer{key} + go serverConn.readLoop() + go serverConn.kexLoop() + + clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} + clientConf.SetDefaults() + clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) + clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} + clientConn.hostKeyCallback = InsecureIgnoreHostKey() + go clientConn.readLoop() + go clientConn.kexLoop() + + var wg sync.WaitGroup + + for _, hs := range []packetConn{serverConn, clientConn} { + if !coupled { + wg.Add(2) + go func(c packetConn) { + for i := 0; ; i++ { + str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8) + err := c.writePacket(Marshal(&serviceRequestMsg{str})) + if err != nil { + break + } + } + wg.Done() + c.Close() + }(hs) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + } + wg.Done() + }(hs) + } else { + wg.Add(1) + go func(c packetConn) { + for { + _, err := c.readPacket() + if err != nil { + break + } + if err := c.writePacket(msg); err != nil { + break + } + + } + wg.Done() + }(hs) + } + } + wg.Wait() +} + +func TestDisconnect(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + checker := &testChecker{} + trC, trS, err := 
handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + errMsg := &disconnectMsg{ + Reason: 42, + Message: "such is life", + } + trC.writePacket(Marshal(errMsg)) + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + + packet, err := trS.readPacket() + if err != nil { + t.Fatalf("readPacket 1: %v", err) + } + if packet[0] != msgRequestSuccess { + t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 2 succeeded") + } else if !reflect.DeepEqual(err, errMsg) { + t.Errorf("got error %#v, want %#v", err, errMsg) + } + + _, err = trS.readPacket() + if err == nil { + t.Errorf("readPacket 3 succeeded") + } +} + +func TestHandshakeRekeyDefault(t *testing.T) { + clientConf := &ClientConfig{ + Config: Config{ + Ciphers: []string{"aes128-ctr"}, + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() + defer trS.Close() + + trC.writePacket([]byte{msgRequestSuccess, 0, 0}) + trC.Close() + + rgb := (1024 + trC.readBytesLeft) >> 30 + wgb := (1024 + trC.writeBytesLeft) >> 30 + + if rgb != 64 { + t.Errorf("got rekey after %dG read, want 64G", rgb) + } + if wgb != 64 { + t.Errorf("got rekey after %dG write, want 64G", wgb) + } +} + +func TestHandshakeAEADCipherNoMAC(t *testing.T) { + for _, cipher := range []string{chacha20Poly1305ID, gcm128CipherID} { + checker := &syncChecker{ + called: make(chan int, 1), + } + clientConf := &ClientConfig{ + Config: Config{ + Ciphers: []string{cipher}, + MACs: []string{}, + }, + HostKeyCallback: checker.Check, + } + trC, trS, err := handshakePair(clientConf, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + defer trC.Close() 
+ defer trS.Close() + + <-checker.called + } +} + +// TestNoSHA2Support tests a host key Signer that is not an AlgorithmSigner and +// therefore can't do SHA-2 signatures. Ensures the server does not advertise +// support for them in this case. +func TestNoSHA2Support(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(&legacyRSASigner{testSigners["rsa"]}) + go func() { + _, _, _, err := NewServerConn(c1, serverConf) + if err != nil { + t.Error(err) + } + }() + + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + } + + if _, _, _, err := NewClientConn(c2, "", clientConf); err != nil { + t.Fatal(err) + } +} + +func TestMultiAlgoSignerHandshake(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(multiAlgoSigner) + go NewServerConn(c1, serverConf) + + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + HostKeyAlgorithms: []string{KeyAlgoRSASHA512}, + } + + if _, _, _, err := NewClientConn(c2, "", 
clientConf); err != nil { + t.Fatal(err) + } +} + +func TestMultiAlgoSignerNoCommonHostKeyAlgo(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + // ssh-rsa is disabled server side + serverConf := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + } + serverConf.AddHostKey(multiAlgoSigner) + go NewServerConn(c1, serverConf) + + // the client only supports ssh-rsa + clientConf := &ClientConfig{ + User: "test", + Auth: []AuthMethod{Password("testpw")}, + HostKeyCallback: FixedHostKey(testSigners["rsa"].PublicKey()), + HostKeyAlgorithms: []string{KeyAlgoRSA}, + } + + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with no common hostkey algorithm") + } +} + +func TestPickIncompatibleHostKeyAlgo(t *testing.T) { + algorithmSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + multiAlgoSigner, err := NewSignerWithAlgorithms(algorithmSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Fatalf("unable to create multi algorithm signer: %v", err) + } + signer := pickHostKey([]Signer{multiAlgoSigner}, KeyAlgoRSA) + if signer != nil { + t.Fatal("incompatible signer returned") + } +} + +func TestStrictKEXResetSeqFirstKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + 
} + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // close the handshake transports before checking the sequence number to + // avoid races. + trC.Close() + trS.Close() + + // check that the sequence number counters. We reset after msgNewKeys, but + // then the server immediately writes msgExtInfo, and we close the + // transports so we expect read 2, write 0 on the client and read 1, write 1 + // on the server. + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestStrictKEXResetSeqSuccessiveKEX(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on 
either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // Request a key exchange, which should cause the sequence numbers to reset + checker.waitCall <- 1 + trC.requestKeyExchange() + <-checker.called + + // write a packet on the client, and then read it, to verify the key change has actually happened, since + // the HostKeyCallback is called _during_ the handshake, so isn't actually indicative of the handshake + // finishing. + dummyPacket := []byte{99} + if err := trS.writePacket(dummyPacket); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if p, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } else if !bytes.Equal(p, dummyPacket) { + t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket) + } + + // close the handshake transports before checking the sequence number to + // avoid races. 
+ trC.Close() + trS.Close() + + if trC.conn.(*transport).reader.seqNum != 2 || trC.conn.(*transport).writer.seqNum != 0 || + trS.conn.(*transport).reader.seqNum != 1 || trS.conn.(*transport).writer.seqNum != 1 { + t.Errorf( + "unexpected sequence counters:\nclient: reader %d (expected 2), writer %d (expected 0)\nserver: reader %d (expected 1), writer %d (expected 1)", + trC.conn.(*transport).reader.seqNum, + trC.conn.(*transport).writer.seqNum, + trS.conn.(*transport).reader.seqNum, + trS.conn.(*transport).writer.seqNum, + ) + } +} + +func TestSeqNumIncrease(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("see golang.org/issue/7237") + } + + checker := &syncChecker{ + waitCall: make(chan int, 10), + called: make(chan int, 10), + } + + checker.waitCall <- 1 + trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) + if err != nil { + t.Fatalf("handshakePair: %v", err) + } + <-checker.called + + t.Cleanup(func() { + trC.Close() + trS.Close() + }) + + // Throw away the msgExtInfo packet sent during the handshake by the server + _, err = trC.readPacket() + if err != nil { + t.Fatalf("readPacket failed: %s", err) + } + + // write and read five packets on either side to bump the sequence numbers + for i := 0; i < 5; i++ { + if err := trC.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trS.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + if err := trS.writePacket([]byte{msgRequestSuccess}); err != nil { + t.Fatalf("writePacket failed: %s", err) + } + if _, err := trC.readPacket(); err != nil { + t.Fatalf("readPacket failed: %s", err) + } + } + + // close the handshake transports before checking the sequence number to + // avoid races. 
+	trC.Close()
+	trS.Close()
+
+	if trC.conn.(*transport).reader.seqNum != 7 || trC.conn.(*transport).writer.seqNum != 5 ||
+		trS.conn.(*transport).reader.seqNum != 6 || trS.conn.(*transport).writer.seqNum != 6 {
+		t.Errorf(
+			"unexpected sequence counters:\nclient: reader %d (expected 7), writer %d (expected 5)\nserver: reader %d (expected 6), writer %d (expected 6)",
+			trC.conn.(*transport).reader.seqNum,
+			trC.conn.(*transport).writer.seqNum,
+			trS.conn.(*transport).reader.seqNum,
+			trS.conn.(*transport).writer.seqNum,
+		)
+	}
+}
+
+func TestStrictKEXUnexpectedMsg(t *testing.T) {
+	if runtime.GOOS == "plan9" {
+		t.Skip("see golang.org/issue/7237")
+	}
+
+	// Check that unexpected messages during the handshake cause failure
+	_, _, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", true)
+	if err == nil {
+		t.Fatal("handshake should fail when there are unexpected messages during the handshake")
+	}
+
+	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }}, "addr", false)
+	if err != nil {
+		t.Fatalf("handshake failed: %s", err)
+	}
+
+	// Check that ignore/debug packets are still ignored outside of the handshake
+	if err := trC.writePacket([]byte{msgIgnore}); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	if err := trC.writePacket([]byte{msgDebug}); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+	dummyPacket := []byte{99}
+	if err := trC.writePacket(dummyPacket); err != nil {
+		t.Fatalf("writePacket failed: %s", err)
+	}
+
+	if p, err := trS.readPacket(); err != nil {
+		t.Fatalf("readPacket failed: %s", err)
+	} else if !bytes.Equal(p, dummyPacket) {
+		t.Fatalf("unexpected packet: got %x, want %x", p, dummyPacket)
+	}
+}
+
+func TestStrictKEXMixed(t *testing.T) {
+	// Test that we still support a mixed connection, where one side sends kex-strict but the other
+	// side 
doesn't. + + a, b, err := netPipe() + if err != nil { + t.Fatalf("netPipe failed: %s", err) + } + + var trC, trS keyingTransport + + trC = newTransport(a, rand.Reader, true) + trS = newTransport(b, rand.Reader, false) + trS = addNoiseTransport(trS) + + clientConf := &ClientConfig{HostKeyCallback: func(hostname string, remote net.Addr, key PublicKey) error { return nil }} + clientConf.SetDefaults() + + v := []byte("version") + client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) + + serverConf := &ServerConfig{} + serverConf.AddHostKey(testSigners["ecdsa"]) + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.SetDefaults() + + transport := newHandshakeTransport(trS, &serverConf.Config, []byte("version"), []byte("version")) + transport.hostKeys = serverConf.hostKeys + transport.publicKeyAuthAlgorithms = serverConf.PublicKeyAuthAlgorithms + + readOneFailure := make(chan error, 1) + go func() { + if _, err := transport.readOnePacket(true); err != nil { + readOneFailure <- err + } + }() + + // Basically sendKexInit, but without the kex-strict extension algorithm + msg := &kexInitMsg{ + KexAlgos: transport.config.KeyExchanges, + CiphersClientServer: transport.config.Ciphers, + CiphersServerClient: transport.config.Ciphers, + MACsClientServer: transport.config.MACs, + MACsServerClient: transport.config.MACs, + CompressionClientServer: supportedCompressions, + CompressionServerClient: supportedCompressions, + ServerHostKeyAlgos: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512, KeyAlgoRSA}, + } + packet := Marshal(msg) + // writePacket destroys the contents, so save a copy. 
+ packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + if err := transport.pushPacket(packetCopy); err != nil { + t.Fatalf("pushPacket: %s", err) + } + transport.sentInitMsg = msg + transport.sentInitPacket = packet + + if err := transport.getWriteError(); err != nil { + t.Fatalf("getWriteError failed: %s", err) + } + var request *pendingKex + select { + case err = <-readOneFailure: + t.Fatalf("server readOnePacket failed: %s", err) + case request = <-transport.startKex: + // nolint:revive + break + } + + // We expect the following calls to fail if the side which does not support + // kex-strict sends unexpected/ignored packets during the handshake, even if + // the other side does support kex-strict. + + if err := transport.enterKeyExchange(request.otherInit); err != nil { + t.Fatalf("enterKeyExchange failed: %s", err) + } + if err := client.waitSession(); err != nil { + t.Fatalf("client.waitSession: %v", err) + } +} diff --git a/tempfork/sshtest/ssh/kex.go b/tempfork/sshtest/ssh/kex.go new file mode 100644 index 0000000000000..8a05f79902c09 --- /dev/null +++ b/tempfork/sshtest/ssh/kex.go @@ -0,0 +1,786 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package ssh + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + + "golang.org/x/crypto/curve25519" +) + +const ( + kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" + kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" + kexAlgoDH14SHA256 = "diffie-hellman-group14-sha256" + kexAlgoDH16SHA512 = "diffie-hellman-group16-sha512" + kexAlgoECDH256 = "ecdh-sha2-nistp256" + kexAlgoECDH384 = "ecdh-sha2-nistp384" + kexAlgoECDH521 = "ecdh-sha2-nistp521" + kexAlgoCurve25519SHA256LibSSH = "curve25519-sha256@libssh.org" + kexAlgoCurve25519SHA256 = "curve25519-sha256" + + // For the following kex only the client half contains a production + // ready implementation. The server half only consists of a minimal + // implementation to satisfy the automated tests. + kexAlgoDHGEXSHA1 = "diffie-hellman-group-exchange-sha1" + kexAlgoDHGEXSHA256 = "diffie-hellman-group-exchange-sha256" +) + +// kexResult captures the outcome of a key exchange. +type kexResult struct { + // Session hash. See also RFC 4253, section 8. + H []byte + + // Shared secret. See also RFC 4253, section 8. + K []byte + + // Host key as hashed into H. + HostKey []byte + + // Signature of H. + Signature []byte + + // A cryptographic hash function that matches the security + // level of the key exchange algorithm. It is used for + // calculating H, and for deriving keys from H and K. + Hash crypto.Hash + + // The session ID, which is the first H computed. This is used + // to derive key material inside the transport. + SessionID []byte +} + +// handshakeMagics contains data that is always included in the +// session hash. 
+type handshakeMagics struct { + clientVersion, serverVersion []byte + clientKexInit, serverKexInit []byte +} + +func (m *handshakeMagics) write(w io.Writer) { + writeString(w, m.clientVersion) + writeString(w, m.serverVersion) + writeString(w, m.clientKexInit) + writeString(w, m.serverKexInit) +} + +// kexAlgorithm abstracts different key exchange algorithms. +type kexAlgorithm interface { + // Server runs server-side key agreement, signing the result + // with a hostkey. algo is the negotiated algorithm, and may + // be a certificate type. + Server(p packetConn, rand io.Reader, magics *handshakeMagics, s AlgorithmSigner, algo string) (*kexResult, error) + + // Client runs the client-side key agreement. Caller is + // responsible for verifying the host key signature. + Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) +} + +// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. +type dhGroup struct { + g, p, pMinus1 *big.Int + hashFunc crypto.Hash +} + +func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { + if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil +} + +func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + var x *big.Int + for { + var err error + if x, err = rand.Int(randSource, group.pMinus1); err != nil { + return nil, err + } + if x.Sign() > 0 { + break + } + } + + X := new(big.Int).Exp(group.g, x, group.p) + kexDHInit := kexDHInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHInit)); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexDHReply kexDHReplyMsg + if err = Unmarshal(packet, &kexDHReply); err != nil { + return nil, err + } + + ki, err := 
group.diffieHellman(kexDHReply.Y, x) + if err != nil { + return nil, err + } + + h := group.hashFunc.New() + magics.write(h) + writeString(h, kexDHReply.HostKey) + writeInt(h, X) + writeInt(h, kexDHReply.Y) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHReply.HostKey, + Signature: kexDHReply.Signature, + Hash: group.hashFunc, + }, nil +} + +func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHInit kexDHInitMsg + if err = Unmarshal(packet, &kexDHInit); err != nil { + return + } + + var y *big.Int + for { + if y, err = rand.Int(randSource, group.pMinus1); err != nil { + return + } + if y.Sign() > 0 { + break + } + } + + Y := new(big.Int).Exp(group.g, y, group.p) + ki, err := group.diffieHellman(kexDHInit.X, y) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := group.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeInt(h, kexDHInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. + sig, err := signAndMarshal(priv, randSource, H, algo) + if err != nil { + return nil, err + } + + kexDHReply := kexDHReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHReply) + + err = c.writePacket(packet) + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: group.hashFunc, + }, err +} + +// ecdh performs Elliptic Curve Diffie-Hellman key exchange as +// described in RFC 5656, section 4. 
+type ecdh struct { + curve elliptic.Curve +} + +func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + kexInit := kexECDHInitMsg{ + ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), + } + + serialized := Marshal(&kexInit) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + + x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) + if err != nil { + return nil, err + } + + // generate shared secret + secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kexInit.ClientPubKey) + writeString(h, reply.EphemeralPubKey) + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: ecHash(kex.curve), + }, nil +} + +// unmarshalECKey parses and checks an EC key. +func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { + x, y = elliptic.Unmarshal(curve, pubkey) + if x == nil { + return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") + } + if !validateECPublicKey(curve, x, y) { + return nil, nil, errors.New("ssh: public key not on curve") + } + return x, y, nil +} + +// validateECPublicKey checks that the point is a valid public key for +// the given curve. 
See [SEC1], 3.2.2 +func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { + if x.Sign() == 0 && y.Sign() == 0 { + return false + } + + if x.Cmp(curve.Params().P) >= 0 { + return false + } + + if y.Cmp(curve.Params().P) >= 0 { + return false + } + + if !curve.IsOnCurve(x, y) { + return false + } + + // We don't check if N * PubKey == 0, since + // + // - the NIST curves have cofactor = 1, so this is implicit. + // (We don't foresee an implementation that supports non NIST + // curves) + // + // - for ephemeral keys, we don't need to worry about small + // subgroup attacks. + return true +} + +func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var kexECDHInit kexECDHInitMsg + if err = Unmarshal(packet, &kexECDHInit); err != nil { + return nil, err + } + + clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) + if err != nil { + return nil, err + } + + // We could cache this key across multiple users/multiple + // connection attempts, but the benefit is small. OpenSSH + // generates a new key for each incoming connection. + ephKey, err := ecdsa.GenerateKey(kex.curve, rand) + if err != nil { + return nil, err + } + + hostKeyBytes := priv.PublicKey().Marshal() + + serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) + + // generate shared secret + secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) + + h := ecHash(kex.curve).New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexECDHInit.ClientPubKey) + writeString(h, serializedEphKey) + + K := make([]byte, intLength(secret)) + marshalInt(K, secret) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. 
+ sig, err := signAndMarshal(priv, rand, H, algo) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: serializedEphKey, + HostKey: hostKeyBytes, + Signature: sig, + } + + serialized := Marshal(&reply) + if err := c.writePacket(serialized); err != nil { + return nil, err + } + + return &kexResult{ + H: H, + K: K, + HostKey: reply.HostKey, + Signature: sig, + Hash: ecHash(kex.curve), + }, nil +} + +// ecHash returns the hash to match the given elliptic curve, see RFC +// 5656, section 6.2.1 +func ecHash(curve elliptic.Curve) crypto.Hash { + bitSize := curve.Params().BitSize + switch { + case bitSize <= 256: + return crypto.SHA256 + case bitSize <= 384: + return crypto.SHA384 + } + return crypto.SHA512 +} + +var kexAlgoMap = map[string]kexAlgorithm{} + +func init() { + // This is the group called diffie-hellman-group1-sha1 in + // RFC 4253 and Oakley Group 2 in RFC 2409. + p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) + kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + hashFunc: crypto.SHA1, + } + + // This are the groups called diffie-hellman-group14-sha1 and + // diffie-hellman-group14-sha256 in RFC 4253 and RFC 8268, + // and Oakley Group 14 in RFC 3526. 
+ p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + group14 := &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + } + + kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ + g: group14.g, p: group14.p, pMinus1: group14.pMinus1, + hashFunc: crypto.SHA1, + } + kexAlgoMap[kexAlgoDH14SHA256] = &dhGroup{ + g: group14.g, p: group14.p, pMinus1: group14.pMinus1, + hashFunc: crypto.SHA256, + } + + // This is the group called diffie-hellman-group16-sha512 in RFC + // 8268 and Oakley Group 16 in RFC 3526. 
+ p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AAAC42DAD33170D04507A33A85521ABDF1CBA64ECFB850458DBEF0A8AEA71575D060C7DB3970F85A6E1E4C7ABF5AE8CDB0933D71E8C94E04A25619DCEE3D2261AD2EE6BF12FFA06D98A0864D87602733EC86A64521F2B18177B200CBBE117577A615D6C770988C0BAD946E208E24FA074E5AB3143DB5BFCE0FD108E4B82D120A92108011A723C12A787E6D788719A10BDBA5B2699C327186AF4E23C1A946834B6150BDA2583E9CA2AD44CE8DBBBC2DB04DE8EF92E8EFC141FBECAA6287C59474E6BC05D99B2964FA090C3A2233BA186515BE7ED1F612970CEE2D7AFB81BDD762170481CD0069127D5B05AA993B4EA988D8FDDC186FFB7DC90A6C08F4DF435C934063199FFFFFFFFFFFFFFFF", 16) + + kexAlgoMap[kexAlgoDH16SHA512] = &dhGroup{ + g: new(big.Int).SetInt64(2), + p: p, + pMinus1: new(big.Int).Sub(p, bigOne), + hashFunc: crypto.SHA512, + } + + kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} + kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} + kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} + kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} + kexAlgoMap[kexAlgoCurve25519SHA256LibSSH] = &curve25519sha256{} + kexAlgoMap[kexAlgoDHGEXSHA1] = &dhGEXSHA{hashFunc: crypto.SHA1} + kexAlgoMap[kexAlgoDHGEXSHA256] = &dhGEXSHA{hashFunc: crypto.SHA256} +} + +// curve25519sha256 implements the curve25519-sha256 (formerly known as +// curve25519-sha256@libssh.org) key exchange method, as described in RFC 8731. 
+type curve25519sha256 struct{} + +type curve25519KeyPair struct { + priv [32]byte + pub [32]byte +} + +func (kp *curve25519KeyPair) generate(rand io.Reader) error { + if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { + return err + } + curve25519.ScalarBaseMult(&kp.pub, &kp.priv) + return nil +} + +// curve25519Zeros is just an array of 32 zero bytes so that we have something +// convenient to compare against in order to reject curve25519 points with the +// wrong order. +var curve25519Zeros [32]byte + +func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { + return nil, err + } + + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var reply kexECDHReplyMsg + if err = Unmarshal(packet, &reply); err != nil { + return nil, err + } + if len(reply.EphemeralPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var servPub, secret [32]byte + copy(servPub[:], reply.EphemeralPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &servPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, reply.HostKey) + writeString(h, kp.pub[:]) + writeString(h, reply.EphemeralPubKey) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: reply.HostKey, + Signature: reply.Signature, + Hash: crypto.SHA256, + }, nil +} + +func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + packet, err := c.readPacket() + 
if err != nil { + return + } + var kexInit kexECDHInitMsg + if err = Unmarshal(packet, &kexInit); err != nil { + return + } + + if len(kexInit.ClientPubKey) != 32 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong length") + } + + var kp curve25519KeyPair + if err := kp.generate(rand); err != nil { + return nil, err + } + + var clientPub, secret [32]byte + copy(clientPub[:], kexInit.ClientPubKey) + curve25519.ScalarMult(&secret, &kp.priv, &clientPub) + if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { + return nil, errors.New("ssh: peer's curve25519 public value has wrong order") + } + + hostKeyBytes := priv.PublicKey().Marshal() + + h := crypto.SHA256.New() + magics.write(h) + writeString(h, hostKeyBytes) + writeString(h, kexInit.ClientPubKey) + writeString(h, kp.pub[:]) + + ki := new(big.Int).SetBytes(secret[:]) + K := make([]byte, intLength(ki)) + marshalInt(K, ki) + h.Write(K) + + H := h.Sum(nil) + + sig, err := signAndMarshal(priv, rand, H, algo) + if err != nil { + return nil, err + } + + reply := kexECDHReplyMsg{ + EphemeralPubKey: kp.pub[:], + HostKey: hostKeyBytes, + Signature: sig, + } + if err := c.writePacket(Marshal(&reply)); err != nil { + return nil, err + } + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: crypto.SHA256, + }, nil +} + +// dhGEXSHA implements the diffie-hellman-group-exchange-sha1 and +// diffie-hellman-group-exchange-sha256 key agreement protocols, +// as described in RFC 4419 +type dhGEXSHA struct { + hashFunc crypto.Hash +} + +const ( + dhGroupExchangeMinimumBits = 2048 + dhGroupExchangePreferredBits = 2048 + dhGroupExchangeMaximumBits = 8192 +) + +func (gex *dhGEXSHA) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { + // Send GexRequest + kexDHGexRequest := kexDHGexRequestMsg{ + MinBits: dhGroupExchangeMinimumBits, + PreferedBits: dhGroupExchangePreferredBits, + MaxBits: dhGroupExchangeMaximumBits, + } + if err 
:= c.writePacket(Marshal(&kexDHGexRequest)); err != nil { + return nil, err + } + + // Receive GexGroup + packet, err := c.readPacket() + if err != nil { + return nil, err + } + + var msg kexDHGexGroupMsg + if err = Unmarshal(packet, &msg); err != nil { + return nil, err + } + + // reject if p's bit length < dhGroupExchangeMinimumBits or > dhGroupExchangeMaximumBits + if msg.P.BitLen() < dhGroupExchangeMinimumBits || msg.P.BitLen() > dhGroupExchangeMaximumBits { + return nil, fmt.Errorf("ssh: server-generated gex p is out of range (%d bits)", msg.P.BitLen()) + } + + // Check if g is safe by verifying that 1 < g < p-1 + pMinusOne := new(big.Int).Sub(msg.P, bigOne) + if msg.G.Cmp(bigOne) <= 0 || msg.G.Cmp(pMinusOne) >= 0 { + return nil, fmt.Errorf("ssh: server provided gex g is not safe") + } + + // Send GexInit + pHalf := new(big.Int).Rsh(msg.P, 1) + x, err := rand.Int(randSource, pHalf) + if err != nil { + return nil, err + } + X := new(big.Int).Exp(msg.G, x, msg.P) + kexDHGexInit := kexDHGexInitMsg{ + X: X, + } + if err := c.writePacket(Marshal(&kexDHGexInit)); err != nil { + return nil, err + } + + // Receive GexReply + packet, err = c.readPacket() + if err != nil { + return nil, err + } + + var kexDHGexReply kexDHGexReplyMsg + if err = Unmarshal(packet, &kexDHGexReply); err != nil { + return nil, err + } + + if kexDHGexReply.Y.Cmp(bigOne) <= 0 || kexDHGexReply.Y.Cmp(pMinusOne) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + kInt := new(big.Int).Exp(kexDHGexReply.Y, x, msg.P) + + // Check if k is safe by verifying that k > 1 and k < p - 1 + if kInt.Cmp(bigOne) <= 0 || kInt.Cmp(pMinusOne) >= 0 { + return nil, fmt.Errorf("ssh: derived k is not safe") + } + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, kexDHGexReply.HostKey) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, 
uint32(dhGroupExchangeMaximumBits)) + writeInt(h, msg.P) + writeInt(h, msg.G) + writeInt(h, X) + writeInt(h, kexDHGexReply.Y) + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + return &kexResult{ + H: h.Sum(nil), + K: K, + HostKey: kexDHGexReply.HostKey, + Signature: kexDHGexReply.Signature, + Hash: gex.hashFunc, + }, nil +} + +// Server half implementation of the Diffie Hellman Key Exchange with SHA1 and SHA256. +// +// This is a minimal implementation to satisfy the automated tests. +func (gex dhGEXSHA) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv AlgorithmSigner, algo string) (result *kexResult, err error) { + // Receive GexRequest + packet, err := c.readPacket() + if err != nil { + return + } + var kexDHGexRequest kexDHGexRequestMsg + if err = Unmarshal(packet, &kexDHGexRequest); err != nil { + return + } + + // Send GexGroup + // This is the group called diffie-hellman-group14-sha1 in RFC + // 4253 and Oakley Group 14 in RFC 3526. 
+ p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) + g := big.NewInt(2) + + msg := &kexDHGexGroupMsg{ + P: p, + G: g, + } + if err := c.writePacket(Marshal(msg)); err != nil { + return nil, err + } + + // Receive GexInit + packet, err = c.readPacket() + if err != nil { + return + } + var kexDHGexInit kexDHGexInitMsg + if err = Unmarshal(packet, &kexDHGexInit); err != nil { + return + } + + pHalf := new(big.Int).Rsh(p, 1) + + y, err := rand.Int(randSource, pHalf) + if err != nil { + return + } + Y := new(big.Int).Exp(g, y, p) + + pMinusOne := new(big.Int).Sub(p, bigOne) + if kexDHGexInit.X.Cmp(bigOne) <= 0 || kexDHGexInit.X.Cmp(pMinusOne) >= 0 { + return nil, errors.New("ssh: DH parameter out of bounds") + } + kInt := new(big.Int).Exp(kexDHGexInit.X, y, p) + + hostKeyBytes := priv.PublicKey().Marshal() + + h := gex.hashFunc.New() + magics.write(h) + writeString(h, hostKeyBytes) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMinimumBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangePreferredBits)) + binary.Write(h, binary.BigEndian, uint32(dhGroupExchangeMaximumBits)) + writeInt(h, p) + writeInt(h, g) + writeInt(h, kexDHGexInit.X) + writeInt(h, Y) + + K := make([]byte, intLength(kInt)) + marshalInt(K, kInt) + h.Write(K) + + H := h.Sum(nil) + + // H is already a hash, but the hostkey signing will apply its + // own key-specific hash algorithm. 
+ sig, err := signAndMarshal(priv, randSource, H, algo) + if err != nil { + return nil, err + } + + kexDHGexReply := kexDHGexReplyMsg{ + HostKey: hostKeyBytes, + Y: Y, + Signature: sig, + } + packet = Marshal(&kexDHGexReply) + + err = c.writePacket(packet) + + return &kexResult{ + H: H, + K: K, + HostKey: hostKeyBytes, + Signature: sig, + Hash: gex.hashFunc, + }, err +} diff --git a/tempfork/sshtest/ssh/kex_test.go b/tempfork/sshtest/ssh/kex_test.go new file mode 100644 index 0000000000000..cb7f66a5092f6 --- /dev/null +++ b/tempfork/sshtest/ssh/kex_test.go @@ -0,0 +1,106 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Key exchange tests. + +import ( + "crypto/rand" + "fmt" + "reflect" + "sync" + "testing" +) + +// Runs multiple key exchanges concurrent to detect potential data races with +// kex obtained from the global kexAlgoMap. +// This test needs to be executed using the race detector in order to detect +// race conditions. 
+func TestKexes(t *testing.T) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + t.Run(name, func(t *testing.T) { + wg := sync.WaitGroup{} + for i := 0; i < 3; i++ { + wg.Add(1) + go func() { + defer wg.Done() + a, b := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + go func() { + r, e := kex.Client(a, rand.Reader, &magics) + a.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type()) + b.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + if clientRes.err != nil { + t.Errorf("client: %v", clientRes.err) + } + if serverRes.err != nil { + t.Errorf("server: %v", serverRes.err) + } + if !reflect.DeepEqual(clientRes.result, serverRes.result) { + t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) + } + }() + } + wg.Wait() + }) + } +} + +func BenchmarkKexes(b *testing.B) { + type kexResultErr struct { + result *kexResult + err error + } + + for name, kex := range kexAlgoMap { + b.Run(name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + t1, t2 := memPipe() + + s := make(chan kexResultErr, 1) + c := make(chan kexResultErr, 1) + var magics handshakeMagics + + go func() { + r, e := kex.Client(t1, rand.Reader, &magics) + t1.Close() + c <- kexResultErr{r, e} + }() + go func() { + r, e := kex.Server(t2, rand.Reader, &magics, testSigners["ecdsa"].(AlgorithmSigner), testSigners["ecdsa"].PublicKey().Type()) + t2.Close() + s <- kexResultErr{r, e} + }() + + clientRes := <-c + serverRes := <-s + + if clientRes.err != nil { + panic(fmt.Sprintf("client: %v", clientRes.err)) + } + if serverRes.err != nil { + panic(fmt.Sprintf("server: %v", serverRes.err)) + } + } + }) + } +} diff --git a/tempfork/sshtest/ssh/keys.go b/tempfork/sshtest/ssh/keys.go new file mode 100644 index 
0000000000000..4a3d769d98447 --- /dev/null +++ b/tempfork/sshtest/ssh/keys.go @@ -0,0 +1,1626 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/asn1" + "encoding/base64" + "encoding/binary" + "encoding/hex" + "encoding/pem" + "errors" + "fmt" + "io" + "math/big" + "strings" +) + +// Public key algorithms names. These values can appear in PublicKey.Type, +// ClientConfig.HostKeyAlgorithms, Signature.Format, or as AlgorithmSigner +// arguments. +const ( + KeyAlgoRSA = "ssh-rsa" + KeyAlgoDSA = "ssh-dss" + KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" + KeyAlgoSKECDSA256 = "sk-ecdsa-sha2-nistp256@openssh.com" + KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" + KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" + KeyAlgoED25519 = "ssh-ed25519" + KeyAlgoSKED25519 = "sk-ssh-ed25519@openssh.com" + + // KeyAlgoRSASHA256 and KeyAlgoRSASHA512 are only public key algorithms, not + // public key formats, so they can't appear as a PublicKey.Type. The + // corresponding PublicKey.Type is KeyAlgoRSA. See RFC 8332, Section 2. + KeyAlgoRSASHA256 = "rsa-sha2-256" + KeyAlgoRSASHA512 = "rsa-sha2-512" +) + +const ( + // Deprecated: use KeyAlgoRSA. + SigAlgoRSA = KeyAlgoRSA + // Deprecated: use KeyAlgoRSASHA256. + SigAlgoRSASHA2256 = KeyAlgoRSASHA256 + // Deprecated: use KeyAlgoRSASHA512. + SigAlgoRSASHA2512 = KeyAlgoRSASHA512 +) + +// parsePubKey parses a public key of the given algorithm. +// Use ParsePublicKey for keys with prepended algorithm. 
+func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { + switch algo { + case KeyAlgoRSA: + return parseRSA(in) + case KeyAlgoDSA: + return parseDSA(in) + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + return parseECDSA(in) + case KeyAlgoSKECDSA256: + return parseSKECDSA(in) + case KeyAlgoED25519: + return parseED25519(in) + case KeyAlgoSKED25519: + return parseSKEd25519(in) + case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoSKECDSA256v01, CertAlgoED25519v01, CertAlgoSKED25519v01: + cert, err := parseCert(in, certKeyAlgoNames[algo]) + if err != nil { + return nil, nil, err + } + return cert, nil, nil + } + return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo) +} + +// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format +// (see sshd(8) manual page) once the options and key type fields have been +// removed. +func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { + in = bytes.TrimSpace(in) + + i := bytes.IndexAny(in, " \t") + if i == -1 { + i = len(in) + } + base64Key := in[:i] + + key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) + n, err := base64.StdEncoding.Decode(key, base64Key) + if err != nil { + return nil, "", err + } + key = key[:n] + out, err = ParsePublicKey(key) + if err != nil { + return nil, "", err + } + comment = string(bytes.TrimSpace(in[i:])) + return out, comment, nil +} + +// ParseKnownHosts parses an entry in the format of the known_hosts file. +// +// The known_hosts format is documented in the sshd(8) manual page. This +// function will parse a single entry from in. On successful return, marker +// will contain the optional marker value (i.e. "cert-authority" or "revoked") +// or else be empty, hosts will contain the hosts that this entry matches, +// pubKey will contain the public key and comment will contain any trailing +// comment at the end of the line. 
See the sshd(8) manual page for the various +// forms that a host string can take. +// +// The unparsed remainder of the input will be returned in rest. This function +// can be called repeatedly to parse multiple entries. +// +// If no entries were found in the input then err will be io.EOF. Otherwise a +// non-nil err value indicates a parse error. +func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + // Strip out the beginning of the known_host key. + // This is either an optional marker or a (set of) hostname(s). + keyFields := bytes.Fields(in) + if len(keyFields) < 3 || len(keyFields) > 5 { + return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data") + } + + // keyFields[0] is either "@cert-authority", "@revoked" or a comma separated + // list of hosts + marker := "" + if keyFields[0][0] == '@' { + marker = string(keyFields[0][1:]) + keyFields = keyFields[1:] + } + + hosts := string(keyFields[0]) + // keyFields[1] contains the key type (e.g. “ssh-rsa”). + // However, that information is duplicated inside the + // base64-encoded key and so is ignored here. + + key := bytes.Join(keyFields[2:], []byte(" ")) + if pubKey, comment, err = parseAuthorizedKey(key); err != nil { + return "", nil, nil, "", nil, err + } + + return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil + } + + return "", nil, nil, "", nil, io.EOF +} + +// ParseAuthorizedKey parses a public key from an authorized_keys +// file used in OpenSSH according to the sshd(8) manual page. 
+func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { + for len(in) > 0 { + end := bytes.IndexByte(in, '\n') + if end != -1 { + rest = in[end+1:] + in = in[:end] + } else { + rest = nil + } + + end = bytes.IndexByte(in, '\r') + if end != -1 { + in = in[:end] + } + + in = bytes.TrimSpace(in) + if len(in) == 0 || in[0] == '#' { + in = rest + continue + } + + i := bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + return out, comment, options, rest, nil + } + + // No key type recognised. Maybe there's an options field at + // the beginning. + var b byte + inQuote := false + var candidateOptions []string + optionStart := 0 + for i, b = range in { + isEnd := !inQuote && (b == ' ' || b == '\t') + if (b == ',' && !inQuote) || isEnd { + if i-optionStart > 0 { + candidateOptions = append(candidateOptions, string(in[optionStart:i])) + } + optionStart = i + 1 + } + if isEnd { + break + } + if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { + inQuote = !inQuote + } + } + for i < len(in) && (in[i] == ' ' || in[i] == '\t') { + i++ + } + if i == len(in) { + // Invalid line: unmatched quote + in = rest + continue + } + + in = in[i:] + i = bytes.IndexAny(in, " \t") + if i == -1 { + in = rest + continue + } + + if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { + options = candidateOptions + return out, comment, options, rest, nil + } + + in = rest + continue + } + + return nil, "", nil, nil, errors.New("ssh: no key found") +} + +// ParsePublicKey parses an SSH public key formatted for use in +// the SSH wire protocol according to RFC 4253, section 6.6. 
+func ParsePublicKey(in []byte) (out PublicKey, err error) { + algo, in, ok := parseString(in) + if !ok { + return nil, errShortRead + } + var rest []byte + out, rest, err = parsePubKey(in, string(algo)) + if len(rest) > 0 { + return nil, errors.New("ssh: trailing junk in public key") + } + + return out, err +} + +// MarshalAuthorizedKey serializes key for inclusion in an OpenSSH +// authorized_keys file. The return value ends with newline. +func MarshalAuthorizedKey(key PublicKey) []byte { + b := &bytes.Buffer{} + b.WriteString(key.Type()) + b.WriteByte(' ') + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(key.Marshal()) + e.Close() + b.WriteByte('\n') + return b.Bytes() +} + +// MarshalPrivateKey returns a PEM block with the private key serialized in the +// OpenSSH format. +func MarshalPrivateKey(key crypto.PrivateKey, comment string) (*pem.Block, error) { + return marshalOpenSSHPrivateKey(key, comment, unencryptedOpenSSHMarshaler) +} + +// PublicKey represents a public key using an unspecified algorithm. +// +// Some PublicKeys provided by this package also implement CryptoPublicKey. +type PublicKey interface { + // Type returns the key format name, e.g. "ssh-rsa". + Type() string + + // Marshal returns the serialized key data in SSH wire format, with the name + // prefix. To unmarshal the returned data, use the ParsePublicKey function. + Marshal() []byte + + // Verify that sig is a signature on the given data using this key. This + // method will hash the data appropriately first. sig.Format is allowed to + // be any signature algorithm compatible with the key type, the caller + // should check if it has more stringent requirements. + Verify(data []byte, sig *Signature) error +} + +// CryptoPublicKey, if implemented by a PublicKey, +// returns the underlying crypto.PublicKey form of the key. +type CryptoPublicKey interface { + CryptoPublicKey() crypto.PublicKey +} + +// A Signer can create signatures that verify against a public key. 
+// +// Some Signers provided by this package also implement MultiAlgorithmSigner. +type Signer interface { + // PublicKey returns the associated PublicKey. + PublicKey() PublicKey + + // Sign returns a signature for the given data. This method will hash the + // data appropriately first. The signature algorithm is expected to match + // the key format returned by the PublicKey.Type method (and not to be any + // alternative algorithm supported by the key format). + Sign(rand io.Reader, data []byte) (*Signature, error) +} + +// An AlgorithmSigner is a Signer that also supports specifying an algorithm to +// use for signing. +// +// An AlgorithmSigner can't advertise the algorithms it supports, unless it also +// implements MultiAlgorithmSigner, so it should be prepared to be invoked with +// every algorithm supported by the public key format. +type AlgorithmSigner interface { + Signer + + // SignWithAlgorithm is like Signer.Sign, but allows specifying a desired + // signing algorithm. Callers may pass an empty string for the algorithm in + // which case the AlgorithmSigner will use a default algorithm. This default + // doesn't currently control any behavior in this package. + SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) +} + +// MultiAlgorithmSigner is an AlgorithmSigner that also reports the algorithms +// supported by that signer. +type MultiAlgorithmSigner interface { + AlgorithmSigner + + // Algorithms returns the available algorithms in preference order. The list + // must not be empty, and it must not include certificate types. + Algorithms() []string +} + +// NewSignerWithAlgorithms returns a signer restricted to the specified +// algorithms. The algorithms must be set in preference order. The list must not +// be empty, and it must not include certificate types. An error is returned if +// the specified algorithms are incompatible with the public key type. 
+func NewSignerWithAlgorithms(signer AlgorithmSigner, algorithms []string) (MultiAlgorithmSigner, error) { + if len(algorithms) == 0 { + return nil, errors.New("ssh: please specify at least one valid signing algorithm") + } + var signerAlgos []string + supportedAlgos := algorithmsForKeyFormat(underlyingAlgo(signer.PublicKey().Type())) + if s, ok := signer.(*multiAlgorithmSigner); ok { + signerAlgos = s.Algorithms() + } else { + signerAlgos = supportedAlgos + } + + for _, algo := range algorithms { + if !contains(supportedAlgos, algo) { + return nil, fmt.Errorf("ssh: algorithm %q is not supported for key type %q", + algo, signer.PublicKey().Type()) + } + if !contains(signerAlgos, algo) { + return nil, fmt.Errorf("ssh: algorithm %q is restricted for the provided signer", algo) + } + } + return &multiAlgorithmSigner{ + AlgorithmSigner: signer, + supportedAlgorithms: algorithms, + }, nil +} + +type multiAlgorithmSigner struct { + AlgorithmSigner + supportedAlgorithms []string +} + +func (s *multiAlgorithmSigner) Algorithms() []string { + return s.supportedAlgorithms +} + +func (s *multiAlgorithmSigner) isAlgorithmSupported(algorithm string) bool { + if algorithm == "" { + algorithm = underlyingAlgo(s.PublicKey().Type()) + } + for _, algo := range s.supportedAlgorithms { + if algorithm == algo { + return true + } + } + return false +} + +func (s *multiAlgorithmSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if !s.isAlgorithmSupported(algorithm) { + return nil, fmt.Errorf("ssh: algorithm %q is not supported: %v", algorithm, s.supportedAlgorithms) + } + return s.AlgorithmSigner.SignWithAlgorithm(rand, data, algorithm) +} + +type rsaPublicKey rsa.PublicKey + +func (r *rsaPublicKey) Type() string { + return "ssh-rsa" +} + +// parseRSA parses an RSA key according to RFC 4253, section 6.6. 
+func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + E *big.Int + N *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if w.E.BitLen() > 24 { + return nil, nil, errors.New("ssh: exponent too large") + } + e := w.E.Int64() + if e < 3 || e&1 == 0 { + return nil, nil, errors.New("ssh: incorrect exponent") + } + + var key rsa.PublicKey + key.E = int(e) + key.N = w.N + return (*rsaPublicKey)(&key), w.Rest, nil +} + +func (r *rsaPublicKey) Marshal() []byte { + e := new(big.Int).SetInt64(int64(r.E)) + // RSA publickey struct layout should match the struct used by + // parseRSACert in the x/crypto/ssh/agent package. + wirekey := struct { + Name string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + e, + r.N, + } + return Marshal(&wirekey) +} + +func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { + supportedAlgos := algorithmsForKeyFormat(r.Type()) + if !contains(supportedAlgos, sig.Format) { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) + } + hash := hashFuncs[sig.Format] + h := hash.New() + h.Write(data) + digest := h.Sum(nil) + + // Signatures in PKCS1v15 must match the key's modulus in + // length. However with SSH, some signers provide RSA + // signatures which are missing the MSB 0's of the bignum + // represented. With ssh-rsa signatures, this is encouraged by + // the spec (even though e.g. OpenSSH will give the full + // length unconditionally). With rsa-sha2-* signatures, the + // verifier is allowed to support these, even though they are + // out of spec. See RFC 4253 Section 6.6 for ssh-rsa and RFC + // 8332 Section 3 for rsa-sha2-* details. 
+ // + // In practice: + // * OpenSSH always allows "short" signatures: + // https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L526 + // but always generates padded signatures: + // https://github.com/openssh/openssh-portable/blob/V_9_8_P1/ssh-rsa.c#L439 + // + // * PuTTY versions 0.81 and earlier will generate short + // signatures for all RSA signature variants. Note that + // PuTTY is embedded in other software, such as WinSCP and + // FileZilla. At the time of writing, a patch has been + // applied to PuTTY to generate padded signatures for + // rsa-sha2-*, but not yet released: + // https://git.tartarus.org/?p=simon/putty.git;a=commitdiff;h=a5bcf3d384e1bf15a51a6923c3724cbbee022d8e + // + // * SSH.NET versions 2024.0.0 and earlier will generate short + // signatures for all RSA signature variants, fixed in 2024.1.0: + // https://github.com/sshnet/SSH.NET/releases/tag/2024.1.0 + // + // As a result, we pad these up to the key size by inserting + // leading 0's. + // + // Note that support for short signatures with rsa-sha2-* may + // be removed in the future due to such signatures not being + // allowed by the spec. + blob := sig.Blob + keySize := (*rsa.PublicKey)(r).Size() + if len(blob) < keySize { + padded := make([]byte, keySize) + copy(padded[keySize-len(blob):], blob) + blob = padded + } + return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, blob) +} + +func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*rsa.PublicKey)(r) +} + +type dsaPublicKey dsa.PublicKey + +func (k *dsaPublicKey) Type() string { + return "ssh-dss" +} + +func checkDSAParams(param *dsa.Parameters) error { + // SSH specifies FIPS 186-2, which only provided a single size + // (1024 bits) DSA key. FIPS 186-3 allows for larger key + // sizes, which would confuse SSH. 
+ if l := param.P.BitLen(); l != 1024 { + return fmt.Errorf("ssh: unsupported DSA key size %d", l) + } + + return nil +} + +// parseDSA parses an DSA key according to RFC 4253, section 6.6. +func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + P, Q, G, Y *big.Int + Rest []byte `ssh:"rest"` + } + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + param := dsa.Parameters{ + P: w.P, + Q: w.Q, + G: w.G, + } + if err := checkDSAParams(¶m); err != nil { + return nil, nil, err + } + + key := &dsaPublicKey{ + Parameters: param, + Y: w.Y, + } + return key, w.Rest, nil +} + +func (k *dsaPublicKey) Marshal() []byte { + // DSA publickey struct layout should match the struct used by + // parseDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + P, Q, G, Y *big.Int + }{ + k.Type(), + k.P, + k.Q, + k.G, + k.Y, + } + + return Marshal(&w) +} + +func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + h := hashFuncs[sig.Format].New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 4253, section 6.6, + // The value for 'dss_signature_blob' is encoded as a string containing + // r, followed by s (which are 160-bit integers, without lengths or + // padding, unsigned, and in network byte order). + // For DSS purposes, sig.Blob should be exactly 40 bytes in length. 
+ if len(sig.Blob) != 40 { + return errors.New("ssh: DSA signature parse error") + } + r := new(big.Int).SetBytes(sig.Blob[:20]) + s := new(big.Int).SetBytes(sig.Blob[20:]) + if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*dsa.PublicKey)(k) +} + +type dsaPrivateKey struct { + *dsa.PrivateKey +} + +func (k *dsaPrivateKey) PublicKey() PublicKey { + return (*dsaPublicKey)(&k.PrivateKey.PublicKey) +} + +func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { + return k.SignWithAlgorithm(rand, data, k.PublicKey().Type()) +} + +func (k *dsaPrivateKey) Algorithms() []string { + return []string{k.PublicKey().Type()} +} + +func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) { + if algorithm != "" && algorithm != k.PublicKey().Type() { + return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm) + } + + h := hashFuncs[k.PublicKey().Type()].New() + h.Write(data) + digest := h.Sum(nil) + r, s, err := dsa.Sign(rand, k.PrivateKey, digest) + if err != nil { + return nil, err + } + + sig := make([]byte, 40) + rb := r.Bytes() + sb := s.Bytes() + + copy(sig[20-len(rb):20], rb) + copy(sig[40-len(sb):], sb) + + return &Signature{ + Format: k.PublicKey().Type(), + Blob: sig, + }, nil +} + +type ecdsaPublicKey ecdsa.PublicKey + +func (k *ecdsaPublicKey) Type() string { + return "ecdsa-sha2-" + k.nistID() +} + +func (k *ecdsaPublicKey) nistID() string { + switch k.Params().BitSize { + case 256: + return "nistp256" + case 384: + return "nistp384" + case 521: + return "nistp521" + } + panic("ssh: unsupported ecdsa key size") +} + +type ed25519PublicKey ed25519.PublicKey + +func (k ed25519PublicKey) Type() string { + return KeyAlgoED25519 +} + +func parseED25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes 
[]byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + return ed25519PublicKey(w.KeyBytes), w.Rest, nil +} + +func (k ed25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + }{ + KeyAlgoED25519, + []byte(k), + } + return Marshal(&w) +} + +func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k); l != ed25519.PublicKeySize { + return fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l) + } + + if ok := ed25519.Verify(ed25519.PublicKey(k), b, sig.Blob); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return ed25519.PublicKey(k) +} + +func supportedEllipticCurve(curve elliptic.Curve) bool { + return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() +} + +// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. 
+func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(ecdsa.PublicKey) + + switch w.Curve { + case "nistp256": + key.Curve = elliptic.P256() + case "nistp384": + key.Curve = elliptic.P384() + case "nistp521": + key.Curve = elliptic.P521() + default: + return nil, nil, errors.New("ssh: unsupported curve") + } + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + return (*ecdsaPublicKey)(key), w.Rest, nil +} + +func (k *ecdsaPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. + keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + // ECDSA publickey struct layout should match the struct used by + // parseECDSACert in the x/crypto/ssh/agent package. + w := struct { + Name string + ID string + Key []byte + }{ + k.Type(), + k.nistID(), + keyBytes, + } + + return Marshal(&w) +} + +func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := hashFuncs[sig.Format].New() + h.Write(data) + digest := h.Sum(nil) + + // Per RFC 5656, section 3.1.2, + // The ecdsa_signature_blob value has the following specific encoding: + // mpint r + // mpint s + var ecSig struct { + R *big.Int + S *big.Int + } + + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey { + return (*ecdsa.PublicKey)(k) +} + +// skFields holds the additional fields present in U2F/FIDO2 signatures. +// See openssh/PROTOCOL.u2f 'SSH U2F Signatures' for details. 
+type skFields struct { + // Flags contains U2F/FIDO2 flags such as 'user present' + Flags byte + // Counter is a monotonic signature counter which can be + // used to detect concurrent use of a private key, should + // it be extracted from hardware. + Counter uint32 +} + +type skECDSAPublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. + application string + ecdsa.PublicKey +} + +func (k *skECDSAPublicKey) Type() string { + return KeyAlgoSKECDSA256 +} + +func (k *skECDSAPublicKey) nistID() string { + return "nistp256" +} + +func parseSKECDSA(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + Curve string + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + key := new(skECDSAPublicKey) + key.application = w.Application + + if w.Curve != "nistp256" { + return nil, nil, errors.New("ssh: unsupported curve") + } + key.Curve = elliptic.P256() + + key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) + if key.X == nil || key.Y == nil { + return nil, nil, errors.New("ssh: invalid curve point") + } + + return key, w.Rest, nil +} + +func (k *skECDSAPublicKey) Marshal() []byte { + // See RFC 5656, section 3.1. 
+ keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y) + w := struct { + Name string + ID string + Key []byte + Application string + }{ + k.Type(), + k.nistID(), + keyBytes, + k.application, + } + + return Marshal(&w) +} + +func (k *skECDSAPublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + + h := hashFuncs[sig.Format].New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var ecSig struct { + R *big.Int + S *big.Int + } + if err := Unmarshal(sig.Blob, &ecSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + h.Reset() + h.Write(original) + digest := h.Sum(nil) + + if ecdsa.Verify((*ecdsa.PublicKey)(&k.PublicKey), digest, ecSig.R, ecSig.S) { + return nil + } + return errors.New("ssh: signature did not verify") +} + +func (k *skECDSAPublicKey) CryptoPublicKey() crypto.PublicKey { + return &k.PublicKey +} + +type skEd25519PublicKey struct { + // application is a URL-like string, typically "ssh:" for SSH. + // see openssh/PROTOCOL.u2f for details. 
+ application string + ed25519.PublicKey +} + +func (k *skEd25519PublicKey) Type() string { + return KeyAlgoSKED25519 +} + +func parseSKEd25519(in []byte) (out PublicKey, rest []byte, err error) { + var w struct { + KeyBytes []byte + Application string + Rest []byte `ssh:"rest"` + } + + if err := Unmarshal(in, &w); err != nil { + return nil, nil, err + } + + if l := len(w.KeyBytes); l != ed25519.PublicKeySize { + return nil, nil, fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + key := new(skEd25519PublicKey) + key.application = w.Application + key.PublicKey = ed25519.PublicKey(w.KeyBytes) + + return key, w.Rest, nil +} + +func (k *skEd25519PublicKey) Marshal() []byte { + w := struct { + Name string + KeyBytes []byte + Application string + }{ + KeyAlgoSKED25519, + []byte(k.PublicKey), + k.application, + } + return Marshal(&w) +} + +func (k *skEd25519PublicKey) Verify(data []byte, sig *Signature) error { + if sig.Format != k.Type() { + return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) + } + if l := len(k.PublicKey); l != ed25519.PublicKeySize { + return fmt.Errorf("invalid size %d for Ed25519 public key", l) + } + + h := hashFuncs[sig.Format].New() + h.Write([]byte(k.application)) + appDigest := h.Sum(nil) + + h.Reset() + h.Write(data) + dataDigest := h.Sum(nil) + + var edSig struct { + Signature []byte `ssh:"rest"` + } + + if err := Unmarshal(sig.Blob, &edSig); err != nil { + return err + } + + var skf skFields + if err := Unmarshal(sig.Rest, &skf); err != nil { + return err + } + + blob := struct { + ApplicationDigest []byte `ssh:"rest"` + Flags byte + Counter uint32 + MessageDigest []byte `ssh:"rest"` + }{ + appDigest, + skf.Flags, + skf.Counter, + dataDigest, + } + + original := Marshal(blob) + + if ok := ed25519.Verify(k.PublicKey, original, edSig.Signature); !ok { + return errors.New("ssh: signature did not verify") + } + + return nil +} + +func (k *skEd25519PublicKey) CryptoPublicKey() crypto.PublicKey { + return 
k.PublicKey +} + +// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey, +// *ecdsa.PrivateKey or any other crypto.Signer and returns a +// corresponding Signer instance. ECDSA keys must use P-256, P-384 or +// P-521. DSA keys must use parameter size L1024N160. +func NewSignerFromKey(key interface{}) (Signer, error) { + switch key := key.(type) { + case crypto.Signer: + return NewSignerFromSigner(key) + case *dsa.PrivateKey: + return newDSAPrivateKey(key) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", key) + } +} + +func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) { + if err := checkDSAParams(&key.PublicKey.Parameters); err != nil { + return nil, err + } + + return &dsaPrivateKey{key}, nil +} + +type wrappedSigner struct { + signer crypto.Signer + pubKey PublicKey +} + +// NewSignerFromSigner takes any crypto.Signer implementation and +// returns a corresponding Signer interface. This can be used, for +// example, with keys kept in hardware modules. 
func NewSignerFromSigner(signer crypto.Signer) (Signer, error) {
	pubKey, err := NewPublicKey(signer.Public())
	if err != nil {
		return nil, err
	}

	return &wrappedSigner{signer, pubKey}, nil
}

// PublicKey returns the SSH public key corresponding to the wrapped signer.
func (s *wrappedSigner) PublicKey() PublicKey {
	return s.pubKey
}

// Sign signs data with the key's default algorithm (the plain key format
// name), delegating to SignWithAlgorithm.
func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
	return s.SignWithAlgorithm(rand, data, s.pubKey.Type())
}

// Algorithms reports the signature algorithms supported for this key format.
func (s *wrappedSigner) Algorithms() []string {
	return algorithmsForKeyFormat(s.pubKey.Type())
}

// SignWithAlgorithm signs data with the requested signature algorithm. An
// empty algorithm selects the key's default; otherwise the algorithm must be
// one of those returned by Algorithms.
func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
	if algorithm == "" {
		algorithm = s.pubKey.Type()
	}

	if !contains(s.Algorithms(), algorithm) {
		return nil, fmt.Errorf("ssh: unsupported signature algorithm %q for key format %q", algorithm, s.pubKey.Type())
	}

	// A zero hash function means the algorithm signs the message directly
	// (no pre-hashing); otherwise hash the data first and hand the digest
	// plus the hash identifier to the crypto.Signer, per its contract.
	hashFunc := hashFuncs[algorithm]
	var digest []byte
	if hashFunc != 0 {
		h := hashFunc.New()
		h.Write(data)
		digest = h.Sum(nil)
	} else {
		digest = data
	}

	signature, err := s.signer.Sign(rand, digest, hashFunc)
	if err != nil {
		return nil, err
	}

	// crypto.Signer.Sign is expected to return an ASN.1-encoded signature
	// for ECDSA and DSA, but that's not the encoding expected by SSH, so
	// re-encode.
	switch s.pubKey.(type) {
	case *ecdsaPublicKey, *dsaPublicKey:
		type asn1Signature struct {
			R, S *big.Int
		}
		asn1Sig := new(asn1Signature)
		_, err := asn1.Unmarshal(signature, asn1Sig)
		if err != nil {
			return nil, err
		}

		switch s.pubKey.(type) {
		case *ecdsaPublicKey:
			// SSH wraps ECDSA signatures as an SSH-marshaled (r, s) pair.
			signature = Marshal(asn1Sig)

		case *dsaPublicKey:
			// SSH encodes DSA signatures as a fixed 40-byte blob: r and s
			// as two 20-byte big-endian values, left-padded with zeros.
			signature = make([]byte, 40)
			r := asn1Sig.R.Bytes()
			s := asn1Sig.S.Bytes()
			copy(signature[20-len(r):20], r)
			copy(signature[40-len(s):40], s)
		}
	}

	return &Signature{
		Format: algorithm,
		Blob:   signature,
	}, nil
}

// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey,
// or ed25519.PublicKey returns a corresponding PublicKey instance.
// ECDSA keys must use P-256, P-384 or P-521.
func NewPublicKey(key interface{}) (PublicKey, error) {
	switch key := key.(type) {
	case *rsa.PublicKey:
		return (*rsaPublicKey)(key), nil
	case *ecdsa.PublicKey:
		if !supportedEllipticCurve(key.Curve) {
			return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
		}
		return (*ecdsaPublicKey)(key), nil
	case *dsa.PublicKey:
		return (*dsaPublicKey)(key), nil
	case ed25519.PublicKey:
		if l := len(key); l != ed25519.PublicKeySize {
			return nil, fmt.Errorf("ssh: invalid size %d for Ed25519 public key", l)
		}
		return ed25519PublicKey(key), nil
	default:
		return nil, fmt.Errorf("ssh: unsupported key type %T", key)
	}
}

// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
// the same keys as ParseRawPrivateKey. If the private key is encrypted, it
// will return a PassphraseMissingError.
func ParsePrivateKey(pemBytes []byte) (Signer, error) {
	key, err := ParseRawPrivateKey(pemBytes)
	if err != nil {
		return nil, err
	}

	return NewSignerFromKey(key)
}

// encryptedBlock tells whether a private key is
// encrypted by examining its Proc-Type header
// for a mention of ENCRYPTED
// according to RFC 1421 Section 4.6.1.1.
func encryptedBlock(block *pem.Block) bool {
	return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED")
}

// A PassphraseMissingError indicates that parsing this private key requires a
// passphrase. Use ParsePrivateKeyWithPassphrase.
type PassphraseMissingError struct {
	// PublicKey will be set if the private key format includes an unencrypted
	// public key along with the encrypted private key.
	PublicKey PublicKey
}

func (*PassphraseMissingError) Error() string {
	return "ssh: this private key is passphrase protected"
}

// ParseRawPrivateKey returns a private key from a PEM encoded private key. It supports
// RSA, DSA, ECDSA, and Ed25519 private keys in PKCS#1, PKCS#8, OpenSSL, and OpenSSH
// formats. If the private key is encrypted, it will return a PassphraseMissingError.
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
	block, _ := pem.Decode(pemBytes)
	if block == nil {
		return nil, errors.New("ssh: no key found")
	}

	// A PEM block carrying an RFC 1421 Proc-Type: ENCRYPTED header needs a
	// passphrase; report that to the caller instead of failing to parse.
	if encryptedBlock(block) {
		return nil, &PassphraseMissingError{}
	}

	// Dispatch on the PEM type header to the matching decoder.
	switch block.Type {
	case "RSA PRIVATE KEY":
		return x509.ParsePKCS1PrivateKey(block.Bytes)
	// RFC5208 - https://tools.ietf.org/html/rfc5208
	case "PRIVATE KEY":
		return x509.ParsePKCS8PrivateKey(block.Bytes)
	case "EC PRIVATE KEY":
		return x509.ParseECPrivateKey(block.Bytes)
	case "DSA PRIVATE KEY":
		return ParseDSAPrivateKey(block.Bytes)
	case "OPENSSH PRIVATE KEY":
		return parseOpenSSHPrivateKey(block.Bytes, unencryptedOpenSSHKey)
	default:
		return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
	}
}

// ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as
// specified by the OpenSSL DSA man page.
+func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { + var k struct { + Version int + P *big.Int + Q *big.Int + G *big.Int + Pub *big.Int + Priv *big.Int + } + rest, err := asn1.Unmarshal(der, &k) + if err != nil { + return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) + } + if len(rest) > 0 { + return nil, errors.New("ssh: garbage after DSA key") + } + + return &dsa.PrivateKey{ + PublicKey: dsa.PublicKey{ + Parameters: dsa.Parameters{ + P: k.P, + Q: k.Q, + G: k.G, + }, + Y: k.Pub, + }, + X: k.Priv, + }, nil +} + +func unencryptedOpenSSHKey(cipherName, kdfName, kdfOpts string, privKeyBlock []byte) ([]byte, error) { + if kdfName != "none" || cipherName != "none" { + return nil, &PassphraseMissingError{} + } + if kdfOpts != "" { + return nil, errors.New("ssh: invalid openssh private key") + } + return privKeyBlock, nil +} + +func unencryptedOpenSSHMarshaler(privKeyBlock []byte) ([]byte, string, string, string, error) { + key := generateOpenSSHPadding(privKeyBlock, 8) + return key, "none", "none", "", nil +} + +const privateKeyAuthMagic = "openssh-key-v1\x00" + +type openSSHDecryptFunc func(CipherName, KdfName, KdfOpts string, PrivKeyBlock []byte) ([]byte, error) +type openSSHEncryptFunc func(PrivKeyBlock []byte) (ProtectedKeyBlock []byte, cipherName, kdfName, kdfOptions string, err error) + +type openSSHEncryptedPrivateKey struct { + CipherName string + KdfName string + KdfOpts string + NumKeys uint32 + PubKey []byte + PrivKeyBlock []byte +} + +type openSSHPrivateKey struct { + Check1 uint32 + Check2 uint32 + Keytype string + Rest []byte `ssh:"rest"` +} + +type openSSHRSAPrivateKey struct { + N *big.Int + E *big.Int + D *big.Int + Iqmp *big.Int + P *big.Int + Q *big.Int + Comment string + Pad []byte `ssh:"rest"` +} + +type openSSHEd25519PrivateKey struct { + Pub []byte + Priv []byte + Comment string + Pad []byte `ssh:"rest"` +} + +type openSSHECDSAPrivateKey struct { + Curve string + Pub []byte + D *big.Int + Comment string + Pad []byte 
`ssh:"rest"` +} + +// parseOpenSSHPrivateKey parses an OpenSSH private key, using the decrypt +// function to unwrap the encrypted portion. unencryptedOpenSSHKey can be used +// as the decrypt function to parse an unencrypted private key. See +// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key. +func parseOpenSSHPrivateKey(key []byte, decrypt openSSHDecryptFunc) (crypto.PrivateKey, error) { + if len(key) < len(privateKeyAuthMagic) || string(key[:len(privateKeyAuthMagic)]) != privateKeyAuthMagic { + return nil, errors.New("ssh: invalid openssh private key format") + } + remaining := key[len(privateKeyAuthMagic):] + + var w openSSHEncryptedPrivateKey + if err := Unmarshal(remaining, &w); err != nil { + return nil, err + } + if w.NumKeys != 1 { + // We only support single key files, and so does OpenSSH. + // https://github.com/openssh/openssh-portable/blob/4103a3ec7/sshkey.c#L4171 + return nil, errors.New("ssh: multi-key files are not supported") + } + + privKeyBlock, err := decrypt(w.CipherName, w.KdfName, w.KdfOpts, w.PrivKeyBlock) + if err != nil { + if err, ok := err.(*PassphraseMissingError); ok { + pub, errPub := ParsePublicKey(w.PubKey) + if errPub != nil { + return nil, fmt.Errorf("ssh: failed to parse embedded public key: %v", errPub) + } + err.PublicKey = pub + } + return nil, err + } + + var pk1 openSSHPrivateKey + if err := Unmarshal(privKeyBlock, &pk1); err != nil || pk1.Check1 != pk1.Check2 { + if w.CipherName != "none" { + return nil, x509.IncorrectPasswordError + } + return nil, errors.New("ssh: malformed OpenSSH key") + } + + switch pk1.Keytype { + case KeyAlgoRSA: + var key openSSHRSAPrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := &rsa.PrivateKey{ + PublicKey: rsa.PublicKey{ + N: key.N, + E: int(key.E.Int64()), + }, + D: key.D, + Primes: []*big.Int{key.P, key.Q}, + } + + if err := pk.Validate(); err != 
nil { + return nil, err + } + + pk.Precompute() + + return pk, nil + case KeyAlgoED25519: + var key openSSHEd25519PrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if len(key.Priv) != ed25519.PrivateKeySize { + return nil, errors.New("ssh: private key unexpected length") + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) + copy(pk, key.Priv) + return &pk, nil + case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: + var key openSSHECDSAPrivateKey + if err := Unmarshal(pk1.Rest, &key); err != nil { + return nil, err + } + + if err := checkOpenSSHKeyPadding(key.Pad); err != nil { + return nil, err + } + + var curve elliptic.Curve + switch key.Curve { + case "nistp256": + curve = elliptic.P256() + case "nistp384": + curve = elliptic.P384() + case "nistp521": + curve = elliptic.P521() + default: + return nil, errors.New("ssh: unhandled elliptic curve: " + key.Curve) + } + + X, Y := elliptic.Unmarshal(curve, key.Pub) + if X == nil || Y == nil { + return nil, errors.New("ssh: failed to unmarshal public key") + } + + if key.D.Cmp(curve.Params().N) >= 0 { + return nil, errors.New("ssh: scalar is out of range") + } + + x, y := curve.ScalarBaseMult(key.D.Bytes()) + if x.Cmp(X) != 0 || y.Cmp(Y) != 0 { + return nil, errors.New("ssh: public key does not match private key") + } + + return &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{ + Curve: curve, + X: X, + Y: Y, + }, + D: key.D, + }, nil + default: + return nil, errors.New("ssh: unhandled key type") + } +} + +func marshalOpenSSHPrivateKey(key crypto.PrivateKey, comment string, encrypt openSSHEncryptFunc) (*pem.Block, error) { + var w openSSHEncryptedPrivateKey + var pk1 openSSHPrivateKey + + // Random check bytes. 
+ var check uint32 + if err := binary.Read(rand.Reader, binary.BigEndian, &check); err != nil { + return nil, err + } + + pk1.Check1 = check + pk1.Check2 = check + w.NumKeys = 1 + + // Use a []byte directly on ed25519 keys. + if k, ok := key.(*ed25519.PrivateKey); ok { + key = *k + } + + switch k := key.(type) { + case *rsa.PrivateKey: + E := new(big.Int).SetInt64(int64(k.PublicKey.E)) + // Marshal public key: + // E and N are in reversed order in the public and private key. + pubKey := struct { + KeyType string + E *big.Int + N *big.Int + }{ + KeyAlgoRSA, + E, k.PublicKey.N, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHRSAPrivateKey{ + N: k.PublicKey.N, + E: E, + D: k.D, + Iqmp: k.Precomputed.Qinv, + P: k.Primes[0], + Q: k.Primes[1], + Comment: comment, + } + pk1.Keytype = KeyAlgoRSA + pk1.Rest = Marshal(key) + case ed25519.PrivateKey: + pub := make([]byte, ed25519.PublicKeySize) + priv := make([]byte, ed25519.PrivateKeySize) + copy(pub, k[32:]) + copy(priv, k) + + // Marshal public key. + pubKey := struct { + KeyType string + Pub []byte + }{ + KeyAlgoED25519, pub, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. + key := openSSHEd25519PrivateKey{ + Pub: pub, + Priv: priv, + Comment: comment, + } + pk1.Keytype = KeyAlgoED25519 + pk1.Rest = Marshal(key) + case *ecdsa.PrivateKey: + var curve, keyType string + switch name := k.Curve.Params().Name; name { + case "P-256": + curve = "nistp256" + keyType = KeyAlgoECDSA256 + case "P-384": + curve = "nistp384" + keyType = KeyAlgoECDSA384 + case "P-521": + curve = "nistp521" + keyType = KeyAlgoECDSA521 + default: + return nil, errors.New("ssh: unhandled elliptic curve " + name) + } + + pub := elliptic.Marshal(k.Curve, k.PublicKey.X, k.PublicKey.Y) + + // Marshal public key. + pubKey := struct { + KeyType string + Curve string + Pub []byte + }{ + keyType, curve, pub, + } + w.PubKey = Marshal(pubKey) + + // Marshal private key. 
+ key := openSSHECDSAPrivateKey{ + Curve: curve, + Pub: pub, + D: k.D, + Comment: comment, + } + pk1.Keytype = keyType + pk1.Rest = Marshal(key) + default: + return nil, fmt.Errorf("ssh: unsupported key type %T", k) + } + + var err error + // Add padding and encrypt the key if necessary. + w.PrivKeyBlock, w.CipherName, w.KdfName, w.KdfOpts, err = encrypt(Marshal(pk1)) + if err != nil { + return nil, err + } + + b := Marshal(w) + block := &pem.Block{ + Type: "OPENSSH PRIVATE KEY", + Bytes: append([]byte(privateKeyAuthMagic), b...), + } + return block, nil +} + +func checkOpenSSHKeyPadding(pad []byte) error { + for i, b := range pad { + if int(b) != i+1 { + return errors.New("ssh: padding not as expected") + } + } + return nil +} + +func generateOpenSSHPadding(block []byte, blockSize int) []byte { + for i, l := 0, len(block); (l+i)%blockSize != 0; i++ { + block = append(block, byte(i+1)) + } + return block +} + +// FingerprintLegacyMD5 returns the user presentation of the key's +// fingerprint as described by RFC 4716 section 4. +func FingerprintLegacyMD5(pubKey PublicKey) string { + md5sum := md5.Sum(pubKey.Marshal()) + hexarray := make([]string, len(md5sum)) + for i, c := range md5sum { + hexarray[i] = hex.EncodeToString([]byte{c}) + } + return strings.Join(hexarray, ":") +} + +// FingerprintSHA256 returns the user presentation of the key's +// fingerprint as unpadded base64 encoded sha256 hash. +// This format was introduced from OpenSSH 6.8. 
+// https://www.openssh.com/txt/release-6.8 +// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding) +func FingerprintSHA256(pubKey PublicKey) string { + sha256sum := sha256.Sum256(pubKey.Marshal()) + hash := base64.RawStdEncoding.EncodeToString(sha256sum[:]) + return "SHA256:" + hash +} diff --git a/tempfork/sshtest/ssh/keys_test.go b/tempfork/sshtest/ssh/keys_test.go new file mode 100644 index 0000000000000..bf1f0be1b2de3 --- /dev/null +++ b/tempfork/sshtest/ssh/keys_test.go @@ -0,0 +1,724 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/dsa" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "encoding/base64" + "encoding/hex" + "encoding/pem" + "fmt" + "io" + "reflect" + "strings" + "testing" + + "golang.org/x/crypto/ssh/testdata" +) + +func rawKey(pub PublicKey) interface{} { + switch k := pub.(type) { + case *rsaPublicKey: + return (*rsa.PublicKey)(k) + case *dsaPublicKey: + return (*dsa.PublicKey)(k) + case *ecdsaPublicKey: + return (*ecdsa.PublicKey)(k) + case ed25519PublicKey: + return (ed25519.PublicKey)(k) + case *Certificate: + return k + } + panic("unknown key type") +} + +func TestKeyMarshalParse(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + roundtrip, err := ParsePublicKey(pub.Marshal()) + if err != nil { + t.Errorf("ParsePublicKey(%T): %v", pub, err) + } + + k1 := rawKey(pub) + k2 := rawKey(roundtrip) + + if !reflect.DeepEqual(k1, k2) { + t.Errorf("got %#v in roundtrip, want %#v", k2, k1) + } + } +} + +func TestUnsupportedCurves(t *testing.T) { + raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey: %v", err) + } + + if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPrivateKey should not 
succeed with P-224, got: %v", err) + } + + if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") { + t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err) + } +} + +func TestNewPublicKey(t *testing.T) { + for _, k := range testSigners { + raw := rawKey(k.PublicKey()) + // Skip certificates, as NewPublicKey does not support them. + if _, ok := raw.(*Certificate); ok { + continue + } + pub, err := NewPublicKey(raw) + if err != nil { + t.Errorf("NewPublicKey(%#v): %v", raw, err) + } + if !reflect.DeepEqual(k.PublicKey(), pub) { + t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) + } + } +} + +func TestKeySignVerify(t *testing.T) { + for _, priv := range testSigners { + pub := priv.PublicKey() + + data := []byte("sign me") + sig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } +} + +func TestKeySignWithAlgorithmVerify(t *testing.T) { + for k, priv := range testSigners { + if algorithmSigner, ok := priv.(MultiAlgorithmSigner); !ok { + t.Errorf("Signers %q constructed by ssh package should always implement the MultiAlgorithmSigner interface: %T", k, priv) + } else { + pub := priv.PublicKey() + data := []byte("sign me") + + signWithAlgTestCase := func(algorithm string, expectedAlg string) { + sig, err := algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + if sig.Format != expectedAlg { + t.Errorf("signature format did not match requested signature algorithm: %s != %s", sig.Format, expectedAlg) + } + + if err := pub.Verify(data, sig); err != nil { + t.Errorf("publicKey.Verify(%T): %v", priv, err) + } + sig.Blob[5]++ + if err := 
pub.Verify(data, sig); err == nil { + t.Errorf("publicKey.Verify on broken sig did not fail") + } + } + + // Using the empty string as the algorithm name should result in the same signature format as the algorithm-free Sign method. + defaultSig, err := priv.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("Sign(%T): %v", priv, err) + } + signWithAlgTestCase("", defaultSig.Format) + + // RSA keys are the only ones which currently support more than one signing algorithm + if pub.Type() == KeyAlgoRSA { + for _, algorithm := range []string{KeyAlgoRSA, KeyAlgoRSASHA256, KeyAlgoRSASHA512} { + signWithAlgTestCase(algorithm, algorithm) + } + } + } + } +} + +func TestKeySignWithShortSignature(t *testing.T) { + signer := testSigners["rsa"].(AlgorithmSigner) + pub := signer.PublicKey() + // Note: data obtained by empirically trying until a result + // starting with 0 appeared + tests := []struct { + algorithm string + data []byte + }{ + { + algorithm: KeyAlgoRSA, + data: []byte("sign me92"), + }, + { + algorithm: KeyAlgoRSASHA256, + data: []byte("sign me294"), + }, + { + algorithm: KeyAlgoRSASHA512, + data: []byte("sign me60"), + }, + } + + for _, tt := range tests { + sig, err := signer.SignWithAlgorithm(rand.Reader, tt.data, tt.algorithm) + if err != nil { + t.Fatalf("Sign(%T): %v", signer, err) + } + if sig.Blob[0] != 0 { + t.Errorf("%s: Expected signature with a leading 0", tt.algorithm) + } + sig.Blob = sig.Blob[1:] + if err := pub.Verify(tt.data, sig); err != nil { + t.Errorf("publicKey.Verify(%s): %v", tt.algorithm, err) + } + } +} + +func TestParseRSAPrivateKey(t *testing.T) { + key := testPrivateKeys["rsa"] + + rsa, ok := key.(*rsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *rsa.PrivateKey", rsa) + } + + if err := rsa.Validate(); err != nil { + t.Errorf("Validate: %v", err) + } +} + +func TestParseECPrivateKey(t *testing.T) { + key := testPrivateKeys["ecdsa"] + + ecKey, ok := key.(*ecdsa.PrivateKey) + if !ok { + t.Fatalf("got %T, want *ecdsa.PrivateKey", 
ecKey) + } + + if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { + t.Fatalf("public key does not validate.") + } +} + +func TestParseDSA(t *testing.T) { + // We actually exercise the ParsePrivateKey codepath here, as opposed to + // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go + // uses. + s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) + if err != nil { + t.Fatalf("ParsePrivateKey returned error: %s", err) + } + + data := []byte("sign me") + sig, err := s.Sign(rand.Reader, data) + if err != nil { + t.Fatalf("dsa.Sign: %v", err) + } + + if err := s.PublicKey().Verify(data, sig); err != nil { + t.Errorf("Verify failed: %v", err) + } +} + +// Tests for authorized_keys parsing. + +// getTestKey returns a public key, and its base64 encoding. +func getTestKey() (PublicKey, string) { + k := testPublicKeys["rsa"] + + b := &bytes.Buffer{} + e := base64.NewEncoder(base64.StdEncoding, b) + e.Write(k.Marshal()) + e.Close() + + return k, b.String() +} + +func TestMarshalParsePublicKey(t *testing.T) { + pub, pubSerialized := getTestKey() + line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) + + authKeys := MarshalAuthorizedKey(pub) + actualFields := strings.Fields(string(authKeys)) + if len(actualFields) == 0 { + t.Fatalf("failed authKeys: %v", authKeys) + } + + // drop the comment + expectedFields := strings.Fields(line)[0:2] + + if !reflect.DeepEqual(actualFields, expectedFields) { + t.Errorf("got %v, expected %v", actualFields, expectedFields) + } + + actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) + if err != nil { + t.Fatalf("cannot parse %v: %v", line, err) + } + if !reflect.DeepEqual(actPub, pub) { + t.Errorf("got %v, expected %v", actPub, pub) + } +} + +func TestMarshalPrivateKey(t *testing.T) { + tests := []struct { + name string + }{ + {"rsa-openssh-format"}, + {"ed25519"}, + {"p256-openssh-format"}, + {"p384-openssh-format"}, + {"p521-openssh-format"}, + } + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + expected, ok := testPrivateKeys[tt.name] + if !ok { + t.Fatalf("cannot find key %s", tt.name) + } + + block, err := MarshalPrivateKey(expected, "test@golang.org") + if err != nil { + t.Fatalf("cannot marshal %s: %v", tt.name, err) + } + + key, err := ParseRawPrivateKey(pem.EncodeToMemory(block)) + if err != nil { + t.Fatalf("cannot parse %s: %v", tt.name, err) + } + + if !reflect.DeepEqual(expected, key) { + t.Errorf("unexpected marshaled key %s", tt.name) + } + }) + } +} + +type testAuthResult struct { + pubKey PublicKey + options []string + comments string + rest string + ok bool +} + +func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []testAuthResult) { + rest := authKeys + var values []testAuthResult + for len(rest) > 0 { + var r testAuthResult + var err error + r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) + r.ok = (err == nil) + t.Log(err) + r.rest = string(rest) + values = append(values, r) + } + + if !reflect.DeepEqual(values, expected) { + t.Errorf("got %#v, expected %#v", values, expected) + } +} + +func TestAuthorizedKeyBasic(t *testing.T) { + pub, pubSerialized := getTestKey() + line := "ssh-rsa " + pubSerialized + " user@host" + testAuthorizedKeys(t, []byte(line), + []testAuthResult{ + {pub, nil, "user@host", "", true}, + }) +} + +func TestAuth(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithOptions := []string{ + `# comments to ignore before any keys...`, + ``, + `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, + `# comments to ignore, along with a blank line`, + ``, + `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, + ``, + `# more comments, plus a invalid entry`, + `ssh-rsa data-that-will-not-parse user@host3`, + } + for _, eol := range []string{"\n", "\r\n"} { + authOptions := strings.Join(authWithOptions, eol) + rest2 := strings.Join(authWithOptions[3:], eol) + rest3 := strings.Join(authWithOptions[6:], eol) + 
testAuthorizedKeys(t, []byte(authOptions), []testAuthResult{ + {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, + {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, + {nil, nil, "", "", false}, + }) + } +} + +func TestAuthWithQuotedSpaceInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedCommaInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) +} + +func TestAuthWithQuotedQuoteInEnv(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) + authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) + testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []testAuthResult{ + {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, + }) + + testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []testAuthResult{ + {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, + }) +} + +func TestAuthWithInvalidSpace(t *testing.T) { + _, pubSerialized := getTestKey() + authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +#more to 
follow but still no valid keys`) + testAuthorizedKeys(t, []byte(authWithInvalidSpace), []testAuthResult{ + {nil, nil, "", "", false}, + }) +} + +func TestAuthWithMissingQuote(t *testing.T) { + pub, pubSerialized := getTestKey() + authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host +env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) + + testAuthorizedKeys(t, []byte(authWithMissingQuote), []testAuthResult{ + {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, + }) +} + +func TestInvalidEntry(t *testing.T) { + authInvalid := []byte(`ssh-rsa`) + _, _, _, _, err := ParseAuthorizedKey(authInvalid) + if err == nil { + t.Errorf("got valid entry for %q", authInvalid) + } +} + +var knownHostsParseTests = []struct { + input string + err string + + marker string + comment string + hosts []string + rest string +}{ + { + "", + "EOF", + + "", "", nil, "", + }, + { + "# Just a comment", + "EOF", + + "", "", nil, "", + }, + { + " \t ", + "EOF", + + "", "", nil, "", + }, + { + "localhost ssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}", + "", + + "", "", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n", + "", + + "", "comment comment", []string{"localhost"}, "", + }, + { + "localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line", + "", + + "", "comment comment", []string{"localhost"}, "next line", + }, + { + "localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment", + "", + + "", "comment comment", []string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}", + "", + + "marker", "", 
[]string{"localhost", "[host2:123]"}, "", + }, + { + "@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd", + "short read", + + "", "", nil, "", + }, +} + +func TestKnownHostsParsing(t *testing.T) { + rsaPub, rsaPubSerialized := getTestKey() + + for i, test := range knownHostsParseTests { + var expectedKey PublicKey + const rsaKeyToken = "{RSAPUB}" + + input := test.input + if strings.Contains(input, rsaKeyToken) { + expectedKey = rsaPub + input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1) + } + + marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input)) + if err != nil { + if len(test.err) == 0 { + t.Errorf("#%d: unexpectedly failed with %q", i, err) + } else if !strings.Contains(err.Error(), test.err) { + t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err) + } + continue + } else if len(test.err) != 0 { + t.Errorf("#%d: succeeded but expected error including %q", i, test.err) + continue + } + + if !reflect.DeepEqual(expectedKey, pubKey) { + t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey) + } + + if marker != test.marker { + t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker) + } + + if comment != test.comment { + t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment) + } + + if !reflect.DeepEqual(test.hosts, hosts) { + t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts) + } + + if rest := string(rest); rest != test.rest { + t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest) + } + } +} + +func TestFingerprintLegacyMD5(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintLegacyMD5(pub) + want := "b7:ef:d3:d5:89:29:52:96:9f:df:47:41:4d:15:37:f4" // ssh-keygen -lf -E md5 rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestFingerprintSHA256(t *testing.T) { + pub, _ := getTestKey() + fingerprint := FingerprintSHA256(pub) + 
want := "SHA256:fi5+D7UmDZDE9Q2sAVvvlpcQSIakN4DERdINgXd2AnE" // ssh-keygen -lf rsa + if fingerprint != want { + t.Errorf("got fingerprint %q want %q", fingerprint, want) + } +} + +func TestInvalidKeys(t *testing.T) { + keyTypes := []string{ + "RSA PRIVATE KEY", + "PRIVATE KEY", + "EC PRIVATE KEY", + "DSA PRIVATE KEY", + "OPENSSH PRIVATE KEY", + } + + for _, keyType := range keyTypes { + for _, dataLen := range []int{0, 1, 2, 5, 10, 20} { + data := make([]byte, dataLen) + if _, err := io.ReadFull(rand.Reader, data); err != nil { + t.Fatal(err) + } + + var buf bytes.Buffer + pem.Encode(&buf, &pem.Block{ + Type: keyType, + Bytes: data, + }) + + // This test is just to ensure that the function + // doesn't panic so the return value is ignored. + ParseRawPrivateKey(buf.Bytes()) + } + } +} + +func TestSKKeys(t *testing.T) { + for _, d := range testdata.SKData { + pk, _, _, _, err := ParseAuthorizedKey(d.PubKey) + if err != nil { + t.Fatalf("parseAuthorizedKey returned error: %v", err) + } + + sigBuf := make([]byte, hex.DecodedLen(len(d.HexSignature))) + if _, err := hex.Decode(sigBuf, d.HexSignature); err != nil { + t.Fatalf("hex.Decode() failed: %v", err) + } + + dataBuf := make([]byte, hex.DecodedLen(len(d.HexData))) + if _, err := hex.Decode(dataBuf, d.HexData); err != nil { + t.Fatalf("hex.Decode() failed: %v", err) + } + + sig, _, ok := parseSignature(sigBuf) + if !ok { + t.Fatalf("parseSignature(%v) failed", sigBuf) + } + + // Test that good data and signature pass verification + if err := pk.Verify(dataBuf, sig); err != nil { + t.Errorf("%s: PublicKey.Verify(%v, %v) failed: %v", d.Name, dataBuf, sig, err) + } + + // Invalid data being passed in + invalidData := []byte("INVALID DATA") + if err := pk.Verify(invalidData, sig); err == nil { + t.Errorf("%s with invalid data: PublicKey.Verify(%v, %v) passed unexpectedly", d.Name, invalidData, sig) + } + + // Change byte in blob to corrup signature + sig.Blob[5] = byte('A') + // Corrupted data being passed in + if err := 
pk.Verify(dataBuf, sig); err == nil { + t.Errorf("%s with corrupted signature: PublicKey.Verify(%v, %v) passed unexpectedly", d.Name, dataBuf, sig) + } + } +} + +func TestNewSignerWithAlgos(t *testing.T) { + algorithSigner, ok := testSigners["rsa"].(AlgorithmSigner) + if !ok { + t.Fatal("rsa test signer does not implement the AlgorithmSigner interface") + } + _, err := NewSignerWithAlgorithms(algorithSigner, nil) + if err == nil { + t.Error("signer with algos created with no algorithms") + } + + _, err = NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoED25519}) + if err == nil { + t.Error("signer with algos created with invalid algorithms") + } + + _, err = NewSignerWithAlgorithms(algorithSigner, []string{CertAlgoRSASHA256v01}) + if err == nil { + t.Error("signer with algos created with certificate algorithms") + } + + mas, err := NewSignerWithAlgorithms(algorithSigner, []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}) + if err != nil { + t.Errorf("unable to create signer with valid algorithms: %v", err) + } + + _, err = NewSignerWithAlgorithms(mas, []string{KeyAlgoRSA}) + if err == nil { + t.Error("signer with algos created with restricted algorithms") + } +} + +func TestCryptoPublicKey(t *testing.T) { + for _, priv := range testSigners { + p1 := priv.PublicKey() + key, ok := p1.(CryptoPublicKey) + if !ok { + continue + } + p2, err := NewPublicKey(key.CryptoPublicKey()) + if err != nil { + t.Fatalf("NewPublicKey(CryptoPublicKey) failed for %s, got: %v", p1.Type(), err) + } + if !reflect.DeepEqual(p1, p2) { + t.Errorf("got %#v in NewPublicKey, want %#v", p2, p1) + } + } + for _, d := range testdata.SKData { + p1, _, _, _, err := ParseAuthorizedKey(d.PubKey) + if err != nil { + t.Fatalf("parseAuthorizedKey returned error: %v", err) + } + k1, ok := p1.(CryptoPublicKey) + if !ok { + t.Fatalf("%T does not implement CryptoPublicKey", p1) + } + + var p2 PublicKey + switch pub := k1.CryptoPublicKey().(type) { + case *ecdsa.PublicKey: + p2 = &skECDSAPublicKey{ + 
application: "ssh:", + PublicKey: *pub, + } + case ed25519.PublicKey: + p2 = &skEd25519PublicKey{ + application: "ssh:", + PublicKey: pub, + } + default: + t.Fatalf("unexpected type %T from CryptoPublicKey()", pub) + } + if !reflect.DeepEqual(p1, p2) { + t.Errorf("got %#v, want %#v", p2, p1) + } + } +} diff --git a/tempfork/sshtest/ssh/mac.go b/tempfork/sshtest/ssh/mac.go new file mode 100644 index 0000000000000..06a1b27507ee3 --- /dev/null +++ b/tempfork/sshtest/ssh/mac.go @@ -0,0 +1,68 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Message authentication support + +import ( + "crypto/hmac" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "hash" +) + +type macMode struct { + keySize int + etm bool + new func(key []byte) hash.Hash +} + +// truncatingMAC wraps around a hash.Hash and truncates the output digest to +// a given size. +type truncatingMAC struct { + length int + hmac hash.Hash +} + +func (t truncatingMAC) Write(data []byte) (int, error) { + return t.hmac.Write(data) +} + +func (t truncatingMAC) Sum(in []byte) []byte { + out := t.hmac.Sum(in) + return out[:len(in)+t.length] +} + +func (t truncatingMAC) Reset() { + t.hmac.Reset() +} + +func (t truncatingMAC) Size() int { + return t.length +} + +func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } + +var macModes = map[string]*macMode{ + "hmac-sha2-512-etm@openssh.com": {64, true, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, + "hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha2-512": {64, false, func(key []byte) hash.Hash { + return hmac.New(sha512.New, key) + }}, + "hmac-sha2-256": {32, false, func(key []byte) hash.Hash { + return hmac.New(sha256.New, key) + }}, + "hmac-sha1": {20, false, func(key []byte) hash.Hash { + return hmac.New(sha1.New, key) + 
}}, + "hmac-sha1-96": {20, false, func(key []byte) hash.Hash { + return truncatingMAC{12, hmac.New(sha1.New, key)} + }}, +} diff --git a/tempfork/sshtest/ssh/mempipe_test.go b/tempfork/sshtest/ssh/mempipe_test.go new file mode 100644 index 0000000000000..f27339c51a1b4 --- /dev/null +++ b/tempfork/sshtest/ssh/mempipe_test.go @@ -0,0 +1,124 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "io" + "sync" + "testing" +) + +// An in-memory packetConn. It is safe to call Close and writePacket +// from different goroutines. +type memTransport struct { + eof bool + pending [][]byte + write *memTransport + writeCount uint64 + sync.Mutex + *sync.Cond +} + +func (t *memTransport) readPacket() ([]byte, error) { + t.Lock() + defer t.Unlock() + for { + if len(t.pending) > 0 { + r := t.pending[0] + t.pending = t.pending[1:] + return r, nil + } + if t.eof { + return nil, io.EOF + } + t.Cond.Wait() + } +} + +func (t *memTransport) closeSelf() error { + t.Lock() + defer t.Unlock() + if t.eof { + return io.EOF + } + t.eof = true + t.Cond.Broadcast() + return nil +} + +func (t *memTransport) Close() error { + err := t.write.closeSelf() + t.closeSelf() + return err +} + +func (t *memTransport) writePacket(p []byte) error { + t.write.Lock() + defer t.write.Unlock() + if t.write.eof { + return io.EOF + } + c := make([]byte, len(p)) + copy(c, p) + t.write.pending = append(t.write.pending, c) + t.write.Cond.Signal() + t.writeCount++ + return nil +} + +func (t *memTransport) getWriteCount() uint64 { + t.write.Lock() + defer t.write.Unlock() + return t.writeCount +} + +func memPipe() (a, b packetConn) { + t1 := memTransport{} + t2 := memTransport{} + t1.write = &t2 + t2.write = &t1 + t1.Cond = sync.NewCond(&t1.Mutex) + t2.Cond = sync.NewCond(&t2.Mutex) + return &t1, &t2 +} + +func TestMemPipe(t *testing.T) { + a, b := memPipe() + if err := 
a.writePacket([]byte{42}); err != nil { + t.Fatalf("writePacket: %v", err) + } + if wc := a.(*memTransport).getWriteCount(); wc != 1 { + t.Fatalf("got %v, want 1", wc) + } + if err := a.Close(); err != nil { + t.Fatal("Close: ", err) + } + p, err := b.readPacket() + if err != nil { + t.Fatal("readPacket: ", err) + } + if len(p) != 1 || p[0] != 42 { + t.Fatalf("got %v, want {42}", p) + } + p, err = b.readPacket() + if err != io.EOF { + t.Fatalf("got %v, %v, want EOF", p, err) + } + if wc := b.(*memTransport).getWriteCount(); wc != 0 { + t.Fatalf("got %v, want 0", wc) + } +} + +func TestDoubleClose(t *testing.T) { + a, _ := memPipe() + err := a.Close() + if err != nil { + t.Errorf("Close: %v", err) + } + err = a.Close() + if err != io.EOF { + t.Errorf("expect EOF on double close.") + } +} diff --git a/tempfork/sshtest/ssh/messages.go b/tempfork/sshtest/ssh/messages.go new file mode 100644 index 0000000000000..b55f860564fe3 --- /dev/null +++ b/tempfork/sshtest/ssh/messages.go @@ -0,0 +1,891 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + "reflect" + "strconv" + "strings" +) + +// These are SSH message type numbers. They are scattered around several +// documents but many were taken from [SSH-PARAMETERS]. +const ( + msgIgnore = 2 + msgUnimplemented = 3 + msgDebug = 4 + msgNewKeys = 21 +) + +// SSH messages: +// +// These structures mirror the wire format of the corresponding SSH messages. +// They are marshaled using reflection with the marshal and unmarshal functions +// in this file. The only wrinkle is that a final member of type []byte with a +// ssh tag of "rest" receives the remainder of a packet when unmarshaling. + +// See RFC 4253, section 11.1. +const msgDisconnect = 1 + +// disconnectMsg is the message that signals a disconnect. 
It is also +// the error type returned from mux.Wait() +type disconnectMsg struct { + Reason uint32 `sshtype:"1"` + Message string + Language string +} + +func (d *disconnectMsg) Error() string { + return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message) +} + +// See RFC 4253, section 7.1. +const msgKexInit = 20 + +type kexInitMsg struct { + Cookie [16]byte `sshtype:"20"` + KexAlgos []string + ServerHostKeyAlgos []string + CiphersClientServer []string + CiphersServerClient []string + MACsClientServer []string + MACsServerClient []string + CompressionClientServer []string + CompressionServerClient []string + LanguagesClientServer []string + LanguagesServerClient []string + FirstKexFollows bool + Reserved uint32 +} + +// See RFC 4253, section 8. + +// Diffie-Hellman +const msgKexDHInit = 30 + +type kexDHInitMsg struct { + X *big.Int `sshtype:"30"` +} + +const msgKexECDHInit = 30 + +type kexECDHInitMsg struct { + ClientPubKey []byte `sshtype:"30"` +} + +const msgKexECDHReply = 31 + +type kexECDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + EphemeralPubKey []byte + Signature []byte +} + +const msgKexDHReply = 31 + +type kexDHReplyMsg struct { + HostKey []byte `sshtype:"31"` + Y *big.Int + Signature []byte +} + +// See RFC 4419, section 5. +const msgKexDHGexGroup = 31 + +type kexDHGexGroupMsg struct { + P *big.Int `sshtype:"31"` + G *big.Int +} + +const msgKexDHGexInit = 32 + +type kexDHGexInitMsg struct { + X *big.Int `sshtype:"32"` +} + +const msgKexDHGexReply = 33 + +type kexDHGexReplyMsg struct { + HostKey []byte `sshtype:"33"` + Y *big.Int + Signature []byte +} + +const msgKexDHGexRequest = 34 + +type kexDHGexRequestMsg struct { + MinBits uint32 `sshtype:"34"` + PreferedBits uint32 + MaxBits uint32 +} + +// See RFC 4253, section 10. +const msgServiceRequest = 5 + +type serviceRequestMsg struct { + Service string `sshtype:"5"` +} + +// See RFC 4253, section 10. 
+const msgServiceAccept = 6 + +type serviceAcceptMsg struct { + Service string `sshtype:"6"` +} + +// See RFC 8308, section 2.3 +const msgExtInfo = 7 + +type extInfoMsg struct { + NumExtensions uint32 `sshtype:"7"` + Payload []byte `ssh:"rest"` +} + +// See RFC 4252, section 5. +const msgUserAuthRequest = 50 + +type userAuthRequestMsg struct { + User string `sshtype:"50"` + Service string + Method string + Payload []byte `ssh:"rest"` +} + +// Used for debug printouts of packets. +type userAuthSuccessMsg struct { +} + +// See RFC 4252, section 5.1 +const msgUserAuthFailure = 51 + +type userAuthFailureMsg struct { + Methods []string `sshtype:"51"` + PartialSuccess bool +} + +// See RFC 4252, section 5.1 +const msgUserAuthSuccess = 52 + +// See RFC 4252, section 5.4 +const msgUserAuthBanner = 53 + +type userAuthBannerMsg struct { + Message string `sshtype:"53"` + // unused, but required to allow message parsing + Language string +} + +// See RFC 4256, section 3.2 +const msgUserAuthInfoRequest = 60 +const msgUserAuthInfoResponse = 61 + +type userAuthInfoRequestMsg struct { + Name string `sshtype:"60"` + Instruction string + Language string + NumPrompts uint32 + Prompts []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpen = 90 + +type channelOpenMsg struct { + ChanType string `sshtype:"90"` + PeersID uint32 + PeersWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +const msgChannelExtendedData = 95 +const msgChannelData = 94 + +// Used for debug print outs of packets. +type channelDataMsg struct { + PeersID uint32 `sshtype:"94"` + Length uint32 + Rest []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. +const msgChannelOpenConfirm = 91 + +type channelOpenConfirmMsg struct { + PeersID uint32 `sshtype:"91"` + MyID uint32 + MyWindow uint32 + MaxPacketSize uint32 + TypeSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.1. 
+const msgChannelOpenFailure = 92 + +type channelOpenFailureMsg struct { + PeersID uint32 `sshtype:"92"` + Reason RejectionReason + Message string + Language string +} + +const msgChannelRequest = 98 + +type channelRequestMsg struct { + PeersID uint32 `sshtype:"98"` + Request string + WantReply bool + RequestSpecificData []byte `ssh:"rest"` +} + +// See RFC 4254, section 5.4. +const msgChannelSuccess = 99 + +type channelRequestSuccessMsg struct { + PeersID uint32 `sshtype:"99"` +} + +// See RFC 4254, section 5.4. +const msgChannelFailure = 100 + +type channelRequestFailureMsg struct { + PeersID uint32 `sshtype:"100"` +} + +// See RFC 4254, section 5.3 +const msgChannelClose = 97 + +type channelCloseMsg struct { + PeersID uint32 `sshtype:"97"` +} + +// See RFC 4254, section 5.3 +const msgChannelEOF = 96 + +type channelEOFMsg struct { + PeersID uint32 `sshtype:"96"` +} + +// See RFC 4254, section 4 +const msgGlobalRequest = 80 + +type globalRequestMsg struct { + Type string `sshtype:"80"` + WantReply bool + Data []byte `ssh:"rest"` +} + +// See RFC 4254, section 4 +const msgRequestSuccess = 81 + +type globalRequestSuccessMsg struct { + Data []byte `ssh:"rest" sshtype:"81"` +} + +// See RFC 4254, section 4 +const msgRequestFailure = 82 + +type globalRequestFailureMsg struct { + Data []byte `ssh:"rest" sshtype:"82"` +} + +// See RFC 4254, section 5.2 +const msgChannelWindowAdjust = 93 + +type windowAdjustMsg struct { + PeersID uint32 `sshtype:"93"` + AdditionalBytes uint32 +} + +// See RFC 4252, section 7 +const msgUserAuthPubKeyOk = 60 + +type userAuthPubKeyOkMsg struct { + Algo string `sshtype:"60"` + PubKey []byte +} + +// See RFC 4462, section 3 +const msgUserAuthGSSAPIResponse = 60 + +type userAuthGSSAPIResponse struct { + SupportMech []byte `sshtype:"60"` +} + +const msgUserAuthGSSAPIToken = 61 + +type userAuthGSSAPIToken struct { + Token []byte `sshtype:"61"` +} + +const msgUserAuthGSSAPIMIC = 66 + +type userAuthGSSAPIMIC struct { + MIC []byte `sshtype:"66"` +} 
+ +// See RFC 4462, section 3.9 +const msgUserAuthGSSAPIErrTok = 64 + +type userAuthGSSAPIErrTok struct { + ErrorToken []byte `sshtype:"64"` +} + +// See RFC 4462, section 3.8 +const msgUserAuthGSSAPIError = 65 + +type userAuthGSSAPIError struct { + MajorStatus uint32 `sshtype:"65"` + MinorStatus uint32 + Message string + LanguageTag string +} + +// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9 +const msgPing = 192 + +type pingMsg struct { + Data string `sshtype:"192"` +} + +// Transport layer OpenSSH extension. See [PROTOCOL], section 1.9 +const msgPong = 193 + +type pongMsg struct { + Data string `sshtype:"193"` +} + +// typeTags returns the possible type bytes for the given reflect.Type, which +// should be a struct. The possible values are separated by a '|' character. +func typeTags(structType reflect.Type) (tags []byte) { + tagStr := structType.Field(0).Tag.Get("sshtype") + + for _, tag := range strings.Split(tagStr, "|") { + i, err := strconv.Atoi(tag) + if err == nil { + tags = append(tags, byte(i)) + } + } + + return tags +} + +func fieldError(t reflect.Type, field int, problem string) error { + if problem != "" { + problem = ": " + problem + } + return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) +} + +var errShortRead = errors.New("ssh: short read") + +// Unmarshal parses data in SSH wire format into a structure. The out +// argument should be a pointer to struct. If the first member of the +// struct has the "sshtype" tag set to a '|'-separated set of numbers +// in decimal, the packet must start with one of those numbers. In +// case of error, Unmarshal returns a ParseError or +// UnexpectedMessageError. 
+func Unmarshal(data []byte, out interface{}) error { + v := reflect.ValueOf(out).Elem() + structType := v.Type() + expectedTypes := typeTags(structType) + + var expectedType byte + if len(expectedTypes) > 0 { + expectedType = expectedTypes[0] + } + + if len(data) == 0 { + return parseError(expectedType) + } + + if len(expectedTypes) > 0 { + goodType := false + for _, e := range expectedTypes { + if e > 0 && data[0] == e { + goodType = true + break + } + } + if !goodType { + return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes) + } + data = data[1:] + } + + var ok bool + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + t := field.Type() + switch t.Kind() { + case reflect.Bool: + if len(data) < 1 { + return errShortRead + } + field.SetBool(data[0] != 0) + data = data[1:] + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + return fieldError(structType, i, "array of unsupported type") + } + if len(data) < t.Len() { + return errShortRead + } + for j, n := 0, t.Len(); j < n; j++ { + field.Index(j).Set(reflect.ValueOf(data[j])) + } + data = data[t.Len():] + case reflect.Uint64: + var u64 uint64 + if u64, data, ok = parseUint64(data); !ok { + return errShortRead + } + field.SetUint(u64) + case reflect.Uint32: + var u32 uint32 + if u32, data, ok = parseUint32(data); !ok { + return errShortRead + } + field.SetUint(uint64(u32)) + case reflect.Uint8: + if len(data) < 1 { + return errShortRead + } + field.SetUint(uint64(data[0])) + data = data[1:] + case reflect.String: + var s []byte + if s, data, ok = parseString(data); !ok { + return fieldError(structType, i, "") + } + field.SetString(string(s)) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if structType.Field(i).Tag.Get("ssh") == "rest" { + field.Set(reflect.ValueOf(data)) + data = nil + } else { + var s []byte + if s, data, ok = parseString(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(s)) + } + case 
reflect.String: + var nl []string + if nl, data, ok = parseNameList(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(nl)) + default: + return fieldError(structType, i, "slice of unsupported type") + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + if n, data, ok = parseInt(data); !ok { + return errShortRead + } + field.Set(reflect.ValueOf(n)) + } else { + return fieldError(structType, i, "pointer to unsupported type") + } + default: + return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t)) + } + } + + if len(data) != 0 { + return parseError(expectedType) + } + + return nil +} + +// Marshal serializes the message in msg to SSH wire format. The msg +// argument should be a struct or pointer to struct. If the first +// member has the "sshtype" tag set to a number in decimal, that +// number is prepended to the result. If the last of member has the +// "ssh" tag set to "rest", its contents are appended to the output. +func Marshal(msg interface{}) []byte { + out := make([]byte, 0, 64) + return marshalStruct(out, msg) +} + +func marshalStruct(out []byte, msg interface{}) []byte { + v := reflect.Indirect(reflect.ValueOf(msg)) + msgTypes := typeTags(v.Type()) + if len(msgTypes) > 0 { + out = append(out, msgTypes[0]) + } + + for i, n := 0, v.NumField(); i < n; i++ { + field := v.Field(i) + switch t := field.Type(); t.Kind() { + case reflect.Bool: + var v uint8 + if field.Bool() { + v = 1 + } + out = append(out, v) + case reflect.Array: + if t.Elem().Kind() != reflect.Uint8 { + panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) + } + for j, l := 0, t.Len(); j < l; j++ { + out = append(out, uint8(field.Index(j).Uint())) + } + case reflect.Uint32: + out = appendU32(out, uint32(field.Uint())) + case reflect.Uint64: + out = appendU64(out, uint64(field.Uint())) + case reflect.Uint8: + out = append(out, uint8(field.Uint())) + case reflect.String: + s := field.String() + out = appendInt(out, len(s)) + 
out = append(out, s...) + case reflect.Slice: + switch t.Elem().Kind() { + case reflect.Uint8: + if v.Type().Field(i).Tag.Get("ssh") != "rest" { + out = appendInt(out, field.Len()) + } + out = append(out, field.Bytes()...) + case reflect.String: + offset := len(out) + out = appendU32(out, 0) + if n := field.Len(); n > 0 { + for j := 0; j < n; j++ { + f := field.Index(j) + if j != 0 { + out = append(out, ',') + } + out = append(out, f.String()...) + } + // overwrite length value + binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) + } + default: + panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) + } + case reflect.Ptr: + if t == bigIntType { + var n *big.Int + nValue := reflect.ValueOf(&n) + nValue.Elem().Set(field) + needed := intLength(n) + oldLength := len(out) + + if cap(out)-len(out) < needed { + newOut := make([]byte, len(out), 2*(len(out)+needed)) + copy(newOut, out) + out = newOut + } + out = out[:oldLength+needed] + marshalInt(out[oldLength:], n) + } else { + panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) + } + } + } + + return out +} + +var bigOne = big.NewInt(1) + +func parseString(in []byte) (out, rest []byte, ok bool) { + if len(in) < 4 { + return + } + length := binary.BigEndian.Uint32(in) + in = in[4:] + if uint32(len(in)) < length { + return + } + out = in[:length] + rest = in[length:] + ok = true + return +} + +var ( + comma = []byte{','} + emptyNameList = []string{} +) + +func parseNameList(in []byte) (out []string, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + if len(contents) == 0 { + out = emptyNameList + return + } + parts := bytes.Split(contents, comma) + out = make([]string, len(parts)) + for i, part := range parts { + out[i] = string(part) + } + return +} + +func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { + contents, rest, ok := parseString(in) + if !ok { + return + } + out = new(big.Int) + + if 
len(contents) > 0 && contents[0]&0x80 == 0x80 { + // This is a negative number + notBytes := make([]byte, len(contents)) + for i := range notBytes { + notBytes[i] = ^contents[i] + } + out.SetBytes(notBytes) + out.Add(out, bigOne) + out.Neg(out) + } else { + // Positive number + out.SetBytes(contents) + } + ok = true + return +} + +func parseUint32(in []byte) (uint32, []byte, bool) { + if len(in) < 4 { + return 0, nil, false + } + return binary.BigEndian.Uint32(in), in[4:], true +} + +func parseUint64(in []byte) (uint64, []byte, bool) { + if len(in) < 8 { + return 0, nil, false + } + return binary.BigEndian.Uint64(in), in[8:], true +} + +func intLength(n *big.Int) int { + length := 4 /* length bytes */ + if n.Sign() < 0 { + nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bitLen := nMinus1.BitLen() + if bitLen%8 == 0 { + // The number will need 0xff padding + length++ + } + length += (bitLen + 7) / 8 + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bitLen := n.BitLen() + if bitLen%8 == 0 { + // The number will need 0x00 padding + length++ + } + length += (bitLen + 7) / 8 + } + + return length +} + +func marshalUint32(to []byte, n uint32) []byte { + binary.BigEndian.PutUint32(to, n) + return to[4:] +} + +func marshalUint64(to []byte, n uint64) []byte { + binary.BigEndian.PutUint64(to, n) + return to[8:] +} + +func marshalInt(to []byte, n *big.Int) []byte { + lengthBytes := to + to = to[4:] + length := 0 + + if n.Sign() < 0 { + // A negative number has to be converted to two's-complement + // form. So we'll subtract 1 and invert. If the + // most-significant-bit isn't set then we'll need to pad the + // beginning with 0xff in order to keep the number negative. 
+ nMinus1 := new(big.Int).Neg(n) + nMinus1.Sub(nMinus1, bigOne) + bytes := nMinus1.Bytes() + for i := range bytes { + bytes[i] ^= 0xff + } + if len(bytes) == 0 || bytes[0]&0x80 == 0 { + to[0] = 0xff + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } else if n.Sign() == 0 { + // A zero is the zero length string + } else { + bytes := n.Bytes() + if len(bytes) > 0 && bytes[0]&0x80 != 0 { + // We'll have to pad this with a 0x00 in order to + // stop it looking like a negative number. + to[0] = 0 + to = to[1:] + length++ + } + nBytes := copy(to, bytes) + to = to[nBytes:] + length += nBytes + } + + lengthBytes[0] = byte(length >> 24) + lengthBytes[1] = byte(length >> 16) + lengthBytes[2] = byte(length >> 8) + lengthBytes[3] = byte(length) + return to +} + +func writeInt(w io.Writer, n *big.Int) { + length := intLength(n) + buf := make([]byte, length) + marshalInt(buf, n) + w.Write(buf) +} + +func writeString(w io.Writer, s []byte) { + var lengthBytes [4]byte + lengthBytes[0] = byte(len(s) >> 24) + lengthBytes[1] = byte(len(s) >> 16) + lengthBytes[2] = byte(len(s) >> 8) + lengthBytes[3] = byte(len(s)) + w.Write(lengthBytes[:]) + w.Write(s) +} + +func stringLength(n int) int { + return 4 + n +} + +func marshalString(to []byte, s []byte) []byte { + to[0] = byte(len(s) >> 24) + to[1] = byte(len(s) >> 16) + to[2] = byte(len(s) >> 8) + to[3] = byte(len(s)) + to = to[4:] + copy(to, s) + return to[len(s):] +} + +var bigIntType = reflect.TypeOf((*big.Int)(nil)) + +// Decode a packet into its corresponding message. 
+func decode(packet []byte) (interface{}, error) { + var msg interface{} + switch packet[0] { + case msgDisconnect: + msg = new(disconnectMsg) + case msgServiceRequest: + msg = new(serviceRequestMsg) + case msgServiceAccept: + msg = new(serviceAcceptMsg) + case msgExtInfo: + msg = new(extInfoMsg) + case msgKexInit: + msg = new(kexInitMsg) + case msgKexDHInit: + msg = new(kexDHInitMsg) + case msgKexDHReply: + msg = new(kexDHReplyMsg) + case msgUserAuthRequest: + msg = new(userAuthRequestMsg) + case msgUserAuthSuccess: + return new(userAuthSuccessMsg), nil + case msgUserAuthFailure: + msg = new(userAuthFailureMsg) + case msgUserAuthPubKeyOk: + msg = new(userAuthPubKeyOkMsg) + case msgGlobalRequest: + msg = new(globalRequestMsg) + case msgRequestSuccess: + msg = new(globalRequestSuccessMsg) + case msgRequestFailure: + msg = new(globalRequestFailureMsg) + case msgChannelOpen: + msg = new(channelOpenMsg) + case msgChannelData: + msg = new(channelDataMsg) + case msgChannelOpenConfirm: + msg = new(channelOpenConfirmMsg) + case msgChannelOpenFailure: + msg = new(channelOpenFailureMsg) + case msgChannelWindowAdjust: + msg = new(windowAdjustMsg) + case msgChannelEOF: + msg = new(channelEOFMsg) + case msgChannelClose: + msg = new(channelCloseMsg) + case msgChannelRequest: + msg = new(channelRequestMsg) + case msgChannelSuccess: + msg = new(channelRequestSuccessMsg) + case msgChannelFailure: + msg = new(channelRequestFailureMsg) + case msgUserAuthGSSAPIToken: + msg = new(userAuthGSSAPIToken) + case msgUserAuthGSSAPIMIC: + msg = new(userAuthGSSAPIMIC) + case msgUserAuthGSSAPIErrTok: + msg = new(userAuthGSSAPIErrTok) + case msgUserAuthGSSAPIError: + msg = new(userAuthGSSAPIError) + default: + return nil, unexpectedMessageError(0, packet[0]) + } + if err := Unmarshal(packet, msg); err != nil { + return nil, err + } + return msg, nil +} + +var packetTypeNames = map[byte]string{ + msgDisconnect: "disconnectMsg", + msgServiceRequest: "serviceRequestMsg", + msgServiceAccept: 
"serviceAcceptMsg", + msgExtInfo: "extInfoMsg", + msgKexInit: "kexInitMsg", + msgKexDHInit: "kexDHInitMsg", + msgKexDHReply: "kexDHReplyMsg", + msgUserAuthRequest: "userAuthRequestMsg", + msgUserAuthSuccess: "userAuthSuccessMsg", + msgUserAuthFailure: "userAuthFailureMsg", + msgUserAuthPubKeyOk: "userAuthPubKeyOkMsg", + msgGlobalRequest: "globalRequestMsg", + msgRequestSuccess: "globalRequestSuccessMsg", + msgRequestFailure: "globalRequestFailureMsg", + msgChannelOpen: "channelOpenMsg", + msgChannelData: "channelDataMsg", + msgChannelOpenConfirm: "channelOpenConfirmMsg", + msgChannelOpenFailure: "channelOpenFailureMsg", + msgChannelWindowAdjust: "windowAdjustMsg", + msgChannelEOF: "channelEOFMsg", + msgChannelClose: "channelCloseMsg", + msgChannelRequest: "channelRequestMsg", + msgChannelSuccess: "channelRequestSuccessMsg", + msgChannelFailure: "channelRequestFailureMsg", +} diff --git a/tempfork/sshtest/ssh/messages_test.go b/tempfork/sshtest/ssh/messages_test.go new file mode 100644 index 0000000000000..e79076412ab49 --- /dev/null +++ b/tempfork/sshtest/ssh/messages_test.go @@ -0,0 +1,288 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package ssh + +import ( + "bytes" + "math/big" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +var intLengthTests = []struct { + val, length int +}{ + {0, 4 + 0}, + {1, 4 + 1}, + {127, 4 + 1}, + {128, 4 + 2}, + {-1, 4 + 1}, +} + +func TestIntLength(t *testing.T) { + for _, test := range intLengthTests { + v := new(big.Int).SetInt64(int64(test.val)) + length := intLength(v) + if length != test.length { + t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) + } + } +} + +type msgAllTypes struct { + Bool bool `sshtype:"21"` + Array [16]byte + Uint64 uint64 + Uint32 uint32 + Uint8 uint8 + String string + Strings []string + Bytes []byte + Int *big.Int + Rest []byte `ssh:"rest"` +} + +func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { + m := &msgAllTypes{} + m.Bool = rand.Intn(2) == 1 + randomBytes(m.Array[:], rand) + m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) + m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) + m.Uint8 = uint8(rand.Intn(1 << 8)) + m.String = string(m.Array[:]) + m.Strings = randomNameList(rand) + m.Bytes = m.Array[:] + m.Int = randomInt(rand) + m.Rest = m.Array[:] + return reflect.ValueOf(m) +} + +func TestMarshalUnmarshal(t *testing.T) { + rand := rand.New(rand.NewSource(0)) + iface := &msgAllTypes{} + ty := reflect.ValueOf(iface).Type() + + n := 100 + if testing.Short() { + n = 5 + } + for j := 0; j < n; j++ { + v, ok := quick.Value(ty, rand) + if !ok { + t.Errorf("failed to create value") + break + } + + m1 := v.Elem().Interface() + m2 := iface + + marshaled := Marshal(m1) + if err := Unmarshal(marshaled, m2); err != nil { + t.Errorf("Unmarshal %#v: %s", m1, err) + break + } + + if !reflect.DeepEqual(v.Interface(), m2) { + t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) + break + } + } +} + +func TestUnmarshalEmptyPacket(t *testing.T) { + var b []byte + var m channelRequestSuccessMsg + if err := Unmarshal(b, &m); err == nil { + t.Fatalf("unmarshal of empty slice succeeded") + } +} + 
+func TestUnmarshalUnexpectedPacket(t *testing.T) { + type S struct { + I uint32 `sshtype:"43"` + S string + B bool + } + + s := S{11, "hello", true} + packet := Marshal(s) + packet[0] = 42 + roundtrip := S{} + err := Unmarshal(packet, &roundtrip) + if err == nil { + t.Fatal("expected error, not nil") + } +} + +func TestMarshalPtr(t *testing.T) { + s := struct { + S string + }{"hello"} + + m1 := Marshal(s) + m2 := Marshal(&s) + if !bytes.Equal(m1, m2) { + t.Errorf("got %q, want %q for marshaled pointer", m2, m1) + } +} + +func TestBareMarshalUnmarshal(t *testing.T) { + type S struct { + I uint32 + S string + B bool + } + + s := S{42, "hello", true} + packet := Marshal(s) + roundtrip := S{} + Unmarshal(packet, &roundtrip) + + if !reflect.DeepEqual(s, roundtrip) { + t.Errorf("got %#v, want %#v", roundtrip, s) + } +} + +func TestBareMarshal(t *testing.T) { + type S2 struct { + I uint32 + } + s := S2{42} + packet := Marshal(s) + i, rest, ok := parseUint32(packet) + if len(rest) > 0 || !ok { + t.Errorf("parseInt(%q): parse error", packet) + } + if i != s.I { + t.Errorf("got %d, want %d", i, s.I) + } +} + +func TestUnmarshalShortKexInitPacket(t *testing.T) { + // This used to panic. 
+ // Issue 11348 + packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} + kim := &kexInitMsg{} + if err := Unmarshal(packet, kim); err == nil { + t.Error("truncated packet unmarshaled without error") + } +} + +func TestMarshalMultiTag(t *testing.T) { + var res struct { + A uint32 `sshtype:"1|2"` + } + + good1 := struct { + A uint32 `sshtype:"1"` + }{ + 1, + } + good2 := struct { + A uint32 `sshtype:"2"` + }{ + 1, + } + + if e := Unmarshal(Marshal(good1), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + if e := Unmarshal(Marshal(good2), &res); e != nil { + t.Errorf("error unmarshaling multipart tag: %v", e) + } + + bad1 := struct { + A uint32 `sshtype:"3"` + }{ + 1, + } + if e := Unmarshal(Marshal(bad1), &res); e == nil { + t.Errorf("bad struct unmarshaled without error") + } +} + +func randomBytes(out []byte, rand *rand.Rand) { + for i := 0; i < len(out); i++ { + out[i] = byte(rand.Int31()) + } +} + +func randomNameList(rand *rand.Rand) []string { + ret := make([]string, rand.Int31()&15) + for i := range ret { + s := make([]byte, 1+(rand.Int31()&15)) + for j := range s { + s[j] = 'a' + uint8(rand.Int31()&15) + } + ret[i] = string(s) + } + return ret +} + +func randomInt(rand *rand.Rand) *big.Int { + return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) +} + +func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + ki := &kexInitMsg{} + randomBytes(ki.Cookie[:], rand) + ki.KexAlgos = randomNameList(rand) + ki.ServerHostKeyAlgos = randomNameList(rand) + ki.CiphersClientServer = randomNameList(rand) + ki.CiphersServerClient = randomNameList(rand) + ki.MACsClientServer = randomNameList(rand) + ki.MACsServerClient = randomNameList(rand) + ki.CompressionClientServer = randomNameList(rand) + ki.CompressionServerClient = randomNameList(rand) + ki.LanguagesClientServer = randomNameList(rand) + ki.LanguagesServerClient = 
randomNameList(rand) + if rand.Int31()&1 == 1 { + ki.FirstKexFollows = true + } + return reflect.ValueOf(ki) +} + +func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { + dhi := &kexDHInitMsg{} + dhi.X = randomInt(rand) + return reflect.ValueOf(dhi) +} + +var ( + _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() + + _kexInit = Marshal(_kexInitMsg) + _kexDHInit = Marshal(_kexDHInitMsg) +) + +func BenchmarkMarshalKexInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexInitMsg) + } +} + +func BenchmarkUnmarshalKexInitMsg(b *testing.B) { + m := new(kexInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexInit, m) + } +} + +func BenchmarkMarshalKexDHInitMsg(b *testing.B) { + for i := 0; i < b.N; i++ { + Marshal(_kexDHInitMsg) + } +} + +func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { + m := new(kexDHInitMsg) + for i := 0; i < b.N; i++ { + Unmarshal(_kexDHInit, m) + } +} diff --git a/tempfork/sshtest/ssh/mux.go b/tempfork/sshtest/ssh/mux.go new file mode 100644 index 0000000000000..d2d24c635d32a --- /dev/null +++ b/tempfork/sshtest/ssh/mux.go @@ -0,0 +1,357 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/binary" + "fmt" + "io" + "log" + "sync" + "sync/atomic" +) + +// debugMux, if set, causes messages in the connection protocol to be +// logged. +const debugMux = false + +// chanList is a thread safe channel list. +type chanList struct { + // protects concurrent access to chans + sync.Mutex + + // chans are indexed by the local id of the channel, which the + // other side should send in the PeersId field. + chans []*channel + + // This is a debugging aid: it offsets all IDs by this + // amount. 
This helps distinguish otherwise identical + // server/client muxes + offset uint32 +} + +// Assigns a channel ID to the given channel. +func (c *chanList) add(ch *channel) uint32 { + c.Lock() + defer c.Unlock() + for i := range c.chans { + if c.chans[i] == nil { + c.chans[i] = ch + return uint32(i) + c.offset + } + } + c.chans = append(c.chans, ch) + return uint32(len(c.chans)-1) + c.offset +} + +// getChan returns the channel for the given ID. +func (c *chanList) getChan(id uint32) *channel { + id -= c.offset + + c.Lock() + defer c.Unlock() + if id < uint32(len(c.chans)) { + return c.chans[id] + } + return nil +} + +func (c *chanList) remove(id uint32) { + id -= c.offset + c.Lock() + if id < uint32(len(c.chans)) { + c.chans[id] = nil + } + c.Unlock() +} + +// dropAll forgets all channels it knows, returning them in a slice. +func (c *chanList) dropAll() []*channel { + c.Lock() + defer c.Unlock() + var r []*channel + + for _, ch := range c.chans { + if ch == nil { + continue + } + r = append(r, ch) + } + c.chans = nil + return r +} + +// mux represents the state for the SSH connection protocol, which +// multiplexes many channels onto a single packet transport. +type mux struct { + conn packetConn + chanList chanList + + incomingChannels chan NewChannel + + globalSentMu sync.Mutex + globalResponses chan interface{} + incomingRequests chan *Request + + errCond *sync.Cond + err error +} + +// When debugging, each new chanList instantiation has a different +// offset. +var globalOff uint32 + +func (m *mux) Wait() error { + m.errCond.L.Lock() + defer m.errCond.L.Unlock() + for m.err == nil { + m.errCond.Wait() + } + return m.err +} + +// newMux returns a mux that runs over the given connection. 
+func newMux(p packetConn) *mux { + m := &mux{ + conn: p, + incomingChannels: make(chan NewChannel, chanSize), + globalResponses: make(chan interface{}, 1), + incomingRequests: make(chan *Request, chanSize), + errCond: newCond(), + } + if debugMux { + m.chanList.offset = atomic.AddUint32(&globalOff, 1) + } + + go m.loop() + return m +} + +func (m *mux) sendMessage(msg interface{}) error { + p := Marshal(msg) + if debugMux { + log.Printf("send global(%d): %#v", m.chanList.offset, msg) + } + return m.conn.writePacket(p) +} + +func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { + if wantReply { + m.globalSentMu.Lock() + defer m.globalSentMu.Unlock() + } + + if err := m.sendMessage(globalRequestMsg{ + Type: name, + WantReply: wantReply, + Data: payload, + }); err != nil { + return false, nil, err + } + + if !wantReply { + return false, nil, nil + } + + msg, ok := <-m.globalResponses + if !ok { + return false, nil, io.EOF + } + switch msg := msg.(type) { + case *globalRequestFailureMsg: + return false, msg.Data, nil + case *globalRequestSuccessMsg: + return true, msg.Data, nil + default: + return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) + } +} + +// ackRequest must be called after processing a global request that +// has WantReply set. +func (m *mux) ackRequest(ok bool, data []byte) error { + if ok { + return m.sendMessage(globalRequestSuccessMsg{Data: data}) + } + return m.sendMessage(globalRequestFailureMsg{Data: data}) +} + +func (m *mux) Close() error { + return m.conn.Close() +} + +// loop runs the connection machine. It will process packets until an +// error is encountered. To synchronize on loop exit, use mux.Wait. 
+func (m *mux) loop() { + var err error + for err == nil { + err = m.onePacket() + } + + for _, ch := range m.chanList.dropAll() { + ch.close() + } + + close(m.incomingChannels) + close(m.incomingRequests) + close(m.globalResponses) + + m.conn.Close() + + m.errCond.L.Lock() + m.err = err + m.errCond.Broadcast() + m.errCond.L.Unlock() + + if debugMux { + log.Println("loop exit", err) + } +} + +// onePacket reads and processes one packet. +func (m *mux) onePacket() error { + packet, err := m.conn.readPacket() + if err != nil { + return err + } + + if debugMux { + if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { + log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) + } else { + p, _ := decode(packet) + log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) + } + } + + switch packet[0] { + case msgChannelOpen: + return m.handleChannelOpen(packet) + case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: + return m.handleGlobalPacket(packet) + case msgPing: + var msg pingMsg + if err := Unmarshal(packet, &msg); err != nil { + return fmt.Errorf("failed to unmarshal ping@openssh.com message: %w", err) + } + return m.sendMessage(pongMsg(msg)) + } + + // assume a channel packet. 
+ if len(packet) < 5 { + return parseError(packet[0]) + } + id := binary.BigEndian.Uint32(packet[1:]) + ch := m.chanList.getChan(id) + if ch == nil { + return m.handleUnknownChannelPacket(id, packet) + } + + return ch.handlePacket(packet) +} + +func (m *mux) handleGlobalPacket(packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + case *globalRequestMsg: + m.incomingRequests <- &Request{ + Type: msg.Type, + WantReply: msg.WantReply, + Payload: msg.Data, + mux: m, + } + case *globalRequestSuccessMsg, *globalRequestFailureMsg: + m.globalResponses <- msg + default: + panic(fmt.Sprintf("not a global message %#v", msg)) + } + + return nil +} + +// handleChannelOpen schedules a channel to be Accept()ed. +func (m *mux) handleChannelOpen(packet []byte) error { + var msg channelOpenMsg + if err := Unmarshal(packet, &msg); err != nil { + return err + } + + if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { + failMsg := channelOpenFailureMsg{ + PeersID: msg.PeersID, + Reason: ConnectionFailed, + Message: "invalid request", + Language: "en_US.UTF-8", + } + return m.sendMessage(failMsg) + } + + c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) + c.remoteId = msg.PeersID + c.maxRemotePayload = msg.MaxPacketSize + c.remoteWin.add(msg.PeersWindow) + m.incomingChannels <- c + return nil +} + +func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { + ch, err := m.openChannel(chanType, extra) + if err != nil { + return nil, nil, err + } + + return ch, ch.incomingRequests, nil +} + +func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { + ch := m.newChannel(chanType, channelOutbound, extra) + + ch.maxIncomingPayload = channelMaxPacket + + open := channelOpenMsg{ + ChanType: chanType, + PeersWindow: ch.myWindow, + MaxPacketSize: ch.maxIncomingPayload, + TypeSpecificData: extra, + PeersID: ch.localId, + } + if err := 
m.sendMessage(open); err != nil { + return nil, err + } + + switch msg := (<-ch.msg).(type) { + case *channelOpenConfirmMsg: + return ch, nil + case *channelOpenFailureMsg: + return nil, &OpenChannelError{msg.Reason, msg.Message} + default: + return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) + } +} + +func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error { + msg, err := decode(packet) + if err != nil { + return err + } + + switch msg := msg.(type) { + // RFC 4254 section 5.4 says unrecognized channel requests should + // receive a failure response. + case *channelRequestMsg: + if msg.WantReply { + return m.sendMessage(channelRequestFailureMsg{ + PeersID: msg.PeersID, + }) + } + return nil + default: + return fmt.Errorf("ssh: invalid channel %d", id) + } +} diff --git a/tempfork/sshtest/ssh/mux_test.go b/tempfork/sshtest/ssh/mux_test.go new file mode 100644 index 0000000000000..21f0ac3e325a4 --- /dev/null +++ b/tempfork/sshtest/ssh/mux_test.go @@ -0,0 +1,839 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "errors" + "fmt" + "io" + "sync" + "testing" +) + +func muxPair() (*mux, *mux) { + a, b := memPipe() + + s := newMux(a) + c := newMux(b) + + return s, c +} + +// Returns both ends of a channel, and the mux for the 2nd +// channel. 
+func channelPair(t *testing.T) (*channel, *channel, *mux) { + c, s := muxPair() + + res := make(chan *channel, 1) + go func() { + newCh, ok := <-s.incomingChannels + if !ok { + t.Error("no incoming channel") + close(res) + return + } + if newCh.ChannelType() != "chan" { + t.Errorf("got type %q want chan", newCh.ChannelType()) + newCh.Reject(Prohibited, fmt.Sprintf("got type %q want chan", newCh.ChannelType())) + close(res) + return + } + ch, _, err := newCh.Accept() + if err != nil { + t.Errorf("accept: %v", err) + close(res) + return + } + res <- ch.(*channel) + }() + + ch, err := c.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + w := <-res + if w == nil { + t.Fatal("unable to get write channel") + } + + return w, ch, c +} + +// Test that stderr and stdout can be addressed from different +// goroutines. This is intended for use with the race detector. +func TestMuxChannelExtendedThreadSafety(t *testing.T) { + writer, reader, mux := channelPair(t) + defer writer.Close() + defer reader.Close() + defer mux.Close() + + var wr, rd sync.WaitGroup + magic := "hello world" + + wr.Add(2) + go func() { + io.WriteString(writer, magic) + wr.Done() + }() + go func() { + io.WriteString(writer.Stderr(), magic) + wr.Done() + }() + + rd.Add(2) + go func() { + c, err := io.ReadAll(reader) + if string(c) != magic { + t.Errorf("stdout read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + go func() { + c, err := io.ReadAll(reader.Stderr()) + if string(c) != magic { + t.Errorf("stderr read got %q, want %q (error %s)", c, magic, err) + } + rd.Done() + }() + + wr.Wait() + writer.CloseWrite() + rd.Wait() +} + +func TestMuxReadWrite(t *testing.T) { + s, c, mux := channelPair(t) + defer s.Close() + defer c.Close() + defer mux.Close() + + magic := "hello world" + magicExt := "hello stderr" + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Write([]byte(magic)) + if err != nil { + 
t.Errorf("Write: %v", err) + return + } + _, err = s.Extended(1).Write([]byte(magicExt)) + if err != nil { + t.Errorf("Write: %v", err) + return + } + }() + + var buf [1024]byte + n, err := c.Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + got := string(buf[:n]) + if got != magic { + t.Fatalf("server: got %q want %q", got, magic) + } + + n, err = c.Extended(1).Read(buf[:]) + if err != nil { + t.Fatalf("server Read: %v", err) + } + + got = string(buf[:n]) + if got != magicExt { + t.Fatalf("server: got %q want %q", got, magic) + } +} + +func TestMuxChannelOverflow(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + writer.Write(make([]byte, 1)) + }() + writer.remoteWin.waitWriterBlocked() + + // Send 1 byte. 
+ packet := make([]byte, 1+4+4+1) + packet[0] = msgChannelData + marshalUint32(packet[1:], writer.remoteId) + marshalUint32(packet[5:], uint32(1)) + packet[9] = 42 + + if err := writer.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + if _, err := reader.SendRequest("hello", true, nil); err == nil { + t.Errorf("SendRequest succeeded.") + } +} + +func TestMuxChannelReadUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != nil { + t.Errorf("Write: %v", err) + } + writer.Close() + }() + + writer.remoteWin.waitWriterBlocked() + + buf := make([]byte, 32768) + for { + _, err := reader.Read(buf) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Read: %v", err) + } + } +} + +func TestMuxChannelCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + }() + + writer.remoteWin.waitWriterBlocked() + reader.Close() +} + +func TestMuxConnectionCloseWriteUnblock(t *testing.T) { + reader, writer, mux := channelPair(t) + defer reader.Close() + defer writer.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { + t.Errorf("could not fill window: %v", 
err) + } + if _, err := writer.Write(make([]byte, 1)); err != io.EOF { + t.Errorf("got %v, want EOF for unblock write", err) + } + }() + + writer.remoteWin.waitWriterBlocked() + mux.Close() +} + +func TestMuxReject(t *testing.T) { + client, server := muxPair() + defer server.Close() + defer client.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + + ch, ok := <-server.incomingChannels + if !ok { + t.Error("cannot accept channel") + return + } + if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { + t.Errorf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) + ch.Reject(RejectionReason(UnknownChannelType), UnknownChannelType.String()) + return + } + ch.Reject(RejectionReason(42), "message") + }() + + ch, err := client.openChannel("ch", []byte("extra")) + if ch != nil { + t.Fatal("openChannel not rejected") + } + + ocf, ok := err.(*OpenChannelError) + if !ok { + t.Errorf("got %#v want *OpenChannelError", err) + } else if ocf.Reason != 42 || ocf.Message != "message" { + t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") + } + + want := "ssh: rejected: unknown reason 42 (message)" + if err.Error() != want { + t.Errorf("got %q, want %q", err.Error(), want) + } +} + +func TestMuxChannelRequest(t *testing.T) { + client, server, mux := channelPair(t) + defer server.Close() + defer client.Close() + defer mux.Close() + + var received int + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + for r := range server.incomingRequests { + received++ + r.Reply(r.Type == "yes", nil) + } + wg.Done() + }() + _, err := client.SendRequest("yes", false, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + ok, err := client.SendRequest("yes", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) + } + + if !ok { + t.Errorf("SendRequest(yes): %v", ok) + + } + + ok, err = client.SendRequest("no", true, nil) + if err != nil { + t.Fatalf("SendRequest: %v", err) 
+ } + if ok { + t.Errorf("SendRequest(no): %v", ok) + } + + client.Close() + wg.Wait() + + if received != 3 { + t.Errorf("got %d requests, want %d", received, 3) + } +} + +func TestMuxUnknownChannelRequests(t *testing.T) { + clientPipe, serverPipe := memPipe() + client := newMux(clientPipe) + defer serverPipe.Close() + defer client.Close() + + kDone := make(chan error, 1) + go func() { + // Ignore unknown channel messages that don't want a reply. + err := serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: 1, + Request: "keepalive@openssh.com", + WantReply: false, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Send a keepalive, which should get a channel failure message + // in response. + err = serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: 2, + Request: "keepalive@openssh.com", + WantReply: true, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + packet, err := serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + decoded, err := decode(packet) + if err != nil { + kDone <- fmt.Errorf("decode failed: %w", err) + return + } + + switch msg := decoded.(type) { + case *channelRequestFailureMsg: + if msg.PeersID != 2 { + kDone <- fmt.Errorf("received response to wrong message: %v", msg) + return + + } + default: + kDone <- fmt.Errorf("unexpected channel message: %v", msg) + return + } + + kDone <- nil + + // Receive and respond to the keepalive to confirm the mux is + // still processing requests. 
+ packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgGlobalRequest { + kDone <- errors.New("expected global request") + return + } + + err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ + Data: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return + } + + close(kDone) + }() + + // Wait for the server to send the keepalive message and receive back a + // response. + if err := <-kDone; err != nil { + t.Fatal(err) + } + + // Confirm client hasn't closed. + if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { + t.Fatalf("failed to send keepalive: %v", err) + } + + // Wait for the server to shut down. + if err := <-kDone; err != nil { + t.Fatal(err) + } +} + +func TestMuxClosedChannel(t *testing.T) { + clientPipe, serverPipe := memPipe() + client := newMux(clientPipe) + defer serverPipe.Close() + defer client.Close() + + kDone := make(chan error, 1) + go func() { + // Open the channel. + packet, err := serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelOpen { + kDone <- errors.New("expected chan open") + return + } + + var openMsg channelOpenMsg + if err := Unmarshal(packet, &openMsg); err != nil { + kDone <- fmt.Errorf("unmarshal: %w", err) + return + } + + // Send back the opened channel confirmation. + err = serverPipe.writePacket(Marshal(channelOpenConfirmMsg{ + PeersID: openMsg.PeersID, + MyID: 0, + MyWindow: 0, + MaxPacketSize: channelMaxPacket, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Close the channel. + err = serverPipe.writePacket(Marshal(channelCloseMsg{ + PeersID: openMsg.PeersID, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Send a keepalive message on the channel we just closed. 
+ err = serverPipe.writePacket(Marshal(channelRequestMsg{ + PeersID: openMsg.PeersID, + Request: "keepalive@openssh.com", + WantReply: true, + RequestSpecificData: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("send: %w", err) + return + } + + // Receive the channel closed response. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelClose { + kDone <- errors.New("expected channel close") + return + } + + // Receive the keepalive response failure. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgChannelFailure { + kDone <- errors.New("expected channel failure") + return + } + kDone <- nil + + // Receive and respond to the keepalive to confirm the mux is + // still processing requests. + packet, err = serverPipe.readPacket() + if err != nil { + kDone <- fmt.Errorf("read packet: %w", err) + return + } + if packet[0] != msgGlobalRequest { + kDone <- errors.New("expected global request") + return + } + + err = serverPipe.writePacket(Marshal(globalRequestFailureMsg{ + Data: []byte{}, + })) + if err != nil { + kDone <- fmt.Errorf("failed to send failure msg: %w", err) + return + } + + close(kDone) + }() + + // Open a channel. + ch, err := client.openChannel("chan", nil) + if err != nil { + t.Fatalf("OpenChannel: %v", err) + } + defer ch.Close() + + // Wait for the server to close the channel and send the keepalive. + <-kDone + + // Make sure the channel closed. + if _, ok := <-ch.incomingRequests; ok { + t.Fatalf("channel not closed") + } + + // Confirm client hasn't closed + if _, _, err := client.SendRequest("keepalive@golang.org", true, nil); err != nil { + t.Fatalf("failed to send keepalive: %v", err) + } + + // Wait for the server to shut down. 
+	<-kDone
+}
+
+func TestMuxGlobalRequest(t *testing.T) {
+	var sawPeek bool
+	var wg sync.WaitGroup
+	defer func() {
+		wg.Wait()
+		if !sawPeek {
+			t.Errorf("never saw 'peek' request")
+		}
+	}()
+
+	clientMux, serverMux := muxPair()
+	defer serverMux.Close()
+	defer clientMux.Close()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for r := range serverMux.incomingRequests {
+			sawPeek = sawPeek || r.Type == "peek"
+			if r.WantReply {
+				err := r.Reply(r.Type == "yes",
+					append([]byte(r.Type), r.Payload...))
+				if err != nil {
+					t.Errorf("AckRequest: %v", err)
+				}
+			}
+		}
+	}()
+
+	_, _, err := clientMux.SendRequest("peek", false, nil)
+	if err != nil {
+		t.Errorf("SendRequest: %v", err)
+	}
+
+	ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
+	if !ok || string(data) != "yesa" || err != nil {
+		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
+			ok, data, err)
+	}
+	if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
+		t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
+			ok, data, err)
+	}
+
+	if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
+		t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
+			ok, data, err)
+	}
+}
+
+func TestMuxGlobalRequestUnblock(t *testing.T) {
+	clientMux, serverMux := muxPair()
+	defer serverMux.Close()
+	defer clientMux.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		_, _, err := clientMux.SendRequest("hello", true, nil)
+		result <- err
+	}()
+
+	<-serverMux.incomingRequests
+	serverMux.conn.Close()
+	err := <-result
+
+	if err != io.EOF {
+		// Print the error actually received, not the io.EOF sentinel
+		// (the old message always printed "got EOF" regardless of err).
+		t.Errorf("want EOF, got %v", err)
+	}
+}
+
+func TestMuxChannelRequestUnblock(t *testing.T) {
+	a, b, connB := channelPair(t)
+	defer a.Close()
+	defer b.Close()
+	defer connB.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		_, err := a.SendRequest("hello", true, nil)
+		result <- err
+	}()
+
+	<-b.incomingRequests
+	connB.conn.Close()
+	err := <-result
+
+	if err != io.EOF {
+		t.Errorf("want EOF, got %v", err)
+	}
+}
+
+func TestMuxCloseChannel(t *testing.T) {
+	r, w, mux := channelPair(t)
+	defer mux.Close()
+	defer r.Close()
+	defer w.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		var b [1024]byte
+		_, err := r.Read(b[:])
+		result <- err
+	}()
+	if err := w.Close(); err != nil {
+		t.Errorf("w.Close: %v", err)
+	}
+
+	if _, err := w.Write([]byte("hello")); err != io.EOF {
+		t.Errorf("got err %v, want io.EOF after Close", err)
+	}
+
+	if err := <-result; err != io.EOF {
+		t.Errorf("got %v (%T), want io.EOF", err, err)
+	}
+}
+
+func TestMuxCloseWriteChannel(t *testing.T) {
+	r, w, mux := channelPair(t)
+	defer mux.Close()
+
+	result := make(chan error, 1)
+	go func() {
+		var b [1024]byte
+		_, err := r.Read(b[:])
+		result <- err
+	}()
+	if err := w.CloseWrite(); err != nil {
+		t.Errorf("w.CloseWrite: %v", err)
+	}
+
+	if _, err := w.Write([]byte("hello")); err != io.EOF {
+		t.Errorf("got err %v, want io.EOF after CloseWrite", err)
+	}
+
+	if err := <-result; err != io.EOF {
+		t.Errorf("got %v (%T), want io.EOF", err, err)
+	}
+}
+
+func TestMuxInvalidRecord(t *testing.T) {
+	a, b := muxPair()
+	defer a.Close()
+	defer b.Close()
+
+	packet := make([]byte, 1+4+4+1)
+	packet[0] = msgChannelData
+	marshalUint32(packet[1:], 29348723 /* invalid channel id */)
+	marshalUint32(packet[5:], 1)
+	packet[9] = 42
+
+	a.conn.writePacket(packet)
+	go a.SendRequest("hello", false, nil)
+	// 'a' wrote an invalid packet, so 'b' has exited.
+	req, ok := <-b.incomingRequests
+	if ok {
+		t.Errorf("got request %#v after receiving invalid packet", req)
+	}
+}
+
+func TestZeroWindowAdjust(t *testing.T) {
+	a, b, mux := channelPair(t)
+	defer a.Close()
+	defer b.Close()
+	defer mux.Close()
+
+	go func() {
+		io.WriteString(a, "hello")
+		// bogus adjust.
+ a.sendMessage(windowAdjustMsg{}) + io.WriteString(a, "world") + a.Close() + }() + + want := "helloworld" + c, _ := io.ReadAll(b) + if string(c) != want { + t.Errorf("got %q want %q", c, want) + } +} + +func TestMuxMaxPacketSize(t *testing.T) { + a, b, mux := channelPair(t) + defer a.Close() + defer b.Close() + defer mux.Close() + + large := make([]byte, a.maxRemotePayload+1) + packet := make([]byte, 1+4+4+1+len(large)) + packet[0] = msgChannelData + marshalUint32(packet[1:], a.remoteId) + marshalUint32(packet[5:], uint32(len(large))) + packet[9] = 42 + + if err := a.mux.conn.writePacket(packet); err != nil { + t.Errorf("could not send packet") + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + a.SendRequest("hello", false, nil) + wg.Done() + }() + + _, ok := <-b.incomingRequests + if ok { + t.Errorf("connection still alive after receiving large packet.") + } +} + +func TestMuxChannelWindowDeferredUpdates(t *testing.T) { + s, c, mux := channelPair(t) + cTransport := mux.conn.(*memTransport) + defer s.Close() + defer c.Close() + defer mux.Close() + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + + data := make([]byte, 1024) + + wg.Add(1) + go func() { + defer wg.Done() + _, err := s.Write(data) + if err != nil { + t.Errorf("Write: %v", err) + return + } + }() + cWritesInit := cTransport.getWriteCount() + buf := make([]byte, 1) + for i := 0; i < len(data); i++ { + n, err := c.Read(buf) + if n != len(buf) || err != nil { + t.Fatalf("Read: %v, %v", n, err) + } + } + cWrites := cTransport.getWriteCount() - cWritesInit + // reading 1 KiB should not cause any window updates to be sent, but allow + // for some unexpected writes + if cWrites > 30 { + t.Fatalf("reading 1 KiB from channel caused %v writes", cWrites) + } +} + +// Don't ship code with debug=true. 
+func TestDebug(t *testing.T) { + if debugMux { + t.Error("mux debug switched on") + } + if debugHandshake { + t.Error("handshake debug switched on") + } + if debugTransport { + t.Error("transport debug switched on") + } +} diff --git a/tempfork/sshtest/ssh/server.go b/tempfork/sshtest/ssh/server.go new file mode 100644 index 0000000000000..1839ddc6a4bdc --- /dev/null +++ b/tempfork/sshtest/ssh/server.go @@ -0,0 +1,933 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "strings" +) + +// The Permissions type holds fine-grained permissions that are +// specific to a user or a specific authentication method for a user. +// The Permissions value for a successful authentication attempt is +// available in ServerConn, so it can be used to pass information from +// the user-authentication phase to the application layer. +type Permissions struct { + // CriticalOptions indicate restrictions to the default + // permissions, and are typically used in conjunction with + // user certificates. The standard for SSH certificates + // defines "force-command" (only allow the given command to + // execute) and "source-address" (only allow connections from + // the given address). The SSH package currently only enforces + // the "source-address" critical option. It is up to server + // implementations to enforce other critical options, such as + // "force-command", by checking them after the SSH handshake + // is successful. In general, SSH servers should reject + // connections that specify critical options that are unknown + // or not supported. + CriticalOptions map[string]string + + // Extensions are extra functionality that the server may + // offer on authenticated connections. Lack of support for an + // extension does not preclude authenticating a user. 
Common + // extensions are "permit-agent-forwarding", + // "permit-X11-forwarding". The Go SSH library currently does + // not act on any extension, and it is up to server + // implementations to honor them. Extensions can be used to + // pass data from the authentication callbacks to the server + // application layer. + Extensions map[string]string +} + +type GSSAPIWithMICConfig struct { + // AllowLogin, must be set, is called when gssapi-with-mic + // authentication is selected (RFC 4462 section 3). The srcName is from the + // results of the GSS-API authentication. The format is username@DOMAIN. + // GSSAPI just guarantees to the server who the user is, but not if they can log in, and with what permissions. + // This callback is called after the user identity is established with GSSAPI to decide if the user can login with + // which permissions. If the user is allowed to login, it should return a nil error. + AllowLogin func(conn ConnMetadata, srcName string) (*Permissions, error) + + // Server must be set. It's the implementation + // of the GSSAPIServer interface. See GSSAPIServer interface for details. + Server GSSAPIServer +} + +// SendAuthBanner implements [ServerPreAuthConn]. +func (s *connection) SendAuthBanner(msg string) error { + return s.transport.writePacket(Marshal(&userAuthBannerMsg{ + Message: msg, + })) +} + +func (*connection) unexportedMethodForFutureProofing() {} + +// ServerPreAuthConn is the interface available on an incoming server +// connection before authentication has completed. +type ServerPreAuthConn interface { + unexportedMethodForFutureProofing() // permits growing ServerPreAuthConn safely later, ala testing.TB + + ConnMetadata + + // SendAuthBanner sends a banner message to the client. + // It returns an error once the authentication phase has ended. + SendAuthBanner(string) error +} + +// ServerConfig holds server specific configuration data. 
+type ServerConfig struct {
+	// Config contains configuration shared between client and server.
+	Config
+
+	// PublicKeyAuthAlgorithms specifies the supported client public key
+	// authentication algorithms. Note that this should not include certificate
+	// types since those use the underlying algorithm. This list is sent to the
+	// client if it supports the server-sig-algs extension. Order is irrelevant.
+	// If unspecified then a default set of algorithms is used.
+	PublicKeyAuthAlgorithms []string
+
+	hostKeys []Signer
+
+	// NoClientAuth is true if clients are allowed to connect without
+	// authenticating.
+	// To determine NoClientAuth at runtime, set NoClientAuth to true
+	// and the optional NoClientAuthCallback to a non-nil value.
+	NoClientAuth bool
+
+	// NoClientAuthCallback, if non-nil, is called when a user
+	// attempts to authenticate with auth method "none".
+	// NoClientAuth must also be set to true for this to be used, or
+	// this func is unused.
+	NoClientAuthCallback func(ConnMetadata) (*Permissions, error)
+
+	// MaxAuthTries specifies the maximum number of authentication attempts
+	// permitted per connection. If set to a negative number, the number of
+	// attempts is unlimited. If set to zero, the number of attempts is limited
+	// to 6.
+	MaxAuthTries int
+
+	// PasswordCallback, if non-nil, is called when a user
+	// attempts to authenticate using a password.
+	PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
+
+	// PublicKeyCallback, if non-nil, is called when a client
+	// offers a public key for authentication. It must return a nil error
+	// if the given public key can be used to authenticate the
+	// given user. For example, see CertChecker.Authenticate. A
+	// call to this function does not guarantee that the key
+	// offered is in fact used to authenticate. To record any data
+	// depending on the public key, store it inside a
+	// Permissions.Extensions entry.
+	PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+
+	// KeyboardInteractiveCallback, if non-nil, is called when
+	// keyboard-interactive authentication is selected (RFC
+	// 4256). The client object's Challenge function should be
+	// used to query the user. The callback may offer multiple
+	// Challenge rounds. To avoid information leaks, the client
+	// should be presented a challenge even if the user is
+	// unknown.
+	KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
+
+	// AuthLogCallback, if non-nil, is called to log all authentication
+	// attempts.
+	AuthLogCallback func(conn ConnMetadata, method string, err error)
+
+	// PreAuthConnCallback, if non-nil, is called upon receiving a new connection
+	// before any authentication has started. The provided ServerPreAuthConn
+	// can be used at any time before authentication is complete, including
+	// after this callback has returned.
+	PreAuthConnCallback func(ServerPreAuthConn)
+
+	// ServerVersion is the version identification string to announce in
+	// the public handshake.
+	// If empty, a reasonable default is used.
+	// Note that RFC 4253 section 4.2 requires that this string start with
+	// "SSH-2.0-".
+	ServerVersion string
+
+	// BannerCallback, if present, is called and the return string is sent to
+	// the client after key exchange completed but before authentication.
+	BannerCallback func(conn ConnMetadata) string
+
+	// GSSAPIWithMICConfig includes gssapi server and callback, which if both non-nil, is used
+	// when gssapi-with-mic authentication is selected (RFC 4462 section 3).
+	GSSAPIWithMICConfig *GSSAPIWithMICConfig
+}
+
+// AddHostKey adds a private key as a host key. If an existing host
+// key exists with the same public key format, it is replaced. Each server
+// config must have at least one host key.
+func (s *ServerConfig) AddHostKey(key Signer) {
+	for i, k := range s.hostKeys {
+		if k.PublicKey().Type() == key.PublicKey().Type() {
+			s.hostKeys[i] = key
+			return
+		}
+	}
+
+	s.hostKeys = append(s.hostKeys, key)
+}
+
+// cachedPubKey contains the results of querying whether a public key is
+// acceptable for a user. This is a FIFO cache.
+type cachedPubKey struct {
+	user       string
+	pubKeyData []byte
+	result     error
+	perms      *Permissions
+}
+
+// maxCachedPubKeys is the number of cache entries we store.
+//
+// Due to consistent misuse of the PublicKeyCallback API, we have reduced this
+// to 1, such that the only key in the cache is the most recently seen one. This
+// forces the behavior that the last call to PublicKeyCallback will always be
+// with the key that is used for authentication.
+const maxCachedPubKeys = 1
+
+// pubKeyCache caches tests for public keys. Since SSH clients
+// will query whether a public key is acceptable before attempting to
+// authenticate with it, we end up with duplicate queries for public
+// key validity. The cache only applies to a single ServerConn.
+type pubKeyCache struct {
+	keys []cachedPubKey
+}
+
+// get returns the result for a given user/algo/key tuple.
+func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) {
+	for _, k := range c.keys {
+		if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) {
+			return k, true
+		}
+	}
+	return cachedPubKey{}, false
+}
+
+// add adds the given tuple to the cache.
+func (c *pubKeyCache) add(candidate cachedPubKey) {
+	if len(c.keys) >= maxCachedPubKeys {
+		c.keys = c.keys[1:]
+	}
+	c.keys = append(c.keys, candidate)
+}
+
+// ServerConn is an authenticated SSH connection, as seen from the
+// server.
+type ServerConn struct {
+	Conn
+
+	// If the succeeding authentication callback returned a
+	// non-nil Permissions pointer, it is stored here.
+	Permissions *Permissions
+}
+
+// NewServerConn starts a new SSH server with c as the underlying
+// transport. It starts with a handshake and, if the handshake is
+// unsuccessful, it closes the connection and returns an error. The
+// Request and NewChannel channels must be serviced, or the connection
+// will hang.
+//
+// The returned error may be of type *ServerAuthError for
+// authentication errors.
+func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
+	fullConf := *config
+	fullConf.SetDefaults()
+	if fullConf.MaxAuthTries == 0 {
+		fullConf.MaxAuthTries = 6
+	}
+	if len(fullConf.PublicKeyAuthAlgorithms) == 0 {
+		fullConf.PublicKeyAuthAlgorithms = supportedPubKeyAuthAlgos
+	} else {
+		for _, algo := range fullConf.PublicKeyAuthAlgorithms {
+			if !contains(supportedPubKeyAuthAlgos, algo) {
+				c.Close()
+				return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
+			}
+		}
+	}
+	// Check if the config contains any unsupported key exchanges
+	for _, kex := range fullConf.KeyExchanges {
+		if _, ok := serverForbiddenKexAlgos[kex]; ok {
+			c.Close()
+			return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
+		}
+	}
+
+	s := &connection{
+		sshConn: sshConn{conn: c},
+	}
+	perms, err := s.serverHandshake(&fullConf)
+	if err != nil {
+		c.Close()
+		return nil, nil, nil, err
+	}
+	return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil
+}
+
+// signAndMarshal signs the data with the appropriate algorithm,
+// and serializes the result in SSH wire format. algo is the negotiated
+// algorithm and may be a certificate type.
+func signAndMarshal(k AlgorithmSigner, rand io.Reader, data []byte, algo string) ([]byte, error) {
+	sig, err := k.SignWithAlgorithm(rand, data, underlyingAlgo(algo))
+	if err != nil {
+		return nil, err
+	}
+
+	return Marshal(sig), nil
+}
+
+// handshake performs key exchange and user authentication.
+func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) {
+	if len(config.hostKeys) == 0 {
+		return nil, errors.New("ssh: server has no host keys")
+	}
+
+	if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil &&
+		config.KeyboardInteractiveCallback == nil && (config.GSSAPIWithMICConfig == nil ||
+		config.GSSAPIWithMICConfig.AllowLogin == nil || config.GSSAPIWithMICConfig.Server == nil) {
+		return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
+	}
+
+	if config.ServerVersion != "" {
+		s.serverVersion = []byte(config.ServerVersion)
+	} else {
+		s.serverVersion = []byte(packageVersion)
+	}
+	var err error
+	s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
+	if err != nil {
+		return nil, err
+	}
+
+	tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
+	s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
+
+	if err := s.transport.waitSession(); err != nil {
+		return nil, err
+	}
+
+	// We just did the key change, so the session ID is established.
+	s.sessionID = s.transport.getSessionID()
+
+	var packet []byte
+	if packet, err = s.transport.readPacket(); err != nil {
+		return nil, err
+	}
+
+	var serviceRequest serviceRequestMsg
+	if err = Unmarshal(packet, &serviceRequest); err != nil {
+		return nil, err
+	}
+	if serviceRequest.Service != serviceUserAuth {
+		return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
+	}
+	serviceAccept := serviceAcceptMsg{
+		Service: serviceUserAuth,
+	}
+	if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil {
+		return nil, err
+	}
+
+	perms, err := s.serverAuthenticate(config)
+	if err != nil {
+		return nil, err
+	}
+	s.mux = newMux(s.transport)
+	return perms, err
+}
+
+// checkSourceAddress enforces the "source-address" critical option.
+func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
+	if addr == nil {
+		return errors.New("ssh: no address known for client, but source-address match required")
+	}
+
+	tcpAddr, ok := addr.(*net.TCPAddr)
+	if !ok {
+		return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
+	}
+
+	for _, sourceAddr := range strings.Split(sourceAddrs, ",") {
+		if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
+			if allowedIP.Equal(tcpAddr.IP) {
+				return nil
+			}
+		} else {
+			_, ipNet, err := net.ParseCIDR(sourceAddr)
+			if err != nil {
+				return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
+			}
+
+			if ipNet.Contains(tcpAddr.IP) {
+				return nil
+			}
+		}
+	}
+
+	return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
+}
+
+func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection,
+	sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) {
+	gssAPIServer := gssapiConfig.Server
+	defer gssAPIServer.DeleteSecContext()
+	var srcName string
+	for {
+		var (
+			outToken     []byte
+			needContinue bool
+		)
+		outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token)
+		if err != nil {
+			return err, nil, nil
+		}
+		if len(outToken) != 0 {
+			if err := s.transport.writePacket(Marshal(&userAuthGSSAPIToken{
+				Token: outToken,
+			})); err != nil {
+				return nil, nil, err
+			}
+		}
+		if !needContinue {
+			break
+		}
+		packet, err := s.transport.readPacket()
+		if err != nil {
+			return nil, nil, err
+		}
+		userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
+		if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
+			return nil, nil, err
+		}
+		token = userAuthGSSAPITokenReq.Token
+	}
+	packet, err := s.transport.readPacket()
+	if err != nil {
+		return nil, nil, err
+	}
+	userAuthGSSAPIMICReq := &userAuthGSSAPIMIC{}
+	if err := Unmarshal(packet, userAuthGSSAPIMICReq); err != nil {
+		return nil, nil, err
+	}
+	mic := buildMIC(string(sessionID), userAuthReq.User, userAuthReq.Service, userAuthReq.Method)
+	if err := gssAPIServer.VerifyMIC(mic, userAuthGSSAPIMICReq.MIC); err != nil {
+		return err, nil, nil
+	}
+	perms, authErr = gssapiConfig.AllowLogin(s, srcName)
+	return authErr, perms, nil
+}
+
+// isAlgoCompatible checks if the signature format is compatible with the
+// selected algorithm, taking into account edge cases that occur with old
+// clients.
+func isAlgoCompatible(algo, sigFormat string) bool {
+	// Compatibility for old clients.
+	//
+	// For certificate authentication with OpenSSH 7.2-7.7 signature format can
+	// be rsa-sha2-256 or rsa-sha2-512 for the algorithm
+	// ssh-rsa-cert-v01@openssh.com.
+	//
+	// With gpg-agent < 2.2.6 the algorithm can be rsa-sha2-256 or rsa-sha2-512
+	// for signature format ssh-rsa.
+	if isRSA(algo) && isRSA(sigFormat) {
+		return true
+	}
+	// Standard case: the underlying algorithm must match the signature format.
+	return underlyingAlgo(algo) == sigFormat
+}
+
+// ServerAuthError represents server authentication errors and is
+// sometimes returned by NewServerConn.
+// It appends any authentication errors that may occur, and is
+// returned if all of the authentication methods provided by the user
+// failed to authenticate.
+type ServerAuthError struct {
+	// Errors contains authentication errors returned by the authentication
+	// callback methods. The first entry is typically ErrNoAuth.
+	Errors []error
+}
+
+func (l ServerAuthError) Error() string {
+	var errs []string
+	for _, err := range l.Errors {
+		errs = append(errs, err.Error())
+	}
+	return "[" + strings.Join(errs, ", ") + "]"
+}
+
+// ServerAuthCallbacks defines server-side authentication callbacks.
+type ServerAuthCallbacks struct {
+	// PasswordCallback behaves like [ServerConfig.PasswordCallback].
+	PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
+
+	// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
+	PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
+
+	// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
+	KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
+
+	// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
+	GSSAPIWithMICConfig *GSSAPIWithMICConfig
+}
+
+// PartialSuccessError can be returned by any of the [ServerConfig]
+// authentication callbacks to indicate to the client that authentication has
+// partially succeeded, but further steps are required.
+type PartialSuccessError struct {
+	// Next defines the authentication callbacks to apply to further steps. The
+	// available methods communicated to the client are based on the non-nil
+	// ServerAuthCallbacks fields.
+	Next ServerAuthCallbacks
+}
+
+func (p *PartialSuccessError) Error() string {
+	return "ssh: authenticated with partial success"
+}
+
+// ErrNoAuth is the error value returned if no
+// authentication method has been passed yet.
+// This happens as a normal part of the authentication loop, since the
+// client first tries 'none' authentication to discover available
+// methods.
+// It is returned in ServerAuthError.Errors from NewServerConn.
+var ErrNoAuth = errors.New("ssh: no auth passed yet")
+
+// BannerError is an error that can be returned by authentication handlers in
+// ServerConfig to send a banner message to the client.
+type BannerError struct {
+	Err     error
+	Message string
+}
+
+func (b *BannerError) Unwrap() error {
+	return b.Err
+}
+
+func (b *BannerError) Error() string {
+	if b.Err == nil {
+		return b.Message
+	}
+	return b.Err.Error()
+}
+
+func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
+	if config.PreAuthConnCallback != nil {
+		config.PreAuthConnCallback(s)
+	}
+
+	sessionID := s.transport.getSessionID()
+	var cache pubKeyCache
+	var perms *Permissions
+
+	authFailures := 0
+	noneAuthCount := 0
+	var authErrs []error
+	var calledBannerCallback bool
+	partialSuccessReturned := false
+	// Set the initial authentication callbacks from the config. They can be
+	// changed if a PartialSuccessError is returned.
+	authConfig := ServerAuthCallbacks{
+		PasswordCallback:            config.PasswordCallback,
+		PublicKeyCallback:           config.PublicKeyCallback,
+		KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
+		GSSAPIWithMICConfig:         config.GSSAPIWithMICConfig,
+	}
+
+userAuthLoop:
+	for {
+		if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
+			discMsg := &disconnectMsg{
+				Reason:  2, // SSH_DISCONNECT_PROTOCOL_ERROR, RFC 4253 section 11.1
+				Message: "too many authentication failures",
+			}
+
+			if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
+				return nil, err
+			}
+			authErrs = append(authErrs, discMsg)
+			return nil, &ServerAuthError{Errors: authErrs}
+		}
+
+		var userAuthReq userAuthRequestMsg
+		if packet, err := s.transport.readPacket(); err != nil {
+			if err == io.EOF {
+				return nil, &ServerAuthError{Errors: authErrs}
+			}
+			return nil, err
+		} else if err = Unmarshal(packet, &userAuthReq); err != nil {
+			return nil, err
+		}
+
+		if userAuthReq.Service != serviceSSH {
+			return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
+		}
+
+		if s.user != userAuthReq.User && partialSuccessReturned {
+			return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
+				s.user, userAuthReq.User)
+		}
+
+		s.user = userAuthReq.User
+
+		if !calledBannerCallback && config.BannerCallback != nil {
+			calledBannerCallback = true
+			if msg := config.BannerCallback(s); msg != "" {
+				if err := s.SendAuthBanner(msg); err != nil {
+					return nil, err
+				}
+			}
+		}
+
+		perms = nil
+		authErr := ErrNoAuth
+
+		switch userAuthReq.Method {
+		case "none":
+			noneAuthCount++
+			// We don't allow none authentication after a partial success
+			// response.
+			if config.NoClientAuth && !partialSuccessReturned {
+				if config.NoClientAuthCallback != nil {
+					perms, authErr = config.NoClientAuthCallback(s)
+				} else {
+					authErr = nil
+				}
+			}
+		case "password":
+			if authConfig.PasswordCallback == nil {
+				authErr = errors.New("ssh: password auth not configured")
+				break
+			}
+			payload := userAuthReq.Payload
+			if len(payload) < 1 || payload[0] != 0 {
+				return nil, parseError(msgUserAuthRequest)
+			}
+			payload = payload[1:]
+			password, payload, ok := parseString(payload)
+			if !ok || len(payload) > 0 {
+				return nil, parseError(msgUserAuthRequest)
+			}
+
+			perms, authErr = authConfig.PasswordCallback(s, password)
+		case "keyboard-interactive":
+			if authConfig.KeyboardInteractiveCallback == nil {
+				authErr = errors.New("ssh: keyboard-interactive auth not configured")
+				break
+			}
+
+			prompter := &sshClientKeyboardInteractive{s}
+			perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
+		case "publickey":
+			if authConfig.PublicKeyCallback == nil {
+				authErr = errors.New("ssh: publickey auth not configured")
+				break
+			}
+			payload := userAuthReq.Payload
+			if len(payload) < 1 {
+				return nil, parseError(msgUserAuthRequest)
+			}
+			isQuery := payload[0] == 0
+			payload = payload[1:]
+			algoBytes, payload, ok := parseString(payload)
+			if !ok {
+				return nil, parseError(msgUserAuthRequest)
+			}
+			algo := string(algoBytes)
+			if !contains(config.PublicKeyAuthAlgorithms, underlyingAlgo(algo)) {
+				authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo)
+				break
+			}
+
+			pubKeyData, payload, ok := parseString(payload)
+			if !ok {
+				return nil, parseError(msgUserAuthRequest)
+			}
+
+			pubKey, err := ParsePublicKey(pubKeyData)
+			if err != nil {
+				return nil, err
+			}
+
+			candidate, ok := cache.get(s.user, pubKeyData)
+			if !ok {
+				candidate.user = s.user
+				candidate.pubKeyData = pubKeyData
+				candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+
+				if (candidate.result == nil || isPartialSuccessError) &&
+					candidate.perms != nil &&
+					candidate.perms.CriticalOptions != nil &&
+					candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
+					if err := checkSourceAddress(
+						s.RemoteAddr(),
+						candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
+						candidate.result = err
+					}
+				}
+				cache.add(candidate)
+			}
+
+			if isQuery {
+				// The client can query if the given public key
+				// would be okay.
+
+				if len(payload) > 0 {
+					return nil, parseError(msgUserAuthRequest)
+				}
+				_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
+				if candidate.result == nil || isPartialSuccessError {
+					okMsg := userAuthPubKeyOkMsg{
+						Algo:   algo,
+						PubKey: pubKeyData,
+					}
+					if err = s.transport.writePacket(Marshal(&okMsg)); err != nil {
+						return nil, err
+					}
+					continue userAuthLoop
+				}
+				authErr = candidate.result
+			} else {
+				sig, payload, ok := parseSignature(payload)
+				if !ok || len(payload) > 0 {
+					return nil, parseError(msgUserAuthRequest)
+				}
+				// Ensure the declared public key algo is compatible with the
+				// decoded one. This check will ensure we don't accept e.g.
+				// ssh-rsa-cert-v01@openssh.com algorithm with ssh-rsa public
+				// key type. The algorithm and public key type must be
+				// consistent: both must be certificate algorithms, or neither.
+				if !contains(algorithmsForKeyFormat(pubKey.Type()), algo) {
+					authErr = fmt.Errorf("ssh: public key type %q not compatible with selected algorithm %q",
+						pubKey.Type(), algo)
+					break
+				}
+				// Ensure the public key algo and signature algo
+				// are supported. Compare the private key
+				// algorithm name that corresponds to algo with
+				// sig.Format. This is usually the same, but
+				// for certs, the names differ.
+				if !contains(config.PublicKeyAuthAlgorithms, sig.Format) {
+					authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
+					break
+				}
+				if !isAlgoCompatible(algo, sig.Format) {
+					authErr = fmt.Errorf("ssh: signature %q not compatible with selected algorithm %q", sig.Format, algo)
+					break
+				}
+
+				signedData := buildDataSignedForAuth(sessionID, userAuthReq, algo, pubKeyData)
+
+				if err := pubKey.Verify(signedData, sig); err != nil {
+					return nil, err
+				}
+
+				authErr = candidate.result
+				perms = candidate.perms
+			}
+		case "gssapi-with-mic":
+			if authConfig.GSSAPIWithMICConfig == nil {
+				authErr = errors.New("ssh: gssapi-with-mic auth not configured")
+				break
+			}
+			gssapiConfig := authConfig.GSSAPIWithMICConfig
+			userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
+			if err != nil {
+				return nil, parseError(msgUserAuthRequest)
+			}
+			// OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication.
+			if userAuthRequestGSSAPI.N == 0 {
+				authErr = fmt.Errorf("ssh: Mechanism negotiation is not supported")
+				break
+			}
+			var i uint32
+			present := false
+			for i = 0; i < userAuthRequestGSSAPI.N; i++ {
+				if userAuthRequestGSSAPI.OIDS[i].Equal(krb5Mesh) {
+					present = true
+					break
+				}
+			}
+			if !present {
+				authErr = fmt.Errorf("ssh: GSSAPI authentication must use the Kerberos V5 mechanism")
+				break
+			}
+			// Initial server response, see RFC 4462 section 3.3.
+			if err := s.transport.writePacket(Marshal(&userAuthGSSAPIResponse{
+				SupportMech: krb5OID,
+			})); err != nil {
+				return nil, err
+			}
+			// Exchange token, see RFC 4462 section 3.4.
+			packet, err := s.transport.readPacket()
+			if err != nil {
+				return nil, err
+			}
+			userAuthGSSAPITokenReq := &userAuthGSSAPIToken{}
+			if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
+				return nil, err
+			}
+			authErr, perms, err = gssExchangeToken(gssapiConfig, userAuthGSSAPITokenReq.Token, s, sessionID,
+				userAuthReq)
+			if err != nil {
+				return nil, err
+			}
+		default:
+			authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
+		}
+
+		authErrs = append(authErrs, authErr)
+
+		if config.AuthLogCallback != nil {
+			config.AuthLogCallback(s, userAuthReq.Method, authErr)
+		}
+
+		var bannerErr *BannerError
+		if errors.As(authErr, &bannerErr) {
+			if bannerErr.Message != "" {
+				if err := s.SendAuthBanner(bannerErr.Message); err != nil {
+					return nil, err
+				}
+			}
+		}
+
+		if authErr == nil {
+			break userAuthLoop
+		}
+
+		var failureMsg userAuthFailureMsg
+
+		if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
+			// After a partial success error we don't allow changing the user
+			// name or executing the NoClientAuthCallback.
+			partialSuccessReturned = true
+
+			// In case a partial success is returned, the server may send
+			// a new set of authentication methods.
+			authConfig = partialSuccess.Next
+
+			// Reset pubkey cache, as the new PublicKeyCallback might
+			// accept a different set of public keys.
+			cache = pubKeyCache{}
+
+			// Send back a partial success message to the user.
+			failureMsg.PartialSuccess = true
+		} else {
+			// Allow initial attempt of 'none' without penalty.
+			if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 {
+				authFailures++
+			}
+			if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
+				// If we have hit the max attempts, don't bother sending the
+				// final SSH_MSG_USERAUTH_FAILURE message, since there are
+				// no more authentication methods which can be attempted,
+				// and this message may cause the client to re-attempt
+				// authentication while we send the disconnect message.
+				// Continue, and trigger the disconnect at the start of
+				// the loop.
+				//
+				// The SSH specification is somewhat confusing about this,
+				// RFC 4252 Section 5.1 requires each authentication failure
+				// be responded to with a respective SSH_MSG_USERAUTH_FAILURE
+				// message, but Section 4 says the server should disconnect
+				// after some number of attempts, but it isn't explicit which
+				// message should take precedence (i.e. should there be a failure
+				// message then a disconnect message, or if we are going to
+				// disconnect, should we only send that message.)
+				//
+				// Either way, OpenSSH disconnects immediately after the last
+				// failed authentication attempt, and given they are typically
+				// considered the golden implementation it seems reasonable
+				// to match that behavior.
+				continue
+			}
+		}
+
+		if authConfig.PasswordCallback != nil {
+			failureMsg.Methods = append(failureMsg.Methods, "password")
+		}
+		if authConfig.PublicKeyCallback != nil {
+			failureMsg.Methods = append(failureMsg.Methods, "publickey")
+		}
+		if authConfig.KeyboardInteractiveCallback != nil {
+			failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
+		}
+		if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
+			authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
+			failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
+		}
+
+		if len(failureMsg.Methods) == 0 {
+			return nil, errors.New("ssh: no authentication methods available")
+		}
+
+		if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
+			return nil, err
+		}
+	}
+
+	if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
+		return nil, err
+	}
+	return perms, nil
+}
+
+// sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
+// asking the client on the other side of a ServerConn.
+type sshClientKeyboardInteractive struct {
+	*connection
+}
+
+func (c *sshClientKeyboardInteractive) Challenge(name, instruction string, questions []string, echos []bool) (answers []string, err error) {
+	if len(questions) != len(echos) {
+		return nil, errors.New("ssh: echos and questions must have equal length")
+	}
+
+	var prompts []byte
+	for i := range questions {
+		prompts = appendString(prompts, questions[i])
+		prompts = appendBool(prompts, echos[i])
+	}
+
+	if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{
+		Name:        name,
+		Instruction: instruction,
+		NumPrompts:  uint32(len(questions)),
+		Prompts:     prompts,
+	})); err != nil {
+		return nil, err
+	}
+
+	packet, err := c.transport.readPacket()
+	if err != nil {
+		return nil, err
+	}
+	if packet[0] != msgUserAuthInfoResponse { // NOTE(review): assumes a non-empty packet; confirm readPacket never returns len 0 without error
+		return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0])
+	}
+	packet = packet[1:]
+
+	n, packet, ok := parseUint32(packet)
+	if !ok || int(n) != len(questions) {
+		return nil, parseError(msgUserAuthInfoResponse)
+	}
+
+	for i := uint32(0); i < n; i++ {
+		ans, rest, ok := parseString(packet)
+		if !ok {
+			return nil, parseError(msgUserAuthInfoResponse)
+		}
+
+		answers = append(answers, string(ans))
+		packet = rest
+	}
+	if len(packet) != 0 {
+		return nil, errors.New("ssh: junk at end of message")
+	}
+
+	return answers, nil
+}
diff --git a/tempfork/sshtest/ssh/server_multi_auth_test.go b/tempfork/sshtest/ssh/server_multi_auth_test.go
new file mode 100644
index 0000000000000..3b3980243763f
--- /dev/null
+++ b/tempfork/sshtest/ssh/server_multi_auth_test.go
@@ -0,0 +1,412 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+ +package ssh + +import ( + "bytes" + "errors" + "fmt" + "strings" + "testing" +) + +func doClientServerAuth(t *testing.T, serverConfig *ServerConfig, clientConfig *ClientConfig) ([]error, error) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + var serverAuthErrors []error + + serverConfig.AddHostKey(testSigners["rsa"]) + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err == nil { + c.Close() + } + return serverAuthErrors, err +} + +func TestMultiStepAuth(t *testing.T) { + // This user can login with password, public key or public key + password. + username := "testuser" + // This user can login with public key + password only. + usernameSecondFactor := "testuser_second_factor" + errPwdAuthFailed := errors.New("password auth failed") + errWrongSequence := errors.New("wrong sequence") + + serverConfig := &ServerConfig{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == usernameSecondFactor { + return nil, errWrongSequence + } + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errPwdAuthFailed + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + if conn.User() == usernameSecondFactor { + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if string(password) == clientPassword { + return nil, nil + } + return nil, errPwdAuthFailed + }, + }, + } + } + return nil, nil + } + return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) + }, + } + + clientConfig := &ClientConfig{ + User: 
usernameSecondFactor, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + + // The error sequence is: + // - no auth passed yet + // - partial success + // - nil + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatalf("expected partial success error, got: %v", serverAuthErrors[1]) + } + // Now test a wrong sequence. + clientConfig.Auth = []AuthMethod{ + Password(clientPassword), + PublicKeys(testSigners["rsa"]), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with wrong sequence must fail") + } + // The error sequence is: + // - no auth passed yet + // - wrong sequence + // - partial success + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if serverAuthErrors[1] != errWrongSequence { + t.Fatal("server not returned wrong sequence") + } + if _, ok := serverAuthErrors[2].(*PartialSuccessError); !ok { + t.Fatalf("expected partial success error, got: %v", serverAuthErrors[2]) + } + // Now test using a correct sequence but a wrong password before the right + // one. 
+ n := 0 + passwords := []string{"WRONG", "WRONG", clientPassword} + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + RetryableAuthMethod(PasswordCallback(func() (string, error) { + p := passwords[n] + n++ + return p, nil + }), 3), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - wrong password + // - wrong password + // - nil + if len(serverAuthErrors) != 5 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + if serverAuthErrors[2] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + if serverAuthErrors[3] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + // Only password authentication should fail. + clientConfig.Auth = []AuthMethod{ + Password(clientPassword), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with password only must fail") + } + // The error sequence is: + // - no auth passed yet + // - wrong sequence + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if serverAuthErrors[1] != errWrongSequence { + t.Fatal("server not returned wrong sequence") + } + + // Only public key authentication should fail. 
+ clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with public key only must fail") + } + // The error sequence is: + // - no auth passed yet + // - partial success + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // Public key and wrong password. + clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password("WRONG"), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("client login with wrong password after public key must fail") + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - password auth failed + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + if serverAuthErrors[2] != errPwdAuthFailed { + t.Fatal("server not returned password authentication failed") + } + + // Public key, public key again and then correct password. Public key + // authentication is attempted only once because the partial success error + // returns only "password" as the allowed authentication method. 
+ clientConfig.Auth = []AuthMethod{ + PublicKeys(testSigners["rsa"]), + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - no auth passed yet + // - partial success + // - nil + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[1].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // The unrestricted username can do anything + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } + + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } + + clientConfig = &ClientConfig{ + User: username, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("unrestricted client login error: %s", err) + } +} + +func TestDynamicAuthCallbacks(t *testing.T) { + user1 := "user1" + user2 := "user2" + errInvalidCredentials := errors.New("invalid credentials") + + serverConfig := &ServerConfig{ + NoClientAuth: true, + NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { + switch conn.User() { + case user1: + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + 
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == user1 && string(password) == clientPassword { + return nil, nil + } + return nil, errInvalidCredentials + }, + }, + } + case user2: + return nil, &PartialSuccessError{ + Next: ServerAuthCallbacks{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { + if conn.User() == user2 { + return nil, nil + } + } + return nil, errInvalidCredentials + }, + }, + } + default: + return nil, errInvalidCredentials + } + }, + } + + clientConfig := &ClientConfig{ + User: user1, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - partial success + // - nil + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + clientConfig = &ClientConfig{ + User: user2, + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err != nil { + t.Fatalf("client login error: %s", err) + } + // The error sequence is: + // - partial success + // - nil + if len(serverAuthErrors) != 2 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + + // user1 cannot login with public key + clientConfig = &ClientConfig{ + User: user1, + Auth: []AuthMethod{ + 
PublicKeys(testSigners["rsa"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("user1 login with public key must fail") + } + if !strings.Contains(err.Error(), "no supported methods remain") { + t.Errorf("got %v, expected 'no supported methods remain'", err) + } + if len(serverAuthErrors) != 1 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } + // user2 cannot login with password + clientConfig = &ClientConfig{ + User: user2, + Auth: []AuthMethod{ + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + serverAuthErrors, err = doClientServerAuth(t, serverConfig, clientConfig) + if err == nil { + t.Fatal("user2 login with password must fail") + } + if !strings.Contains(err.Error(), "no supported methods remain") { + t.Errorf("got %v, expected 'no supported methods remain'", err) + } + if len(serverAuthErrors) != 1 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + if _, ok := serverAuthErrors[0].(*PartialSuccessError); !ok { + t.Fatal("server not returned partial success") + } +} diff --git a/tempfork/sshtest/ssh/server_test.go b/tempfork/sshtest/ssh/server_test.go new file mode 100644 index 0000000000000..c2b24f47ce878 --- /dev/null +++ b/tempfork/sshtest/ssh/server_test.go @@ -0,0 +1,478 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package ssh + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "reflect" + "strings" + "sync/atomic" + "testing" + "time" +) + +func TestClientAuthRestrictedPublicKeyAlgos(t *testing.T) { + for _, tt := range []struct { + name string + key Signer + wantError bool + }{ + {"rsa", testSigners["rsa"], false}, + {"dsa", testSigners["dsa"], true}, + {"ed25519", testSigners["ed25519"], true}, + } { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + PublicKeyAuthAlgorithms: []string{KeyAlgoRSASHA256, KeyAlgoRSASHA512}, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, nil + }, + } + serverConf.AddHostKey(testSigners["ecdsap256"]) + + done := make(chan struct{}) + go func() { + defer close(done) + NewServerConn(c1, serverConf) + }() + + clientConf := ClientConfig{ + User: "user", + Auth: []AuthMethod{ + PublicKeys(tt.key), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + if !tt.wantError { + t.Errorf("%s: got unexpected error %q", tt.name, err.Error()) + } + } else if tt.wantError { + t.Errorf("%s: succeeded, but want error", tt.name) + } + <-done + } +} + +func TestMaxAuthTriesNoneMethod(t *testing.T) { + username := "testuser" + serverConfig := &ServerConfig{ + MaxAuthTries: 2, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + if conn.User() == username && string(password) == clientPassword { + return nil, nil + } + return nil, errors.New("invalid credentials") + }, + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + var serverAuthErrors []error + + serverConfig.AddHostKey(testSigners["rsa"]) + serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) { + serverAuthErrors = append(serverAuthErrors, err) + } + 
go newServer(c1, serverConfig) + + clientConfig := ClientConfig{ + User: username, + HostKeyCallback: InsecureIgnoreHostKey(), + } + clientConfig.SetDefaults() + // Our client will send 'none' auth only once, so we need to send the + // requests manually. + c := &connection{ + sshConn: sshConn{ + conn: c2, + user: username, + clientVersion: []byte(packageVersion), + }, + } + c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) + if err != nil { + t.Fatalf("unable to exchange version: %v", err) + } + c.transport = newClientTransport( + newTransport(c.sshConn.conn, clientConfig.Rand, true /* is client */), + c.clientVersion, c.serverVersion, &clientConfig, "", c.sshConn.RemoteAddr()) + if err := c.transport.waitSession(); err != nil { + t.Fatalf("unable to wait session: %v", err) + } + c.sessionID = c.transport.getSessionID() + if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { + t.Fatalf("unable to send ssh-userauth message: %v", err) + } + packet, err := c.transport.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) > 0 && packet[0] == msgExtInfo { + packet, err = c.transport.readPacket() + if err != nil { + t.Fatal(err) + } + } + var serviceAccept serviceAcceptMsg + if err := Unmarshal(packet, &serviceAccept); err != nil { + t.Fatal(err) + } + for i := 0; i <= serverConfig.MaxAuthTries; i++ { + auth := new(noneAuth) + _, _, err := auth.auth(c.sessionID, clientConfig.User, c.transport, clientConfig.Rand, nil) + if i < serverConfig.MaxAuthTries { + if err != nil { + t.Fatal(err) + } + continue + } + if err == nil { + t.Fatal("client: got no error") + } else if !strings.Contains(err.Error(), "too many authentication failures") { + t.Fatalf("client: got unexpected error: %v", err) + } + } + if len(serverAuthErrors) != 3 { + t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors) + } + for _, err := range serverAuthErrors { + if 
!errors.Is(err, ErrNoAuth) {
+			// Fixed message typo: "go error" -> "got error".
+			t.Errorf("got error: %v; want: %v", err, ErrNoAuth)
+		}
+	}
+}
+
+func TestMaxAuthTriesFirstNoneAuthErrorIgnored(t *testing.T) {
+	username := "testuser"
+	serverConfig := &ServerConfig{
+		MaxAuthTries: 1,
+		PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
+			if conn.User() == username && string(password) == clientPassword {
+				return nil, nil
+			}
+			return nil, errors.New("invalid credentials")
+		},
+	}
+	clientConfig := &ClientConfig{
+		User: username,
+		Auth: []AuthMethod{
+			Password(clientPassword),
+		},
+		HostKeyCallback: InsecureIgnoreHostKey(),
+	}
+
+	serverAuthErrors, err := doClientServerAuth(t, serverConfig, clientConfig)
+	if err != nil {
+		t.Fatalf("client login error: %s", err)
+	}
+	if len(serverAuthErrors) != 2 {
+		t.Fatalf("unexpected number of server auth errors: %v, errors: %+v", len(serverAuthErrors), serverAuthErrors)
+	}
+	if !errors.Is(serverAuthErrors[0], ErrNoAuth) {
+		// Fixed message typo: "go error" -> "got error".
+		t.Errorf("got error: %v; want: %v", serverAuthErrors[0], ErrNoAuth)
+	}
+	if serverAuthErrors[1] != nil {
+		t.Errorf("unexpected error: %v", serverAuthErrors[1])
+	}
+}
+
+func TestNewServerConnValidationErrors(t *testing.T) {
+	serverConf := &ServerConfig{
+		PublicKeyAuthAlgorithms: []string{CertAlgoRSAv01},
+	}
+	c := &markerConn{}
+	_, _, _, err := NewServerConn(c, serverConf)
+	if err == nil {
+		t.Fatal("NewServerConn with invalid public key auth algorithms succeeded")
+	}
+	if !c.isClosed() {
+		t.Fatal("NewServerConn with invalid public key auth algorithms left connection open")
+	}
+	if c.isUsed() {
+		t.Fatal("NewServerConn with invalid public key auth algorithms used connection")
+	}
+
+	serverConf = &ServerConfig{
+		Config: Config{
+			KeyExchanges: []string{kexAlgoDHGEXSHA256},
+		},
+	}
+	c = &markerConn{}
+	_, _, _, err = NewServerConn(c, serverConf)
+	if err == nil {
+		t.Fatal("NewServerConn with unsupported key exchange succeeded")
+	}
+	if !c.isClosed() {
+		t.Fatal("NewServerConn with unsupported key
exchange left connection open") + } + if c.isUsed() { + t.Fatal("NewServerConn with unsupported key exchange used connection") + } +} + +func TestBannerError(t *testing.T) { + serverConfig := &ServerConfig{ + BannerCallback: func(ConnMetadata) string { + return "banner from BannerCallback" + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + err := &BannerError{ + Err: errors.New("error from NoClientAuthCallback"), + Message: "banner from NoClientAuthCallback", + } + return nil, fmt.Errorf("wrapped: %w", err) + }, + PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) { + return &Permissions{}, nil + }, + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + return nil, &BannerError{ + Err: errors.New("error from PublicKeyCallback"), + Message: "banner from PublicKeyCallback", + } + }, + KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) { + return nil, &BannerError{ + Err: nil, // make sure that a nil inner error is allowed + Message: "banner from KeyboardInteractiveCallback", + } + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"]), + KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) { + return []string{"letmein"}, nil + }), + Password(clientPassword), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + banners = append(banners, msg) + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "banner 
from BannerCallback", + "banner from NoClientAuthCallback", + "banner from PublicKeyCallback", + "banner from KeyboardInteractiveCallback", + } + if !reflect.DeepEqual(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } +} + +func TestPublicKeyCallbackLastSeen(t *testing.T) { + var lastSeenKey PublicKey + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + serverConf := &ServerConfig{ + PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { + lastSeenKey = key + fmt.Printf("seen %#v\n", key) + if _, ok := key.(*dsaPublicKey); !ok { + return nil, errors.New("nope") + } + return nil, nil + }, + } + serverConf.AddHostKey(testSigners["ecdsap256"]) + + done := make(chan struct{}) + go func() { + defer close(done) + NewServerConn(c1, serverConf) + }() + + clientConf := ClientConfig{ + User: "user", + Auth: []AuthMethod{ + PublicKeys(testSigners["rsa"], testSigners["dsa"], testSigners["ed25519"]), + }, + HostKeyCallback: InsecureIgnoreHostKey(), + } + + _, _, _, err = NewClientConn(c2, "", &clientConf) + if err != nil { + t.Fatal(err) + } + <-done + + expectedPublicKey := testSigners["dsa"].PublicKey().Marshal() + lastSeenMarshalled := lastSeenKey.Marshal() + if !bytes.Equal(lastSeenMarshalled, expectedPublicKey) { + t.Errorf("unexpected key: got %#v, want %#v", lastSeenKey, testSigners["dsa"].PublicKey()) + } +} + +func TestPreAuthConnAndBanners(t *testing.T) { + testDone := make(chan struct{}) + defer close(testDone) + + authConnc := make(chan ServerPreAuthConn, 1) + serverConfig := &ServerConfig{ + PreAuthConnCallback: func(c ServerPreAuthConn) { + t.Logf("got ServerPreAuthConn: %v", c) + authConnc <- c // for use later in the test + for _, s := range []string{"hello1", "hello2"} { + if err := c.SendAuthBanner(s); err != nil { + t.Errorf("failed to send banner %q: %v", s, err) + } + } + // Now start a goroutine to spam 
SendAuthBanner in hopes + // of hitting a race. + go func() { + for { + select { + case <-testDone: + return + default: + if err := c.SendAuthBanner("attempted-race"); err != nil && err != errSendBannerPhase { + t.Errorf("unexpected error from SendAuthBanner: %v", err) + } + time.Sleep(5 * time.Millisecond) + } + } + }() + }, + NoClientAuth: true, + NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) { + t.Logf("got NoClientAuthCallback") + return &Permissions{}, nil + }, + } + serverConfig.AddHostKey(testSigners["rsa"]) + + var banners []string + clientConfig := &ClientConfig{ + User: "test", + HostKeyCallback: InsecureIgnoreHostKey(), + BannerCallback: func(msg string) error { + if msg != "attempted-race" { + banners = append(banners, msg) + } + return nil + }, + } + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + go newServer(c1, serverConfig) + c, _, _, err := NewClientConn(c2, "", clientConfig) + if err != nil { + t.Fatalf("client connection failed: %v", err) + } + defer c.Close() + + wantBanners := []string{ + "hello1", + "hello2", + } + if !reflect.DeepEqual(banners, wantBanners) { + t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners) + } + + // Now that we're authenticated, verify that use of SendBanner + // is an error. 
+ var bc ServerPreAuthConn + select { + case bc = <-authConnc: + default: + t.Fatal("expected ServerPreAuthConn") + } + if err := bc.SendAuthBanner("wrong-phase"); err == nil { + t.Error("unexpected success of SendAuthBanner after authentication") + } else if err != errSendBannerPhase { + t.Errorf("unexpected error: %v; want %v", err, errSendBannerPhase) + } +} + +type markerConn struct { + closed uint32 + used uint32 +} + +func (c *markerConn) isClosed() bool { + return atomic.LoadUint32(&c.closed) != 0 +} + +func (c *markerConn) isUsed() bool { + return atomic.LoadUint32(&c.used) != 0 +} + +func (c *markerConn) Close() error { + atomic.StoreUint32(&c.closed, 1) + return nil +} + +func (c *markerConn) Read(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.EOF + } +} + +func (c *markerConn) Write(b []byte) (n int, err error) { + atomic.StoreUint32(&c.used, 1) + if atomic.LoadUint32(&c.closed) != 0 { + return 0, net.ErrClosed + } else { + return 0, io.ErrClosedPipe + } +} + +func (*markerConn) LocalAddr() net.Addr { return nil } +func (*markerConn) RemoteAddr() net.Addr { return nil } + +func (*markerConn) SetDeadline(t time.Time) error { return nil } +func (*markerConn) SetReadDeadline(t time.Time) error { return nil } +func (*markerConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/tempfork/sshtest/ssh/session.go b/tempfork/sshtest/ssh/session.go new file mode 100644 index 0000000000000..acef62259fdee --- /dev/null +++ b/tempfork/sshtest/ssh/session.go @@ -0,0 +1,647 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session implements an interactive session described in +// "RFC 4254, section 6". 
+ +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" + "sync" +) + +type Signal string + +// POSIX signals as listed in RFC 4254 Section 6.10. +const ( + SIGABRT Signal = "ABRT" + SIGALRM Signal = "ALRM" + SIGFPE Signal = "FPE" + SIGHUP Signal = "HUP" + SIGILL Signal = "ILL" + SIGINT Signal = "INT" + SIGKILL Signal = "KILL" + SIGPIPE Signal = "PIPE" + SIGQUIT Signal = "QUIT" + SIGSEGV Signal = "SEGV" + SIGTERM Signal = "TERM" + SIGUSR1 Signal = "USR1" + SIGUSR2 Signal = "USR2" +) + +var signals = map[Signal]int{ + SIGABRT: 6, + SIGALRM: 14, + SIGFPE: 8, + SIGHUP: 1, + SIGILL: 4, + SIGINT: 2, + SIGKILL: 9, + SIGPIPE: 13, + SIGQUIT: 3, + SIGSEGV: 11, + SIGTERM: 15, +} + +type TerminalModes map[uint8]uint32 + +// POSIX terminal mode flags as listed in RFC 4254 Section 8. +const ( + tty_OP_END = 0 + VINTR = 1 + VQUIT = 2 + VERASE = 3 + VKILL = 4 + VEOF = 5 + VEOL = 6 + VEOL2 = 7 + VSTART = 8 + VSTOP = 9 + VSUSP = 10 + VDSUSP = 11 + VREPRINT = 12 + VWERASE = 13 + VLNEXT = 14 + VFLUSH = 15 + VSWTCH = 16 + VSTATUS = 17 + VDISCARD = 18 + IGNPAR = 30 + PARMRK = 31 + INPCK = 32 + ISTRIP = 33 + INLCR = 34 + IGNCR = 35 + ICRNL = 36 + IUCLC = 37 + IXON = 38 + IXANY = 39 + IXOFF = 40 + IMAXBEL = 41 + IUTF8 = 42 // RFC 8160 + ISIG = 50 + ICANON = 51 + XCASE = 52 + ECHO = 53 + ECHOE = 54 + ECHOK = 55 + ECHONL = 56 + NOFLSH = 57 + TOSTOP = 58 + IEXTEN = 59 + ECHOCTL = 60 + ECHOKE = 61 + PENDIN = 62 + OPOST = 70 + OLCUC = 71 + ONLCR = 72 + OCRNL = 73 + ONOCR = 74 + ONLRET = 75 + CS7 = 90 + CS8 = 91 + PARENB = 92 + PARODD = 93 + TTY_OP_ISPEED = 128 + TTY_OP_OSPEED = 129 +) + +// A Session represents a connection to a remote command or shell. +type Session struct { + // Stdin specifies the remote process's standard input. + // If Stdin is nil, the remote process reads from an empty + // bytes.Buffer. + Stdin io.Reader + + // Stdout and Stderr specify the remote process's standard + // output and error. 
+ // + // If either is nil, Run connects the corresponding file + // descriptor to an instance of io.Discard. There is a + // fixed amount of buffering that is shared for the two streams. + // If either blocks it may eventually cause the remote + // command to block. + Stdout io.Writer + Stderr io.Writer + + ch Channel // the channel backing this session + started bool // true once Start, Run or Shell is invoked. + copyFuncs []func() error + errors chan error // one send per copyFunc + + // true if pipe method is active + stdinpipe, stdoutpipe, stderrpipe bool + + // stdinPipeWriter is non-nil if StdinPipe has not been called + // and Stdin was specified by the user; it is the write end of + // a pipe connecting Session.Stdin to the stdin channel. + stdinPipeWriter io.WriteCloser + + exitStatus chan error +} + +// SendRequest sends an out-of-band channel request on the SSH channel +// underlying the session. +func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { + return s.ch.SendRequest(name, wantReply, payload) +} + +func (s *Session) Close() error { + return s.ch.Close() +} + +// RFC 4254 Section 6.4. +type setenvRequest struct { + Name string + Value string +} + +// Setenv sets an environment variable that will be applied to any +// command executed by Shell or Run. +func (s *Session) Setenv(name, value string) error { + msg := setenvRequest{ + Name: name, + Value: value, + } + ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: setenv failed") + } + return err +} + +// RFC 4254 Section 6.2. +type ptyRequestMsg struct { + Term string + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 + Modelist string +} + +// RequestPty requests the association of a pty with the session on the remote host. 
+func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { + var tm []byte + for k, v := range termmodes { + kv := struct { + Key byte + Val uint32 + }{k, v} + + tm = append(tm, Marshal(&kv)...) + } + tm = append(tm, tty_OP_END) + req := ptyRequestMsg{ + Term: term, + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + Modelist: string(tm), + } + ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) + if err == nil && !ok { + err = errors.New("ssh: pty-req failed") + } + return err +} + +// RFC 4254 Section 6.5. +type subsystemRequestMsg struct { + Subsystem string +} + +// RequestSubsystem requests the association of a subsystem with the session on the remote host. +// A subsystem is a predefined command that runs in the background when the ssh session is initiated +func (s *Session) RequestSubsystem(subsystem string) error { + msg := subsystemRequestMsg{ + Subsystem: subsystem, + } + ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) + if err == nil && !ok { + err = errors.New("ssh: subsystem request failed") + } + return err +} + +// RFC 4254 Section 6.7. +type ptyWindowChangeMsg struct { + Columns uint32 + Rows uint32 + Width uint32 + Height uint32 +} + +// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns. +func (s *Session) WindowChange(h, w int) error { + req := ptyWindowChangeMsg{ + Columns: uint32(w), + Rows: uint32(h), + Width: uint32(w * 8), + Height: uint32(h * 8), + } + _, err := s.ch.SendRequest("window-change", false, Marshal(&req)) + return err +} + +// RFC 4254 Section 6.9. +type signalMsg struct { + Signal string +} + +// Signal sends the given signal to the remote process. +// sig is one of the SIG* constants. +func (s *Session) Signal(sig Signal) error { + msg := signalMsg{ + Signal: string(sig), + } + + _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) + return err +} + +// RFC 4254 Section 6.5. 
+type execMsg struct { + Command string +} + +// Start runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start or Shell. +func (s *Session) Start(cmd string) error { + if s.started { + return errors.New("ssh: session already started") + } + req := execMsg{ + Command: cmd, + } + + ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) + if err == nil && !ok { + err = fmt.Errorf("ssh: command %v failed", cmd) + } + if err != nil { + return err + } + return s.start() +} + +// Run runs cmd on the remote host. Typically, the remote +// server passes cmd to the shell for interpretation. +// A Session only accepts one call to Run, Start, Shell, Output, +// or CombinedOutput. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. +func (s *Session) Run(cmd string) error { + err := s.Start(cmd) + if err != nil { + return err + } + return s.Wait() +} + +// Output runs cmd on the remote host and returns its standard output. +func (s *Session) Output(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + var b bytes.Buffer + s.Stdout = &b + err := s.Run(cmd) + return b.Bytes(), err +} + +type singleWriter struct { + b bytes.Buffer + mu sync.Mutex +} + +func (w *singleWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + return w.b.Write(p) +} + +// CombinedOutput runs cmd on the remote host and returns its combined +// standard output and standard error. 
+func (s *Session) CombinedOutput(cmd string) ([]byte, error) { + if s.Stdout != nil { + return nil, errors.New("ssh: Stdout already set") + } + if s.Stderr != nil { + return nil, errors.New("ssh: Stderr already set") + } + var b singleWriter + s.Stdout = &b + s.Stderr = &b + err := s.Run(cmd) + return b.b.Bytes(), err +} + +// Shell starts a login shell on the remote host. A Session only +// accepts one call to Run, Start, Shell, Output, or CombinedOutput. +func (s *Session) Shell() error { + if s.started { + return errors.New("ssh: session already started") + } + + ok, err := s.ch.SendRequest("shell", true, nil) + if err == nil && !ok { + return errors.New("ssh: could not start shell") + } + if err != nil { + return err + } + return s.start() +} + +func (s *Session) start() error { + s.started = true + + type F func(*Session) + for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { + setupFd(s) + } + + s.errors = make(chan error, len(s.copyFuncs)) + for _, fn := range s.copyFuncs { + go func(fn func() error) { + s.errors <- fn() + }(fn) + } + return nil +} + +// Wait waits for the remote command to exit. +// +// The returned error is nil if the command runs, has no problems +// copying stdin, stdout, and stderr, and exits with a zero exit +// status. +// +// If the remote server does not send an exit status, an error of type +// *ExitMissingError is returned. If the command completes +// unsuccessfully or is interrupted by a signal, the error is of type +// *ExitError. Other error types may be returned for I/O problems. 
+func (s *Session) Wait() error { + if !s.started { + return errors.New("ssh: session not started") + } + waitErr := <-s.exitStatus + + if s.stdinPipeWriter != nil { + s.stdinPipeWriter.Close() + } + var copyError error + for range s.copyFuncs { + if err := <-s.errors; err != nil && copyError == nil { + copyError = err + } + } + if waitErr != nil { + return waitErr + } + return copyError +} + +func (s *Session) wait(reqs <-chan *Request) error { + wm := Waitmsg{status: -1} + // Wait for msg channel to be closed before returning. + for msg := range reqs { + switch msg.Type { + case "exit-status": + wm.status = int(binary.BigEndian.Uint32(msg.Payload)) + case "exit-signal": + var sigval struct { + Signal string + CoreDumped bool + Error string + Lang string + } + if err := Unmarshal(msg.Payload, &sigval); err != nil { + return err + } + + // Must sanitize strings? + wm.signal = sigval.Signal + wm.msg = sigval.Error + wm.lang = sigval.Lang + default: + // This handles keepalives and matches + // OpenSSH's behaviour. + if msg.WantReply { + msg.Reply(false, nil) + } + } + } + if wm.status == 0 { + return nil + } + if wm.status == -1 { + // exit-status was never sent from server + if wm.signal == "" { + // signal was not sent either. RFC 4254 + // section 6.10 recommends against this + // behavior, but it is allowed, so we let + // clients handle it. + return &ExitMissingError{} + } + wm.status = 128 + if _, ok := signals[Signal(wm.signal)]; ok { + wm.status += signals[Signal(wm.signal)] + } + } + + return &ExitError{wm} +} + +// ExitMissingError is returned if a session is torn down cleanly, but +// the server sends no confirmation of the exit status. 
type ExitMissingError struct{}

func (e *ExitMissingError) Error() string {
	return "wait: remote command exited without exit status or exit signal"
}

// stdin registers the copy function that feeds s.Stdin (or an empty
// buffer when unset) into the channel and then half-closes the write
// side. A no-op when the caller took a StdinPipe instead.
func (s *Session) stdin() {
	if s.stdinpipe {
		return
	}
	var stdin io.Reader
	if s.Stdin == nil {
		stdin = new(bytes.Buffer)
	} else {
		// An io.Pipe interposed in front of s.Stdin lets Wait
		// unblock the copy via stdinPipeWriter.Close.
		r, w := io.Pipe()
		go func() {
			_, err := io.Copy(w, s.Stdin)
			w.CloseWithError(err)
		}()
		stdin, s.stdinPipeWriter = r, w
	}
	s.copyFuncs = append(s.copyFuncs, func() error {
		_, err := io.Copy(s.ch, stdin)
		if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF {
			err = err1
		}
		return err
	})
}

// stdout registers the copy function that drains channel stdout into
// s.Stdout, discarding output when the caller set no destination.
// A no-op when the caller took a StdoutPipe instead.
func (s *Session) stdout() {
	if s.stdoutpipe {
		return
	}
	if s.Stdout == nil {
		s.Stdout = io.Discard
	}
	s.copyFuncs = append(s.copyFuncs, func() error {
		_, err := io.Copy(s.Stdout, s.ch)
		return err
	})
}

// stderr registers the copy function that drains the channel's stderr
// stream into s.Stderr, discarding output when the caller set no
// destination. A no-op when the caller took a StderrPipe instead.
func (s *Session) stderr() {
	if s.stderrpipe {
		return
	}
	if s.Stderr == nil {
		s.Stderr = io.Discard
	}
	s.copyFuncs = append(s.copyFuncs, func() error {
		_, err := io.Copy(s.Stderr, s.ch.Stderr())
		return err
	})
}

// sessionStdin reroutes Close to CloseWrite.
type sessionStdin struct {
	io.Writer
	ch Channel
}

func (s *sessionStdin) Close() error {
	return s.ch.CloseWrite()
}

// StdinPipe returns a pipe that will be connected to the
// remote command's standard input when the command starts.
func (s *Session) StdinPipe() (io.WriteCloser, error) {
	if s.Stdin != nil {
		return nil, errors.New("ssh: Stdin already set")
	}
	if s.started {
		return nil, errors.New("ssh: StdinPipe after process started")
	}
	s.stdinpipe = true
	return &sessionStdin{s.ch, s.ch}, nil
}

// StdoutPipe returns a pipe that will be connected to the
// remote command's standard output when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StdoutPipe reader is
// not serviced fast enough it may eventually cause the
// remote command to block.
func (s *Session) StdoutPipe() (io.Reader, error) {
	if s.Stdout != nil {
		return nil, errors.New("ssh: Stdout already set")
	}
	if s.started {
		return nil, errors.New("ssh: StdoutPipe after process started")
	}
	s.stdoutpipe = true
	return s.ch, nil
}

// StderrPipe returns a pipe that will be connected to the
// remote command's standard error when the command starts.
// There is a fixed amount of buffering that is shared between
// stdout and stderr streams. If the StderrPipe reader is
// not serviced fast enough it may eventually cause the
// remote command to block.
func (s *Session) StderrPipe() (io.Reader, error) {
	if s.Stderr != nil {
		return nil, errors.New("ssh: Stderr already set")
	}
	if s.started {
		return nil, errors.New("ssh: StderrPipe after process started")
	}
	s.stderrpipe = true
	return s.ch.Stderr(), nil
}

// newSession returns a new interactive session on the remote host.
func newSession(ch Channel, reqs <-chan *Request) (*Session, error) {
	s := &Session{
		ch: ch,
	}
	// Buffered so the background goroutine can deliver the exit
	// result and terminate even if Wait is never called.
	s.exitStatus = make(chan error, 1)
	go func() {
		s.exitStatus <- s.wait(reqs)
	}()

	return s, nil
}

// An ExitError reports unsuccessful completion of a remote command.
type ExitError struct {
	Waitmsg
}

func (e *ExitError) Error() string {
	return e.Waitmsg.String()
}

// Waitmsg stores the information about an exited remote command
// as reported by Wait.
type Waitmsg struct {
	status int
	signal string
	msg    string
	lang   string
}

// ExitStatus returns the exit status of the remote command.
func (w Waitmsg) ExitStatus() int {
	return w.status
}

// Signal returns the exit signal of the remote command if
// it was terminated violently.
func (w Waitmsg) Signal() string {
	return w.signal
}

// Msg returns the exit message given by the remote command.
func (w Waitmsg) Msg() string {
	return w.msg
}

// Lang returns the language tag.
See RFC 3066 +func (w Waitmsg) Lang() string { + return w.lang +} + +func (w Waitmsg) String() string { + str := fmt.Sprintf("Process exited with status %v", w.status) + if w.signal != "" { + str += fmt.Sprintf(" from signal %v", w.signal) + } + if w.msg != "" { + str += fmt.Sprintf(". Reason was: %v", w.msg) + } + return str +} diff --git a/tempfork/sshtest/ssh/session_test.go b/tempfork/sshtest/ssh/session_test.go new file mode 100644 index 0000000000000..807a913e5ace7 --- /dev/null +++ b/tempfork/sshtest/ssh/session_test.go @@ -0,0 +1,892 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +// Session tests. + +import ( + "bytes" + crypto_rand "crypto/rand" + "errors" + "io" + "math/rand" + "net" + "sync" + "testing" + + "golang.org/x/crypto/ssh/terminal" +) + +type serverType func(Channel, <-chan *Request, *testing.T) + +// dial constructs a new test server and returns a *ClientConn. 
func dial(handler serverType, t *testing.T) *Client {
	c1, c2, err := netPipe()
	if err != nil {
		t.Fatalf("netPipe: %v", err)
	}

	// Every goroutine spawned below is tracked so t.Cleanup can wait
	// for them all before the test binary tears the pipe down.
	var wg sync.WaitGroup
	t.Cleanup(wg.Wait)
	wg.Add(1)
	go func() {
		defer func() {
			c1.Close()
			wg.Done()
		}()
		conf := ServerConfig{
			NoClientAuth: true,
		}
		conf.AddHostKey(testSigners["rsa"])

		conn, chans, reqs, err := NewServerConn(c1, &conf)
		if err != nil {
			t.Errorf("Unable to handshake: %v", err)
			return
		}
		wg.Add(1)
		go func() {
			DiscardRequests(reqs)
			wg.Done()
		}()

		for newCh := range chans {
			if newCh.ChannelType() != "session" {
				newCh.Reject(UnknownChannelType, "unknown channel type")
				continue
			}

			ch, inReqs, err := newCh.Accept()
			if err != nil {
				t.Errorf("Accept: %v", err)
				continue
			}
			wg.Add(1)
			go func() {
				handler(ch, inReqs, t)
				wg.Done()
			}()
		}
		if err := conn.Wait(); err != io.EOF {
			t.Logf("server exit reason: %v", err)
		}
	}()

	config := &ClientConfig{
		User:            "testuser",
		HostKeyCallback: InsecureIgnoreHostKey(),
	}

	conn, chans, reqs, err := NewClientConn(c2, "", config)
	if err != nil {
		t.Fatalf("unable to dial remote side: %v", err)
	}

	return NewClient(conn, chans, reqs)
}

// Test a simple string is returned to session.Stdout.
func TestSessionShell(t *testing.T) {
	conn := dial(shellHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()
	stdout := new(bytes.Buffer)
	session.Stdout = stdout
	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %s", err)
	}
	if err := session.Wait(); err != nil {
		t.Fatalf("Remote command did not exit cleanly: %v", err)
	}
	actual := stdout.String()
	if actual != "golang" {
		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
	}
}

// TODO(dfc) add support for Std{in,err}Pipe when the Server supports it.

// Test a simple string is returned via StdoutPipe.
func TestSessionStdoutPipe(t *testing.T) {
	conn := dial(shellHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()
	stdout, err := session.StdoutPipe()
	if err != nil {
		t.Fatalf("Unable to request StdoutPipe(): %v", err)
	}
	var buf bytes.Buffer
	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %v", err)
	}
	done := make(chan bool, 1)
	go func() {
		if _, err := io.Copy(&buf, stdout); err != nil {
			t.Errorf("Copy of stdout failed: %v", err)
		}
		done <- true
	}()
	if err := session.Wait(); err != nil {
		t.Fatalf("Remote command did not exit cleanly: %v", err)
	}
	// Wait for the copier goroutine before reading buf, so the read
	// does not race with the final io.Copy writes.
	<-done
	actual := buf.String()
	if actual != "golang" {
		t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual)
	}
}

// Test that a simple string is returned via the Output helper,
// and that stderr is discarded.
func TestSessionOutput(t *testing.T) {
	conn := dial(fixedOutputHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()

	buf, err := session.Output("") // cmd is ignored by fixedOutputHandler
	if err != nil {
		t.Error("Remote command did not exit cleanly:", err)
	}
	w := "this-is-stdout."
	g := string(buf)
	if g != w {
		t.Error("Remote command did not return expected string:")
		t.Logf("want %q", w)
		t.Logf("got %q", g)
	}
}

// Test that both stdout and stderr are returned
// via the CombinedOutput helper.
func TestSessionCombinedOutput(t *testing.T) {
	conn := dial(fixedOutputHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()

	buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler
	if err != nil {
		t.Error("Remote command did not exit cleanly:", err)
	}
	const stdout = "this-is-stdout."
	const stderr = "this-is-stderr."
	g := string(buf)
	// The two streams are copied concurrently on the client, so
	// either ordering of the whole strings is acceptable.
	if g != stdout+stderr && g != stderr+stdout {
		t.Error("Remote command did not return expected string:")
		t.Logf("want %q, or %q", stdout+stderr, stderr+stdout)
		t.Logf("got %q", g)
	}
}

// Test non-0 exit status is returned correctly.
func TestExitStatusNonZero(t *testing.T) {
	conn := dial(exitStatusNonZeroHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()
	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %v", err)
	}
	err = session.Wait()
	if err == nil {
		t.Fatalf("expected command to fail but it didn't")
	}
	e, ok := err.(*ExitError)
	if !ok {
		t.Fatalf("expected *ExitError but got %T", err)
	}
	if e.ExitStatus() != 15 {
		t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus())
	}
}

// Test 0 exit status is returned correctly.
func TestExitStatusZero(t *testing.T) {
	conn := dial(exitStatusZeroHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()

	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %v", err)
	}
	err = session.Wait()
	if err != nil {
		t.Fatalf("expected nil but got %v", err)
	}
}

// Test exit signal and status are both returned correctly.
func TestExitSignalAndStatus(t *testing.T) {
	conn := dial(exitSignalAndStatusHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()
	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %v", err)
	}
	err = session.Wait()
	if err == nil {
		t.Fatalf("expected command to fail but it didn't")
	}
	e, ok := err.(*ExitError)
	if !ok {
		t.Fatalf("expected *ExitError but got %T", err)
	}
	if e.Signal() != "TERM" || e.ExitStatus() != 15 {
		t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus())
	}
}

// Test that a known exit signal with no exit status is mapped to
// status 128+signal number (TERM => 143).
func TestKnownExitSignalOnly(t *testing.T) {
	conn := dial(exitSignalHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()
	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %v", err)
	}
	err = session.Wait()
	if err == nil {
		t.Fatalf("expected command to fail but it didn't")
	}
	e, ok := err.(*ExitError)
	if !ok {
		t.Fatalf("expected *ExitError but got %T", err)
	}
	if e.Signal() != "TERM" || e.ExitStatus() != 143 {
		t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus())
	}
}

// Test that an exit signal the client does not recognize is mapped to
// the generic status 128.
func TestUnknownExitSignal(t *testing.T) {
	conn := dial(exitSignalUnknownHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()
	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %v", err)
	}
	err = session.Wait()
	if err == nil {
		t.Fatalf("expected command to fail but it didn't")
	}
	e, ok := err.(*ExitError)
	if !ok {
		t.Fatalf("expected *ExitError but got %T", err)
	}
	if e.Signal() != "SYS" || e.ExitStatus() != 128 {
		t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus())
	}
}

// TestExitWithoutStatusOrSignal checks that a clean teardown without an
// exit-status or exit-signal request yields *ExitMissingError.
func TestExitWithoutStatusOrSignal(t *testing.T) {
	conn := dial(exitWithoutSignalOrStatus, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatalf("Unable to request new session: %v", err)
	}
	defer session.Close()
	if err := session.Shell(); err != nil {
		t.Fatalf("Unable to execute command: %v", err)
	}
	err = session.Wait()
	if err == nil {
		t.Fatalf("expected command to fail but it didn't")
	}
	if _, ok := err.(*ExitMissingError); !ok {
		t.Fatalf("got %T want *ExitMissingError", err)
	}
}

// windowTestBytes is the number of bytes that we'll send to the SSH server.
const windowTestBytes = 16000 * 200

// TestServerWindow writes random data to the server. The server is expected to echo
// the same data back, which is compared against the original.
func TestServerWindow(t *testing.T) {
	origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
	io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes)
	origBytes := origBuf.Bytes()

	conn := dial(echoHandler, t)
	defer conn.Close()
	session, err := conn.NewSession()
	if err != nil {
		t.Fatal(err)
	}
	defer session.Close()

	serverStdin, err := session.StdinPipe()
	if err != nil {
		t.Fatalf("StdinPipe failed: %v", err)
	}

	// Read the echoed data concurrently with writing; draining stdout
	// keeps the flow-control window open so the writes can finish.
	result := make(chan []byte)
	go func() {
		defer close(result)
		echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes))
		serverStdout, err := session.StdoutPipe()
		if err != nil {
			t.Errorf("StdoutPipe failed: %v", err)
			return
		}
		n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes)
		if err != nil && err != io.EOF {
			t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err)
		}
		result <- echoedBuf.Bytes()
	}()

	written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes)
	if err != nil {
		t.Errorf("failed to copy origBuf to serverStdin: %v", err)
	} else if written != windowTestBytes {
		t.Errorf("Wrote only %d of %d bytes to server", written, windowTestBytes)
	}

	echoedBytes := <-result

	if !bytes.Equal(origBytes, echoedBytes) {
		t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes))
	}
}

// Verify the client can handle a keepalive packet from the server.
+func TestClientHandlesKeepalives(t *testing.T) { + conn := dial(channelKeepaliveSender, t) + defer conn.Close() + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + if err := session.Shell(); err != nil { + t.Fatalf("Unable to execute command: %v", err) + } + err = session.Wait() + if err != nil { + t.Fatalf("expected nil but got: %v", err) + } +} + +type exitStatusMsg struct { + Status uint32 +} + +type exitSignalMsg struct { + Signal string + CoreDumped bool + Errmsg string + Lang string +} + +func handleTerminalRequests(in <-chan *Request) { + for req := range in { + ok := false + switch req.Type { + case "shell": + ok = true + if len(req.Payload) > 0 { + // We don't accept any commands, only the default shell. + ok = false + } + case "env": + ok = true + } + req.Reply(ok, nil) + } +} + +func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { + term := terminal.NewTerminal(ch, prompt) + go handleTerminalRequests(in) + return term +} + +func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(0, ch, t) +} + +func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) +} + +func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendStatus(15, ch, t) + sendSignal("TERM", ch, t) +} + +func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + sendSignal("TERM", ch, t) +} + +func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + 
readLine(shell, t) + sendSignal("SYS", ch, t) +} + +func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) +} + +func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + // this string is returned to stdout + shell := newServerShell(ch, in, "golang") + readLine(shell, t) + sendStatus(0, ch, t) +} + +// Ignores the command, writes fixed strings to stderr and stdout. +// Strings are "this-is-stdout." and "this-is-stderr.". +func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + _, err := ch.Read(nil) + + req, ok := <-in + if !ok { + t.Fatalf("error: expected channel request, got: %#v", err) + return + } + + // ignore request, always send some text + req.Reply(true, nil) + + _, err = io.WriteString(ch, "this-is-stdout.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") + if err != nil { + t.Fatalf("error writing on server: %v", err) + } + sendStatus(0, ch, t) +} + +func readLine(shell *terminal.Terminal, t *testing.T) { + if _, err := shell.ReadLine(); err != nil && err != io.EOF { + t.Errorf("unable to read line: %v", err) + } +} + +func sendStatus(status uint32, ch Channel, t *testing.T) { + msg := exitStatusMsg{ + Status: status, + } + if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { + t.Errorf("unable to send status: %v", err) + } +} + +func sendSignal(signal string, ch Channel, t *testing.T) { + sig := exitSignalMsg{ + Signal: signal, + CoreDumped: false, + Errmsg: "Process terminated", + Lang: "en-GB-oed", + } + if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { + t.Errorf("unable to send signal: %v", err) + } +} + +func discardHandler(ch Channel, t *testing.T) { + defer ch.Close() + io.Copy(io.Discard, ch) +} + +func echoHandler(ch Channel, in <-chan 
*Request, t *testing.T) { + defer ch.Close() + if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { + t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) + } +} + +// copyNRandomly copies n bytes from src to dst. It uses a variable, and random, +// buffer size to exercise more code paths. +func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { + var ( + buf = make([]byte, 32*1024) + written int + remaining = n + ) + for remaining > 0 { + l := rand.Intn(1 << 15) + if remaining < l { + l = remaining + } + nr, er := src.Read(buf[:l]) + nw, ew := dst.Write(buf[:nr]) + remaining -= nw + written += nw + if ew != nil { + return written, ew + } + if nr != nw { + return written, io.ErrShortWrite + } + if er != nil && er != io.EOF { + return written, er + } + } + return written, nil +} + +func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { + defer ch.Close() + shell := newServerShell(ch, in, "> ") + readLine(shell, t) + if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { + t.Errorf("unable to send channel keepalive request: %v", err) + } + sendStatus(0, ch, t) +} + +func TestClientWriteEOF(t *testing.T) { + conn := dial(simpleEchoHandler, t) + defer conn.Close() + + session, err := conn.NewSession() + if err != nil { + t.Fatal(err) + } + defer session.Close() + stdin, err := session.StdinPipe() + if err != nil { + t.Fatalf("StdinPipe failed: %v", err) + } + stdout, err := session.StdoutPipe() + if err != nil { + t.Fatalf("StdoutPipe failed: %v", err) + } + + data := []byte(`0000`) + _, err = stdin.Write(data) + if err != nil { + t.Fatalf("Write failed: %v", err) + } + stdin.Close() + + res, err := io.ReadAll(stdout) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + + if !bytes.Equal(data, res) { + t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) + } +} + +func simpleEchoHandler(ch Channel, in <-chan *Request, t 
*testing.T) { + defer ch.Close() + data, err := io.ReadAll(ch) + if err != nil { + t.Errorf("handler read error: %v", err) + } + _, err = ch.Write(data) + if err != nil { + t.Errorf("handler write error: %v", err) + } +} + +func TestSessionID(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serverID := make(chan []byte, 1) + clientID := make(chan []byte, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: "user", + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + + srvErrCh := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + conn, chans, reqs, err := NewServerConn(c1, serverConf) + srvErrCh <- err + if err != nil { + return + } + serverID <- conn.SessionID() + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + cliErrCh := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + conn, chans, reqs, err := NewClientConn(c2, "", clientConf) + cliErrCh <- err + if err != nil { + return + } + clientID <- conn.SessionID() + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + if err := <-srvErrCh; err != nil { + t.Fatalf("server handshake: %v", err) + } + + if err := <-cliErrCh; err != nil { + t.Fatalf("client handshake: %v", err) + } + + s := <-serverID + c := <-clientID + if bytes.Compare(s, c) != 0 { + t.Errorf("server session ID (%x) != client session ID (%x)", s, c) + } else if len(s) == 0 { + t.Errorf("client and server SessionID were empty.") + } +} + +type noReadConn struct { + readSeen bool + net.Conn +} + +func (c *noReadConn) Close() error { + return nil +} + +func (c *noReadConn) Read(b []byte) (int, error) { + c.readSeen = true + return 0, 
errors.New("noReadConn error") +} + +func TestInvalidServerConfiguration(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + serveConn := noReadConn{Conn: c1} + serverConf := &ServerConfig{} + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") + } + + serverConf.AddHostKey(testSigners["ecdsa"]) + + NewServerConn(&serveConn, serverConf) + if serveConn.readSeen { + t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") + } +} + +func TestHostKeyAlgorithms(t *testing.T) { + serverConf := &ServerConfig{ + NoClientAuth: true, + } + serverConf.AddHostKey(testSigners["rsa"]) + serverConf.AddHostKey(testSigners["ecdsa"]) + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + connect := func(clientConf *ClientConfig, want string) { + var alg string + clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { + alg = key.Type() + return nil + } + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() + _, _, _, err = NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("NewClientConn: %v", err) + } + if alg != want { + t.Errorf("selected key algorithm %s, want %s", alg, want) + } + } + + // By default, we get the preferred algorithm, which is ECDSA 256. + + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + } + connect(clientConf, KeyAlgoECDSA256) + + // Client asks for RSA explicitly. + clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} + connect(clientConf, KeyAlgoRSA) + + // Client asks for RSA-SHA2-512 explicitly. 
+ clientConf.HostKeyAlgorithms = []string{KeyAlgoRSASHA512} + // We get back an "ssh-rsa" key but the verification happened + // with an RSA-SHA2-512 signature. + connect(clientConf, KeyAlgoRSA) + + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + wg.Add(1) + go func() { + NewServerConn(c1, serverConf) + wg.Done() + }() + clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} + _, _, _, err = NewClientConn(c2, "", clientConf) + if err == nil { + t.Fatal("succeeded connecting with unknown hostkey algorithm") + } +} + +func TestServerClientAuthCallback(t *testing.T) { + c1, c2, err := netPipe() + if err != nil { + t.Fatalf("netPipe: %v", err) + } + defer c1.Close() + defer c2.Close() + + userCh := make(chan string, 1) + + serverConf := &ServerConfig{ + NoClientAuth: true, + NoClientAuthCallback: func(conn ConnMetadata) (*Permissions, error) { + userCh <- conn.User() + return nil, nil + }, + } + const someUsername = "some-username" + + serverConf.AddHostKey(testSigners["ecdsa"]) + clientConf := &ClientConfig{ + HostKeyCallback: InsecureIgnoreHostKey(), + User: someUsername, + } + + var wg sync.WaitGroup + t.Cleanup(wg.Wait) + wg.Add(1) + go func() { + defer wg.Done() + _, chans, reqs, err := NewServerConn(c1, serverConf) + if err != nil { + t.Errorf("server handshake: %v", err) + userCh <- "error" + return + } + wg.Add(1) + go func() { + DiscardRequests(reqs) + wg.Done() + }() + for ch := range chans { + ch.Reject(Prohibited, "") + } + }() + + conn, _, _, err := NewClientConn(c2, "", clientConf) + if err != nil { + t.Fatalf("client handshake: %v", err) + return + } + conn.Close() + + got := <-userCh + if got != someUsername { + t.Errorf("username = %q; want %q", got, someUsername) + } +} diff --git a/tempfork/sshtest/ssh/ssh_gss.go b/tempfork/sshtest/ssh/ssh_gss.go new file mode 100644 index 0000000000000..24bd7c8e83048 --- /dev/null +++ b/tempfork/sshtest/ssh/ssh_gss.go @@ -0,0 +1,139 
@@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "encoding/asn1" + "errors" +) + +var krb5OID []byte + +func init() { + krb5OID, _ = asn1.Marshal(krb5Mesh) +} + +// GSSAPIClient provides the API to plug-in GSSAPI authentication for client logins. +type GSSAPIClient interface { + // InitSecContext initiates the establishment of a security context for GSS-API between the + // ssh client and ssh server. Initially the token parameter should be specified as nil. + // The routine may return a outputToken which should be transferred to + // the ssh server, where the ssh server will present it to + // AcceptSecContext. If no token need be sent, InitSecContext will indicate this by setting + // needContinue to false. To complete the context + // establishment, one or more reply tokens may be required from the ssh + // server;if so, InitSecContext will return a needContinue which is true. + // In this case, InitSecContext should be called again when the + // reply token is received from the ssh server, passing the reply + // token to InitSecContext via the token parameters. + // See RFC 2743 section 2.2.1 and RFC 4462 section 3.4. + InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error) + // GetMIC generates a cryptographic MIC for the SSH2 message, and places + // the MIC in a token for transfer to the ssh server. + // The contents of the MIC field are obtained by calling GSS_GetMIC() + // over the following, using the GSS-API context that was just + // established: + // string session identifier + // byte SSH_MSG_USERAUTH_REQUEST + // string user name + // string service + // string "gssapi-with-mic" + // See RFC 2743 section 2.3.1 and RFC 4462 3.5. 
+ GetMIC(micFiled []byte) ([]byte, error) + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +// GSSAPIServer provides the API to plug in GSSAPI authentication for server logins. +type GSSAPIServer interface { + // AcceptSecContext allows a remotely initiated security context between the application + // and a remote peer to be established by the ssh client. The routine may return a + // outputToken which should be transferred to the ssh client, + // where the ssh client will present it to InitSecContext. + // If no token need be sent, AcceptSecContext will indicate this + // by setting the needContinue to false. To + // complete the context establishment, one or more reply tokens may be + // required from the ssh client. if so, AcceptSecContext + // will return a needContinue which is true, in which case it + // should be called again when the reply token is received from the ssh + // client, passing the token to AcceptSecContext via the + // token parameters. + // The srcName return value is the authenticated username. + // See RFC 2743 section 2.2.2 and RFC 4462 section 3.4. + AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error) + // VerifyMIC verifies that a cryptographic MIC, contained in the token parameter, + // fits the supplied message is received from the ssh client. + // See RFC 2743 section 2.3.2. 
+ VerifyMIC(micField []byte, micToken []byte) error + // Whenever possible, it should be possible for + // DeleteSecContext() calls to be successfully processed even + // if other calls cannot succeed, thereby enabling context-related + // resources to be released. + // In addition to deleting established security contexts, + // gss_delete_sec_context must also be able to delete "half-built" + // security contexts resulting from an incomplete sequence of + // InitSecContext()/AcceptSecContext() calls. + // See RFC 2743 section 2.2.3. + DeleteSecContext() error +} + +var ( + // OpenSSH supports Kerberos V5 mechanism only for GSS-API authentication, + // so we also support the krb5 mechanism only. + // See RFC 1964 section 1. + krb5Mesh = asn1.ObjectIdentifier{1, 2, 840, 113554, 1, 2, 2} +) + +// The GSS-API authentication method is initiated when the client sends an SSH_MSG_USERAUTH_REQUEST +// See RFC 4462 section 3.2. +type userAuthRequestGSSAPI struct { + N uint32 + OIDS []asn1.ObjectIdentifier +} + +func parseGSSAPIPayload(payload []byte) (*userAuthRequestGSSAPI, error) { + n, rest, ok := parseUint32(payload) + if !ok { + return nil, errors.New("parse uint32 failed") + } + s := &userAuthRequestGSSAPI{ + N: n, + OIDS: make([]asn1.ObjectIdentifier, n), + } + for i := 0; i < int(n); i++ { + var ( + desiredMech []byte + err error + ) + desiredMech, rest, ok = parseString(rest) + if !ok { + return nil, errors.New("parse string failed") + } + if rest, err = asn1.Unmarshal(desiredMech, &s.OIDS[i]); err != nil { + return nil, err + } + + } + return s, nil +} + +// See RFC 4462 section 3.6. 
// buildMIC assembles the plaintext over which the GSS-API MIC is
// computed during "gssapi-with-mic" user authentication: the session
// identifier followed by the fields of the SSH_MSG_USERAUTH_REQUEST
// packet. See RFC 4462 section 3.6.
func buildMIC(sessionID string, username string, service string, authMethod string) []byte {
	out := make([]byte, 0, 0)
	out = appendString(out, sessionID)
	out = append(out, msgUserAuthRequest)
	out = appendString(out, username)
	out = appendString(out, service)
	out = appendString(out, authMethod)
	return out
}

// ---- tempfork/sshtest/ssh/ssh_gss_test.go ----

package ssh

import (
	"fmt"
	"testing"
)

// TestParseGSSAPIPayload checks that a wire-format userauth GSS-API
// payload (count followed by DER-encoded OIDs) decodes to the krb5
// mechanism OID.
func TestParseGSSAPIPayload(t *testing.T) {
	payload := []byte{0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0b, 0x06, 0x09,
		0x2a, 0x86, 0x48, 0x86, 0xf7, 0x12, 0x01, 0x02, 0x02}
	res, err := parseGSSAPIPayload(payload)
	if err != nil {
		t.Fatal(err)
	}
	if ok := res.OIDS[0].Equal(krb5Mesh); !ok {
		t.Fatalf("got %v, want %v", res, krb5Mesh)
	}
}

// TestBuildMIC pins the exact byte layout produced by buildMIC for a
// known session ID and request fields.
func TestBuildMIC(t *testing.T) {
	sessionID := []byte{134, 180, 134, 194, 62, 145, 171, 82, 119, 149, 254, 196, 125, 173, 177, 145, 187, 85, 53,
		183, 44, 150, 219, 129, 166, 195, 19, 33, 209, 246, 175, 121}
	username := "testuser"
	service := "ssh-connection"
	authMethod := "gssapi-with-mic"
	expected := []byte{0, 0, 0, 32, 134, 180, 134, 194, 62, 145, 171, 82, 119, 149, 254, 196, 125, 173, 177, 145, 187, 85, 53, 183, 44, 150, 219, 129, 166, 195, 19, 33, 209, 246, 175, 121, 50, 0, 0, 0, 8, 116, 101, 115, 116, 117, 115, 101, 114, 0, 0, 0, 14, 115, 115, 104, 45, 99, 111, 110, 110, 101, 99, 116, 105, 111, 110, 0, 0, 0, 15, 103, 115, 115, 97, 112, 105, 45, 119, 105, 116, 104, 45, 109, 105, 99}
	result := buildMIC(string(sessionID), username, service, authMethod)
	if string(result) != string(expected) {
		t.Fatalf("buildMic: got %v, want %v", result, expected)
	}
}

// exchange describes one scripted token round trip: the token this
// side emits and the token it expects to receive from the peer.
type exchange struct {
	outToken      string
	expectedToken string
}

// FakeClient implements the GSSAPIClient interface by replaying a
// fixed script of token exchanges.
type FakeClient struct {
	exchanges []*exchange // scripted rounds, indexed by round
	round     int         // next round to replay
	mic       []byte      // canned MIC returned by GetMIC
	maxRound  int         // total rounds; needContinue while round < maxRound
}

// InitSecContext replays the next scripted exchange, validating the
// peer's token when one is expected.
func (f *FakeClient) InitSecContext(target string, token []byte, isGSSDelegCreds bool) (outputToken []byte, needContinue bool, err error) {
	if token == nil {
		if f.exchanges[f.round].expectedToken != "" {
			err = fmt.Errorf("got empty token, want %q", f.exchanges[f.round].expectedToken)
		} else {
			outputToken = []byte(f.exchanges[f.round].outToken)
		}
	} else {
		if string(token) != string(f.exchanges[f.round].expectedToken) {
			err = fmt.Errorf("got %q, want token %q", token, f.exchanges[f.round].expectedToken)
		} else {
			outputToken = []byte(f.exchanges[f.round].outToken)
		}
	}
	f.round++
	needContinue = f.round < f.maxRound
	return
}

// GetMIC returns the canned MIC regardless of micField.
func (f *FakeClient) GetMIC(micField []byte) ([]byte, error) {
	return f.mic, nil
}

// DeleteSecContext is a no-op for the fake.
func (f *FakeClient) DeleteSecContext() error {
	return nil
}

// FakeServer implements the GSSAPIServer interface by replaying a
// fixed script of token exchanges.
type FakeServer struct {
	exchanges   []*exchange
	round       int
	expectedMIC []byte // MIC that VerifyMIC accepts
	srcName     string // source name reported to the caller
	maxRound    int
}

// AcceptSecContext replays the next scripted exchange, validating the
// client's token when one is expected.
func (f *FakeServer) AcceptSecContext(token []byte) (outputToken []byte, srcName string, needContinue bool, err error) {
	if token == nil {
		if f.exchanges[f.round].expectedToken != "" {
			err = fmt.Errorf("got empty token, want %q", f.exchanges[f.round].expectedToken)
		} else {
			outputToken = []byte(f.exchanges[f.round].outToken)
		}
	} else {
		if string(token) != string(f.exchanges[f.round].expectedToken) {
			err = fmt.Errorf("got %q, want token %q", token, f.exchanges[f.round].expectedToken)
		} else {
			outputToken = []byte(f.exchanges[f.round].outToken)
		}
	}
	f.round++
	needContinue = f.round < f.maxRound
	srcName = f.srcName
	return
}

// VerifyMIC accepts exactly the configured expected MIC.
func (f *FakeServer) VerifyMIC(micField []byte, micToken []byte) error {
	if string(micToken) != string(f.expectedMIC) {
		return fmt.Errorf("got MICToken %q, want %q", micToken, f.expectedMIC)
	}
	return nil
}

// DeleteSecContext is a no-op for the fake.
func (f *FakeServer) DeleteSecContext() error {
	return nil
}
// ---- tempfork/sshtest/ssh/streamlocal.go ----

package ssh

import (
	"errors"
	"io"
	"net"
)

// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "direct-streamlocal@openssh.com" string.
//
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
// NOTE: field order matters; Marshal encodes fields positionally.
type streamLocalChannelOpenDirectMsg struct {
	socketPath string
	reserved0  string
	reserved1  uint32
}

// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "forwarded-streamlocal@openssh.com" string.
type forwardedStreamLocalPayload struct {
	SocketPath string
	Reserved0  string
}

// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
type streamLocalChannelForwardMsg struct {
	socketPath string
}

// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
	c.handleForwardsOnce.Do(c.handleForwards)
	m := streamLocalChannelForwardMsg{
		socketPath,
	}
	// send message
	ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
	if err != nil {
		return nil, err
	}
	if !ok {
		return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
	}
	ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})

	return &unixListener{socketPath, c, ch}, nil
}

// dialStreamLocal opens a "direct-streamlocal@openssh.com" channel to
// the given remote socket path and discards its out-of-band requests.
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
	msg := streamLocalChannelOpenDirectMsg{
		socketPath: socketPath,
	}
	ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
	if err != nil {
		return nil, err
	}
	go DiscardRequests(in)
	return ch, err
}

// unixListener is the net.Listener returned by ListenUnix; it drains
// incoming forwards from the registered channel.
type unixListener struct {
	socketPath string

	conn *Client
	in   <-chan forward
}

// Accept waits for and returns the next connection to the listener.
func (l *unixListener) Accept() (net.Conn, error) {
	s, ok := <-l.in
	if !ok {
		// The forward was removed or the client closed; report EOF.
		return nil, io.EOF
	}
	ch, incoming, err := s.newCh.Accept()
	if err != nil {
		return nil, err
	}
	go DiscardRequests(incoming)

	return &chanConn{
		Channel: ch,
		laddr: &net.UnixAddr{
			Name: l.socketPath,
			Net:  "unix",
		},
		raddr: &net.UnixAddr{
			Name: "@",
			Net:  "unix",
		},
	}, nil
}

// Close closes the listener.
func (l *unixListener) Close() error {
	// this also closes the listener.
	l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
	m := streamLocalChannelForwardMsg{
		l.socketPath,
	}
	ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
	if err == nil && !ok {
		err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
	}
	return err
}

// Addr returns the listener's network address.
func (l *unixListener) Addr() net.Addr {
	return &net.UnixAddr{
		Name: l.socketPath,
		Net:  "unix",
	}
}

// ---- tempfork/sshtest/ssh/tcpip.go ----

// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"
)

// Listen requests the remote peer open a listening socket on
// addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang.
// N must be "tcp", "tcp4", "tcp6", or "unix".
func (c *Client) Listen(n, addr string) (net.Listener, error) {
	switch n {
	case "tcp", "tcp4", "tcp6":
		laddr, err := net.ResolveTCPAddr(n, addr)
		if err != nil {
			return nil, err
		}
		return c.ListenTCP(laddr)
	case "unix":
		return c.ListenUnix(addr)
	default:
		return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
	}
}

// Automatic port allocation is broken with OpenSSH before 6.0. See
// also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In
// particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0,
// rather than the actual port number. This means you can never open
// two different listeners with auto allocated ports. We work around
// this by trying explicit ports until we succeed.

const openSSHPrefix = "OpenSSH_"

var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano()))

// isBrokenOpenSSHVersion returns true if the given version string
// specifies a version of OpenSSH that is known to have a bug in port
// forwarding.
+func isBrokenOpenSSHVersion(versionStr string) bool { + i := strings.Index(versionStr, openSSHPrefix) + if i < 0 { + return false + } + i += len(openSSHPrefix) + j := i + for ; j < len(versionStr); j++ { + if versionStr[j] < '0' || versionStr[j] > '9' { + break + } + } + version, _ := strconv.Atoi(versionStr[i:j]) + return version < 6 +} + +// autoPortListenWorkaround simulates automatic port allocation by +// trying random ports repeatedly. +func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { + var sshListener net.Listener + var err error + const tries = 10 + for i := 0; i < tries; i++ { + addr := *laddr + addr.Port = 1024 + portRandomizer.Intn(60000) + sshListener, err = c.ListenTCP(&addr) + if err == nil { + laddr.Port = addr.Port + return sshListener, err + } + } + return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) +} + +// RFC 4254 7.1 +type channelForwardMsg struct { + addr string + rport uint32 +} + +// handleForwards starts goroutines handling forwarded connections. +// It's called on first use by (*Client).ListenTCP to not launch +// goroutines until needed. +func (c *Client) handleForwards() { + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip")) + go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com")) +} + +// ListenTCP requests the remote peer open a listening socket +// on laddr. Incoming connections will be available by calling +// Accept on the returned net.Listener. 
+func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { + c.handleForwardsOnce.Do(c.handleForwards) + if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { + return c.autoPortListenWorkaround(laddr) + } + + m := channelForwardMsg{ + laddr.IP.String(), + uint32(laddr.Port), + } + // send message + ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + // If the original port was 0, then the remote side will + // supply a real port number in the response. + if laddr.Port == 0 { + var p struct { + Port uint32 + } + if err := Unmarshal(resp, &p); err != nil { + return nil, err + } + laddr.Port = int(p.Port) + } + + // Register this forward, using the port number we obtained. + ch := c.forwards.add(laddr) + + return &tcpListener{laddr, c, ch}, nil +} + +// forwardList stores a mapping between remote +// forward requests and the tcpListeners. +type forwardList struct { + sync.Mutex + entries []forwardEntry +} + +// forwardEntry represents an established mapping of a laddr on a +// remote ssh server to a channel connected to a tcpListener. +type forwardEntry struct { + laddr net.Addr + c chan forward +} + +// forward represents an incoming forwarded tcpip connection. The +// arguments to add/remove/lookup should be address as specified in +// the original forward-request. 
type forward struct {
	newCh NewChannel // the ssh client channel underlying this forward
	raddr net.Addr   // the raddr of the incoming connection
}

// add registers addr as an active remote forward and returns the
// channel on which incoming connections for it will be delivered.
func (l *forwardList) add(addr net.Addr) chan forward {
	l.Lock()
	defer l.Unlock()
	f := forwardEntry{
		laddr: addr,
		c:     make(chan forward, 1),
	}
	l.entries = append(l.entries, f)
	return f.c
}

// See RFC 4254, section 7.2
type forwardedTCPPayload struct {
	Addr       string
	Port       uint32
	OriginAddr string
	OriginPort uint32
}

// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
	if port == 0 || port > 65535 {
		return nil, fmt.Errorf("ssh: port number out of range: %d", port)
	}
	ip := net.ParseIP(string(addr))
	if ip == nil {
		return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr)
	}
	return &net.TCPAddr{IP: ip, Port: int(port)}, nil
}

// handleChannels dispatches each incoming forwarded channel to the
// listener registered for its address, rejecting channels with
// malformed payloads or with no matching registered forward.
func (l *forwardList) handleChannels(in <-chan NewChannel) {
	for ch := range in {
		var (
			laddr net.Addr
			raddr net.Addr
			err   error
		)
		switch channelType := ch.ChannelType(); channelType {
		case "forwarded-tcpip":
			var payload forwardedTCPPayload
			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
				ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
				continue
			}

			// RFC 4254 section 7.2 specifies that incoming
			// addresses should list the address, in string
			// format. It is implied that this should be an IP
			// address, as it would be impossible to connect to it
			// otherwise.
			laddr, err = parseTCPAddr(payload.Addr, payload.Port)
			if err != nil {
				ch.Reject(ConnectionFailed, err.Error())
				continue
			}
			raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
			if err != nil {
				ch.Reject(ConnectionFailed, err.Error())
				continue
			}

		case "forwarded-streamlocal@openssh.com":
			var payload forwardedStreamLocalPayload
			if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
				ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
				continue
			}
			laddr = &net.UnixAddr{
				Name: payload.SocketPath,
				Net:  "unix",
			}
			raddr = &net.UnixAddr{
				Name: "@",
				Net:  "unix",
			}
		default:
			// Only the two channel types registered by handleForwards
			// can arrive here; anything else is a programming error.
			panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
		}
		if ok := l.forward(laddr, raddr, ch); !ok {
			// Section 7.2, implementations MUST reject spurious incoming
			// connections.
			ch.Reject(Prohibited, "no forward for address")
			continue
		}

	}
}

// remove removes the forward entry, and the channel feeding its
// listener.
func (l *forwardList) remove(addr net.Addr) {
	l.Lock()
	defer l.Unlock()
	for i, f := range l.entries {
		if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
			l.entries = append(l.entries[:i], l.entries[i+1:]...)
			// Closing the channel makes the listener's Accept return EOF.
			close(f.c)
			return
		}
	}
}

// closeAll closes and clears all forwards.
func (l *forwardList) closeAll() {
	l.Lock()
	defer l.Unlock()
	for _, f := range l.entries {
		close(f.c)
	}
	l.entries = nil
}

// forward delivers an accepted connection to the listener registered
// for laddr and reports whether such a listener exists.
func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
	l.Lock()
	defer l.Unlock()
	for _, f := range l.entries {
		if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
			f.c <- forward{newCh: ch, raddr: raddr}
			return true
		}
	}
	return false
}

// tcpListener is the net.Listener returned by ListenTCP; it drains
// incoming forwards from the registered channel.
type tcpListener struct {
	laddr *net.TCPAddr

	conn *Client
	in   <-chan forward
}

// Accept waits for and returns the next connection to the listener.
+func (l *tcpListener) Accept() (net.Conn, error) { + s, ok := <-l.in + if !ok { + return nil, io.EOF + } + ch, incoming, err := s.newCh.Accept() + if err != nil { + return nil, err + } + go DiscardRequests(incoming) + + return &chanConn{ + Channel: ch, + laddr: l.laddr, + raddr: s.raddr, + }, nil +} + +// Close closes the listener. +func (l *tcpListener) Close() error { + m := channelForwardMsg{ + l.laddr.IP.String(), + uint32(l.laddr.Port), + } + + // this also closes the listener. + l.conn.forwards.remove(l.laddr) + ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) + if err == nil && !ok { + err = errors.New("ssh: cancel-tcpip-forward failed") + } + return err +} + +// Addr returns the listener's network address. +func (l *tcpListener) Addr() net.Addr { + return l.laddr +} + +// DialContext initiates a connection to the addr from the remote host. +// +// The provided Context must be non-nil. If the context expires before the +// connection is complete, an error is returned. Once successfully connected, +// any expiration of the context will not affect the connection. +// +// See func Dial for additional information. +func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + type connErr struct { + conn net.Conn + err error + } + ch := make(chan connErr) + go func() { + conn, err := c.Dial(n, addr) + select { + case ch <- connErr{conn, err}: + case <-ctx.Done(): + if conn != nil { + conn.Close() + } + } + }() + select { + case res := <-ch: + return res.conn, res.err + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// Dial initiates a connection to the addr from the remote host. +// The resulting connection has a zero LocalAddr() and RemoteAddr(). +func (c *Client) Dial(n, addr string) (net.Conn, error) { + var ch Channel + switch n { + case "tcp", "tcp4", "tcp6": + // Parse the address into host and numeric port. 
+ host, portString, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.ParseUint(portString, 10, 16) + if err != nil { + return nil, err + } + ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) + if err != nil { + return nil, err + } + // Use a zero address for local and remote address. + zeroAddr := &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + return &chanConn{ + Channel: ch, + laddr: zeroAddr, + raddr: zeroAddr, + }, nil + case "unix": + var err error + ch, err = c.dialStreamLocal(addr) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: &net.UnixAddr{ + Name: "@", + Net: "unix", + }, + raddr: &net.UnixAddr{ + Name: addr, + Net: "unix", + }, + }, nil + default: + return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) + } +} + +// DialTCP connects to the remote address raddr on the network net, +// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used +// as the local address for the connection. 
+func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { + if laddr == nil { + laddr = &net.TCPAddr{ + IP: net.IPv4zero, + Port: 0, + } + } + ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) + if err != nil { + return nil, err + } + return &chanConn{ + Channel: ch, + laddr: laddr, + raddr: raddr, + }, nil +} + +// RFC 4254 7.2 +type channelOpenDirectMsg struct { + raddr string + rport uint32 + laddr string + lport uint32 +} + +func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { + msg := channelOpenDirectMsg{ + raddr: raddr, + rport: uint32(rport), + laddr: laddr, + lport: uint32(lport), + } + ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) + if err != nil { + return nil, err + } + go DiscardRequests(in) + return ch, err +} + +type tcpChan struct { + Channel // the backing channel +} + +// chanConn fulfills the net.Conn interface without +// the tcpChan having to hold laddr or raddr directly. +type chanConn struct { + Channel + laddr, raddr net.Addr +} + +// LocalAddr returns the local network address. +func (t *chanConn) LocalAddr() net.Addr { + return t.laddr +} + +// RemoteAddr returns the remote network address. +func (t *chanConn) RemoteAddr() net.Addr { + return t.raddr +} + +// SetDeadline sets the read and write deadlines associated +// with the connection. +func (t *chanConn) SetDeadline(deadline time.Time) error { + if err := t.SetReadDeadline(deadline); err != nil { + return err + } + return t.SetWriteDeadline(deadline) +} + +// SetReadDeadline sets the read deadline. +// A zero value for t means Read will not time out. +// After the deadline, the error from Read will implement net.Error +// with Timeout() == true. 
+func (t *chanConn) SetReadDeadline(deadline time.Time) error { + // for compatibility with previous version, + // the error message contains "tcpChan" + return errors.New("ssh: tcpChan: deadline not supported") +} + +// SetWriteDeadline exists to satisfy the net.Conn interface +// but is not implemented by this type. It always returns an error. +func (t *chanConn) SetWriteDeadline(deadline time.Time) error { + return errors.New("ssh: tcpChan: deadline not supported") +} diff --git a/tempfork/sshtest/ssh/tcpip_test.go b/tempfork/sshtest/ssh/tcpip_test.go new file mode 100644 index 0000000000000..4d8511472782d --- /dev/null +++ b/tempfork/sshtest/ssh/tcpip_test.go @@ -0,0 +1,53 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "context" + "net" + "testing" + "time" +) + +func TestAutoPortListenBroken(t *testing.T) { + broken := "SSH-2.0-OpenSSH_5.9hh11" + works := "SSH-2.0-OpenSSH_6.1" + if !isBrokenOpenSSHVersion(broken) { + t.Errorf("version %q not marked as broken", broken) + } + if isBrokenOpenSSHVersion(works) { + t.Errorf("version %q marked as broken", works) + } +} + +func TestClientImplementsDialContext(t *testing.T) { + type ContextDialer interface { + DialContext(context.Context, string, string) (net.Conn, error) + } + // Belt and suspenders assertion, since package net does not + // declare a ContextDialer type. 
+ var _ ContextDialer = &net.Dialer{} + var _ ContextDialer = &Client{} +} + +func TestClientDialContextWithCancel(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.Canceled { + t.Errorf("DialContext: got nil error, expected %v", context.Canceled) + } +} + +func TestClientDialContextWithDeadline(t *testing.T) { + c := &Client{} + ctx, cancel := context.WithDeadline(context.Background(), time.Now()) + defer cancel() + _, err := c.DialContext(ctx, "tcp", "localhost:1000") + if err != context.DeadlineExceeded { + t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded) + } +} diff --git a/tempfork/sshtest/ssh/testdata_test.go b/tempfork/sshtest/ssh/testdata_test.go new file mode 100644 index 0000000000000..2da8c79dc64a4 --- /dev/null +++ b/tempfork/sshtest/ssh/testdata_test.go @@ -0,0 +1,63 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places: +// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three +// instances. 
+ +package ssh + +import ( + "crypto/rand" + "fmt" + + "golang.org/x/crypto/ssh/testdata" +) + +var ( + testPrivateKeys map[string]interface{} + testSigners map[string]Signer + testPublicKeys map[string]PublicKey +) + +func init() { + var err error + + n := len(testdata.PEMBytes) + testPrivateKeys = make(map[string]interface{}, n) + testSigners = make(map[string]Signer, n) + testPublicKeys = make(map[string]PublicKey, n) + for t, k := range testdata.PEMBytes { + testPrivateKeys[t], err = ParseRawPrivateKey(k) + if err != nil { + panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) + } + testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) + if err != nil { + panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) + } + testPublicKeys[t] = testSigners[t].PublicKey() + } + + // Create a cert and sign it for use in tests. + testCert := &Certificate{ + Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage + ValidAfter: 0, // unix epoch + ValidBefore: CertTimeInfinity, // The end of currently representable time. + Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil + Key: testPublicKeys["ecdsa"], + SignatureKey: testPublicKeys["rsa"], + Permissions: Permissions{ + CriticalOptions: map[string]string{}, + Extensions: map[string]string{}, + }, + } + testCert.SignCert(rand.Reader, testSigners["rsa"]) + testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] + testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) + if err != nil { + panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) + } +} diff --git a/tempfork/sshtest/ssh/transport.go b/tempfork/sshtest/ssh/transport.go new file mode 100644 index 0000000000000..0424d2d37c0bb --- /dev/null +++ b/tempfork/sshtest/ssh/transport.go @@ -0,0 +1,380 @@ +// Copyright 2011 The Go Authors. All rights reserved. 
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssh

import (
	"bufio"
	"bytes"
	"errors"
	"io"
	"log"
)

// debugTransport if set, will print packet types as they go over the
// wire. No message decoding is done, to minimize the impact on timing.
const debugTransport = false

const (
	gcm128CipherID = "aes128-gcm@openssh.com"
	gcm256CipherID = "aes256-gcm@openssh.com"
	aes128cbcID    = "aes128-cbc"
	tripledescbcID = "3des-cbc"
)

// packetConn represents a transport that implements packet based
// operations.
type packetConn interface {
	// Encrypt and send a packet of data to the remote peer.
	writePacket(packet []byte) error

	// Read a packet from the connection. The read is blocking,
	// i.e. if error is nil, then the returned byte slice is
	// always non-empty.
	readPacket() ([]byte, error)

	// Close closes the write-side of the connection.
	Close() error
}

// transport is the keyingTransport that implements the SSH packet
// protocol.
type transport struct {
	reader connectionState
	writer connectionState

	bufReader *bufio.Reader
	bufWriter *bufio.Writer
	rand      io.Reader
	isClient  bool
	io.Closer

	strictMode     bool // strict KEX mode (sequence number resets on rekey)
	initialKEXDone bool // set once the first key exchange completes
}

// packetCipher represents a combination of SSH encryption/MAC
// protocol. A single instance should be used for one direction only.
type packetCipher interface {
	// writeCipherPacket encrypts the packet and writes it to w. The
	// contents of the packet are generally scrambled.
	writeCipherPacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error

	// readCipherPacket reads and decrypts a packet of data. The
	// returned packet may be overwritten by future calls of
	// readPacket.
	readCipherPacket(seqnum uint32, r io.Reader) ([]byte, error)
}

// connectionState represents one side (read or write) of the
// connection. This is necessary because each direction has its own
// keys, and can even have its own algorithms.
type connectionState struct {
	packetCipher
	seqNum           uint32
	dir              direction
	pendingKeyChange chan packetCipher
}

// setStrictMode enables strict KEX mode. It must be requested while
// only the initial KEXINIT has been read (read sequence number 1);
// anything else indicates packets preceded the KEXINIT and is an error.
func (t *transport) setStrictMode() error {
	if t.reader.seqNum != 1 {
		return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
	}
	t.strictMode = true
	return nil
}

// setInitialKEXDone records that the first key exchange has completed.
func (t *transport) setInitialKEXDone() {
	t.initialKEXDone = true
}

// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
	ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
	if err != nil {
		return err
	}
	t.reader.pendingKeyChange <- ciph

	ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
	if err != nil {
		return err
	}
	t.writer.pendingKeyChange <- ciph

	return nil
}

// printPacket logs the message type byte of p for debugTransport
// tracing; no decoding is done to minimize the impact on timing.
func (t *transport) printPacket(p []byte, write bool) {
	if len(p) == 0 {
		return
	}
	who := "server"
	if t.isClient {
		who = "client"
	}
	what := "read"
	if write {
		what = "write"
	}

	log.Println(what, who, p[0])
}

// Read and decrypt next packet.
// readPacket reads and decrypts the next packet, transparently
// skipping SSH_MSG_IGNORE and SSH_MSG_DEBUG packets (except in strict
// mode before the initial KEX is done, where they are passed through
// so the caller can treat them as protocol violations).
func (t *transport) readPacket() (p []byte, err error) {
	for {
		p, err = t.reader.readPacket(t.bufReader, t.strictMode)
		if err != nil {
			break
		}
		// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
		if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
			break
		}
	}
	if debugTransport {
		t.printPacket(p, false)
	}

	return p, err
}

// readPacket decrypts one packet, advances the read sequence number,
// installs the pending cipher on msgNewKeys (resetting the sequence
// number in strict mode), and converts msgDisconnect into an error.
func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
	packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
	s.seqNum++
	if err == nil && len(packet) == 0 {
		err = errors.New("ssh: zero length packet")
	}

	if len(packet) > 0 {
		switch packet[0] {
		case msgNewKeys:
			select {
			case cipher := <-s.pendingKeyChange:
				s.packetCipher = cipher
				if strictMode {
					s.seqNum = 0
				}
			default:
				return nil, errors.New("ssh: got bogus newkeys message")
			}

		case msgDisconnect:
			// Transform a disconnect message into an
			// error. Since this is lowest level at which
			// we interpret message types, doing it here
			// ensures that we don't have to handle it
			// elsewhere.
			var msg disconnectMsg
			if err := Unmarshal(packet, &msg); err != nil {
				return nil, err
			}
			return nil, &msg
		}
	}

	// The packet may point to an internal buffer, so copy the
	// packet out here.
	fresh := make([]byte, len(packet))
	copy(fresh, packet)

	return fresh, err
}

// writePacket encrypts and sends packet on the write side.
func (t *transport) writePacket(packet []byte) error {
	if debugTransport {
		t.printPacket(packet, true)
	}
	return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
}

// writePacket encrypts and flushes one packet, advances the write
// sequence number, and after sending msgNewKeys installs the pending
// cipher (resetting the sequence number in strict mode). A missing
// pending cipher at that point is a programming error, hence panic.
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
	changeKeys := len(packet) > 0 && packet[0] == msgNewKeys

	err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
	if err != nil {
		return err
	}
	if err = w.Flush(); err != nil {
		return err
	}
	s.seqNum++
	if changeKeys {
		select {
		case cipher := <-s.pendingKeyChange:
			s.packetCipher = cipher
			if strictMode {
				s.seqNum = 0
			}
		default:
			panic("ssh: no key material for msgNewKeys")
		}
	}
	return err
}

// newTransport wraps rwc in a transport whose two directions start
// with the identity (none) cipher; real ciphers are installed after
// the first key exchange. Key-derivation tags depend on which side we
// are: a client reads with server keys and writes with client keys,
// and vice versa.
func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport {
	t := &transport{
		bufReader: bufio.NewReader(rwc),
		bufWriter: bufio.NewWriter(rwc),
		rand:      rand,
		reader: connectionState{
			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
			pendingKeyChange: make(chan packetCipher, 1),
		},
		writer: connectionState{
			packetCipher:     &streamPacketCipher{cipher: noneCipher{}},
			pendingKeyChange: make(chan packetCipher, 1),
		},
		Closer: rwc,
	}
	t.isClient = isClient

	if isClient {
		t.reader.dir = serverKeys
		t.writer.dir = clientKeys
	} else {
		t.reader.dir = clientKeys
		t.writer.dir = serverKeys
	}

	return t
}

// direction holds the RFC 4253 section 7.2 key-derivation tag bytes
// for one direction of the connection.
type direction struct {
	ivTag     []byte
	keyTag    []byte
	macKeyTag []byte
}

var (
	serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
	clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)

// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
+func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { + cipherMode := cipherModes[algs.Cipher] + + iv := make([]byte, cipherMode.ivSize) + key := make([]byte, cipherMode.keySize) + + generateKeyMaterial(iv, d.ivTag, kex) + generateKeyMaterial(key, d.keyTag, kex) + + var macKey []byte + if !aeadCiphers[algs.Cipher] { + macMode := macModes[algs.MAC] + macKey = make([]byte, macMode.keySize) + generateKeyMaterial(macKey, d.macKeyTag, kex) + } + + return cipherModes[algs.Cipher].create(key, iv, macKey, algs) +} + +// generateKeyMaterial fills out with key material generated from tag, K, H +// and sessionId, as specified in RFC 4253, section 7.2. +func generateKeyMaterial(out, tag []byte, r *kexResult) { + var digestsSoFar []byte + + h := r.Hash.New() + for len(out) > 0 { + h.Reset() + h.Write(r.K) + h.Write(r.H) + + if len(digestsSoFar) == 0 { + h.Write(tag) + h.Write(r.SessionID) + } else { + h.Write(digestsSoFar) + } + + digest := h.Sum(nil) + n := copy(out, digest) + out = out[n:] + if len(out) > 0 { + digestsSoFar = append(digestsSoFar, digest...) + } + } +} + +const packageVersion = "SSH-2.0-Go" + +// Sends and receives a version line. The versionLine string should +// be US ASCII, start with "SSH-2.0-", and should not include a +// newline. exchangeVersions returns the other side's version line. +func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { + // Contrary to the RFC, we do not ignore lines that don't + // start with "SSH-2.0-" to make the library usable with + // nonconforming servers. + for _, c := range versionLine { + // The spec disallows non US-ASCII chars, and + // specifically forbids null chars. 
+ if c < 32 { + return nil, errors.New("ssh: junk character in version line") + } + } + if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { + return + } + + them, err = readVersion(rw) + return them, err +} + +// maxVersionStringBytes is the maximum number of bytes that we'll +// accept as a version string. RFC 4253 section 4.2 limits this at 255 +// chars +const maxVersionStringBytes = 255 + +// Read version string as specified by RFC 4253, section 4.2. +func readVersion(r io.Reader) ([]byte, error) { + versionString := make([]byte, 0, 64) + var ok bool + var buf [1]byte + + for length := 0; length < maxVersionStringBytes; length++ { + _, err := io.ReadFull(r, buf[:]) + if err != nil { + return nil, err + } + // The RFC says that the version should be terminated with \r\n + // but several SSH servers actually only send a \n. + if buf[0] == '\n' { + if !bytes.HasPrefix(versionString, []byte("SSH-")) { + // RFC 4253 says we need to ignore all version string lines + // except the one containing the SSH version (provided that + // all the lines do not exceed 255 bytes in total). + versionString = versionString[:0] + continue + } + ok = true + break + } + + // non ASCII chars are disallowed, but we are lenient, + // since Go doesn't use null-terminated strings. + + // The RFC allows a comment after a space, however, + // all of it (version and comments) goes into the + // session hash. + versionString = append(versionString, buf[0]) + } + + if !ok { + return nil, errors.New("ssh: overflow reading version string") + } + + // There might be a '\r' on the end which we should remove. 
+ if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { + versionString = versionString[:len(versionString)-1] + } + return versionString, nil +} diff --git a/tempfork/sshtest/ssh/transport_test.go b/tempfork/sshtest/ssh/transport_test.go new file mode 100644 index 0000000000000..8445e1e561c15 --- /dev/null +++ b/tempfork/sshtest/ssh/transport_test.go @@ -0,0 +1,113 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssh + +import ( + "bytes" + "crypto/rand" + "encoding/binary" + "strings" + "testing" +) + +func TestReadVersion(t *testing.T) { + longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253] + multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n" + cases := map[string]string{ + "SSH-2.0-bla\r\n": "SSH-2.0-bla", + "SSH-2.0-bla\n": "SSH-2.0-bla", + multiLineVersion: "SSH-2.0-bla", + longVersion + "\r\n": longVersion, + } + + for in, want := range cases { + result, err := readVersion(bytes.NewBufferString(in)) + if err != nil { + t.Errorf("readVersion(%q): %s", in, err) + } + got := string(result) + if got != want { + t.Errorf("got %q, want %q", got, want) + } + } +} + +func TestReadVersionError(t *testing.T) { + longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253] + multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n" + cases := []string{ + longVersion + "too-long\r\n", + multiLineVersion, + } + for _, in := range cases { + if _, err := readVersion(bytes.NewBufferString(in)); err == nil { + t.Errorf("readVersion(%q) should have failed", in) + } + } +} + +func TestExchangeVersionsBasic(t *testing.T) { + v := "SSH-2.0-bla" + buf := bytes.NewBufferString(v + "\r\n") + them, err := exchangeVersions(buf, []byte("xyz")) + if err != nil { + t.Errorf("exchangeVersions: %v", err) + } + + if want := "SSH-2.0-bla"; string(them) != want { + t.Errorf("got %q want %q for our version", 
them, want) + } +} + +func TestExchangeVersions(t *testing.T) { + cases := []string{ + "not\x000allowed", + "not allowed\x01\r\n", + } + for _, c := range cases { + buf := bytes.NewBufferString("SSH-2.0-bla\r\n") + if _, err := exchangeVersions(buf, []byte(c)); err == nil { + t.Errorf("exchangeVersions(%q): should have failed", c) + } + } +} + +type closerBuffer struct { + bytes.Buffer +} + +func (b *closerBuffer) Close() error { + return nil +} + +func TestTransportMaxPacketWrite(t *testing.T) { + buf := &closerBuffer{} + tr := newTransport(buf, rand.Reader, true) + huge := make([]byte, maxPacket+1) + err := tr.writePacket(huge) + if err == nil { + t.Errorf("transport accepted write for a huge packet.") + } +} + +func TestTransportMaxPacketReader(t *testing.T) { + var header [5]byte + huge := make([]byte, maxPacket+128) + binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) + // padding. + header[4] = 0 + + buf := &closerBuffer{} + buf.Write(header[:]) + buf.Write(huge) + + tr := newTransport(buf, rand.Reader, true) + _, err := tr.readPacket() + if err == nil { + t.Errorf("transport succeeded reading huge packet.") + } else if !strings.Contains(err.Error(), "large") { + t.Errorf("got %q, should mention %q", err.Error(), "large") + } +} diff --git a/tool/gocross/autoflags.go b/tool/gocross/autoflags.go index c66cab55a6770..b28d3bc5dd26e 100644 --- a/tool/gocross/autoflags.go +++ b/tool/gocross/autoflags.go @@ -35,7 +35,7 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ cc = "cc" targetOS = cmp.Or(env.Get("GOOS", ""), nativeGOOS) targetArch = cmp.Or(env.Get("GOARCH", ""), nativeGOARCH) - buildFlags = []string{"-trimpath"} + buildFlags = []string{} cgoCflags = []string{"-O3", "-std=gnu11", "-g"} cgoLdflags []string ldflags []string @@ -47,6 +47,10 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ subcommand = argv[1] } + if subcommand != "test" { + buildFlags = append(buildFlags, "-trimpath") + 
} + switch subcommand { case "build", "env", "install", "run", "test", "list": default: @@ -146,7 +150,11 @@ func autoflagsForTest(argv []string, env *Environment, goroot, nativeGOOS, nativ case env.IsSet("MACOSX_DEPLOYMENT_TARGET"): xcodeFlags = append(xcodeFlags, "-mmacosx-version-min="+env.Get("MACOSX_DEPLOYMENT_TARGET", "")) case env.IsSet("TVOS_DEPLOYMENT_TARGET"): - xcodeFlags = append(xcodeFlags, "-mtvos-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + if env.Get("TARGET_DEVICE_PLATFORM_NAME", "") == "appletvsimulator" { + xcodeFlags = append(xcodeFlags, "-mtvos-simulator-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + } else { + xcodeFlags = append(xcodeFlags, "-mtvos-version-min="+env.Get("TVOS_DEPLOYMENT_TARGET", "")) + } default: return nil, nil, fmt.Errorf("invoked by Xcode but couldn't figure out deployment target. Did Xcode change its envvars again?") } diff --git a/tool/gocross/autoflags_test.go b/tool/gocross/autoflags_test.go index 8f24dd8a32797..a0f3edfd2bb68 100644 --- a/tool/gocross/autoflags_test.go +++ b/tool/gocross/autoflags_test.go @@ -163,7 +163,6 @@ GOTOOLCHAIN=local (was ) TS_LINK_FAIL_REFLECT=0 (was )`, wantArgv: []string{ "gocross", "test", - "-trimpath", "-tags=tailscale_go,osusergo,netgo", "-ldflags", "-X tailscale.com/version.longStamp=1.2.3-long -X tailscale.com/version.shortStamp=1.2.3 -X tailscale.com/version.gitCommitStamp=abcd -X tailscale.com/version.extraGitCommitStamp=defg '-extldflags=-static'", "-race", diff --git a/tool/gocross/gocross.go b/tool/gocross/gocross.go index 8011c10956c05..d14ea03885868 100644 --- a/tool/gocross/gocross.go +++ b/tool/gocross/gocross.go @@ -15,9 +15,9 @@ import ( "fmt" "os" "path/filepath" + "runtime/debug" "tailscale.com/atomicfile" - "tailscale.com/version" ) func main() { @@ -28,8 +28,19 @@ func main() { // any time. 
switch os.Args[1] { case "gocross-version": - fmt.Println(version.GetMeta().GitCommit) - os.Exit(0) + bi, ok := debug.ReadBuildInfo() + if !ok { + fmt.Fprintln(os.Stderr, "failed getting build info") + os.Exit(1) + } + for _, s := range bi.Settings { + if s.Key == "vcs.revision" { + fmt.Println(s.Value) + os.Exit(0) + } + } + fmt.Fprintln(os.Stderr, "did not find vcs.revision in build info") + os.Exit(1) case "is-gocross": // This subcommand exits with an error code when called on a // regular go binary, so it can be used to detect when `go` is @@ -85,9 +96,9 @@ func main() { path := filepath.Join(toolchain, "bin") + string(os.PathListSeparator) + os.Getenv("PATH") env.Set("PATH", path) - debug("Input: %s\n", formatArgv(os.Args)) - debug("Command: %s\n", formatArgv(newArgv)) - debug("Set the following flags/envvars:\n%s\n", env.Diff()) + debugf("Input: %s\n", formatArgv(os.Args)) + debugf("Command: %s\n", formatArgv(newArgv)) + debugf("Set the following flags/envvars:\n%s\n", env.Diff()) args = newArgv if err := env.Apply(); err != nil { @@ -103,7 +114,7 @@ func main() { //go:embed gocross-wrapper.sh var wrapperScript []byte -func debug(format string, args ...any) { +func debugf(format string, args ...any) { debug := os.Getenv("GOCROSS_DEBUG") var ( out *os.File diff --git a/tool/gocross/gocross_test.go b/tool/gocross/gocross_test.go new file mode 100644 index 0000000000000..82afd268c6d8f --- /dev/null +++ b/tool/gocross/gocross_test.go @@ -0,0 +1,19 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "testing" + + "tailscale.com/tstest/deptest" +) + +func TestDeps(t *testing.T) { + deptest.DepChecker{ + BadDeps: map[string]string{ + "tailscale.com/tailcfg": "circular dependency via go generate", + "tailscale.com/version": "circular dependency via go generate", + }, + }.Check(t) +} diff --git a/tool/node.rev b/tool/node.rev index 17719ce25a0c2..7d41c735d7127 100644 --- a/tool/node.rev +++ 
b/tool/node.rev @@ -1 +1 @@ -18.20.4 +22.14.0 diff --git a/tsconsensus/authorization.go b/tsconsensus/authorization.go new file mode 100644 index 0000000000000..1e0b70c0759d3 --- /dev/null +++ b/tsconsensus/authorization.go @@ -0,0 +1,134 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "context" + "errors" + "net/netip" + "sync" + "time" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" + "tailscale.com/types/views" + "tailscale.com/util/set" +) + +type statusGetter interface { + getStatus(context.Context) (*ipnstate.Status, error) +} + +type tailscaleStatusGetter struct { + ts *tsnet.Server + + mu sync.Mutex // protects the following + lastStatus *ipnstate.Status + lastStatusTime time.Time +} + +func (sg *tailscaleStatusGetter) fetchStatus(ctx context.Context) (*ipnstate.Status, error) { + lc, err := sg.ts.LocalClient() + if err != nil { + return nil, err + } + return lc.Status(ctx) +} + +func (sg *tailscaleStatusGetter) getStatus(ctx context.Context) (*ipnstate.Status, error) { + sg.mu.Lock() + defer sg.mu.Unlock() + if sg.lastStatus != nil && time.Since(sg.lastStatusTime) < 1*time.Second { + return sg.lastStatus, nil + } + status, err := sg.fetchStatus(ctx) + if err != nil { + return nil, err + } + sg.lastStatus = status + sg.lastStatusTime = time.Now() + return status, nil +} + +type authorization struct { + sg statusGetter + tag string + + mu sync.Mutex + peers *peers // protected by mu +} + +func newAuthorization(ts *tsnet.Server, tag string) *authorization { + return &authorization{ + sg: &tailscaleStatusGetter{ + ts: ts, + }, + tag: tag, + } +} + +func (a *authorization) Refresh(ctx context.Context) error { + tStatus, err := a.sg.getStatus(ctx) + if err != nil { + return err + } + if tStatus == nil { + return errors.New("no status") + } + if tStatus.BackendState != ipn.Running.String() { + return errors.New("ts Server is not running") + } + a.mu.Lock() + 
defer a.mu.Unlock() + a.peers = newPeers(tStatus, a.tag) + return nil +} + +func (a *authorization) AllowsHost(addr netip.Addr) bool { + if a.peers == nil { + return false + } + a.mu.Lock() + defer a.mu.Unlock() + return a.peers.addrs.Contains(addr) +} + +func (a *authorization) SelfAllowed() bool { + if a.peers == nil { + return false + } + a.mu.Lock() + defer a.mu.Unlock() + return a.peers.status.Self.Tags != nil && views.SliceContains(*a.peers.status.Self.Tags, a.tag) +} + +func (a *authorization) AllowedPeers() views.Slice[*ipnstate.PeerStatus] { + if a.peers == nil { + return views.Slice[*ipnstate.PeerStatus]{} + } + a.mu.Lock() + defer a.mu.Unlock() + return views.SliceOf(a.peers.statuses) +} + +type peers struct { + status *ipnstate.Status + addrs set.Set[netip.Addr] + statuses []*ipnstate.PeerStatus +} + +func newPeers(status *ipnstate.Status, tag string) *peers { + ps := &peers{ + status: status, + addrs: set.Set[netip.Addr]{}, + } + for _, p := range status.Peer { + if p.Tags != nil && views.SliceContains(*p.Tags, tag) { + ps.statuses = append(ps.statuses, p) + ps.addrs.AddSlice(p.TailscaleIPs) + } + } + return ps +} diff --git a/tsconsensus/authorization_test.go b/tsconsensus/authorization_test.go new file mode 100644 index 0000000000000..e0023f4ff24d2 --- /dev/null +++ b/tsconsensus/authorization_test.go @@ -0,0 +1,230 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "context" + "fmt" + "net/netip" + "testing" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +type testStatusGetter struct { + status *ipnstate.Status +} + +func (sg testStatusGetter) getStatus(ctx context.Context) (*ipnstate.Status, error) { + return sg.status, nil +} + +const testTag string = "tag:clusterTag" + +func makeAuthTestPeer(i int, tags views.Slice[string]) *ipnstate.PeerStatus { + return &ipnstate.PeerStatus{ + 
ID: tailcfg.StableNodeID(fmt.Sprintf("%d", i)), + Tags: &tags, + TailscaleIPs: []netip.Addr{ + netip.AddrFrom4([4]byte{100, 0, 0, byte(i)}), + netip.MustParseAddr(fmt.Sprintf("fd7a:115c:a1e0:0::%d", i)), + }, + } +} + +func makeAuthTestPeers(tags [][]string) []*ipnstate.PeerStatus { + peers := make([]*ipnstate.PeerStatus, len(tags)) + for i, ts := range tags { + peers[i] = makeAuthTestPeer(i, views.SliceOf(ts)) + } + return peers +} + +func authForStatus(s *ipnstate.Status) *authorization { + return &authorization{ + sg: testStatusGetter{ + status: s, + }, + tag: testTag, + } +} + +func authForPeers(self *ipnstate.PeerStatus, peers []*ipnstate.PeerStatus) *authorization { + s := &ipnstate.Status{ + BackendState: ipn.Running.String(), + Self: self, + Peer: map[key.NodePublic]*ipnstate.PeerStatus{}, + } + for _, p := range peers { + s.Peer[key.NewNode().Public()] = p + } + return authForStatus(s) +} + +func TestAuthRefreshErrorsNotRunning(t *testing.T) { + tests := []struct { + in *ipnstate.Status + expected string + }{ + { + in: nil, + expected: "no status", + }, + { + in: &ipnstate.Status{ + BackendState: "NeedsMachineAuth", + }, + expected: "ts Server is not running", + }, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + ctx := t.Context() + a := authForStatus(tt.in) + err := a.Refresh(ctx) + if err == nil { + t.Fatalf("expected err to be non-nil") + } + if err.Error() != tt.expected { + t.Fatalf("expected: %s, got: %s", tt.expected, err.Error()) + } + }) + } +} + +func TestAuthUnrefreshed(t *testing.T) { + a := authForStatus(nil) + if a.AllowsHost(netip.MustParseAddr("100.0.0.1")) { + t.Fatalf("never refreshed authorization, allowsHost: expected false, got true") + } + gotAllowedPeers := a.AllowedPeers() + if gotAllowedPeers.Len() != 0 { + t.Fatalf("never refreshed authorization, allowedPeers: expected [], got %v", gotAllowedPeers) + } + if a.SelfAllowed() != false { + t.Fatalf("never refreshed authorization, selfAllowed: expected 
false got true") + } +} + +func TestAuthAllowsHost(t *testing.T) { + peerTags := [][]string{ + {"woo"}, + nil, + {"woo", testTag}, + {testTag}, + } + peers := makeAuthTestPeers(peerTags) + + tests := []struct { + name string + peerStatus *ipnstate.PeerStatus + expected bool + }{ + { + name: "tagged with different tag", + peerStatus: peers[0], + expected: false, + }, + { + name: "not tagged", + peerStatus: peers[1], + expected: false, + }, + { + name: "tags includes testTag", + peerStatus: peers[2], + expected: true, + }, + { + name: "only tag is testTag", + peerStatus: peers[3], + expected: true, + }, + } + + a := authForPeers(nil, peers) + err := a.Refresh(t.Context()) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // test we get the expected result for any of the peers TailscaleIPs + for _, addr := range tt.peerStatus.TailscaleIPs { + got := a.AllowsHost(addr) + if got != tt.expected { + t.Fatalf("allowed for peer with tags: %v, expected: %t, got %t", tt.peerStatus.Tags, tt.expected, got) + } + } + }) + } +} + +func TestAuthAllowedPeers(t *testing.T) { + ctx := t.Context() + peerTags := [][]string{ + {"woo"}, + nil, + {"woo", testTag}, + {testTag}, + } + peers := makeAuthTestPeers(peerTags) + a := authForPeers(nil, peers) + err := a.Refresh(ctx) + if err != nil { + t.Fatal(err) + } + ps := a.AllowedPeers() + if ps.Len() != 2 { + t.Fatalf("expected: 2, got: %d", ps.Len()) + } + for _, i := range []int{2, 3} { + if !ps.ContainsFunc(func(p *ipnstate.PeerStatus) bool { + return p.ID == peers[i].ID + }) { + t.Fatalf("expected peers[%d] to be in AllowedPeers because it is tagged with testTag", i) + } + } +} + +func TestAuthSelfAllowed(t *testing.T) { + tests := []struct { + name string + in []string + expected bool + }{ + { + name: "self has different tag", + in: []string{"woo"}, + expected: false, + }, + { + name: "selfs tags include testTag", + in: []string{"woo", testTag}, + expected: true, + }, + } + + 
for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := t.Context() + self := makeAuthTestPeer(0, views.SliceOf(tt.in)) + a := authForPeers(self, nil) + err := a.Refresh(ctx) + if err != nil { + t.Fatal(err) + } + got := a.SelfAllowed() + if got != tt.expected { + t.Fatalf("expected: %t, got: %t", tt.expected, got) + } + }) + } +} diff --git a/tsconsensus/http.go b/tsconsensus/http.go new file mode 100644 index 0000000000000..d2a44015f8f68 --- /dev/null +++ b/tsconsensus/http.go @@ -0,0 +1,182 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "time" + + "tailscale.com/util/httpm" +) + +type joinRequest struct { + RemoteHost string + RemoteID string +} + +type commandClient struct { + port uint16 + httpClient *http.Client +} + +func (rac *commandClient) url(host string, path string) string { + return fmt.Sprintf("http://%s:%d%s", host, rac.port, path) +} + +const maxBodyBytes = 1024 * 1024 + +func readAllMaxBytes(r io.Reader) ([]byte, error) { + return io.ReadAll(io.LimitReader(r, maxBodyBytes+1)) +} + +func (rac *commandClient) join(host string, jr joinRequest) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + rBs, err := json.Marshal(jr) + if err != nil { + return err + } + url := rac.url(host, "/join") + req, err := http.NewRequestWithContext(ctx, httpm.POST, url, bytes.NewReader(rBs)) + if err != nil { + return err + } + resp, err := rac.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + respBs, err := readAllMaxBytes(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs)) + } + return nil +} + +func (rac *commandClient) executeCommand(host string, bs []byte) (CommandResult, error) { + ctx, cancel := 
context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + url := rac.url(host, "/executeCommand") + req, err := http.NewRequestWithContext(ctx, httpm.POST, url, bytes.NewReader(bs)) + if err != nil { + return CommandResult{}, err + } + resp, err := rac.httpClient.Do(req) + if err != nil { + return CommandResult{}, err + } + defer resp.Body.Close() + respBs, err := readAllMaxBytes(resp.Body) + if err != nil { + return CommandResult{}, err + } + if resp.StatusCode != 200 { + return CommandResult{}, fmt.Errorf("remote responded %d: %s", resp.StatusCode, string(respBs)) + } + var cr CommandResult + if err = json.Unmarshal(respBs, &cr); err != nil { + return CommandResult{}, err + } + return cr, nil +} + +type authedHandler struct { + auth *authorization + handler http.Handler +} + +func (h authedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + err := h.auth.Refresh(r.Context()) + if err != nil { + log.Printf("error authedHandler ServeHTTP refresh auth: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + a, err := addrFromServerAddress(r.RemoteAddr) + if err != nil { + log.Printf("error authedHandler ServeHTTP refresh auth: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + allowed := h.auth.AllowsHost(a) + if !allowed { + http.Error(w, "peer not allowed", http.StatusForbidden) + return + } + h.handler.ServeHTTP(w, r) +} + +func (c *Consensus) handleJoinHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + decoder := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBodyBytes+1)) + var jr joinRequest + err := decoder.Decode(&jr) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + _, err = decoder.Token() + if !errors.Is(err, io.EOF) { + http.Error(w, "Request body must only contain a single JSON object", http.StatusBadRequest) + return + } + if jr.RemoteHost == "" { + http.Error(w, "Required: remoteAddr", http.StatusBadRequest) + return 
+ } + if jr.RemoteID == "" { + http.Error(w, "Required: remoteID", http.StatusBadRequest) + return + } + err = c.handleJoin(jr) + if err != nil { + log.Printf("join handler error: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } +} + +func (c *Consensus) handleExecuteCommandHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + decoder := json.NewDecoder(r.Body) + var cmd Command + err := decoder.Decode(&cmd) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + result, err := c.executeCommandLocally(cmd) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(result); err != nil { + log.Printf("error encoding execute command result: %v", err) + return + } +} + +func (c *Consensus) makeCommandMux() *http.ServeMux { + mux := http.NewServeMux() + mux.HandleFunc("POST /join", c.handleJoinHTTP) + mux.HandleFunc("POST /executeCommand", c.handleExecuteCommandHTTP) + return mux +} + +func (c *Consensus) makeCommandHandler(auth *authorization) http.Handler { + return authedHandler{ + handler: c.makeCommandMux(), + auth: auth, + } +} diff --git a/tsconsensus/monitor.go b/tsconsensus/monitor.go new file mode 100644 index 0000000000000..61a5a74a07c42 --- /dev/null +++ b/tsconsensus/monitor.go @@ -0,0 +1,160 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "slices" + + "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tsnet" + "tailscale.com/util/dnsname" +) + +type status struct { + Status *ipnstate.Status + RaftState string +} + +type monitor struct { + ts *tsnet.Server + con *Consensus + sg statusGetter +} + +func (m *monitor) getStatus(ctx context.Context) (status, error) { + tStatus, err := m.sg.getStatus(ctx) + if err != nil { + return status{}, err + } + 
return status{Status: tStatus, RaftState: m.con.raft.State().String()}, nil +} + +func serveMonitor(c *Consensus, ts *tsnet.Server, listenAddr string) (*http.Server, error) { + ln, err := ts.Listen("tcp", listenAddr) + if err != nil { + return nil, err + } + m := &monitor{con: c, ts: ts, sg: &tailscaleStatusGetter{ + ts: ts, + }} + mux := http.NewServeMux() + mux.HandleFunc("GET /full", m.handleFullStatus) + mux.HandleFunc("GET /{$}", m.handleSummaryStatus) + mux.HandleFunc("GET /netmap", m.handleNetmap) + mux.HandleFunc("POST /dial", m.handleDial) + srv := &http.Server{Handler: mux} + go func() { + err := srv.Serve(ln) + log.Printf("MonitorHTTP stopped serving with error: %v", err) + }() + return srv, nil +} + +func (m *monitor) handleFullStatus(w http.ResponseWriter, r *http.Request) { + s, err := m.getStatus(r.Context()) + if err != nil { + log.Printf("monitor: error getStatus: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + if err := json.NewEncoder(w).Encode(s); err != nil { + log.Printf("monitor: error encoding full status: %v", err) + return + } +} + +func (m *monitor) handleSummaryStatus(w http.ResponseWriter, r *http.Request) { + s, err := m.getStatus(r.Context()) + if err != nil { + log.Printf("monitor: error getStatus: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + lines := []string{} + for _, p := range s.Status.Peer { + if p.Online { + name := dnsname.FirstLabel(p.DNSName) + lines = append(lines, fmt.Sprintf("%s\t\t%d\t%d\t%t", name, p.RxBytes, p.TxBytes, p.Active)) + } + } + _, err = w.Write([]byte(fmt.Sprintf("RaftState: %s\n", s.RaftState))) + if err != nil { + log.Printf("monitor: error writing status: %v", err) + return + } + + slices.Sort(lines) + for _, l := range lines { + _, err = w.Write([]byte(fmt.Sprintf("%s\n", l))) + if err != nil { + log.Printf("monitor: error writing status: %v", err) + return + } + } +} + +func (m *monitor) handleNetmap(w http.ResponseWriter, r *http.Request) { 
+ var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap + mask |= ipn.NotifyNoPrivateKeys + lc, err := m.ts.LocalClient() + if err != nil { + log.Printf("monitor: error LocalClient: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + watcher, err := lc.WatchIPNBus(r.Context(), mask) + if err != nil { + log.Printf("monitor: error WatchIPNBus: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + defer watcher.Close() + + n, err := watcher.Next() + if err != nil { + log.Printf("monitor: error watcher.Next: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + encoder := json.NewEncoder(w) + encoder.SetIndent("", "\t") + if err := encoder.Encode(n); err != nil { + log.Printf("monitor: error encoding netmap: %v", err) + return + } +} + +func (m *monitor) handleDial(w http.ResponseWriter, r *http.Request) { + var dialParams struct { + Addr string + } + defer r.Body.Close() + bs, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes)) + if err != nil { + log.Printf("monitor: error reading body: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + err = json.Unmarshal(bs, &dialParams) + if err != nil { + log.Printf("monitor: error unmarshalling json: %v", err) + http.Error(w, "", http.StatusBadRequest) + return + } + c, err := m.ts.Dial(r.Context(), "tcp", dialParams.Addr) + if err != nil { + log.Printf("monitor: error dialing: %v", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + c.Close() + w.Write([]byte("ok\n")) +} diff --git a/tsconsensus/tsconsensus.go b/tsconsensus/tsconsensus.go new file mode 100644 index 0000000000000..74094782f4383 --- /dev/null +++ b/tsconsensus/tsconsensus.go @@ -0,0 +1,447 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package tsconsensus implements a consensus algorithm for a group of tsnet.Servers +// +// The Raft consensus algorithm relies on you implementing a state 
machine that will give the same
+// result to a given command as long as the same logs have been applied in the same order.
+//
+// tsconsensus uses the hashicorp/raft library to implement leader elections and log application.
+//
+// tsconsensus provides:
+//   - cluster peer discovery based on tailscale tags
+//   - executing a command on the leader
+//   - communication between cluster peers over tailscale using tsnet
+//
+// Users implement a state machine that satisfies the raft.FSM interface, with the business logic they desire.
+// When changes to state are needed any node may
+//   - create a Command instance with serialized Args.
+//   - call ExecuteCommand with the Command instance
+//     this will propagate the command to the leader,
+//     and then from the leader to every node via raft.
+//   - the state machine then can implement raft.Apply, and dispatch commands via the Command.Name
+//     returning a CommandResult with an Err or a serialized Result.
+package tsconsensus
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"net"
+	"net/http"
+	"net/netip"
+	"time"
+
+	"github.com/hashicorp/go-hclog"
+	"github.com/hashicorp/raft"
+	"tailscale.com/ipn/ipnstate"
+	"tailscale.com/tsnet"
+	"tailscale.com/types/views"
+)
+
+func raftAddr(host netip.Addr, cfg Config) string {
+	return netip.AddrPortFrom(host, cfg.RaftPort).String()
+}
+
+func addrFromServerAddress(sa string) (netip.Addr, error) {
+	addrPort, err := netip.ParseAddrPort(sa)
+	if err != nil {
+		return netip.Addr{}, err
+	}
+	return addrPort.Addr(), nil
+}
+
+// A selfRaftNode is the info we need to talk to hashicorp/raft about our node.
+// We specify the ID and Addr on Consensus Start, and then use it later for raft
+// operations such as BootstrapCluster and AddVoter.
+type selfRaftNode struct {
+	id       string
+	hostAddr netip.Addr
+}
+
+// A Config holds configurable values such as ports and timeouts.
+// Use DefaultConfig to get a useful Config. 
+type Config struct { + CommandPort uint16 + RaftPort uint16 + MonitorPort uint16 + Raft *raft.Config + MaxConnPool int + ConnTimeout time.Duration + ServeDebugMonitor bool +} + +// DefaultConfig returns a Config populated with default values ready for use. +func DefaultConfig() Config { + raftConfig := raft.DefaultConfig() + // these values are 2x the raft DefaultConfig + raftConfig.HeartbeatTimeout = 2000 * time.Millisecond + raftConfig.ElectionTimeout = 2000 * time.Millisecond + raftConfig.LeaderLeaseTimeout = 1000 * time.Millisecond + + return Config{ + CommandPort: 6271, + RaftPort: 6270, + MonitorPort: 8081, + Raft: raftConfig, + MaxConnPool: 5, + ConnTimeout: 5 * time.Second, + } +} + +// StreamLayer implements an interface asked for by raft.NetworkTransport. +// It does the raft interprocess communication via tailscale. +type StreamLayer struct { + net.Listener + s *tsnet.Server + auth *authorization + shutdownCtx context.Context +} + +// Dial implements the raft.StreamLayer interface with the tsnet.Server's Dial. 
+func (sl StreamLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(sl.shutdownCtx, timeout) + defer cancel() + authorized, err := sl.addrAuthorized(ctx, string(address)) + if err != nil { + return nil, err + } + if !authorized { + return nil, errors.New("dial: peer is not allowed") + } + return sl.s.Dial(ctx, "tcp", string(address)) +} + +func (sl StreamLayer) addrAuthorized(ctx context.Context, address string) (bool, error) { + addr, err := addrFromServerAddress(address) + if err != nil { + // bad RemoteAddr is not authorized + return false, nil + } + err = sl.auth.Refresh(ctx) + if err != nil { + // might be authorized, we couldn't tell + return false, err + } + return sl.auth.AllowsHost(addr), nil +} + +func (sl StreamLayer) Accept() (net.Conn, error) { + ctx, cancel := context.WithCancel(sl.shutdownCtx) + defer cancel() + for { + conn, err := sl.Listener.Accept() + if err != nil || conn == nil { + return conn, err + } + addr := conn.RemoteAddr() + if addr == nil { + conn.Close() + return nil, errors.New("conn has no remote addr") + } + authorized, err := sl.addrAuthorized(ctx, addr.String()) + if err != nil { + conn.Close() + return nil, err + } + if !authorized { + log.Printf("StreamLayer accept: unauthorized: %s", addr) + conn.Close() + continue + } + return conn, err + } +} + +// Start returns a pointer to a running Consensus instance. +// Calling it with a *tsnet.Server will cause that server to join or start a consensus cluster +// with other nodes on the tailnet tagged with the clusterTag. The *tsnet.Server will run the state +// machine defined by the raft.FSM also provided, and keep it in sync with the other cluster members' +// state machines using Raft. 
+func Start(ctx context.Context, ts *tsnet.Server, fsm raft.FSM, clusterTag string, cfg Config) (*Consensus, error) { + if clusterTag == "" { + return nil, errors.New("cluster tag must be provided") + } + + cc := commandClient{ + port: cfg.CommandPort, + httpClient: ts.HTTPClient(), + } + v4, _ := ts.TailscaleIPs() + // TODO(fran) support tailnets that have ipv4 disabled + self := selfRaftNode{ + id: v4.String(), + hostAddr: v4, + } + shutdownCtx, shutdownCtxCancel := context.WithCancel(ctx) + c := Consensus{ + commandClient: &cc, + self: self, + config: cfg, + shutdownCtxCancel: shutdownCtxCancel, + } + + auth := newAuthorization(ts, clusterTag) + err := auth.Refresh(shutdownCtx) + if err != nil { + return nil, fmt.Errorf("auth refresh: %w", err) + } + if !auth.SelfAllowed() { + return nil, errors.New("this node is not tagged with the cluster tag") + } + + srv, err := c.serveCommandHTTP(ts, auth) + if err != nil { + return nil, err + } + c.cmdHttpServer = srv + + // after startRaft it's possible some other raft node that has us in their configuration will get + // in contact, so by the time we do anything else we may already be a functioning member + // of a consensus + r, err := startRaft(shutdownCtx, ts, &fsm, c.self, auth, cfg) + if err != nil { + return nil, err + } + c.raft = r + + c.bootstrap(auth.AllowedPeers()) + + if cfg.ServeDebugMonitor { + srv, err = serveMonitor(&c, ts, netip.AddrPortFrom(c.self.hostAddr, cfg.MonitorPort).String()) + if err != nil { + return nil, err + } + c.monitorHttpServer = srv + } + + return &c, nil +} + +func startRaft(shutdownCtx context.Context, ts *tsnet.Server, fsm *raft.FSM, self selfRaftNode, auth *authorization, cfg Config) (*raft.Raft, error) { + cfg.Raft.LocalID = raft.ServerID(self.id) + + // no persistence (for now?) 
+ logStore := raft.NewInmemStore() + stableStore := raft.NewInmemStore() + snapshots := raft.NewInmemSnapshotStore() + + // opens the listener on the raft port, raft will close it when it thinks it's appropriate + ln, err := ts.Listen("tcp", raftAddr(self.hostAddr, cfg)) + if err != nil { + return nil, err + } + + logger := hclog.New(&hclog.LoggerOptions{ + Name: "raft-net", + Output: cfg.Raft.LogOutput, + Level: hclog.LevelFromString(cfg.Raft.LogLevel), + }) + + transport := raft.NewNetworkTransportWithLogger(StreamLayer{ + s: ts, + Listener: ln, + auth: auth, + shutdownCtx: shutdownCtx, + }, + cfg.MaxConnPool, + cfg.ConnTimeout, + logger) + + return raft.NewRaft(cfg.Raft, *fsm, logStore, stableStore, snapshots, transport) +} + +// A Consensus is the consensus algorithm for a tsnet.Server +// It wraps a raft.Raft instance and performs the peer discovery +// and command execution on the leader. +type Consensus struct { + raft *raft.Raft + commandClient *commandClient + self selfRaftNode + config Config + cmdHttpServer *http.Server + monitorHttpServer *http.Server + shutdownCtxCancel context.CancelFunc +} + +// bootstrap tries to join a raft cluster, or start one. +// +// We need to do the very first raft cluster configuration, but after that raft manages it. +// bootstrap is called at start up, and we are not currently aware of what the cluster config might be, +// our node may already be in it. Try to join the raft cluster of all the other nodes we know about, and +// if unsuccessful, assume we are the first and start our own. +// +// It's possible for bootstrap to return an error, or start a errant breakaway cluster. +// +// We have a list of expected cluster members already from control (the members of the tailnet with the tag) +// so we could do the initial configuration with all servers specified. +// Choose to start with just this machine in the raft configuration instead, as: +// - We want to handle machines joining after start anyway. 
+// - Not all tagged nodes tailscale believes are active are necessarily actually responsive right now, +// so let each node opt in when able. +func (c *Consensus) bootstrap(targets views.Slice[*ipnstate.PeerStatus]) error { + log.Printf("Trying to find cluster: num targets to try: %d", targets.Len()) + for _, p := range targets.All() { + if !p.Online { + log.Printf("Trying to find cluster: tailscale reports not online: %s", p.TailscaleIPs[0]) + continue + } + log.Printf("Trying to find cluster: trying %s", p.TailscaleIPs[0]) + err := c.commandClient.join(p.TailscaleIPs[0].String(), joinRequest{ + RemoteHost: c.self.hostAddr.String(), + RemoteID: c.self.id, + }) + if err != nil { + log.Printf("Trying to find cluster: could not join %s: %v", p.TailscaleIPs[0], err) + continue + } + log.Printf("Trying to find cluster: joined %s", p.TailscaleIPs[0]) + return nil + } + + log.Printf("Trying to find cluster: unsuccessful, starting as leader: %s", c.self.hostAddr.String()) + f := c.raft.BootstrapCluster( + raft.Configuration{ + Servers: []raft.Server{ + { + ID: raft.ServerID(c.self.id), + Address: raft.ServerAddress(c.raftAddr(c.self.hostAddr)), + }, + }, + }) + return f.Error() +} + +// ExecuteCommand propagates a Command to be executed on the leader. Which +// uses raft to Apply it to the followers. +func (c *Consensus) ExecuteCommand(cmd Command) (CommandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return CommandResult{}, err + } + result, err := c.executeCommandLocally(cmd) + var leErr lookElsewhereError + for errors.As(err, &leErr) { + result, err = c.commandClient.executeCommand(leErr.where, b) + } + return result, err +} + +// Stop attempts to gracefully shutdown various components. 
+func (c *Consensus) Stop(ctx context.Context) error { + fut := c.raft.Shutdown() + err := fut.Error() + if err != nil { + log.Printf("Stop: Error in Raft Shutdown: %v", err) + } + c.shutdownCtxCancel() + err = c.cmdHttpServer.Shutdown(ctx) + if err != nil { + log.Printf("Stop: Error in command HTTP Shutdown: %v", err) + } + if c.monitorHttpServer != nil { + err = c.monitorHttpServer.Shutdown(ctx) + if err != nil { + log.Printf("Stop: Error in monitor HTTP Shutdown: %v", err) + } + } + return nil +} + +// A Command is a representation of a state machine action. +type Command struct { + // The Name can be used to dispatch the command when received. + Name string + // The Args are serialized for transport. + Args json.RawMessage +} + +// A CommandResult is a representation of the result of a state +// machine action. +type CommandResult struct { + // Err is any error that occurred on the node that tried to execute the command, + // including any error from the underlying operation and deserialization problems etc. + Err error + // Result is serialized for transport. + Result json.RawMessage +} + +type lookElsewhereError struct { + where string +} + +func (e lookElsewhereError) Error() string { + return fmt.Sprintf("not the leader, try: %s", e.where) +} + +var errLeaderUnknown = errors.New("leader unknown") + +func (c *Consensus) serveCommandHTTP(ts *tsnet.Server, auth *authorization) (*http.Server, error) { + ln, err := ts.Listen("tcp", c.commandAddr(c.self.hostAddr)) + if err != nil { + return nil, err + } + srv := &http.Server{Handler: c.makeCommandHandler(auth)} + go func() { + err := srv.Serve(ln) + log.Printf("CmdHttp stopped serving with err: %v", err) + }() + return srv, nil +} + +func (c *Consensus) getLeader() (string, error) { + raftLeaderAddr, _ := c.raft.LeaderWithID() + leaderAddr := (string)(raftLeaderAddr) + if leaderAddr == "" { + // Raft doesn't know who the leader is. 
+ return "", errLeaderUnknown + } + // Raft gives us the address with the raft port, we don't always want that. + host, _, err := net.SplitHostPort(leaderAddr) + return host, err +} + +func (c *Consensus) executeCommandLocally(cmd Command) (CommandResult, error) { + b, err := json.Marshal(cmd) + if err != nil { + return CommandResult{}, err + } + f := c.raft.Apply(b, 0) + err = f.Error() + result := f.Response() + if errors.Is(err, raft.ErrNotLeader) { + leader, err := c.getLeader() + if err != nil { + // we know we're not leader but we were unable to give the address of the leader + return CommandResult{}, err + } + return CommandResult{}, lookElsewhereError{where: leader} + } + if result == nil { + result = CommandResult{} + } + return result.(CommandResult), err +} + +func (c *Consensus) handleJoin(jr joinRequest) error { + addr, err := netip.ParseAddr(jr.RemoteHost) + if err != nil { + return err + } + remoteAddr := c.raftAddr(addr) + f := c.raft.AddVoter(raft.ServerID(jr.RemoteID), raft.ServerAddress(remoteAddr), 0, 0) + if f.Error() != nil { + return f.Error() + } + return nil +} + +func (c *Consensus) raftAddr(host netip.Addr) string { + return raftAddr(host, c.config) +} + +func (c *Consensus) commandAddr(host netip.Addr) string { + return netip.AddrPortFrom(host, c.config.CommandPort).String() +} diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go new file mode 100644 index 0000000000000..d1b92f8a489f7 --- /dev/null +++ b/tsconsensus/tsconsensus_test.go @@ -0,0 +1,741 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsconsensus + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "net/netip" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/hashicorp/go-hclog" + "github.com/hashicorp/raft" + "tailscale.com/client/tailscale" + 
"tailscale.com/cmd/testwrapper/flakytest" + "tailscale.com/ipn/store/mem" + "tailscale.com/net/netns" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/nettest" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/types/views" + "tailscale.com/util/racebuild" +) + +type fsm struct { + mu sync.Mutex + applyEvents []string +} + +func commandWith(t *testing.T, s string) []byte { + jsonArgs, err := json.Marshal(s) + if err != nil { + t.Fatal(err) + } + bs, err := json.Marshal(Command{ + Args: jsonArgs, + }) + if err != nil { + t.Fatal(err) + } + return bs +} + +func fromCommand(bs []byte) (string, error) { + var cmd Command + err := json.Unmarshal(bs, &cmd) + if err != nil { + return "", err + } + var args string + err = json.Unmarshal(cmd.Args, &args) + if err != nil { + return "", err + } + return args, nil +} + +func (f *fsm) Apply(l *raft.Log) any { + f.mu.Lock() + defer f.mu.Unlock() + s, err := fromCommand(l.Data) + if err != nil { + return CommandResult{ + Err: err, + } + } + f.applyEvents = append(f.applyEvents, s) + result, err := json.Marshal(len(f.applyEvents)) + if err != nil { + panic("should be able to Marshal that?") + } + return CommandResult{ + Result: result, + } +} + +func (f *fsm) numEvents() int { + f.mu.Lock() + defer f.mu.Unlock() + return len(f.applyEvents) +} + +func (f *fsm) eventsMatch(es []string) bool { + f.mu.Lock() + defer f.mu.Unlock() + return cmp.Equal(es, f.applyEvents) +} + +func (f *fsm) Snapshot() (raft.FSMSnapshot, error) { + return nil, nil +} + +func (f *fsm) Restore(rc io.ReadCloser) error { + return nil +} + +func testConfig(t *testing.T) { + // -race AND Parallel makes things start to take too long. 
+ if !racebuild.On { + t.Parallel() + } + nettest.SkipIfNoNetwork(t) +} + +func startControl(t testing.TB) (control *testcontrol.Server, controlURL string) { + t.Helper() + // tailscale/corp#4520: don't use netns for tests. + netns.SetEnabled(false) + t.Cleanup(func() { + netns.SetEnabled(true) + }) + + derpLogf := logger.Discard + derpMap := integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1") + control = &testcontrol.Server{ + DERPMap: derpMap, + DNSConfig: &tailcfg.DNSConfig{ + Proxied: true, + }, + MagicDNSDomain: "tail-scale.ts.net", + } + control.HTTPTestServer = httptest.NewUnstartedServer(control) + control.HTTPTestServer.Start() + t.Cleanup(control.HTTPTestServer.Close) + controlURL = control.HTTPTestServer.URL + t.Logf("testcontrol listening on %s", controlURL) + return control, controlURL +} + +func startNode(t testing.TB, ctx context.Context, controlURL, hostname string) (*tsnet.Server, key.NodePublic, netip.Addr) { + t.Helper() + + tmp := filepath.Join(t.TempDir(), hostname) + os.MkdirAll(tmp, 0755) + s := &tsnet.Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: hostname, + Store: new(mem.Store), + Ephemeral: true, + } + t.Cleanup(func() { s.Close() }) + + status, err := s.Up(ctx) + if err != nil { + t.Fatal(err) + } + return s, status.Self.PublicKey, status.TailscaleIPs[0] +} + +func waitForNodesToBeTaggedInStatus(t testing.TB, ctx context.Context, ts *tsnet.Server, nodeKeys []key.NodePublic, tag string) { + t.Helper() + waitFor(t, "nodes tagged in status", func() bool { + lc, err := ts.LocalClient() + if err != nil { + t.Fatal(err) + } + status, err := lc.Status(ctx) + if err != nil { + t.Fatalf("error getting status: %v", err) + } + for _, k := range nodeKeys { + var tags *views.Slice[string] + if k == status.Self.PublicKey { + tags = status.Self.Tags + } else { + tags = status.Peer[k].Tags + } + if tag == "" { + if tags != nil && tags.Len() != 0 { + return false + } + } else { + if tags == nil { + return false + } + if tags.Len() != 1 || 
tags.At(0) != tag { + return false + } + } + } + return true + }, 2*time.Second) +} + +func tagNodes(t testing.TB, control *testcontrol.Server, nodeKeys []key.NodePublic, tag string) { + t.Helper() + for _, key := range nodeKeys { + n := control.Node(key) + if tag == "" { + if len(n.Tags) != 1 { + t.Fatalf("expected tags to have one tag") + } + n.Tags = nil + } else { + if len(n.Tags) != 0 { + // if we want this to work with multiple tags we'll have to change the logic + // for checking if a tag got removed yet. + t.Fatalf("expected tags to be empty") + } + n.Tags = append(n.Tags, tag) + } + b := true + n.Online = &b + control.UpdateNode(n) + } +} + +func addIDedLogger(id string, c Config) Config { + // logs that identify themselves + c.Raft.Logger = hclog.New(&hclog.LoggerOptions{ + Name: fmt.Sprintf("raft: %s", id), + Output: c.Raft.LogOutput, + Level: hclog.LevelFromString(c.Raft.LogLevel), + }) + return c +} + +func warnLogConfig() Config { + c := DefaultConfig() + // fewer logs from raft + c.Raft.LogLevel = "WARN" + // timeouts long enough that we can form a cluster under -race + c.Raft.LeaderLeaseTimeout = 2 * time.Second + c.Raft.HeartbeatTimeout = 4 * time.Second + c.Raft.ElectionTimeout = 4 * time.Second + return c +} + +func TestStart(t *testing.T) { + testConfig(t) + control, controlURL := startControl(t) + ctx := context.Background() + one, k, _ := startNode(t, ctx, controlURL, "one") + + clusterTag := "tag:whatever" + // nodes must be tagged with the cluster tag, to find each other + tagNodes(t, control, []key.NodePublic{k}, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, one, []key.NodePublic{k}, clusterTag) + + sm := &fsm{} + r, err := Start(ctx, one, sm, clusterTag, warnLogConfig()) + if err != nil { + t.Fatal(err) + } + defer r.Stop(ctx) +} + +func waitFor(t testing.TB, msg string, condition func() bool, waitBetweenTries time.Duration) { + t.Helper() + try := 0 + for true { + try++ + done := condition() + if done { + t.Logf("waitFor success: 
%s: after %d tries", msg, try) + return + } + time.Sleep(waitBetweenTries) + } +} + +type participant struct { + c *Consensus + sm *fsm + ts *tsnet.Server + key key.NodePublic +} + +// starts and tags the *tsnet.Server nodes with the control, waits for the nodes to make successful +// LocalClient Status calls that show the first node as Online. +func startNodesAndWaitForPeerStatus(t testing.TB, ctx context.Context, clusterTag string, nNodes int) ([]*participant, *testcontrol.Server, string) { + t.Helper() + ps := make([]*participant, nNodes) + keysToTag := make([]key.NodePublic, nNodes) + localClients := make([]*tailscale.LocalClient, nNodes) + control, controlURL := startControl(t) + for i := 0; i < nNodes; i++ { + ts, key, _ := startNode(t, ctx, controlURL, fmt.Sprintf("node %d", i)) + ps[i] = &participant{ts: ts, key: key} + keysToTag[i] = key + lc, err := ts.LocalClient() + if err != nil { + t.Fatalf("%d: error getting local client: %v", i, err) + } + localClients[i] = lc + } + tagNodes(t, control, keysToTag, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, keysToTag, clusterTag) + fxCameOnline := func() bool { + // all the _other_ nodes see the first as online + for i := 1; i < nNodes; i++ { + status, err := localClients[i].Status(ctx) + if err != nil { + t.Fatalf("%d: error getting status: %v", i, err) + } + if !status.Peer[ps[0].key].Online { + return false + } + } + return true + } + waitFor(t, "other nodes see node 1 online in ts status", fxCameOnline, 2*time.Second) + return ps, control, controlURL +} + +// populates participants with their consensus fields, waits for all nodes to show all nodes +// as part of the same consensus cluster. Starts the first participant first and waits for it to +// become leader before adding other nodes. 
+func createConsensusCluster(t testing.TB, ctx context.Context, clusterTag string, participants []*participant, cfg Config) { + t.Helper() + participants[0].sm = &fsm{} + myCfg := addIDedLogger("0", cfg) + first, err := Start(ctx, participants[0].ts, participants[0].sm, clusterTag, myCfg) + if err != nil { + t.Fatal(err) + } + fxFirstIsLeader := func() bool { + return first.raft.State() == raft.Leader + } + waitFor(t, "node 0 is leader", fxFirstIsLeader, 2*time.Second) + participants[0].c = first + + for i := 1; i < len(participants); i++ { + participants[i].sm = &fsm{} + myCfg := addIDedLogger(fmt.Sprintf("%d", i), cfg) + c, err := Start(ctx, participants[i].ts, participants[i].sm, clusterTag, myCfg) + if err != nil { + t.Fatal(err) + } + participants[i].c = c + } + + fxRaftConfigContainsAll := func() bool { + for i := 0; i < len(participants); i++ { + fut := participants[i].c.raft.GetConfiguration() + err = fut.Error() + if err != nil { + t.Fatalf("%d: Getting Configuration errored: %v", i, err) + } + if len(fut.Configuration().Servers) != len(participants) { + return false + } + } + return true + } + waitFor(t, "all raft machines have all servers in their config", fxRaftConfigContainsAll, time.Second*2) +} + +func TestApply(t *testing.T) { + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 2) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + fut := ps[0].c.raft.Apply(commandWith(t, "woo"), 2*time.Second) + err := fut.Error() + if err != nil { + t.Fatalf("Raft Apply Error: %v", err) + } + + want := []string{"woo"} + fxBothMachinesHaveTheApply := func() bool { + return ps[0].sm.eventsMatch(want) && ps[1].sm.eventsMatch(want) + } + waitFor(t, "the apply event made it into both state machines", fxBothMachinesHaveTheApply, time.Second*1) +} + +// calls ExecuteCommand on each participant and 
checks that all participants get all commands +func assertCommandsWorkOnAnyNode(t testing.TB, participants []*participant) { + t.Helper() + want := []string{} + for i, p := range participants { + si := fmt.Sprintf("%d", i) + want = append(want, si) + bs, err := json.Marshal(si) + if err != nil { + t.Fatal(err) + } + res, err := p.c.ExecuteCommand(Command{Args: bs}) + if err != nil { + t.Fatalf("%d: Error ExecuteCommand: %v", i, err) + } + if res.Err != nil { + t.Fatalf("%d: Result Error ExecuteCommand: %v", i, res.Err) + } + var retVal int + err = json.Unmarshal(res.Result, &retVal) + if err != nil { + t.Fatal(err) + } + // the test implementation of the fsm returns the count of events that have been received + if retVal != i+1 { + t.Fatalf("Result, want %d, got %d", i+1, retVal) + } + + fxEventsInAll := func() bool { + for _, pOther := range participants { + if !pOther.sm.eventsMatch(want) { + return false + } + } + return true + } + waitFor(t, "event makes it to all", fxEventsInAll, time.Second*1) + } +} + +func TestConfig(t *testing.T) { + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + // test all is well with non default ports + cfg.CommandPort = 12347 + cfg.RaftPort = 11882 + mp := uint16(8798) + cfg.MonitorPort = mp + cfg.ServeDebugMonitor = true + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + assertCommandsWorkOnAnyNode(t, ps) + + url := fmt.Sprintf("http://%s:%d/", ps[0].c.self.hostAddr.String(), mp) + httpClientOnTailnet := ps[1].ts.HTTPClient() + rsp, err := httpClientOnTailnet.Get(url) + if err != nil { + t.Fatal(err) + } + if rsp.StatusCode != 200 { + t.Fatalf("monitor status want %d, got %d", 200, rsp.StatusCode) + } + defer rsp.Body.Close() + reader := bufio.NewReader(rsp.Body) + line1, err := reader.ReadString('\n') + if err != nil { + t.Fatal(err) + } + // Not a great 
assertion because it relies on the format of the response. + if !strings.HasPrefix(line1, "RaftState:") { + t.Fatalf("getting monitor status, first line, want something that starts with 'RaftState:', got '%s'", line1) + } +} + +func TestFollowerFailover(t *testing.T) { + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + smThree := ps[2].sm + + fut := ps[0].c.raft.Apply(commandWith(t, "a"), 2*time.Second) + futTwo := ps[0].c.raft.Apply(commandWith(t, "b"), 2*time.Second) + err := fut.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + err = futTwo.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + + wantFirstTwoEvents := []string{"a", "b"} + fxAllMachinesHaveTheApplies := func() bool { + return ps[0].sm.eventsMatch(wantFirstTwoEvents) && + ps[1].sm.eventsMatch(wantFirstTwoEvents) && + smThree.eventsMatch(wantFirstTwoEvents) + } + waitFor(t, "the apply events made it into all state machines", fxAllMachinesHaveTheApplies, time.Second*1) + + //a follower goes loses contact with the cluster + ps[2].c.Stop(ctx) + + // applies still make it to one and two + futThree := ps[0].c.raft.Apply(commandWith(t, "c"), 2*time.Second) + futFour := ps[0].c.raft.Apply(commandWith(t, "d"), 2*time.Second) + err = futThree.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + err = futFour.Error() + if err != nil { + t.Fatalf("Apply Raft error %v", err) + } + wantFourEvents := []string{"a", "b", "c", "d"} + fxAliveMachinesHaveTheApplies := func() bool { + return ps[0].sm.eventsMatch(wantFourEvents) && + ps[1].sm.eventsMatch(wantFourEvents) && + smThree.eventsMatch(wantFirstTwoEvents) + } + waitFor(t, "the apply events made it into eligible state machines", fxAliveMachinesHaveTheApplies, time.Second*1) + + // 
follower comes back + smThreeAgain := &fsm{} + cfg = addIDedLogger("2 after restarting", warnLogConfig()) + rThreeAgain, err := Start(ctx, ps[2].ts, smThreeAgain, clusterTag, cfg) + if err != nil { + t.Fatal(err) + } + defer rThreeAgain.Stop(ctx) + fxThreeGetsCaughtUp := func() bool { + return smThreeAgain.eventsMatch(wantFourEvents) + } + waitFor(t, "the apply events made it into the third node when it appeared with an empty state machine", fxThreeGetsCaughtUp, time.Second*2) + if !smThree.eventsMatch(wantFirstTwoEvents) { + t.Fatalf("Expected smThree to remain on 2 events: got %d", smThree.numEvents()) + } +} + +func TestRejoin(t *testing.T) { + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + // 1st node gets a redundant second join request from the second node + ps[0].c.handleJoin(joinRequest{ + RemoteHost: ps[1].c.self.hostAddr.String(), + RemoteID: ps[1].c.self.id, + }) + + tsJoiner, keyJoiner, _ := startNode(t, ctx, controlURL, "node joiner") + tagNodes(t, control, []key.NodePublic{keyJoiner}, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{keyJoiner}, clusterTag) + smJoiner := &fsm{} + cJoiner, err := Start(ctx, tsJoiner, smJoiner, clusterTag, cfg) + if err != nil { + t.Fatal(err) + } + ps = append(ps, &participant{ + sm: smJoiner, + c: cJoiner, + ts: tsJoiner, + key: keyJoiner, + }) + + assertCommandsWorkOnAnyNode(t, ps) +} + +func TestOnlyTaggedPeersCanDialRaftPort(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, control, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, 
cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + assertCommandsWorkOnAnyNode(t, ps) + + untaggedNode, _, _ := startNode(t, ctx, controlURL, "untagged node") + + taggedNode, taggedKey, _ := startNode(t, ctx, controlURL, "untagged node") + tagNodes(t, control, []key.NodePublic{taggedKey}, clusterTag) + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{taggedKey}, clusterTag) + + // surface area: command http, peer tcp + //untagged + ipv4, _ := ps[0].ts.TailscaleIPs() + sAddr := fmt.Sprintf("%s:%d", ipv4, cfg.RaftPort) + + getErrorFromTryingToSend := func(s *tsnet.Server) error { + ctx := context.Background() + conn, err := s.Dial(ctx, "tcp", sAddr) + if err != nil { + t.Fatalf("unexpected Dial err: %v", err) + } + fmt.Fprintf(conn, "hellllllloooooo") + status, err := bufio.NewReader(conn).ReadString('\n') + if status != "" { + t.Fatalf("node sending non-raft message should get empty response, got: '%s' for: %s", status, s.Hostname) + } + if err == nil { + t.Fatalf("node sending non-raft message should get an error but got nil err for: %s", s.Hostname) + } + return err + } + + isNetErr := func(err error) bool { + var netErr net.Error + return errors.As(err, &netErr) + } + + err := getErrorFromTryingToSend(untaggedNode) + if !isNetErr(err) { + t.Fatalf("untagged node trying to send should get a net.Error, got: %v", err) + } + // we still get an error trying to send but it's EOF the target node was happy to talk + // to us but couldn't understand what we said. 
+ err = getErrorFromTryingToSend(taggedNode) + if isNetErr(err) { + t.Fatalf("tagged node trying to send should not get a net.Error, got: %v", err) + } +} + +func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/15627") + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, control, _ := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + + // make a StreamLayer for ps[0] + ts := ps[0].ts + auth := newAuthorization(ts, clusterTag) + + port := 19841 + lns := make([]net.Listener, 3) + for i, p := range ps { + ln, err := p.ts.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + t.Fatal(err) + } + lns[i] = ln + } + + sl := StreamLayer{ + s: ts, + Listener: lns[0], + auth: auth, + shutdownCtx: ctx, + } + + ip1, _ := ps[1].ts.TailscaleIPs() + a1 := raft.ServerAddress(fmt.Sprintf("%s:%d", ip1, port)) + + ip2, _ := ps[2].ts.TailscaleIPs() + a2 := raft.ServerAddress(fmt.Sprintf("%s:%d", ip2, port)) + + // both can be dialed... 
+ conn, err := sl.Dial(a1, 2*time.Second) + if err != nil { + t.Fatal(err) + } + conn.Close() + + conn, err = sl.Dial(a2, 2*time.Second) + if err != nil { + t.Fatal(err) + } + conn.Close() + + // untag ps[2] + tagNodes(t, control, []key.NodePublic{ps[2].key}, "") + waitForNodesToBeTaggedInStatus(t, ctx, ps[0].ts, []key.NodePublic{ps[2].key}, "") + + // now only ps[1] can be dialed + conn, err = sl.Dial(a1, 2*time.Second) + if err != nil { + t.Fatal(err) + } + conn.Close() + + _, err = sl.Dial(a2, 2*time.Second) + if err.Error() != "dial: peer is not allowed" { + t.Fatalf("expected dial: peer is not allowed, got: %v", err) + } + +} + +func TestOnlyTaggedPeersCanJoin(t *testing.T) { + testConfig(t) + ctx := context.Background() + clusterTag := "tag:whatever" + ps, _, controlURL := startNodesAndWaitForPeerStatus(t, ctx, clusterTag, 3) + cfg := warnLogConfig() + createConsensusCluster(t, ctx, clusterTag, ps, cfg) + for _, p := range ps { + defer p.c.Stop(ctx) + } + + tsJoiner, _, _ := startNode(t, ctx, controlURL, "joiner node") + + ipv4, _ := tsJoiner.TailscaleIPs() + url := fmt.Sprintf("http://%s/join", ps[0].c.commandAddr(ps[0].c.self.hostAddr)) + payload, err := json.Marshal(joinRequest{ + RemoteHost: ipv4.String(), + RemoteID: "node joiner", + }) + if err != nil { + t.Fatal(err) + } + body := bytes.NewBuffer(payload) + req, err := http.NewRequest("POST", url, body) + if err != nil { + t.Fatal(err) + } + resp, err := tsJoiner.HTTPClient().Do(req) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusForbidden { + t.Fatalf("join req when not tagged, expected status: %d, got: %d", http.StatusForbidden, resp.StatusCode) + } + rBody, _ := io.ReadAll(resp.Body) + sBody := strings.TrimSpace(string(rBody)) + expected := "peer not allowed" + if sBody != expected { + t.Fatalf("join req when not tagged, expected body: %s, got: %s", expected, sBody) + } +} diff --git a/tsd/tsd.go b/tsd/tsd.go index acd09560c7601..ccd804f816aaa 
100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -32,6 +32,7 @@ import ( "tailscale.com/net/tstun" "tailscale.com/proxymap" "tailscale.com/types/netmap" + "tailscale.com/util/eventbus" "tailscale.com/util/usermetric" "tailscale.com/wgengine" "tailscale.com/wgengine/magicsock" @@ -39,7 +40,12 @@ import ( ) // System contains all the subsystems of a Tailscale node (tailscaled, etc.) +// +// A valid System value must always have a non-nil Bus populated. Callers must +// ensure this before using the value further. Call [NewSystem] to obtain a +// value ready to use. type System struct { + Bus SubSystem[*eventbus.Bus] Dialer SubSystem[*tsdial.Dialer] DNSManager SubSystem[*dns.Manager] // can get its *resolver.Resolver from DNSManager.Resolver Engine SubSystem[wgengine.Engine] @@ -70,6 +76,14 @@ type System struct { userMetricsRegistry usermetric.Registry } +// NewSystem constructs a new otherwise-empty [System] with a +// freshly-constructed event bus populated. +func NewSystem() *System { + sys := new(System) + sys.Set(eventbus.New()) + return sys +} + // NetstackImpl is the interface that *netstack.Impl implements. // It's an interface for circular dependency reasons: netstack.Impl // references LocalBackend, and LocalBackend has a tsd.System. @@ -82,6 +96,8 @@ type NetstackImpl interface { // has already been set. 
func (s *System) Set(v any) { switch v := v.(type) { + case *eventbus.Bus: + s.Bus.Set(v) case *netmon.Monitor: s.NetMon.Set(v) case *dns.Manager: diff --git a/tsnet/depaware.txt b/tsnet/depaware.txt new file mode 100644 index 0000000000000..662752554d02e --- /dev/null +++ b/tsnet/depaware.txt @@ -0,0 +1,651 @@ +tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) + + filippo.io/edwards25519 from github.com/hdevalence/ed25519consensus + filippo.io/edwards25519/field from filippo.io/edwards25519 + W đŸ’Ŗ github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ + W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate + W đŸ’Ŗ github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy + L github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ + L github.com/aws/aws-sdk-go-v2/aws/arn from tailscale.com/ipn/store/awsstore + L github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ + L github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts + L github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts + L github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry + L github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ + L github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 + L github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ + L github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ + L github.com/aws/aws-sdk-go-v2/config from 
tailscale.com/ipn/store/awsstore + L github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds + L github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ + L github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds + L github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ + L github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ + L github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/ssm/internal/endpoints+ + L github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ + L github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/aws-sdk-go-v2/internal/sdk from 
github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds + L github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ + L github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 + L github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws + L github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry + L github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts + L github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts + L github.com/aws/aws-sdk-go-v2/service/ssm from tailscale.com/ipn/store/awsstore + L github.com/aws/aws-sdk-go-v2/service/ssm/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm + L github.com/aws/aws-sdk-go-v2/service/ssm/types from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ + L github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso + L github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso + L github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ + L github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc + L github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc + L github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ + L github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts + L github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ + L github.com/aws/smithy-go from 
github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ + L github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ + L github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer + L github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ + L github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ + L github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts + L github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer + L github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ + L github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ + L github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config + L github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ + L github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware+ + L github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/ssm+ + L github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ + L github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws/middleware+ + L github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http + L github.com/aws/smithy-go/waiter from github.com/aws/aws-sdk-go-v2/service/ssm + LDW github.com/coder/websocket from 
tailscale.com/util/eventbus + LDW github.com/coder/websocket/internal/errd from github.com/coder/websocket + LDW github.com/coder/websocket/internal/util from github.com/coder/websocket + LDW github.com/coder/websocket/internal/xsync from github.com/coder/websocket + L github.com/coreos/go-iptables/iptables from tailscale.com/util/linuxfw + W đŸ’Ŗ github.com/dblohm7/wingoes from github.com/dblohm7/wingoes/com+ + W đŸ’Ŗ github.com/dblohm7/wingoes/com from tailscale.com/util/osdiag+ + W đŸ’Ŗ github.com/dblohm7/wingoes/com/automation from tailscale.com/util/osdiag/internal/wsc + W github.com/dblohm7/wingoes/internal from github.com/dblohm7/wingoes/com + W đŸ’Ŗ github.com/dblohm7/wingoes/pe from tailscale.com/util/osdiag+ + LW đŸ’Ŗ github.com/digitalocean/go-smbios/smbios from tailscale.com/posture + github.com/fxamacker/cbor/v2 from tailscale.com/tka + github.com/gaissmai/bart from tailscale.com/net/ipset+ + github.com/gaissmai/bart/internal/bitset from github.com/gaissmai/bart+ + github.com/gaissmai/bart/internal/sparse from github.com/gaissmai/bart + github.com/go-json-experiment/json from tailscale.com/types/opt+ + github.com/go-json-experiment/json/internal from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonflags from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ + github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ + W đŸ’Ŗ github.com/go-ole/go-ole from github.com/go-ole/go-ole/oleutil+ + W đŸ’Ŗ github.com/go-ole/go-ole/oleutil from tailscale.com/wgengine/winnet + L đŸ’Ŗ github.com/godbus/dbus/v5 from tailscale.com/net/dns + github.com/golang/groupcache/lru from tailscale.com/net/dnscache + github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/header+ + L github.com/google/nftables from 
tailscale.com/util/linuxfw + L đŸ’Ŗ github.com/google/nftables/alignedbuff from github.com/google/nftables/xt + L đŸ’Ŗ github.com/google/nftables/binaryutil from github.com/google/nftables+ + L github.com/google/nftables/expr from github.com/google/nftables+ + L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ + L github.com/google/nftables/xt from github.com/google/nftables/expr+ + DWI github.com/google/uuid from github.com/prometheus-community/pro-bing+ + LDW github.com/gorilla/csrf from tailscale.com/client/web + LDW github.com/gorilla/securecookie from github.com/gorilla/csrf + github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ + L đŸ’Ŗ github.com/illarion/gonotify/v3 from tailscale.com/net/dns + L github.com/illarion/gonotify/v3/syscallf from github.com/illarion/gonotify/v3 + L github.com/jmespath/go-jmespath from github.com/aws/aws-sdk-go-v2/service/ssm + L đŸ’Ŗ github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon + L github.com/jsimonetti/rtnetlink/internal/unix from github.com/jsimonetti/rtnetlink + github.com/klauspost/compress from github.com/klauspost/compress/zstd + github.com/klauspost/compress/fse from github.com/klauspost/compress/huff0 + github.com/klauspost/compress/huff0 from github.com/klauspost/compress/zstd + github.com/klauspost/compress/internal/cpuinfo from github.com/klauspost/compress/huff0+ + github.com/klauspost/compress/internal/snapref from github.com/klauspost/compress/zstd + github.com/klauspost/compress/zstd from tailscale.com/util/zstdframe + github.com/klauspost/compress/zstd/internal/xxhash from github.com/klauspost/compress/zstd + L github.com/mdlayher/genetlink from tailscale.com/net/tstun + L đŸ’Ŗ github.com/mdlayher/netlink from github.com/google/nftables+ + L đŸ’Ŗ github.com/mdlayher/netlink/nlenc from github.com/jsimonetti/rtnetlink+ + L github.com/mdlayher/netlink/nltest from github.com/google/nftables + L github.com/mdlayher/sdnotify from 
tailscale.com/util/systemd + LA đŸ’Ŗ github.com/mdlayher/socket from github.com/mdlayher/netlink+ + github.com/miekg/dns from tailscale.com/net/dns/recursive + LDW đŸ’Ŗ github.com/mitchellh/go-ps from tailscale.com/safesocket + DI github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack + L đŸ’Ŗ github.com/safchain/ethtool from tailscale.com/doctor/ethtool+ + W đŸ’Ŗ github.com/tailscale/certstore from tailscale.com/control/controlclient + W đŸ’Ŗ github.com/tailscale/go-winio from tailscale.com/safesocket + W đŸ’Ŗ github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio + W đŸ’Ŗ github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio + W github.com/tailscale/go-winio/internal/stringbuffer from github.com/tailscale/go-winio/internal/fs + W github.com/tailscale/go-winio/pkg/guid from github.com/tailscale/go-winio+ + github.com/tailscale/goupnp from github.com/tailscale/goupnp/dcps/internetgateway2+ + github.com/tailscale/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper + github.com/tailscale/goupnp/httpu from github.com/tailscale/goupnp+ + github.com/tailscale/goupnp/scpd from github.com/tailscale/goupnp + github.com/tailscale/goupnp/soap from github.com/tailscale/goupnp+ + github.com/tailscale/goupnp/ssdp from github.com/tailscale/goupnp + LDW github.com/tailscale/hujson from tailscale.com/ipn/conffile + L đŸ’Ŗ github.com/tailscale/netlink from tailscale.com/net/routetable+ + L đŸ’Ŗ github.com/tailscale/netlink/nl from github.com/tailscale/netlink + github.com/tailscale/peercred from tailscale.com/ipn/ipnauth + LDW github.com/tailscale/web-client-prebuilt from tailscale.com/client/web + đŸ’Ŗ github.com/tailscale/wireguard-go/conn from github.com/tailscale/wireguard-go/device+ + W đŸ’Ŗ github.com/tailscale/wireguard-go/conn/winrio from github.com/tailscale/wireguard-go/conn + đŸ’Ŗ github.com/tailscale/wireguard-go/device from tailscale.com/net/tstun+ + đŸ’Ŗ 
github.com/tailscale/wireguard-go/ipc from github.com/tailscale/wireguard-go/device + W đŸ’Ŗ github.com/tailscale/wireguard-go/ipc/namedpipe from github.com/tailscale/wireguard-go/ipc + github.com/tailscale/wireguard-go/ratelimiter from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/replay from github.com/tailscale/wireguard-go/device + github.com/tailscale/wireguard-go/rwcancel from github.com/tailscale/wireguard-go/device+ + github.com/tailscale/wireguard-go/tai64n from github.com/tailscale/wireguard-go/device + đŸ’Ŗ github.com/tailscale/wireguard-go/tun from github.com/tailscale/wireguard-go/device+ + L github.com/vishvananda/netns from github.com/tailscale/netlink+ + github.com/x448/float16 from github.com/fxamacker/cbor/v2 + đŸ’Ŗ go4.org/mem from tailscale.com/client/local+ + go4.org/netipx from tailscale.com/ipn/ipnlocal+ + W đŸ’Ŗ golang.zx2c4.com/wintun from github.com/tailscale/wireguard-go/tun + W đŸ’Ŗ golang.zx2c4.com/wireguard/windows/tunnel/winipcfg from tailscale.com/net/dns+ + gvisor.dev/gvisor/pkg/atomicbitops from gvisor.dev/gvisor/pkg/buffer+ + gvisor.dev/gvisor/pkg/bits from gvisor.dev/gvisor/pkg/buffer + đŸ’Ŗ gvisor.dev/gvisor/pkg/buffer from gvisor.dev/gvisor/pkg/tcpip+ + gvisor.dev/gvisor/pkg/context from gvisor.dev/gvisor/pkg/refs + đŸ’Ŗ gvisor.dev/gvisor/pkg/gohacks from gvisor.dev/gvisor/pkg/state/wire+ + gvisor.dev/gvisor/pkg/linewriter from gvisor.dev/gvisor/pkg/log + gvisor.dev/gvisor/pkg/log from gvisor.dev/gvisor/pkg/context+ + gvisor.dev/gvisor/pkg/rand from gvisor.dev/gvisor/pkg/tcpip+ + gvisor.dev/gvisor/pkg/refs from gvisor.dev/gvisor/pkg/buffer+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/sleep from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + đŸ’Ŗ gvisor.dev/gvisor/pkg/state from gvisor.dev/gvisor/pkg/atomicbitops+ + gvisor.dev/gvisor/pkg/state/wire from gvisor.dev/gvisor/pkg/state + đŸ’Ŗ gvisor.dev/gvisor/pkg/sync from gvisor.dev/gvisor/pkg/atomicbitops+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/sync/locking from 
gvisor.dev/gvisor/pkg/tcpip/stack + gvisor.dev/gvisor/pkg/tcpip from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/tcpip/adapters/gonet from tailscale.com/wgengine/netstack + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/checksum from gvisor.dev/gvisor/pkg/buffer+ + gvisor.dev/gvisor/pkg/tcpip/hash/jenkins from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/header from gvisor.dev/gvisor/pkg/tcpip/header/parse+ + gvisor.dev/gvisor/pkg/tcpip/header/parse from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/internal/tcp from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + gvisor.dev/gvisor/pkg/tcpip/network/hash from gvisor.dev/gvisor/pkg/tcpip/network/ipv4 + gvisor.dev/gvisor/pkg/tcpip/network/internal/fragmentation from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/internal/ip from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/internal/multicast from gvisor.dev/gvisor/pkg/tcpip/network/ipv4+ + gvisor.dev/gvisor/pkg/tcpip/network/ipv4 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/network/ipv6 from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/ports from gvisor.dev/gvisor/pkg/tcpip/stack+ + gvisor.dev/gvisor/pkg/tcpip/seqnum from gvisor.dev/gvisor/pkg/tcpip/header+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/stack from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + LDWA gvisor.dev/gvisor/pkg/tcpip/stack/gro from tailscale.com/wgengine/netstack/gro + gvisor.dev/gvisor/pkg/tcpip/transport from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + gvisor.dev/gvisor/pkg/tcpip/transport/icmp from tailscale.com/wgengine/netstack + gvisor.dev/gvisor/pkg/tcpip/transport/internal/network from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + gvisor.dev/gvisor/pkg/tcpip/transport/internal/noop from gvisor.dev/gvisor/pkg/tcpip/transport/raw + gvisor.dev/gvisor/pkg/tcpip/transport/packet from gvisor.dev/gvisor/pkg/tcpip/transport/raw + 
gvisor.dev/gvisor/pkg/tcpip/transport/raw from gvisor.dev/gvisor/pkg/tcpip/transport/icmp+ + đŸ’Ŗ gvisor.dev/gvisor/pkg/tcpip/transport/tcp from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack from gvisor.dev/gvisor/pkg/tcpip/stack + gvisor.dev/gvisor/pkg/tcpip/transport/udp from gvisor.dev/gvisor/pkg/tcpip/adapters/gonet+ + gvisor.dev/gvisor/pkg/waiter from gvisor.dev/gvisor/pkg/context+ + tailscale.com from tailscale.com/version + tailscale.com/appc from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/atomicfile from tailscale.com/ipn+ + tailscale.com/client/local from tailscale.com/client/tailscale+ + tailscale.com/client/tailscale from tailscale.com/derp+ + tailscale.com/client/tailscale/apitype from tailscale.com/client/local+ + LDW tailscale.com/client/web from tailscale.com/ipn/ipnlocal + tailscale.com/clientupdate from tailscale.com/client/web+ + LW tailscale.com/clientupdate/distsign from tailscale.com/clientupdate + tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ + tailscale.com/control/controlclient from tailscale.com/ipn/ipnext+ + tailscale.com/control/controlhttp from tailscale.com/control/controlclient + tailscale.com/control/controlhttp/controlhttpcommon from tailscale.com/control/controlhttp + tailscale.com/control/controlknobs from tailscale.com/control/controlclient+ + tailscale.com/derp from tailscale.com/derp/derphttp+ + tailscale.com/derp/derpconst from tailscale.com/derp+ + tailscale.com/derp/derphttp from tailscale.com/ipn/localapi+ + tailscale.com/disco from tailscale.com/derp+ + tailscale.com/doctor from tailscale.com/ipn/ipnlocal + tailscale.com/doctor/ethtool from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/doctor/permissions from tailscale.com/ipn/ipnlocal + tailscale.com/doctor/routetable from tailscale.com/ipn/ipnlocal + tailscale.com/drive from tailscale.com/client/local+ + tailscale.com/envknob from tailscale.com/client/local+ + 
tailscale.com/envknob/featureknob from tailscale.com/client/web+ + tailscale.com/feature from tailscale.com/ipn/ipnext+ + tailscale.com/health from tailscale.com/control/controlclient+ + tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal + tailscale.com/hostinfo from tailscale.com/client/web+ + tailscale.com/internal/noiseconn from tailscale.com/control/controlclient + tailscale.com/ipn from tailscale.com/client/local+ + tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ + đŸ’Ŗ tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnext+ + tailscale.com/ipn/ipnext from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/ipnlocal from tailscale.com/ipn/localapi+ + tailscale.com/ipn/ipnstate from tailscale.com/client/local+ + tailscale.com/ipn/localapi from tailscale.com/tsnet + tailscale.com/ipn/policy from tailscale.com/ipn/ipnlocal + tailscale.com/ipn/store from tailscale.com/ipn/ipnlocal+ + L tailscale.com/ipn/store/awsstore from tailscale.com/ipn/store + L tailscale.com/ipn/store/kubestore from tailscale.com/ipn/store + tailscale.com/ipn/store/mem from tailscale.com/ipn/ipnlocal+ + L tailscale.com/kube/kubeapi from tailscale.com/ipn/store/kubestore+ + L tailscale.com/kube/kubeclient from tailscale.com/ipn/store/kubestore + tailscale.com/kube/kubetypes from tailscale.com/envknob+ + LDW tailscale.com/licenses from tailscale.com/client/web + tailscale.com/log/filelogger from tailscale.com/logpolicy + tailscale.com/log/sockstatlog from tailscale.com/ipn/ipnlocal + tailscale.com/logpolicy from tailscale.com/ipn/ipnlocal+ + tailscale.com/logtail from tailscale.com/control/controlclient+ + tailscale.com/logtail/backoff from tailscale.com/control/controlclient+ + tailscale.com/logtail/filch from tailscale.com/log/sockstatlog+ + tailscale.com/metrics from tailscale.com/derp+ + tailscale.com/net/bakedroots from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/captivedetection from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/connstats from 
tailscale.com/net/tstun+ + tailscale.com/net/dns from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/dns/publicdns from tailscale.com/net/dns+ + tailscale.com/net/dns/recursive from tailscale.com/net/dnsfallback + tailscale.com/net/dns/resolvconffile from tailscale.com/net/dns+ + tailscale.com/net/dns/resolver from tailscale.com/net/dns+ + tailscale.com/net/dnscache from tailscale.com/control/controlclient+ + tailscale.com/net/dnsfallback from tailscale.com/control/controlclient+ + tailscale.com/net/flowtrack from tailscale.com/net/packet+ + tailscale.com/net/ipset from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/memnet from tailscale.com/tsnet + tailscale.com/net/netaddr from tailscale.com/ipn+ + tailscale.com/net/netcheck from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/neterror from tailscale.com/net/dns/resolver+ + tailscale.com/net/netkernelconf from tailscale.com/ipn/ipnlocal + tailscale.com/net/netknob from tailscale.com/logpolicy+ + đŸ’Ŗ tailscale.com/net/netmon from tailscale.com/control/controlclient+ + đŸ’Ŗ tailscale.com/net/netns from tailscale.com/derp/derphttp+ + W đŸ’Ŗ tailscale.com/net/netstat from tailscale.com/portlist + tailscale.com/net/netutil from tailscale.com/client/local+ + tailscale.com/net/netx from tailscale.com/control/controlclient+ + tailscale.com/net/packet from tailscale.com/ipn/ipnlocal+ + tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/ping from tailscale.com/net/netcheck+ + tailscale.com/net/portmapper from tailscale.com/ipn/localapi+ + tailscale.com/net/proxymux from tailscale.com/tsnet + tailscale.com/net/routetable from tailscale.com/doctor/routetable + tailscale.com/net/socks5 from tailscale.com/tsnet + tailscale.com/net/sockstats from tailscale.com/control/controlclient+ + tailscale.com/net/stun from tailscale.com/ipn/localapi+ + L tailscale.com/net/tcpinfo from tailscale.com/derp + tailscale.com/net/tlsdial from tailscale.com/control/controlclient+ + 
tailscale.com/net/tlsdial/blockblame from tailscale.com/net/tlsdial + tailscale.com/net/tsaddr from tailscale.com/client/web+ + tailscale.com/net/tsdial from tailscale.com/control/controlclient+ + đŸ’Ŗ tailscale.com/net/tshttpproxy from tailscale.com/clientupdate/distsign+ + tailscale.com/net/tstun from tailscale.com/tsd+ + tailscale.com/net/udprelay/endpoint from tailscale.com/wgengine/magicsock + tailscale.com/omit from tailscale.com/ipn/conffile + tailscale.com/paths from tailscale.com/client/local+ + đŸ’Ŗ tailscale.com/portlist from tailscale.com/ipn/ipnlocal + tailscale.com/posture from tailscale.com/ipn/ipnlocal + tailscale.com/proxymap from tailscale.com/tsd+ + đŸ’Ŗ tailscale.com/safesocket from tailscale.com/client/local+ + tailscale.com/syncs from tailscale.com/control/controlhttp+ + tailscale.com/tailcfg from tailscale.com/client/local+ + tailscale.com/tempfork/acme from tailscale.com/ipn/ipnlocal + tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock + tailscale.com/tempfork/httprec from tailscale.com/control/controlclient + tailscale.com/tka from tailscale.com/client/local+ + tailscale.com/tsconst from tailscale.com/ipn/ipnlocal+ + tailscale.com/tsd from tailscale.com/ipn/ipnext+ + tailscale.com/tstime from tailscale.com/control/controlclient+ + tailscale.com/tstime/mono from tailscale.com/net/tstun+ + tailscale.com/tstime/rate from tailscale.com/derp+ + tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb/varz from tailscale.com/tsweb+ + tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal + tailscale.com/types/bools from tailscale.com/tsnet + tailscale.com/types/dnstype from tailscale.com/client/local+ + tailscale.com/types/empty from tailscale.com/ipn+ + tailscale.com/types/ipproto from tailscale.com/ipn+ + tailscale.com/types/key from tailscale.com/client/local+ + tailscale.com/types/lazy from tailscale.com/clientupdate+ + tailscale.com/types/logger from tailscale.com/appc+ + tailscale.com/types/logid from 
tailscale.com/ipn/ipnlocal+ + tailscale.com/types/mapx from tailscale.com/ipn/ipnext + tailscale.com/types/netlogtype from tailscale.com/net/connstats+ + tailscale.com/types/netmap from tailscale.com/control/controlclient+ + tailscale.com/types/nettype from tailscale.com/ipn/localapi+ + tailscale.com/types/opt from tailscale.com/client/tailscale+ + tailscale.com/types/persist from tailscale.com/control/controlclient+ + tailscale.com/types/preftype from tailscale.com/ipn+ + tailscale.com/types/ptr from tailscale.com/control/controlclient+ + tailscale.com/types/result from tailscale.com/util/lineiter + tailscale.com/types/structs from tailscale.com/control/controlclient+ + tailscale.com/types/tkatype from tailscale.com/client/local+ + tailscale.com/types/views from tailscale.com/appc+ + tailscale.com/util/cibuild from tailscale.com/health + tailscale.com/util/clientmetric from tailscale.com/appc+ + tailscale.com/util/cloudenv from tailscale.com/hostinfo+ + tailscale.com/util/cmpver from tailscale.com/clientupdate+ + tailscale.com/util/ctxkey from tailscale.com/client/tailscale/apitype+ + đŸ’Ŗ tailscale.com/util/deephash from tailscale.com/ipn/ipnlocal+ + LA đŸ’Ŗ tailscale.com/util/dirwalk from tailscale.com/metrics+ + tailscale.com/util/dnsname from tailscale.com/appc+ + tailscale.com/util/eventbus from tailscale.com/ipn/localapi+ + tailscale.com/util/execqueue from tailscale.com/appc+ + tailscale.com/util/goroutines from tailscale.com/ipn/ipnlocal + tailscale.com/util/groupmember from tailscale.com/client/web+ + đŸ’Ŗ tailscale.com/util/hashx from tailscale.com/util/deephash + tailscale.com/util/httpm from tailscale.com/client/tailscale+ + tailscale.com/util/lineiter from tailscale.com/hostinfo+ + L tailscale.com/util/linuxfw from tailscale.com/net/netns+ + tailscale.com/util/mak from tailscale.com/appc+ + tailscale.com/util/multierr from tailscale.com/control/controlclient+ + tailscale.com/util/must from tailscale.com/clientupdate/distsign+ + 
tailscale.com/util/nocasemaps from tailscale.com/types/ipproto + đŸ’Ŗ tailscale.com/util/osdiag from tailscale.com/ipn/localapi + W đŸ’Ŗ tailscale.com/util/osdiag/internal/wsc from tailscale.com/util/osdiag + tailscale.com/util/osuser from tailscale.com/ipn/ipnlocal + tailscale.com/util/race from tailscale.com/net/dns/resolver + tailscale.com/util/racebuild from tailscale.com/logpolicy + tailscale.com/util/rands from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/ringbuffer from tailscale.com/wgengine/magicsock + tailscale.com/util/set from tailscale.com/control/controlclient+ + tailscale.com/util/singleflight from tailscale.com/control/controlclient+ + tailscale.com/util/slicesx from tailscale.com/appc+ + tailscale.com/util/syspolicy from tailscale.com/control/controlclient+ + tailscale.com/util/syspolicy/internal from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/loggerx from tailscale.com/util/syspolicy+ + tailscale.com/util/syspolicy/internal/metrics from tailscale.com/util/syspolicy/source + tailscale.com/util/syspolicy/rsop from tailscale.com/ipn/ipnlocal+ + tailscale.com/util/syspolicy/setting from tailscale.com/client/local+ + tailscale.com/util/syspolicy/source from tailscale.com/util/syspolicy+ + tailscale.com/util/sysresources from tailscale.com/wgengine/magicsock + tailscale.com/util/systemd from tailscale.com/control/controlclient+ + tailscale.com/util/testenv from tailscale.com/control/controlclient+ + tailscale.com/util/truncate from tailscale.com/logtail + tailscale.com/util/usermetric from tailscale.com/health+ + tailscale.com/util/vizerror from tailscale.com/tailcfg+ + đŸ’Ŗ tailscale.com/util/winutil from tailscale.com/clientupdate+ + W đŸ’Ŗ tailscale.com/util/winutil/authenticode from tailscale.com/clientupdate+ + W đŸ’Ŗ tailscale.com/util/winutil/gp from tailscale.com/net/dns+ + W tailscale.com/util/winutil/policy from tailscale.com/ipn/ipnlocal + W đŸ’Ŗ tailscale.com/util/winutil/winenv from tailscale.com/hostinfo+ + 
tailscale.com/util/zstdframe from tailscale.com/control/controlclient+ + tailscale.com/version from tailscale.com/client/web+ + tailscale.com/version/distro from tailscale.com/client/web+ + tailscale.com/wgengine from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/filter from tailscale.com/control/controlclient+ + tailscale.com/wgengine/filter/filtertype from tailscale.com/types/netmap+ + đŸ’Ŗ tailscale.com/wgengine/magicsock from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/netlog from tailscale.com/wgengine + tailscale.com/wgengine/netstack from tailscale.com/tsnet + tailscale.com/wgengine/netstack/gro from tailscale.com/net/tstun+ + tailscale.com/wgengine/router from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg from tailscale.com/ipn/ipnlocal+ + tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal + đŸ’Ŗ tailscale.com/wgengine/wgint from tailscale.com/wgengine+ + tailscale.com/wgengine/wglog from tailscale.com/wgengine + W đŸ’Ŗ tailscale.com/wgengine/winnet from tailscale.com/wgengine/router + golang.org/x/crypto/argon2 from tailscale.com/tka + golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ + golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ + LD golang.org/x/crypto/blowfish from golang.org/x/crypto/ssh/internal/bcrypt_pbkdf + golang.org/x/crypto/chacha20 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/chacha20poly1305 from crypto/internal/hpke+ + golang.org/x/crypto/cryptobyte from crypto/ecdsa+ + golang.org/x/crypto/cryptobyte/asn1 from crypto/ecdsa+ + golang.org/x/crypto/curve25519 from github.com/tailscale/wireguard-go/device+ + golang.org/x/crypto/hkdf from tailscale.com/control/controlbase + golang.org/x/crypto/internal/alias from golang.org/x/crypto/chacha20+ + golang.org/x/crypto/internal/poly1305 from golang.org/x/crypto/chacha20poly1305+ + golang.org/x/crypto/nacl/box from tailscale.com/types/key + golang.org/x/crypto/nacl/secretbox from 
golang.org/x/crypto/nacl/box + golang.org/x/crypto/poly1305 from github.com/tailscale/wireguard-go/device + golang.org/x/crypto/salsa20/salsa from golang.org/x/crypto/nacl/box+ + LD golang.org/x/crypto/ssh from tailscale.com/ipn/ipnlocal + LD golang.org/x/crypto/ssh/internal/bcrypt_pbkdf from golang.org/x/crypto/ssh + golang.org/x/exp/constraints from github.com/dblohm7/wingoes/pe+ + golang.org/x/exp/maps from tailscale.com/ipn/store/mem+ + golang.org/x/net/bpf from github.com/mdlayher/genetlink+ + golang.org/x/net/dns/dnsmessage from net+ + golang.org/x/net/http/httpguts from golang.org/x/net/http2+ + golang.org/x/net/http/httpproxy from net/http+ + golang.org/x/net/http2 from golang.org/x/net/http2/h2c+ + LDW golang.org/x/net/http2/h2c from tailscale.com/ipn/ipnlocal + golang.org/x/net/http2/hpack from golang.org/x/net/http2+ + golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/idna from golang.org/x/net/http/httpguts+ + golang.org/x/net/internal/httpcommon from golang.org/x/net/http2 + golang.org/x/net/internal/iana from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + LDW golang.org/x/net/internal/socks from golang.org/x/net/proxy + golang.org/x/net/ipv4 from github.com/miekg/dns+ + golang.org/x/net/ipv6 from github.com/miekg/dns+ + LDW golang.org/x/net/proxy from tailscale.com/net/netns + DI golang.org/x/net/route from net+ + golang.org/x/sync/errgroup from github.com/mdlayher/socket+ + golang.org/x/sys/cpu from github.com/tailscale/certstore+ + LDAI golang.org/x/sys/unix from github.com/google/nftables+ + W golang.org/x/sys/windows from github.com/dblohm7/wingoes+ + W golang.org/x/sys/windows/registry from github.com/dblohm7/wingoes+ + W golang.org/x/sys/windows/svc from golang.org/x/sys/windows/svc/mgr+ + W golang.org/x/sys/windows/svc/mgr from tailscale.com/util/winutil + golang.org/x/term from tailscale.com/logpolicy + golang.org/x/text/secure/bidirule from golang.org/x/net/idna + 
golang.org/x/text/transform from golang.org/x/text/secure/bidirule+ + golang.org/x/text/unicode/bidi from golang.org/x/net/idna+ + golang.org/x/text/unicode/norm from golang.org/x/net/idna + golang.org/x/time/rate from gvisor.dev/gvisor/pkg/log+ + archive/tar from tailscale.com/clientupdate + bufio from compress/flate+ + bytes from archive/tar+ + cmp from encoding/json+ + compress/flate from compress/gzip+ + compress/gzip from github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding+ + W compress/zlib from debug/pe + container/heap from gvisor.dev/gvisor/pkg/tcpip/transport/tcp + container/list from crypto/tls+ + context from crypto/tls+ + crypto from crypto/ecdh+ + crypto/aes from crypto/internal/hpke+ + crypto/cipher from crypto/aes+ + crypto/des from crypto/tls+ + crypto/dsa from crypto/x509+ + crypto/ecdh from crypto/ecdsa+ + crypto/ecdsa from crypto/tls+ + crypto/ed25519 from crypto/tls+ + crypto/elliptic from crypto/ecdsa+ + crypto/hmac from crypto/tls+ + crypto/internal/boring from crypto/aes+ + crypto/internal/boring/bbig from crypto/ecdsa+ + crypto/internal/boring/sig from crypto/internal/boring + crypto/internal/entropy from crypto/internal/fips140/drbg + crypto/internal/fips140 from crypto/internal/fips140/aes+ + crypto/internal/fips140/aes from crypto/aes+ + crypto/internal/fips140/aes/gcm from crypto/cipher+ + crypto/internal/fips140/alias from crypto/cipher+ + crypto/internal/fips140/bigmod from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/check from crypto/internal/fips140/aes+ + crypto/internal/fips140/drbg from crypto/internal/fips140/aes/gcm+ + crypto/internal/fips140/ecdh from crypto/ecdh + crypto/internal/fips140/ecdsa from crypto/ecdsa + crypto/internal/fips140/ed25519 from crypto/ed25519 + crypto/internal/fips140/edwards25519 from crypto/internal/fips140/ed25519 + crypto/internal/fips140/edwards25519/field from crypto/ecdh+ + crypto/internal/fips140/hkdf from crypto/internal/fips140/tls13+ + crypto/internal/fips140/hmac from 
crypto/hmac+ + crypto/internal/fips140/mlkem from crypto/tls + crypto/internal/fips140/nistec from crypto/elliptic+ + crypto/internal/fips140/nistec/fiat from crypto/internal/fips140/nistec + crypto/internal/fips140/rsa from crypto/rsa + crypto/internal/fips140/sha256 from crypto/internal/fips140/check+ + crypto/internal/fips140/sha3 from crypto/internal/fips140/hmac+ + crypto/internal/fips140/sha512 from crypto/internal/fips140/ecdsa+ + crypto/internal/fips140/subtle from crypto/internal/fips140/aes+ + crypto/internal/fips140/tls12 from crypto/tls + crypto/internal/fips140/tls13 from crypto/tls + crypto/internal/fips140deps/byteorder from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/cpu from crypto/internal/fips140/aes+ + crypto/internal/fips140deps/godebug from crypto/internal/fips140+ + crypto/internal/fips140hash from crypto/ecdsa+ + crypto/internal/fips140only from crypto/cipher+ + crypto/internal/hpke from crypto/tls + crypto/internal/impl from crypto/internal/fips140/aes+ + crypto/internal/randutil from crypto/dsa+ + crypto/internal/sysrand from crypto/internal/entropy+ + crypto/md5 from crypto/tls+ + crypto/rand from crypto/ed25519+ + crypto/rc4 from crypto/tls+ + crypto/rsa from crypto/tls+ + crypto/sha1 from crypto/tls+ + crypto/sha256 from crypto/tls+ + crypto/sha3 from crypto/internal/fips140hash + crypto/sha512 from crypto/ecdsa+ + crypto/subtle from crypto/cipher+ + crypto/tls from github.com/aws/aws-sdk-go-v2/aws/transport/http+ + crypto/tls/internal/fips140tls from crypto/tls + crypto/x509 from crypto/tls+ + DI crypto/x509/internal/macos from crypto/x509 + crypto/x509/pkix from crypto/x509+ + DWI database/sql/driver from github.com/google/uuid + W debug/dwarf from debug/pe + W debug/pe from github.com/dblohm7/wingoes/pe + embed from github.com/tailscale/web-client-prebuilt+ + encoding from encoding/gob+ + encoding/asn1 from crypto/x509+ + encoding/base32 from github.com/fxamacker/cbor/v2+ + encoding/base64 from encoding/json+ + 
encoding/binary from compress/gzip+ + LDW encoding/gob from github.com/gorilla/securecookie + encoding/hex from crypto/x509+ + encoding/json from expvar+ + encoding/pem from crypto/tls+ + encoding/xml from github.com/aws/aws-sdk-go-v2/aws/protocol/xml+ + errors from archive/tar+ + expvar from tailscale.com/derp+ + flag from tailscale.com/util/testenv + fmt from archive/tar+ + hash from compress/zlib+ + W hash/adler32 from compress/zlib + hash/crc32 from compress/gzip+ + hash/maphash from go4.org/mem + html from html/template+ + LDW html/template from github.com/gorilla/csrf+ + internal/abi from crypto/x509/internal/macos+ + internal/asan from internal/runtime/maps+ + internal/bisect from internal/godebug + internal/bytealg from bytes+ + internal/byteorder from crypto/cipher+ + internal/chacha8rand from math/rand/v2+ + internal/coverage/rtcov from runtime + internal/cpu from crypto/internal/fips140deps/cpu+ + internal/filepathlite from os+ + internal/fmtsort from fmt+ + internal/goarch from crypto/internal/fips140deps/cpu+ + internal/godebug from archive/tar+ + internal/godebugs from internal/godebug+ + internal/goexperiment from hash/maphash+ + internal/goos from crypto/x509+ + internal/itoa from internal/poll+ + internal/msan from internal/runtime/maps+ + internal/nettrace from net+ + internal/oserror from io/fs+ + internal/poll from net+ + internal/profile from net/http/pprof + internal/profilerecord from runtime+ + internal/race from internal/poll+ + internal/reflectlite from context+ + internal/runtime/atomic from internal/runtime/exithook+ + internal/runtime/exithook from runtime + internal/runtime/maps from reflect+ + internal/runtime/math from internal/runtime/maps+ + internal/runtime/sys from crypto/subtle+ + LA internal/runtime/syscall from runtime+ + LDW internal/saferio from debug/pe+ + internal/singleflight from net + internal/stringslite from embed+ + internal/sync from sync+ + internal/syscall/execenv from os+ + LDAI internal/syscall/unix from 
crypto/internal/sysrand+ + W internal/syscall/windows from crypto/internal/sysrand+ + W internal/syscall/windows/registry from mime+ + W internal/syscall/windows/sysdll from internal/syscall/windows+ + internal/testlog from os + internal/unsafeheader from internal/reflectlite+ + io from archive/tar+ + io/fs from archive/tar+ + io/ioutil from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ + iter from bytes+ + log from expvar+ + log/internal from log + maps from archive/tar+ + math from archive/tar+ + math/big from crypto/dsa+ + math/bits from bytes+ + math/rand from github.com/fxamacker/cbor/v2+ + math/rand/v2 from crypto/ecdsa+ + mime from mime/multipart+ + mime/multipart from net/http + mime/quotedprintable from mime/multipart + net from crypto/tls+ + net/http from expvar+ + net/http/httptrace from github.com/aws/smithy-go/transport/http+ + net/http/httputil from github.com/aws/smithy-go/transport/http+ + net/http/internal from net/http+ + net/http/internal/ascii from net/http+ + net/http/pprof from tailscale.com/ipn/localapi+ + net/netip from crypto/x509+ + net/textproto from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ + net/url from crypto/x509+ + os from crypto/internal/sysrand+ + os/exec from github.com/aws/aws-sdk-go-v2/credentials/processcreds+ + os/user from archive/tar+ + path from archive/tar+ + path/filepath from archive/tar+ + reflect from archive/tar+ + regexp from github.com/aws/aws-sdk-go-v2/internal/endpoints+ + regexp/syntax from regexp + runtime from archive/tar+ + runtime/debug from github.com/aws/aws-sdk-go-v2/internal/sync/singleflight+ + runtime/pprof from net/http/pprof+ + runtime/trace from net/http/pprof + slices from archive/tar+ + sort from compress/flate+ + strconv from archive/tar+ + strings from archive/tar+ + sync from archive/tar+ + sync/atomic from context+ + syscall from archive/tar+ + text/tabwriter from runtime/pprof + LDW text/template from html/template + LDW text/template/parse from html/template+ + time from archive/tar+ + 
unicode from bytes+ + unicode/utf16 from crypto/x509+ + unicode/utf8 from bufio+ + unique from net/netip + unsafe from bytes+ + weak from unique diff --git a/tsnet/packet_filter_test.go b/tsnet/packet_filter_test.go new file mode 100644 index 0000000000000..462234222f936 --- /dev/null +++ b/tsnet/packet_filter_test.go @@ -0,0 +1,248 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tsnet + +import ( + "context" + "fmt" + "net/netip" + "testing" + "time" + + "tailscale.com/ipn" + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/types/key" + "tailscale.com/types/netmap" + "tailscale.com/util/must" + "tailscale.com/wgengine/filter" +) + +// waitFor blocks until a NetMap is seen on the IPN bus that satisfies the given +// function f. Note: has no timeout, should be called with a ctx that has an +// appropriate timeout set. +func waitFor(t testing.TB, ctx context.Context, s *Server, f func(*netmap.NetworkMap) bool) error { + t.Helper() + watcher, err := s.localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) + if err != nil { + t.Fatalf("error watching IPN bus: %s", err) + } + defer watcher.Close() + + for { + n, err := watcher.Next() + if err != nil { + return fmt.Errorf("getting next ipn.Notify from IPN bus: %w", err) + } + if n.NetMap != nil { + if f(n.NetMap) { + return nil + } + } + } +} + +// TestPacketFilterFromNetmap tests all of the client code for processing +// netmaps and turning them into packet filters together. Only the control-plane +// side is mocked out. 
+func TestPacketFilterFromNetmap(t *testing.T) { + t.Parallel() + + var key key.NodePublic + must.Do(key.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261"))) + + type check struct { + src string + dst string + port uint16 + want filter.Response + } + + tests := []struct { + name string + mapResponse *tailcfg.MapResponse + waitTest func(*netmap.NetworkMap) bool + + incrementalMapResponse *tailcfg.MapResponse // optional + incrementalWaitTest func(*netmap.NetworkMap) bool // optional + + checks []check + }{ + { + name: "IP_based_peers", + mapResponse: &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + }, + Peers: []*tailcfg.Node{{ + ID: 2, + Name: "foo", + Key: key, + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + CapMap: nil, + }}, + PacketFilter: []tailcfg.FilterRule{{ + SrcIPs: []string{"2.2.2.2/32"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.1.1.1/32", + Ports: tailcfg.PortRange{ + First: 22, + Last: 22, + }, + }}, + IPProto: []int{int(ipproto.TCP)}, + }}, + }, + waitTest: func(nm *netmap.NetworkMap) bool { + return len(nm.Peers) > 0 + }, + checks: []check{ + {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, + {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port + {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src + {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst + }, + }, + { + name: "capmap_based_peers", + mapResponse: &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + }, + Peers: []*tailcfg.Node{{ + ID: 2, + Name: "foo", + Key: key, + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + CapMap: tailcfg.NodeCapMap{"X": nil}, + }}, + PacketFilter: []tailcfg.FilterRule{{ + SrcIPs: []string{"cap:X"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.1.1.1/32", + 
Ports: tailcfg.PortRange{ + First: 22, + Last: 22, + }, + }}, + IPProto: []int{int(ipproto.TCP)}, + }}, + }, + waitTest: func(nm *netmap.NetworkMap) bool { + return len(nm.Peers) > 0 + }, + checks: []check{ + {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, + {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port + {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src + {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst + }, + }, + { + name: "capmap_based_peers_changed", + mapResponse: &tailcfg.MapResponse{ + Node: &tailcfg.Node{ + Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, + CapMap: tailcfg.NodeCapMap{"X-sigil": nil}, + }, + PacketFilter: []tailcfg.FilterRule{{ + SrcIPs: []string{"cap:label-1"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "1.1.1.1/32", + Ports: tailcfg.PortRange{ + First: 22, + Last: 22, + }, + }}, + IPProto: []int{int(ipproto.TCP)}, + }}, + }, + waitTest: func(nm *netmap.NetworkMap) bool { + return nm.SelfNode.HasCap("X-sigil") + }, + incrementalMapResponse: &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{{ + ID: 2, + Name: "foo", + Key: key, + Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, + CapMap: tailcfg.NodeCapMap{"label-1": nil}, + }}, + }, + incrementalWaitTest: func(nm *netmap.NetworkMap) bool { + return len(nm.Peers) > 0 + }, + checks: []check{ + {src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, + {src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port + {src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src + {src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) + defer cancel() + + controlURL, c := startControl(t) + s, _, pubKey := startServer(t, 
ctx, controlURL, "node") + + if test.waitTest(s.lb.NetMap()) { + t.Fatal("waitTest already passes before sending initial netmap: this will be flaky") + } + + if !c.AddRawMapResponse(pubKey, test.mapResponse) { + t.Fatalf("could not send map response to %s", pubKey) + } + + if err := waitFor(t, ctx, s, test.waitTest); err != nil { + t.Fatalf("waitFor: %s", err) + } + + pf := s.lb.GetFilterForTest() + + for _, check := range test.checks { + got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP) + + want := check.want + if test.incrementalMapResponse != nil { + want = filter.Drop + } + if got != want { + t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, want) + } + } + + if test.incrementalMapResponse != nil { + if test.incrementalWaitTest == nil { + t.Fatal("incrementalWaitTest must be set if incrementalMapResponse is set") + } + + if test.incrementalWaitTest(s.lb.NetMap()) { + t.Fatal("incrementalWaitTest already passes before sending incremental netmap: this will be flaky") + } + + if !c.AddRawMapResponse(pubKey, test.incrementalMapResponse) { + t.Fatalf("could not send map response to %s", pubKey) + } + + if err := waitFor(t, ctx, s, test.incrementalWaitTest); err != nil { + t.Fatalf("waitFor: %s", err) + } + + pf := s.lb.GetFilterForTest() + + for _, check := range test.checks { + got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP) + if got != check.want { + t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, check.want) + } + } + } + + }) + } +} diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 6751e0bb03cbe..4664a66a796d4 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -26,12 +26,14 @@ import ( "sync" "time" + "tailscale.com/client/local" "tailscale.com/client/tailscale" "tailscale.com/control/controlclient" "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/hostinfo" 
"tailscale.com/ipn" + "tailscale.com/ipn/ipnauth" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnstate" "tailscale.com/ipn/localapi" @@ -46,6 +48,7 @@ import ( "tailscale.com/net/socks5" "tailscale.com/net/tsdial" "tailscale.com/tsd" + "tailscale.com/types/bools" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/nettype" @@ -78,7 +81,7 @@ type Server struct { // If nil, a new FileStore is initialized at `Dir/tailscaled.state`. // See tailscale.com/ipn/store for supported stores. // - // Logs will automatically be uploaded to log.tailscale.io, + // Logs will automatically be uploaded to log.tailscale.com, // where the configuration file for logging will be saved at // `Dir/tailscaled.log.conf`. Store ipn.StateStore @@ -126,17 +129,18 @@ type Server struct { initOnce sync.Once initErr error lb *ipnlocal.LocalBackend + sys *tsd.System netstack *netstack.Impl netMon *netmon.Monitor rootPath string // the state directory hostname string shutdownCtx context.Context shutdownCancel context.CancelFunc - proxyCred string // SOCKS5 proxy auth for loopbackListener - localAPICred string // basic auth password for loopbackListener - loopbackListener net.Listener // optional loopback for localapi and proxies - localAPIListener net.Listener // in-memory, used by localClient - localClient *tailscale.LocalClient // in-memory + proxyCred string // SOCKS5 proxy auth for loopbackListener + localAPICred string // basic auth password for loopbackListener + loopbackListener net.Listener // optional loopback for localapi and proxies + localAPIListener net.Listener // in-memory, used by localClient + localClient *local.Client // in-memory localAPIServer *http.Server logbuffer *filch.Filch logtail *logtail.Logger @@ -168,9 +172,41 @@ func (s *Server) Dial(ctx context.Context, network, address string) (net.Conn, e if err := s.Start(); err != nil { return nil, err } + if err := s.awaitRunning(ctx); err != nil { + return nil, err + } return s.dialer.UserDial(ctx, 
network, address) } +// awaitRunning waits until the backend is in state Running. +// If the backend is in state Starting, it blocks until it reaches +// a terminal state (such as Stopped, NeedsMachineAuth) +// or the context expires. +func (s *Server) awaitRunning(ctx context.Context) error { + st := s.lb.State() + for { + if err := ctx.Err(); err != nil { + return err + } + switch st { + case ipn.Running: + return nil + case ipn.NeedsLogin, ipn.Starting: + // Even after LocalBackend.Start, the state machine is still briefly + // in the "NeedsLogin" state. So treat that as also "Starting" and + // wait for us to get out of that state. + s.lb.WatchNotifications(ctx, ipn.NotifyInitialState, nil, func(n *ipn.Notify) (keepGoing bool) { + if n.State != nil { + st = *n.State + } + return st == ipn.NeedsLogin || st == ipn.Starting + }) + default: + return fmt.Errorf("tsnet: backend in state %v", st) + } + } +} + // HTTPClient returns an HTTP client that is configured to connect over Tailscale. // // This is useful if you need to have your tsnet services connect to other devices on @@ -187,7 +223,7 @@ func (s *Server) HTTPClient() *http.Client { // // It will start the server if it has not been started yet. If the server's // already been started successfully, it doesn't return an error. -func (s *Server) LocalClient() (*tailscale.LocalClient, error) { +func (s *Server) LocalClient() (*local.Client, error) { if err := s.Start(); err != nil { return nil, err } @@ -238,7 +274,7 @@ func (s *Server) Loopback() (addr string, proxyCred, localAPICred string, err er // out the CONNECT code from tailscaled/proxy.go that uses // httputil.ReverseProxy and adding auth support. 
go func() { - lah := localapi.NewHandler(s.lb, s.logf, s.logid) + lah := localapi.NewHandler(ipnauth.Self, s.lb, s.logf, s.logid) lah.PermitWrite = true lah.PermitRead = true lah.RequiredPassword = s.localAPICred @@ -398,8 +434,11 @@ func (s *Server) Close() error { for _, ln := range s.listeners { ln.closeLocked() } - wg.Wait() + + if bus := s.sys.Bus.Get(); bus != nil { + bus.Close() + } s.closed = true return nil } @@ -432,8 +471,7 @@ func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { return } addrs := nm.GetAddresses() - for i := range addrs.Len() { - addr := addrs.At(i) + for _, addr := range addrs.All() { ip := addr.Addr() if ip.Is6() { ip6 = ip @@ -469,6 +507,11 @@ func (s *Server) start() (reterr error) { // directory and hostname when they're not supplied. But we can fall // back to "tsnet" as well. exe = "tsnet" + case "ios": + // When compiled as a framework (via TailscaleKit in libtailscale), + // os.Executable() returns an error, so fall back to "tsnet" there + // too. + exe = "tsnet" default: return err } @@ -517,12 +560,13 @@ func (s *Server) start() (reterr error) { s.Logf(format, a...) } - sys := new(tsd.System) + sys := tsd.NewSystem() + s.sys = sys if err := s.startLogger(&closePool, sys.HealthTracker(), tsLogf); err != nil { return err } - s.netMon, err = netmon.New(tsLogf) + s.netMon, err = netmon.New(sys.Bus.Get(), tsLogf) if err != nil { return err } @@ -530,6 +574,7 @@ func (s *Server) start() (reterr error) { s.dialer = &tsdial.Dialer{Logf: tsLogf} // mutated below (before used) eng, err := wgengine.NewUserspaceEngine(tsLogf, wgengine.Config{ + EventBus: sys.Bus.Get(), ListenPort: s.Port, NetMon: s.netMon, Dialer: s.dialer, @@ -546,7 +591,7 @@ func (s *Server) start() (reterr error) { sys.HealthTracker().SetMetricsRegistry(sys.UserMetricsRegistry()) // TODO(oxtoacart): do we need to support Taildrive on tsnet, and if so, how? 
- ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := netstack.Create(tsLogf, sys.Tun.Get(), eng, sys.MagicSock.Get(), s.dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { return fmt.Errorf("netstack.Create: %w", err) } @@ -565,7 +610,9 @@ func (s *Server) start() (reterr error) { // Note: don't just return ns.DialContextTCP or we'll return // *gonet.TCPConn(nil) instead of a nil interface which trips up // callers. - tcpConn, err := ns.DialContextTCP(ctx, dst) + v4, v6 := s.TailscaleIPs() + src := bools.IfElse(dst.Addr().Is6(), v6, v4) + tcpConn, err := ns.DialContextTCPWithBind(ctx, src, dst) if err != nil { return nil, err } @@ -575,7 +622,9 @@ func (s *Server) start() (reterr error) { // Note: don't just return ns.DialContextUDP or we'll return // *gonet.UDPConn(nil) instead of a nil interface which trips up // callers. - udpConn, err := ns.DialContextUDP(ctx, dst) + v4, v6 := s.TailscaleIPs() + src := bools.IfElse(dst.Addr().Is6(), v6, v4) + udpConn, err := ns.DialContextUDPWithBind(ctx, src, dst) if err != nil { return nil, err } @@ -633,7 +682,7 @@ func (s *Server) start() (reterr error) { go s.printAuthURLLoop() // Run the localapi handler, to allow fetching LetsEncrypt certs. - lah := localapi.NewHandler(lb, tsLogf, s.logid) + lah := localapi.NewHandler(ipnauth.Self, lb, tsLogf, s.logid) lah.PermitWrite = true lah.PermitRead = true @@ -641,7 +690,7 @@ func (s *Server) start() (reterr error) { // nettest.Listen provides a in-memory pipe based implementation for net.Conn. 
lal := memnet.Listen("local-tailscaled.sock:80") s.localAPIListener = lal - s.localClient = &tailscale.LocalClient{Dial: lal.Dial} + s.localClient = &local.Client{Dial: lal.Dial} s.localAPIServer = &http.Server{Handler: lah} s.lb.ConfigureWebClient(s.localClient) go func() { @@ -894,6 +943,8 @@ func getTSNetDir(logf logger.Logf, confDir, prog string) (string, error) { // APIClient returns a tailscale.Client that can be used to make authenticated // requests to the Tailscale control server. // It requires the user to set tailscale.I_Acknowledge_This_API_Is_Unstable. +// +// Deprecated: use AuthenticatedAPITransport with tailscale.com/client/tailscale/v2 instead. func (s *Server) APIClient() (*tailscale.Client, error) { if !tailscale.I_Acknowledge_This_API_Is_Unstable { return nil, errors.New("use of Client without setting I_Acknowledge_This_API_Is_Unstable") @@ -908,6 +959,41 @@ func (s *Server) APIClient() (*tailscale.Client, error) { return c, nil } +// I_Acknowledge_This_API_Is_Experimental must be set true to use AuthenticatedAPITransport() +// for now. +var I_Acknowledge_This_API_Is_Experimental = false + +// AuthenticatedAPITransport provides an HTTP transport that can be used with +// the control server API without needing additional authentication details. It +// authenticates using the current client's nodekey. +// +// It requires the user to set I_Acknowledge_This_API_Is_Experimental. +// +// For example: +// +// import "net/http" +// import "tailscale.com/client/tailscale/v2" +// import "tailscale.com/tsnet" +// +// var s *tsnet.Server +// ... +// rt, err := s.AuthenticatedAPITransport() +// // handler err ... 
+// var client tailscale.Client{HTTP: http.Client{ +// Timeout: 1*time.Minute, +// UserAgent: "your-useragent-here", +// Transport: rt, +// }} +func (s *Server) AuthenticatedAPITransport() (http.RoundTripper, error) { + if !I_Acknowledge_This_API_Is_Experimental { + return nil, errors.New("use of AuthenticatedAPITransport without setting I_Acknowledge_This_API_Is_Experimental") + } + if err := s.Start(); err != nil { + return nil, err + } + return s.lb.KeyProvingNoiseRoundTripper(), nil +} + // Listen announces only on the Tailscale network. // It will start the server if it has not been started yet. // @@ -1014,13 +1100,33 @@ type FunnelOption interface { funnelOption() } -type funnelOnly int +type funnelOnly struct{} func (funnelOnly) funnelOption() {} // FunnelOnly configures the listener to only respond to connections from Tailscale Funnel. // The local tailnet will not be able to connect to the listener. -func FunnelOnly() FunnelOption { return funnelOnly(1) } +func FunnelOnly() FunnelOption { return funnelOnly{} } + +type funnelTLSConfig struct{ conf *tls.Config } + +func (f funnelTLSConfig) funnelOption() {} + +// FunnelTLSConfig configures the TLS configuration for [Server.ListenFunnel] +// +// This is rarely needed but can permit requiring client certificates, specific +// ciphers suites, etc. +// +// The provided conf should at least be able to get a certificate, setting +// GetCertificate, Certificates or GetConfigForClient appropriately. +// The most common configuration is to set GetCertificate to +// Server.LocalClient's GetCertificate method. +// +// Unless [FunnelOnly] is also used, the configuration is also used for +// in-tailnet connections that don't arrive over Funnel. +func FunnelTLSConfig(conf *tls.Config) FunnelOption { + return funnelTLSConfig{conf: conf} +} // ListenFunnel announces on the public internet using Tailscale Funnel. 
// @@ -1053,6 +1159,26 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L return nil, err } + // Process, validate opts. + lnOn := listenOnBoth + var tlsConfig *tls.Config + for _, opt := range opts { + switch v := opt.(type) { + case funnelTLSConfig: + if v.conf == nil { + return nil, errors.New("invalid nil FunnelTLSConfig") + } + tlsConfig = v.conf + case funnelOnly: + lnOn = listenOnFunnel + default: + return nil, fmt.Errorf("unknown opts FunnelOption type %T", v) + } + } + if tlsConfig == nil { + tlsConfig = &tls.Config{GetCertificate: s.getCert} + } + ctx := context.Background() st, err := s.Up(ctx) if err != nil { @@ -1090,19 +1216,11 @@ func (s *Server) ListenFunnel(network, addr string, opts ...FunnelOption) (net.L } // Start a funnel listener. - lnOn := listenOnBoth - for _, opt := range opts { - if _, ok := opt.(funnelOnly); ok { - lnOn = listenOnFunnel - } - } ln, err := s.listen(network, addr, lnOn) if err != nil { return nil, err } - return tls.NewListener(ln, &tls.Config{ - GetCertificate: s.getCert, - }), nil + return tls.NewListener(ln, tlsConfig), nil } type listenOn string @@ -1179,7 +1297,8 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro keys: keys, addr: addr, - conn: make(chan net.Conn), + closedc: make(chan struct{}), + conn: make(chan net.Conn), } s.mu.Lock() for _, key := range keys { @@ -1227,6 +1346,13 @@ func (s *Server) CapturePcap(ctx context.Context, pcapFile string) error { return nil } +// Sys returns a handle to the Tailscale subsystems of this node. +// +// This is not a stable API, nor are the APIs of the returned subsystems. 
+func (s *Server) Sys() *tsd.System { + return s.sys +} + type listenKey struct { network string host netip.Addr // or zero value for unspecified @@ -1235,19 +1361,21 @@ type listenKey struct { } type listener struct { - s *Server - keys []listenKey - addr string - conn chan net.Conn - closed bool // guarded by s.mu + s *Server + keys []listenKey + addr string + conn chan net.Conn // unbuffered, never closed + closedc chan struct{} // closed on [listener.Close] + closed bool // guarded by s.mu } func (ln *listener) Accept() (net.Conn, error) { - c, ok := <-ln.conn - if !ok { + select { + case c := <-ln.conn: + return c, nil + case <-ln.closedc: return nil, fmt.Errorf("tsnet: %w", net.ErrClosed) } - return c, nil } func (ln *listener) Addr() net.Addr { return addr{ln} } @@ -1269,21 +1397,22 @@ func (ln *listener) closeLocked() error { delete(ln.s.listeners, key) } } - close(ln.conn) + close(ln.closedc) ln.closed = true return nil } func (ln *listener) handle(c net.Conn) { - t := time.NewTimer(time.Second) - defer t.Stop() select { case ln.conn <- c: - case <-t.C: + return + case <-ln.closedc: + case <-ln.s.shutdownCtx.Done(): + case <-time.After(time.Second): // TODO(bradfitz): this isn't ideal. Think about how // we how we want to do pushback. - c.Close() } + c.Close() } // Server returns the tsnet Server associated with the listener. 
diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 98c1fd4ab3462..d00628453260f 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -36,9 +36,8 @@ import ( dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" "golang.org/x/net/proxy" - "tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/cmd/testwrapper/flakytest" - "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/ipn/store/mem" "tailscale.com/net/netns" @@ -121,6 +120,7 @@ func startControl(t *testing.T) (controlURL string, control *testcontrol.Server) Proxied: true, }, MagicDNSDomain: "tail-scale.ts.net", + Logf: t.Logf, } control.HTTPTestServer = httptest.NewUnstartedServer(control) control.HTTPTestServer.Start() @@ -222,7 +222,7 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) getCertForTesting: testCertRoot.getCert, } if *verboseNodes { - s.Logf = log.Printf + s.Logf = t.Logf } t.Cleanup(func() { s.Close() }) @@ -233,6 +233,46 @@ func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) return s, status.TailscaleIPs[0], status.Self.PublicKey } +func TestDialBlocks(t *testing.T) { + tstest.ResourceCheck(t) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + controlURL, _ := startControl(t) + + // Make one tsnet that blocks until it's up. + s1, _, _ := startServer(t, ctx, controlURL, "s1") + + ln, err := s1.Listen("tcp", ":8080") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + // Then make another tsnet node that will only be woken up + // upon the first dial. 
+ tmp := filepath.Join(t.TempDir(), "s2") + os.MkdirAll(tmp, 0755) + s2 := &Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: "s2", + Store: new(mem.Store), + Ephemeral: true, + getCertForTesting: testCertRoot.getCert, + } + if *verboseNodes { + s2.Logf = log.Printf + } + t.Cleanup(func() { s2.Close() }) + + c, err := s2.Dial(ctx, "tcp", "s1:8080") + if err != nil { + t.Fatal(err) + } + defer c.Close() +} + func TestConn(t *testing.T) { tstest.ResourceCheck(t) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) @@ -495,6 +535,25 @@ func TestListenerCleanup(t *testing.T) { if err := ln.Close(); !errors.Is(err, net.ErrClosed) { t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err) } + + // Verify that handling a connection from gVisor (from a packet arriving) + // after a listener closed doesn't panic (previously: sending on a closed + // channel) or hang. + c := &closeTrackConn{} + ln.(*listener).handle(c) + if !c.closed { + t.Errorf("c.closed = false, want true") + } +} + +type closeTrackConn struct { + net.Conn + closed bool +} + +func (wc *closeTrackConn) Close() error { + wc.closed = true + return nil } // tests https://github.com/tailscale/tailscale/issues/6973 -- that we can start a tsnet server, @@ -609,6 +668,37 @@ func TestFunnel(t *testing.T) { } } +func TestListenerClose(t *testing.T) { + ctx := context.Background() + controlURL, _ := startControl(t) + + s1, _, _ := startServer(t, ctx, controlURL, "s1") + + ln, err := s1.Listen("tcp", ":8080") + if err != nil { + t.Fatal(err) + } + + errc := make(chan error, 1) + go func() { + c, err := ln.Accept() + if c != nil { + c.Close() + } + errc <- err + }() + + ln.Close() + select { + case err := <-errc: + if !errors.Is(err, net.ErrClosed) { + t.Errorf("unexpected error: %v", err) + } + case <-time.After(10 * time.Second): + t.Fatal("timeout waiting for Accept to return") + } +} + func dialIngressConn(from, to *Server, target string) (net.Conn, error) { toLC := 
must.Get(to.LocalClient()) toStatus := must.Get(toLC.StatusWithoutPeers(context.Background())) @@ -822,16 +912,6 @@ func TestUDPConn(t *testing.T) { } } -// testWarnable is a Warnable that is used within this package for testing purposes only. -var testWarnable = health.Register(&health.Warnable{ - Code: "test-warnable-tsnet", - Title: "Test warnable", - Severity: health.SeverityLow, - Text: func(args health.Args) string { - return args[health.ArgError] - }, -}) - func parseMetrics(m []byte) (map[string]float64, error) { metrics := make(map[string]float64) @@ -905,9 +985,11 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC for { got := make([]byte, bytesCount) n, err := conn.Read(got) - if n != bytesCount { - logf("read %d bytes, want %d", n, bytesCount) + if err != nil { + allReceived <- fmt.Errorf("failed reading packet, %s", err) + return } + got = got[:n] select { case <-stopReceive: @@ -915,13 +997,17 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC default: } - if err != nil { - allReceived <- fmt.Errorf("failed reading packet, %s", err) - return - } - total += n logf("received %d/%d bytes, %.2f %%", total, bytesCount, (float64(total) / (float64(bytesCount)) * 100)) + + // Validate the received bytes to be the same as the sent bytes. 
+ for _, b := range string(got) { + if b != 'A' { + allReceived <- fmt.Errorf("received unexpected byte: %c", b) + return + } + } + if total == bytesCount { break } @@ -947,15 +1033,135 @@ func sendData(logf func(format string, args ...any), ctx context.Context, bytesC return nil } -func TestUserMetrics(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/13420") - tstest.ResourceCheck(t) +func TestUserMetricsByteCounters(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) defer cancel() - controlURL, c := startControl(t) - s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1") + controlURL, _ := startControl(t) + s1, s1ip, _ := startServer(t, ctx, controlURL, "s1") + defer s1.Close() s2, s2ip, _ := startServer(t, ctx, controlURL, "s2") + defer s2.Close() + + lc1, err := s1.LocalClient() + if err != nil { + t.Fatal(err) + } + + lc2, err := s2.LocalClient() + if err != nil { + t.Fatal(err) + } + + // Force an update to the netmap to ensure that the metrics are up-to-date. + s1.lb.DebugForceNetmapUpdate() + s2.lb.DebugForceNetmapUpdate() + + // Wait for both nodes to have a peer in their netmap. + waitForCondition(t, "waiting for netmaps to contain peer", 90*time.Second, func() bool { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + status1, err := lc1.Status(ctx) + if err != nil { + t.Logf("getting status: %s", err) + return false + } + status2, err := lc2.Status(ctx) + if err != nil { + t.Logf("getting status: %s", err) + return false + } + return len(status1.Peers()) > 0 && len(status2.Peers()) > 0 + }) + + // ping to make sure the connection is up. 
+ res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) + if err != nil { + t.Fatalf("pinging: %s", err) + } + t.Logf("ping success: %#+v", res) + + mustDirect(t, t.Logf, lc1, lc2) + + // 1 megabytes + bytesToSend := 1 * 1024 * 1024 + + // This asserts generates some traffic, it is factored out + // of TestUDPConn. + start := time.Now() + err = sendData(t.Logf, ctx, bytesToSend, s1, s2, s1ip, s2ip) + if err != nil { + t.Fatalf("Failed to send packets: %v", err) + } + t.Logf("Sent %d bytes from s1 to s2 in %s", bytesToSend, time.Since(start).String()) + + ctxLc, cancelLc := context.WithTimeout(context.Background(), 5*time.Second) + defer cancelLc() + metrics1, err := lc1.UserMetrics(ctxLc) + if err != nil { + t.Fatal(err) + } + + parsedMetrics1, err := parseMetrics(metrics1) + if err != nil { + t.Fatal(err) + } + + // Allow the metrics for the bytes sent to be off by 15%. + bytesSentTolerance := 1.15 + + t.Logf("Metrics1:\n%s\n", metrics1) + + // Verify that the amount of data recorded in bytes is higher or equal to the data sent + inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] + if inboundBytes1 < float64(bytesToSend) { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1) + } + + // But ensure that it is not too much higher than the data sent. + if inboundBytes1 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, inboundBytes1) + } + + metrics2, err := lc2.UserMetrics(ctx) + if err != nil { + t.Fatal(err) + } + + parsedMetrics2, err := parseMetrics(metrics2) + if err != nil { + t.Fatal(err) + } + + t.Logf("Metrics2:\n%s\n", metrics2) + + // Verify that the amount of data recorded in bytes is higher or equal than the data sent. 
+ outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] + if outboundBytes2 < float64(bytesToSend) { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2) + } + + // But ensure that it is not too much higher than the data sent. + if outboundBytes2 > float64(bytesToSend)*bytesSentTolerance { + t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, outboundBytes2) + } +} + +func TestUserMetricsRouteGauges(t *testing.T) { + // Windows does not seem to support or report back routes when running in + // userspace via tsnet. So, we skip this check on Windows. + // TODO(kradalby): Figure out if this is correct. + if runtime.GOOS == "windows" { + t.Skipf("skipping on windows") + } + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + controlURL, c := startControl(t) + s1, _, s1PubKey := startServer(t, ctx, controlURL, "s1") + defer s1.Close() + s2, _, _ := startServer(t, ctx, controlURL, "s2") + defer s2.Close() s1.lb.EditPrefs(&ipn.MaskedPrefs{ Prefs: ipn.Prefs{ @@ -984,24 +1190,11 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - // ping to make sure the connection is up. - res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP) - if err != nil { - t.Fatalf("pinging: %s", err) - } - t.Logf("ping success: %#+v", res) - - ht := s1.lb.HealthTracker() - ht.SetUnhealthy(testWarnable, health.Args{"Text": "Hello world 1"}) - // Force an update to the netmap to ensure that the metrics are up-to-date. s1.lb.DebugForceNetmapUpdate() s2.lb.DebugForceNetmapUpdate() wantRoutes := float64(2) - if runtime.GOOS == "windows" { - wantRoutes = 0 - } // Wait for the routes to be propagated to node 1 to ensure // that the metrics are up-to-date. 
@@ -1013,31 +1206,11 @@ func TestUserMetrics(t *testing.T) { t.Logf("getting status: %s", err) return false } - if runtime.GOOS == "windows" { - // Windows does not seem to support or report back routes when running in - // userspace via tsnet. So, we skip this check on Windows. - // TODO(kradalby): Figure out if this is correct. - return true - } // Wait for the primary routes to reach our desired routes, which is wantRoutes + 1, because // the PrimaryRoutes list will contain a exit node route, which the metric does not count. return status1.Self.PrimaryRoutes != nil && status1.Self.PrimaryRoutes.Len() == int(wantRoutes)+1 }) - mustDirect(t, t.Logf, lc1, lc2) - - // 10 megabytes - bytesToSend := 10 * 1024 * 1024 - - // This asserts generates some traffic, it is factored out - // of TestUDPConn. - start := time.Now() - err = sendData(t.Logf, ctx, bytesToSend, s1, s2, s1ip, s2ip) - if err != nil { - t.Fatalf("Failed to send packets: %v", err) - } - t.Logf("Sent %d bytes from s1 to s2 in %s", bytesToSend, time.Since(start).String()) - ctxLc, cancelLc := context.WithTimeout(context.Background(), 5*time.Second) defer cancelLc() metrics1, err := lc1.UserMetrics(ctxLc) @@ -1045,19 +1218,11 @@ func TestUserMetrics(t *testing.T) { t.Fatal(err) } - status1, err := lc1.Status(ctxLc) - if err != nil { - t.Fatal(err) - } - parsedMetrics1, err := parseMetrics(metrics1) if err != nil { t.Fatal(err) } - // Allow the metrics for the bytes sent to be off by 15%. 
- bytesSentTolerance := 1.15 - t.Logf("Metrics1:\n%s\n", metrics1) // The node is advertising 4 routes: @@ -1075,40 +1240,11 @@ func TestUserMetrics(t *testing.T) { t.Errorf("metrics1, tailscaled_approved_routes: got %v, want %v", got, want) } - // Validate the health counter metric against the status of the node - if got, want := parsedMetrics1[`tailscaled_health_messages{type="warning"}`], float64(len(status1.Health)); got != want { - t.Errorf("metrics1, tailscaled_health_messages: got %v, want %v", got, want) - } - - // The node is the primary subnet router for 2 routes: - // - 192.0.2.0/24 - // - 192.0.5.1/32 - if got, want := parsedMetrics1["tailscaled_primary_routes"], wantRoutes; got != want { - t.Errorf("metrics1, tailscaled_primary_routes: got %v, want %v", got, want) - } - - // Verify that the amount of data recorded in bytes is higher or equal to the - // 10 megabytes sent. - inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`] - if inboundBytes1 < float64(bytesToSend) { - t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1) - } - - // But ensure that it is not too much higher than the 10 megabytes sent. 
- if inboundBytes1 > float64(bytesToSend)*bytesSentTolerance { - t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, inboundBytes1) - } - metrics2, err := lc2.UserMetrics(ctx) if err != nil { t.Fatal(err) } - status2, err := lc2.Status(ctx) - if err != nil { - t.Fatal(err) - } - parsedMetrics2, err := parseMetrics(metrics2) if err != nil { t.Fatal(err) @@ -1125,28 +1261,6 @@ func TestUserMetrics(t *testing.T) { if got, want := parsedMetrics2["tailscaled_approved_routes"], 0.0; got != want { t.Errorf("metrics2, tailscaled_approved_routes: got %v, want %v", got, want) } - - // Validate the health counter metric against the status of the node - if got, want := parsedMetrics2[`tailscaled_health_messages{type="warning"}`], float64(len(status2.Health)); got != want { - t.Errorf("metrics2, tailscaled_health_messages: got %v, want %v", got, want) - } - - // The node is the primary subnet router for 0 routes - if got, want := parsedMetrics2["tailscaled_primary_routes"], 0.0; got != want { - t.Errorf("metrics2, tailscaled_primary_routes: got %v, want %v", got, want) - } - - // Verify that the amount of data recorded in bytes is higher or equal than the - // 10 megabytes sent. - outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`] - if outboundBytes2 < float64(bytesToSend) { - t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2) - } - - // But ensure that it is not too much higher than the 10 megabytes sent. 
- if outboundBytes2 > float64(bytesToSend)*bytesSentTolerance { - t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, float64(bytesToSend)*bytesSentTolerance, outboundBytes2) - } } func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() bool) { @@ -1160,7 +1274,7 @@ func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() } // mustDirect ensures there is a direct connection between LocalClient 1 and 2 -func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *tailscale.LocalClient) { +func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) { t.Helper() lastLog := time.Now().Add(-time.Minute) // See https://github.com/tailscale/tailscale/issues/654 diff --git a/tstest/deptest/deptest.go b/tstest/deptest/deptest.go index 57db2b79aa3c7..4effd4a7883af 100644 --- a/tstest/deptest/deptest.go +++ b/tstest/deptest/deptest.go @@ -13,14 +13,21 @@ import ( "path/filepath" "regexp" "runtime" + "slices" "strings" + "sync" "testing" + + "tailscale.com/util/set" ) type DepChecker struct { - GOOS string // optional - GOARCH string // optional - BadDeps map[string]string // package => why + GOOS string // optional + GOARCH string // optional + BadDeps map[string]string // package => why + WantDeps set.Set[string] // packages expected + Tags string // comma-separated + ExtraEnv []string // extra environment for "go list" (e.g. 
CGO_ENABLED=1) } func (c DepChecker) Check(t *testing.T) { @@ -29,7 +36,7 @@ func (c DepChecker) Check(t *testing.T) { t.Skip("skipping dep tests on windows hosts") } t.Helper() - cmd := exec.Command("go", "list", "-json", ".") + cmd := exec.Command("go", "list", "-json", "-tags="+c.Tags, ".") var extraEnv []string if c.GOOS != "" { extraEnv = append(extraEnv, "GOOS="+c.GOOS) @@ -37,6 +44,7 @@ func (c DepChecker) Check(t *testing.T) { if c.GOARCH != "" { extraEnv = append(extraEnv, "GOARCH="+c.GOARCH) } + extraEnv = append(extraEnv, c.ExtraEnv...) cmd.Env = append(os.Environ(), extraEnv...) out, err := cmd.Output() if err != nil { @@ -49,11 +57,40 @@ func (c DepChecker) Check(t *testing.T) { t.Fatal(err) } + tsRoot := sync.OnceValue(func() string { + out, err := exec.Command("go", "list", "-f", "{{.Dir}}", "tailscale.com").Output() + if err != nil { + t.Fatalf("failed to find tailscale.com root: %v", err) + } + return strings.TrimSpace(string(out)) + }) + for _, dep := range res.Deps { if why, ok := c.BadDeps[dep]; ok { t.Errorf("package %q is not allowed as a dependency (env: %q); reason: %s", dep, extraEnv, why) } } + // Make sure the BadDeps packages actually exist. If they got renamed or + // moved around, we should update the test referencing the old name. + // Doing this in the general case requires network access at runtime + // (resolving a package path to its module, possibly doing the ?go-get=1 + // meta tag dance), so we just check the common case of + // "tailscale.com/*" packages for now, with the assumption that all + // "tailscale.com/*" packages are in the same module, which isn't + // necessarily true in the general case. 
+ for dep := range c.BadDeps { + if suf, ok := strings.CutPrefix(dep, "tailscale.com/"); ok { + pkgDir := filepath.Join(tsRoot(), suf) + if _, err := os.Stat(pkgDir); err != nil { + t.Errorf("listed BadDep %q doesn't seem to exist anymore: %v", dep, err) + } + } + } + for dep := range c.WantDeps { + if !slices.Contains(res.Deps, dep) { + t.Errorf("expected package %q to be a dependency (env: %q)", dep, extraEnv) + } + } t.Logf("got %d dependencies", len(res.Deps)) } diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index 36a92759f7dd4..d64bfbbd9d755 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -9,92 +9,207 @@ package integration import ( "bytes" + "context" "crypto/tls" "encoding/json" + "flag" "fmt" "io" "log" "net" "net/http" "net/http/httptest" + "net/netip" "os" "os/exec" "path" "path/filepath" + "regexp" "runtime" + "strconv" "strings" "sync" "testing" "time" "go4.org/mem" + "tailscale.com/client/local" "tailscale.com/derp" "tailscale.com/derp/derphttp" + "tailscale.com/ipn" + "tailscale.com/ipn/ipnlocal" + "tailscale.com/ipn/ipnstate" + "tailscale.com/ipn/store" "tailscale.com/net/stun/stuntest" + "tailscale.com/safesocket" + "tailscale.com/syncs" "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" "tailscale.com/types/logger" "tailscale.com/types/logid" "tailscale.com/types/nettype" + "tailscale.com/util/rands" "tailscale.com/util/zstdframe" "tailscale.com/version" ) -// CleanupBinaries cleans up any resources created by calls to BinaryDir, TailscaleBinary, or TailscaledBinary. -// It should be called from TestMain after all tests have completed. 
-func CleanupBinaries() { - buildOnce.Do(func() {}) - if binDir != "" { - os.RemoveAll(binDir) +var ( + verboseTailscaled = flag.Bool("verbose-tailscaled", false, "verbose tailscaled logging") + verboseTailscale = flag.Bool("verbose-tailscale", false, "verbose tailscale CLI logging") +) + +// MainError is an error that's set if an error conditions happens outside of a +// context where a testing.TB is available. The caller can check it in its TestMain +// as a last ditch place to report errors. +var MainError syncs.AtomicValue[error] + +// Binaries contains the paths to the tailscale and tailscaled binaries. +type Binaries struct { + Dir string + Tailscale BinaryInfo + Tailscaled BinaryInfo +} + +// BinaryInfo describes a tailscale or tailscaled binary. +type BinaryInfo struct { + Path string // abs path to tailscale or tailscaled binary + Size int64 + + // FD and FDmu are set on Unix to efficiently copy the binary to a new + // test's automatically-cleaned-up temp directory. + FD *os.File // for Unix (macOS, Linux, ...) + FDMu sync.Locker + + // Contents is used on Windows instead of FD to copy the binary between + // test directories. (On Windows you can't keep an FD open while an earlier + // test's temp directories are deleted.) + // This burns some memory and costs more in I/O, but oh well. + Contents []byte +} + +func (b BinaryInfo) CopyTo(dir string) (BinaryInfo, error) { + ret := b + ret.Path = filepath.Join(dir, path.Base(b.Path)) + + switch runtime.GOOS { + case "linux": + // TODO(bradfitz): be fancy and use linkat with AT_EMPTY_PATH to avoid + // copying? I couldn't get it to work, though. + // For now, just do the same thing as every other Unix and copy + // the binary. 
+ fallthrough + case "darwin", "freebsd", "openbsd", "netbsd": + f, err := os.OpenFile(ret.Path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o755) + if err != nil { + return BinaryInfo{}, err + } + b.FDMu.Lock() + b.FD.Seek(0, 0) + size, err := io.Copy(f, b.FD) + b.FDMu.Unlock() + if err != nil { + f.Close() + return BinaryInfo{}, fmt.Errorf("copying %q: %w", b.Path, err) + } + if size != b.Size { + f.Close() + return BinaryInfo{}, fmt.Errorf("copy %q: size mismatch: %d != %d", b.Path, size, b.Size) + } + if err := f.Close(); err != nil { + return BinaryInfo{}, err + } + return ret, nil + case "windows": + return ret, os.WriteFile(ret.Path, b.Contents, 0o755) + default: + return BinaryInfo{}, fmt.Errorf("unsupported OS %q", runtime.GOOS) } } -// BinaryDir returns a directory containing test tailscale and tailscaled binaries. -// If any test calls BinaryDir, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func BinaryDir(tb testing.TB) string { +// GetBinaries create a temp directory using tb and builds (or copies previously +// built) cmd/tailscale and cmd/tailscaled binaries into that directory. +// +// It fails tb if the build or binary copies fail. +func GetBinaries(tb testing.TB) *Binaries { + dir := tb.TempDir() buildOnce.Do(func() { - binDir, buildErr = buildTestBinaries() + buildErr = buildTestBinaries(dir) }) if buildErr != nil { tb.Fatal(buildErr) } - return binDir -} - -// TailscaleBinary returns the path to the test tailscale binary. -// If any test calls TailscaleBinary, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. -func TailscaleBinary(tb testing.TB) string { - return filepath.Join(BinaryDir(tb), "tailscale"+exe()) -} - -// TailscaledBinary returns the path to the test tailscaled binary. -// If any test calls TailscaleBinary, there must be a TestMain function that calls -// CleanupBinaries after all tests are complete. 
-func TailscaledBinary(tb testing.TB) string { - return filepath.Join(BinaryDir(tb), "tailscaled"+exe()) + if binariesCache.Dir == dir { + return binariesCache + } + ts, err := binariesCache.Tailscale.CopyTo(dir) + if err != nil { + tb.Fatalf("copying tailscale binary: %v", err) + } + tsd, err := binariesCache.Tailscaled.CopyTo(dir) + if err != nil { + tb.Fatalf("copying tailscaled binary: %v", err) + } + return &Binaries{ + Dir: dir, + Tailscale: ts, + Tailscaled: tsd, + } } var ( - buildOnce sync.Once - buildErr error - binDir string + buildOnce sync.Once + buildErr error + binariesCache *Binaries ) // buildTestBinaries builds tailscale and tailscaled. -// It returns the dir containing the binaries. -func buildTestBinaries() (string, error) { - bindir, err := os.MkdirTemp("", "") +// On success, it initializes [binariesCache]. +func buildTestBinaries(dir string) error { + getBinaryInfo := func(name string) (BinaryInfo, error) { + bi := BinaryInfo{Path: filepath.Join(dir, name+exe())} + fi, err := os.Stat(bi.Path) + if err != nil { + return BinaryInfo{}, fmt.Errorf("stat %q: %v", bi.Path, err) + } + bi.Size = fi.Size() + + switch runtime.GOOS { + case "windows": + bi.Contents, err = os.ReadFile(bi.Path) + if err != nil { + return BinaryInfo{}, fmt.Errorf("read %q: %v", bi.Path, err) + } + default: + bi.FD, err = os.OpenFile(bi.Path, os.O_RDONLY, 0) + if err != nil { + return BinaryInfo{}, fmt.Errorf("open %q: %v", bi.Path, err) + } + bi.FDMu = new(sync.Mutex) + // Note: bi.FD is copied around between tests but never closed, by + // design. It will be closed when the process exits, and that will + // close the inode that we're copying the bytes from for each test. 
+ } + return bi, nil + } + err := build(dir, "tailscale.com/cmd/tailscaled", "tailscale.com/cmd/tailscale") if err != nil { - return "", err + return err } - err = build(bindir, "tailscale.com/cmd/tailscaled", "tailscale.com/cmd/tailscale") + b := &Binaries{ + Dir: dir, + } + b.Tailscale, err = getBinaryInfo("tailscale") if err != nil { - os.RemoveAll(bindir) - return "", err + return err } - return bindir, nil + b.Tailscaled, err = getBinaryInfo("tailscaled") + if err != nil { + return err + } + binariesCache = b + return nil } func build(outDir string, targets ...string) error { @@ -361,3 +476,632 @@ func (lc *LogCatcher) ServeHTTP(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(200) // must have no content, but not a 204 } + +// TestEnv contains the test environment (set of servers) used by one +// or more nodes. +type TestEnv struct { + t testing.TB + tunMode bool + cli string + daemon string + loopbackPort *int + + LogCatcher *LogCatcher + LogCatcherServer *httptest.Server + + Control *testcontrol.Server + ControlServer *httptest.Server + + TrafficTrap *trafficTrap + TrafficTrapServer *httptest.Server +} + +// ControlURL returns e.ControlServer.URL, panicking if it's the empty string, +// which it should never be in tests. +func (e *TestEnv) ControlURL() string { + s := e.ControlServer.URL + if s == "" { + panic("control server not set") + } + return s +} + +// TestEnvOpt represents an option that can be passed to NewTestEnv. +type TestEnvOpt interface { + ModifyTestEnv(*TestEnv) +} + +// ConfigureControl is a test option that configures the test control server. +type ConfigureControl func(*testcontrol.Server) + +func (f ConfigureControl) ModifyTestEnv(te *TestEnv) { + f(te.Control) +} + +// NewTestEnv starts a bunch of services and returns a new test environment. +// NewTestEnv arranges for the environment's resources to be cleaned up on exit. 
+func NewTestEnv(t testing.TB, opts ...TestEnvOpt) *TestEnv { + if runtime.GOOS == "windows" { + t.Skip("not tested/working on Windows yet") + } + derpMap := RunDERPAndSTUN(t, logger.Discard, "127.0.0.1") + logc := new(LogCatcher) + control := &testcontrol.Server{ + Logf: logger.WithPrefix(t.Logf, "testcontrol: "), + DERPMap: derpMap, + } + control.HTTPTestServer = httptest.NewUnstartedServer(control) + trafficTrap := new(trafficTrap) + binaries := GetBinaries(t) + e := &TestEnv{ + t: t, + cli: binaries.Tailscale.Path, + daemon: binaries.Tailscaled.Path, + LogCatcher: logc, + LogCatcherServer: httptest.NewServer(logc), + Control: control, + ControlServer: control.HTTPTestServer, + TrafficTrap: trafficTrap, + TrafficTrapServer: httptest.NewServer(trafficTrap), + } + for _, o := range opts { + o.ModifyTestEnv(e) + } + control.HTTPTestServer.Start() + t.Cleanup(func() { + // Shut down e. + if err := e.TrafficTrap.Err(); err != nil { + e.t.Errorf("traffic trap: %v", err) + e.t.Logf("logs: %s", e.LogCatcher.logsString()) + } + e.LogCatcherServer.Close() + e.TrafficTrapServer.Close() + e.ControlServer.Close() + }) + t.Logf("control URL: %v", e.ControlURL()) + return e +} + +// TestNode is a machine with a tailscale & tailscaled. +// Currently, the test is simplistic and user==node==machine. +// That may grow complexity later to test more. +type TestNode struct { + env *TestEnv + tailscaledParser *nodeOutputParser + + dir string // temp dir for sock & state + configFile string // or empty for none + sockFile string + stateFile string + upFlagGOOS string // if non-empty, sets TS_DEBUG_UP_FLAG_GOOS for cmd/tailscale CLI + + mu sync.Mutex + onLogLine []func([]byte) + lc *local.Client +} + +// NewTestNode allocates a temp directory for a new test node. +// The node is not started automatically. 
+func NewTestNode(t *testing.T, env *TestEnv) *TestNode { + dir := t.TempDir() + sockFile := filepath.Join(dir, "tailscale.sock") + if len(sockFile) >= 104 { + // Maximum length for a unix socket on darwin. Try something else. + sockFile = filepath.Join(os.TempDir(), rands.HexString(8)+".sock") + t.Cleanup(func() { os.Remove(sockFile) }) + } + n := &TestNode{ + env: env, + dir: dir, + sockFile: sockFile, + stateFile: filepath.Join(dir, "tailscaled.state"), // matches what cmd/tailscaled uses + } + + // Look for a data race or panic. + // Once we see the start marker, start logging the rest. + var sawRace bool + var sawPanic bool + n.addLogLineHook(func(line []byte) { + lineB := mem.B(line) + if mem.Contains(lineB, mem.S("DEBUG-ADDR=")) { + t.Log(strings.TrimSpace(string(line))) + } + if mem.Contains(lineB, mem.S("WARNING: DATA RACE")) { + sawRace = true + } + if mem.HasPrefix(lineB, mem.S("panic: ")) { + sawPanic = true + } + if sawRace || sawPanic { + t.Logf("%s", line) + } + }) + + return n +} + +func (n *TestNode) LocalClient() *local.Client { + n.mu.Lock() + defer n.mu.Unlock() + if n.lc == nil { + tr := &http.Transport{} + n.lc = &local.Client{ + Socket: n.sockFile, + UseSocketOnly: true, + } + n.env.t.Cleanup(tr.CloseIdleConnections) + } + return n.lc +} + +func (n *TestNode) diskPrefs() *ipn.Prefs { + t := n.env.t + t.Helper() + if _, err := os.ReadFile(n.stateFile); err != nil { + t.Fatalf("reading prefs: %v", err) + } + fs, err := store.NewFileStore(nil, n.stateFile) + if err != nil { + t.Fatalf("reading prefs, NewFileStore: %v", err) + } + p, err := ipnlocal.ReadStartupPrefsForTest(t.Logf, fs) + if err != nil { + t.Fatalf("reading prefs, ReadDiskPrefsForTest: %v", err) + } + return p.AsStruct() +} + +// AwaitResponding waits for n's tailscaled to be up enough to be +// responding, but doesn't wait for any particular state. 
+func (n *TestNode) AwaitResponding() { + t := n.env.t + t.Helper() + n.AwaitListening() + + st := n.MustStatus() + t.Logf("Status: %s", st.BackendState) + + if err := tstest.WaitFor(20*time.Second, func() error { + const sub = `Program starting: ` + if !n.env.LogCatcher.logsContains(mem.S(sub)) { + return fmt.Errorf("log catcher didn't see %#q; got %s", sub, n.env.LogCatcher.logsString()) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// addLogLineHook registers a hook f to be called on each tailscaled +// log line output. +func (n *TestNode) addLogLineHook(f func([]byte)) { + n.mu.Lock() + defer n.mu.Unlock() + n.onLogLine = append(n.onLogLine, f) +} + +// socks5AddrChan returns a channel that receives the address (e.g. "localhost:23874") +// of the node's SOCKS5 listener, once started. +func (n *TestNode) socks5AddrChan() <-chan string { + ch := make(chan string, 1) + n.addLogLineHook(func(line []byte) { + const sub = "SOCKS5 listening on " + i := mem.Index(mem.B(line), mem.S(sub)) + if i == -1 { + return + } + addr := strings.TrimSpace(string(line)[i+len(sub):]) + select { + case ch <- addr: + default: + } + }) + return ch +} + +func (n *TestNode) AwaitSocksAddr(ch <-chan string) string { + t := n.env.t + t.Helper() + timer := time.NewTimer(10 * time.Second) + defer timer.Stop() + select { + case v := <-ch: + return v + case <-timer.C: + t.Fatal("timeout waiting for node to log its SOCK5 listening address") + panic("unreachable") + } +} + +// nodeOutputParser parses stderr of tailscaled processes, calling the +// per-line callbacks previously registered via +// testNode.addLogLineHook. 
+type nodeOutputParser struct { + allBuf bytes.Buffer + pendLineBuf bytes.Buffer + n *TestNode +} + +func (op *nodeOutputParser) Write(p []byte) (n int, err error) { + tn := op.n + tn.mu.Lock() + defer tn.mu.Unlock() + + op.allBuf.Write(p) + n, err = op.pendLineBuf.Write(p) + op.parseLinesLocked() + return +} + +func (op *nodeOutputParser) parseLinesLocked() { + n := op.n + buf := op.pendLineBuf.Bytes() + for len(buf) > 0 { + nl := bytes.IndexByte(buf, '\n') + if nl == -1 { + break + } + line := buf[:nl+1] + buf = buf[nl+1:] + + for _, f := range n.onLogLine { + f(line) + } + } + if len(buf) == 0 { + op.pendLineBuf.Reset() + } else { + io.CopyN(io.Discard, &op.pendLineBuf, int64(op.pendLineBuf.Len()-len(buf))) + } +} + +type Daemon struct { + Process *os.Process +} + +func (d *Daemon) MustCleanShutdown(t testing.TB) { + d.Process.Signal(os.Interrupt) + ps, err := d.Process.Wait() + if err != nil { + t.Fatalf("tailscaled Wait: %v", err) + } + if ps.ExitCode() != 0 { + t.Errorf("tailscaled ExitCode = %d; want 0", ps.ExitCode()) + } +} + +// awaitTailscaledRunnable tries to run `tailscaled --version` until it +// works. This is an unsatisfying workaround for ETXTBSY we were seeing +// on GitHub Actions that aren't understood. It's not clear what's holding +// a writable fd to tailscaled after `go install` completes. +// See https://github.com/tailscale/tailscale/issues/15868. +func (n *TestNode) awaitTailscaledRunnable() error { + t := n.env.t + t.Helper() + if err := tstest.WaitFor(10*time.Second, func() error { + out, err := exec.Command(n.env.daemon, "--version").CombinedOutput() + if err == nil { + return nil + } + t.Logf("error running tailscaled --version: %v, %s", err, out) + return err + }); err != nil { + return fmt.Errorf("gave up trying to run tailscaled: %v", err) + } + return nil +} + +// StartDaemon starts the node's tailscaled, failing if it fails to start. +// StartDaemon ensures that the process will exit when the test completes. 
+func (n *TestNode) StartDaemon() *Daemon { + return n.StartDaemonAsIPNGOOS(runtime.GOOS) +} + +func (n *TestNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { + t := n.env.t + + if err := n.awaitTailscaledRunnable(); err != nil { + t.Fatalf("awaitTailscaledRunnable: %v", err) + } + + cmd := exec.Command(n.env.daemon) + cmd.Args = append(cmd.Args, + "--statedir="+n.dir, + "--socket="+n.sockFile, + "--socks5-server=localhost:0", + "--debug=localhost:0", + ) + if *verboseTailscaled { + cmd.Args = append(cmd.Args, "-verbose=2") + } + if !n.env.tunMode { + cmd.Args = append(cmd.Args, + "--tun=userspace-networking", + ) + } + if n.configFile != "" { + cmd.Args = append(cmd.Args, "--config="+n.configFile) + } + cmd.Env = append(os.Environ(), + "TS_DEBUG_PERMIT_HTTP_C2N=1", + "TS_LOG_TARGET="+n.env.LogCatcherServer.URL, + "HTTP_PROXY="+n.env.TrafficTrapServer.URL, + "HTTPS_PROXY="+n.env.TrafficTrapServer.URL, + "TS_DEBUG_FAKE_GOOS="+ipnGOOS, + "TS_LOGS_DIR="+t.TempDir(), + "TS_NETCHECK_GENERATE_204_URL="+n.env.ControlServer.URL+"/generate_204", + "TS_ASSUME_NETWORK_UP_FOR_TEST=1", // don't pause control client in airplane mode (no wifi, etc) + "TS_PANIC_IF_HIT_MAIN_CONTROL=1", + "TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost + "TS_DEBUG_LOG_RATE=all", + ) + if n.env.loopbackPort != nil { + cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort)) + } + if version.IsRace() { + cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") + } + n.tailscaledParser = &nodeOutputParser{n: n} + cmd.Stderr = n.tailscaledParser + if *verboseTailscaled { + cmd.Stdout = os.Stdout + cmd.Stderr = io.MultiWriter(cmd.Stderr, os.Stderr) + } + if runtime.GOOS != "windows" { + pr, pw, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { pw.Close() }) + cmd.ExtraFiles = append(cmd.ExtraFiles, pr) + cmd.Env = append(cmd.Env, "TS_PARENT_DEATH_FD=3") + } + if err := cmd.Start(); err != nil { + t.Fatalf("starting 
tailscaled: %v", err) + } + t.Cleanup(func() { cmd.Process.Kill() }) + return &Daemon{ + Process: cmd.Process, + } +} + +func (n *TestNode) MustUp(extraArgs ...string) { + t := n.env.t + t.Helper() + args := []string{ + "up", + "--login-server=" + n.env.ControlURL(), + "--reset", + } + args = append(args, extraArgs...) + cmd := n.Tailscale(args...) + t.Logf("Running %v ...", cmd) + cmd.Stdout = nil // in case --verbose-tailscale was set + cmd.Stderr = nil // in case --verbose-tailscale was set + if b, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("up: %v, %v", string(b), err) + } +} + +func (n *TestNode) MustDown() { + t := n.env.t + t.Logf("Running down ...") + if err := n.Tailscale("down", "--accept-risk=all").Run(); err != nil { + t.Fatalf("down: %v", err) + } +} + +func (n *TestNode) MustLogOut() { + t := n.env.t + t.Logf("Running logout ...") + if err := n.Tailscale("logout").Run(); err != nil { + t.Fatalf("logout: %v", err) + } +} + +func (n *TestNode) Ping(otherNode *TestNode) error { + t := n.env.t + ip := otherNode.AwaitIP4().String() + t.Logf("Running ping %v (from %v)...", ip, n.AwaitIP4()) + return n.Tailscale("ping", ip).Run() +} + +// AwaitListening waits for the tailscaled to be serving local clients +// over its localhost IPC mechanism. 
(Unix socket, etc) +func (n *TestNode) AwaitListening() { + t := n.env.t + if err := tstest.WaitFor(20*time.Second, func() (err error) { + c, err := safesocket.ConnectContext(context.Background(), n.sockFile) + if err == nil { + c.Close() + } + return err + }); err != nil { + t.Fatal(err) + } +} + +func (n *TestNode) AwaitIPs() []netip.Addr { + t := n.env.t + t.Helper() + var addrs []netip.Addr + if err := tstest.WaitFor(20*time.Second, func() error { + cmd := n.Tailscale("ip") + cmd.Stdout = nil // in case --verbose-tailscale was set + cmd.Stderr = nil // in case --verbose-tailscale was set + out, err := cmd.Output() + if err != nil { + return err + } + ips := string(out) + ipslice := strings.Fields(ips) + addrs = make([]netip.Addr, len(ipslice)) + + for i, ip := range ipslice { + netIP, err := netip.ParseAddr(ip) + if err != nil { + t.Fatal(err) + } + addrs[i] = netIP + } + return nil + }); err != nil { + t.Fatalf("awaiting an IP address: %v", err) + } + if len(addrs) == 0 { + t.Fatalf("returned IP address was blank") + } + return addrs +} + +// AwaitIP4 returns the IPv4 address of n. +func (n *TestNode) AwaitIP4() netip.Addr { + t := n.env.t + t.Helper() + ips := n.AwaitIPs() + return ips[0] +} + +// AwaitIP6 returns the IPv6 address of n. +func (n *TestNode) AwaitIP6() netip.Addr { + t := n.env.t + t.Helper() + ips := n.AwaitIPs() + return ips[1] +} + +// AwaitRunning waits for n to reach the IPN state "Running". 
+func (n *TestNode) AwaitRunning() { + t := n.env.t + t.Helper() + n.AwaitBackendState("Running") +} + +func (n *TestNode) AwaitBackendState(state string) { + t := n.env.t + t.Helper() + if err := tstest.WaitFor(20*time.Second, func() error { + st, err := n.Status() + if err != nil { + return err + } + if st.BackendState != state { + return fmt.Errorf("in state %q; want %q", st.BackendState, state) + } + return nil + }); err != nil { + t.Fatalf("failure/timeout waiting for transition to Running status: %v", err) + } +} + +// AwaitNeedsLogin waits for n to reach the IPN state "NeedsLogin". +func (n *TestNode) AwaitNeedsLogin() { + t := n.env.t + t.Helper() + if err := tstest.WaitFor(20*time.Second, func() error { + st, err := n.Status() + if err != nil { + return err + } + if st.BackendState != "NeedsLogin" { + return fmt.Errorf("in state %q", st.BackendState) + } + return nil + }); err != nil { + t.Fatalf("failure/timeout waiting for transition to NeedsLogin status: %v", err) + } +} + +func (n *TestNode) TailscaleForOutput(arg ...string) *exec.Cmd { + cmd := n.Tailscale(arg...) + cmd.Stdout = nil + cmd.Stderr = nil + return cmd +} + +// Tailscale returns a command that runs the tailscale CLI with the provided arguments. +// It does not start the process. +func (n *TestNode) Tailscale(arg ...string) *exec.Cmd { + cmd := exec.Command(n.env.cli) + cmd.Args = append(cmd.Args, "--socket="+n.sockFile) + cmd.Args = append(cmd.Args, arg...) 
+ cmd.Dir = n.dir + cmd.Env = append(os.Environ(), + "TS_DEBUG_UP_FLAG_GOOS="+n.upFlagGOOS, + "TS_LOGS_DIR="+n.env.t.TempDir(), + ) + if *verboseTailscale { + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + } + return cmd +} + +func (n *TestNode) Status() (*ipnstate.Status, error) { + cmd := n.Tailscale("status", "--json") + cmd.Stdout = nil // in case --verbose-tailscale was set + cmd.Stderr = nil // in case --verbose-tailscale was set + out, err := cmd.CombinedOutput() + if err != nil { + return nil, fmt.Errorf("running tailscale status: %v, %s", err, out) + } + st := new(ipnstate.Status) + if err := json.Unmarshal(out, st); err != nil { + return nil, fmt.Errorf("decoding tailscale status JSON: %w\njson:\n%s", err, out) + } + return st, nil +} + +func (n *TestNode) MustStatus() *ipnstate.Status { + tb := n.env.t + tb.Helper() + st, err := n.Status() + if err != nil { + tb.Fatal(err) + } + return st +} + +// trafficTrap is an HTTP proxy handler to note whether any +// HTTP traffic tries to leave localhost from tailscaled. We don't +// expect any, so any request triggers a failure. +type trafficTrap struct { + atomicErr syncs.AtomicValue[error] +} + +func (tt *trafficTrap) Err() error { + return tt.atomicErr.Load() +} + +func (tt *trafficTrap) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var got bytes.Buffer + r.Write(&got) + err := fmt.Errorf("unexpected HTTP request via proxy: %s", got.Bytes()) + MainError.Store(err) + if tt.Err() == nil { + // Best effort at remembering the first request. 
+ tt.atomicErr.Store(err) + } + log.Printf("Error: %v", err) + w.WriteHeader(403) +} + +type authURLParserWriter struct { + buf bytes.Buffer + fn func(urlStr string) error +} + +var authURLRx = regexp.MustCompile(`(https?://\S+/auth/\S+)`) + +func (w *authURLParserWriter) Write(p []byte) (n int, err error) { + n, err = w.buf.Write(p) + m := authURLRx.FindSubmatch(w.buf.Bytes()) + if m != nil { + urlStr := string(m[1]) + w.buf.Reset() // so it's not matched again + if err := w.fn(urlStr); err != nil { + return 0, err + } + } + return n, err +} diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 70c5d68c336e0..90cc7e443b5d3 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -13,7 +13,6 @@ import ( "flag" "fmt" "io" - "log" "net" "net/http" "net/http/httptest" @@ -22,57 +21,38 @@ import ( "os/exec" "path/filepath" "regexp" - "runtime" "strconv" - "strings" - "sync" "sync/atomic" "testing" "time" "github.com/miekg/dns" "go4.org/mem" + "tailscale.com/client/local" "tailscale.com/client/tailscale" "tailscale.com/clientupdate" "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/ipn" - "tailscale.com/ipn/ipnlocal" - "tailscale.com/ipn/ipnstate" - "tailscale.com/ipn/store" "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" - "tailscale.com/safesocket" - "tailscale.com/syncs" "tailscale.com/tailcfg" "tailscale.com/tstest" "tailscale.com/tstest/integration/testcontrol" "tailscale.com/types/key" - "tailscale.com/types/logger" "tailscale.com/types/opt" "tailscale.com/types/ptr" - "tailscale.com/util/dnsname" "tailscale.com/util/must" - "tailscale.com/util/rands" - "tailscale.com/version" ) -var ( - verboseTailscaled = flag.Bool("verbose-tailscaled", false, "verbose tailscaled logging") - verboseTailscale = flag.Bool("verbose-tailscale", false, "verbose tailscale CLI logging") -) - -var mainError syncs.AtomicValue[error] - func TestMain(m *testing.M) { // Have to disable UPnP 
which hits the network, otherwise it fails due to HTTP proxy. os.Setenv("TS_DISABLE_UPNP", "true") flag.Parse() v := m.Run() - CleanupBinaries() if v != 0 { os.Exit(v) } - if err := mainError.Load(); err != nil { + if err := MainError.Load(); err != nil { fmt.Fprintf(os.Stderr, "FAIL: %v\n", err) os.Exit(1) } @@ -87,9 +67,9 @@ func TestTUNMode(t *testing.T) { t.Skip("skipping when not root") } tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -104,8 +84,8 @@ func TestTUNMode(t *testing.T) { func TestOneNodeUpNoAuth(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -122,8 +102,8 @@ func TestOneNodeUpNoAuth(t *testing.T) { func TestOneNodeExpiredKey(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -159,8 +139,8 @@ func TestOneNodeExpiredKey(t *testing.T) { func TestControlKnobs(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() defer d1.MustCleanShutdown(t) @@ -190,8 +170,8 @@ func TestControlKnobs(t *testing.T) { func TestCollectPanic(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n := newTestNode(t, env) + env := NewTestEnv(t) + n := NewTestNode(t, env) cmd := exec.Command(env.daemon, "--cleanup") cmd.Env = append(os.Environ(), @@ -221,9 +201,9 @@ func TestCollectPanic(t *testing.T) { func TestControlTimeLogLine(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) env.LogCatcher.StoreRawJSON() - n := newTestNode(t, env) + n := NewTestNode(t, 
env) n.StartDaemon() n.AwaitResponding() @@ -245,8 +225,8 @@ func TestControlTimeLogLine(t *testing.T) { func TestStateSavedOnStart(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -264,7 +244,7 @@ func TestStateSavedOnStart(t *testing.T) { n1.MustDown() // And change the hostname to something: - if err := n1.Tailscale("up", "--login-server="+n1.env.controlURL(), "--hostname=foo").Run(); err != nil { + if err := n1.Tailscale("up", "--login-server="+n1.env.ControlURL(), "--hostname=foo").Run(); err != nil { t.Fatalf("up: %v", err) } @@ -282,11 +262,11 @@ func TestStateSavedOnStart(t *testing.T) { func TestOneNodeUpAuth(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t, configureControl(func(control *testcontrol.Server) { + env := NewTestEnv(t, ConfigureControl(func(control *testcontrol.Server) { control.RequireAuth = true })) - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitListening() @@ -294,18 +274,23 @@ func TestOneNodeUpAuth(t *testing.T) { st := n1.MustStatus() t.Logf("Status: %s", st.BackendState) - t.Logf("Running up --login-server=%s ...", env.controlURL()) + t.Logf("Running up --login-server=%s ...", env.ControlURL()) - cmd := n1.Tailscale("up", "--login-server="+env.controlURL()) - var authCountAtomic int32 + cmd := n1.Tailscale("up", "--login-server="+env.ControlURL()) + var authCountAtomic atomic.Int32 cmd.Stdout = &authURLParserWriter{fn: func(urlStr string) error { + t.Logf("saw auth URL %q", urlStr) if env.Control.CompleteAuth(urlStr) { - atomic.AddInt32(&authCountAtomic, 1) + if authCountAtomic.Add(1) > 1 { + err := errors.New("completed multple auth URLs") + t.Error(err) + return err + } t.Logf("completed auth path %s", urlStr) return nil } err := fmt.Errorf("Failed to complete auth path to %q", urlStr) - t.Log(err) + t.Error(err) return 
err }} cmd.Stderr = cmd.Stdout @@ -316,7 +301,7 @@ func TestOneNodeUpAuth(t *testing.T) { n1.AwaitRunning() - if n := atomic.LoadInt32(&authCountAtomic); n != 1 { + if n := authCountAtomic.Load(); n != 1 { t.Errorf("Auth URLs completed = %d; want 1", n) } @@ -328,11 +313,11 @@ func TestConfigFileAuthKey(t *testing.T) { tstest.Shard(t) t.Parallel() const authKey = "opensesame" - env := newTestEnv(t, configureControl(func(control *testcontrol.Server) { + env := NewTestEnv(t, ConfigureControl(func(control *testcontrol.Server) { control.RequireAuthKey = authKey })) - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) n1.configFile = filepath.Join(n1.dir, "config.json") authKeyFile := filepath.Join(n1.dir, "my-auth-key") must.Do(os.WriteFile(authKeyFile, fmt.Appendf(nil, "%s\n", authKey), 0666)) @@ -353,14 +338,14 @@ func TestConfigFileAuthKey(t *testing.T) { func TestTwoNodes(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) // Create two nodes: - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) n1SocksAddrCh := n1.socks5AddrChan() d1 := n1.StartDaemon() - n2 := newTestNode(t, env) + n2 := NewTestNode(t, env) n2SocksAddrCh := n2.socks5AddrChan() d2 := n2.StartDaemon() @@ -379,7 +364,7 @@ func TestTwoNodes(t *testing.T) { defer n2.mu.Unlock() rxNoDates := regexp.MustCompile(`(?m)^\d{4}.\d{2}.\d{2}.\d{2}:\d{2}:\d{2}`) - cleanLog := func(n *testNode) []byte { + cleanLog := func(n *TestNode) []byte { b := n.tailscaledParser.allBuf.Bytes() b = rxNoDates.ReplaceAll(b, nil) return b @@ -439,10 +424,10 @@ func TestTwoNodes(t *testing.T) { func TestIncrementalMapUpdatePeersRemoved(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) // Create one node: - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitListening() n1.MustUp() @@ -454,7 +439,7 @@ func TestIncrementalMapUpdatePeersRemoved(t *testing.T) { } tnode1 := all[0] - n2 := newTestNode(t, 
env) + n2 := NewTestNode(t, env) d2 := n2.StartDaemon() n2.AwaitListening() n2.MustUp() @@ -524,8 +509,8 @@ func TestNodeAddressIPFields(t *testing.T) { tstest.Shard(t) flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7008") tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitListening() @@ -551,8 +536,8 @@ func TestNodeAddressIPFields(t *testing.T) { func TestAddPingRequest(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) n1.StartDaemon() n1.AwaitListening() @@ -605,7 +590,7 @@ func TestC2NPingRequest(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) gotPing := make(chan bool, 1) env.Control.HandleC2N = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -623,7 +608,7 @@ func TestC2NPingRequest(t *testing.T) { gotPing <- true }) - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) n1.StartDaemon() n1.AwaitListening() @@ -676,8 +661,8 @@ func TestC2NPingRequest(t *testing.T) { func TestNoControlConnWhenDown(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -715,8 +700,8 @@ func TestNoControlConnWhenDown(t *testing.T) { func TestOneNodeUpWindowsStyle(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - n1 := newTestNode(t, env) + env := NewTestEnv(t) + n1 := NewTestNode(t, env) n1.upFlagGOOS = "windows" d1 := n1.StartDaemonAsIPNGOOS("windows") @@ -735,9 +720,9 @@ func TestOneNodeUpWindowsStyle(t *testing.T) { func TestClientSideJailing(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) - registerNode := func() (*testNode, key.NodePublic) { - n := newTestNode(t, env) + env := NewTestEnv(t) + 
registerNode := func() (*TestNode, key.NodePublic) { + n := NewTestNode(t, env) n.StartDaemon() n.AwaitListening() n.MustUp() @@ -755,11 +740,11 @@ func TestClientSideJailing(t *testing.T) { defer ln.Close() port := uint16(ln.Addr().(*net.TCPAddr).Port) - lc1 := &tailscale.LocalClient{ + lc1 := &local.Client{ Socket: n1.sockFile, UseSocketOnly: true, } - lc2 := &tailscale.LocalClient{ + lc2 := &local.Client{ Socket: n2.sockFile, UseSocketOnly: true, } @@ -789,7 +774,7 @@ func TestClientSideJailing(t *testing.T) { }, } - testDial := func(t *testing.T, lc *tailscale.LocalClient, ip netip.Addr, port uint16, shouldFail bool) { + testDial := func(t *testing.T, lc *local.Client, ip netip.Addr, port uint16, shouldFail bool) { t.Helper() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -851,9 +836,9 @@ func TestNATPing(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) for _, v6 := range []bool{false, true} { - env := newTestEnv(t) - registerNode := func() (*testNode, key.NodePublic) { - n := newTestNode(t, env) + env := NewTestEnv(t) + registerNode := func() (*TestNode, key.NodePublic) { + n := NewTestNode(t, env) n.StartDaemon() n.AwaitListening() n.MustUp() @@ -978,11 +963,11 @@ func TestNATPing(t *testing.T) { func TestLogoutRemovesAllPeers(t *testing.T) { tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) // Spin up some nodes. 
- nodes := make([]*testNode, 2) + nodes := make([]*TestNode, 2) for i := range nodes { - nodes[i] = newTestNode(t, env) + nodes[i] = NewTestNode(t, env) nodes[i].StartDaemon() nodes[i].AwaitResponding() nodes[i].MustUp() @@ -1036,9 +1021,9 @@ func TestAutoUpdateDefaults(t *testing.T) { } tstest.Shard(t) tstest.Parallel(t) - env := newTestEnv(t) + env := NewTestEnv(t) - checkDefault := func(n *testNode, want bool) error { + checkDefault := func(n *TestNode, want bool) error { enabled, ok := n.diskPrefs().AutoUpdate.Apply.Get() if !ok { return fmt.Errorf("auto-update for node is unset, should be set as %v", want) @@ -1049,7 +1034,7 @@ func TestAutoUpdateDefaults(t *testing.T) { return nil } - sendAndCheckDefault := func(t *testing.T, n *testNode, send, want bool) { + sendAndCheckDefault := func(t *testing.T, n *TestNode, send, want bool) { t.Helper() if !env.Control.AddRawMapResponse(n.MustStatus().Self.PublicKey, &tailcfg.MapResponse{ DefaultAutoUpdate: opt.NewBool(send), @@ -1065,11 +1050,11 @@ func TestAutoUpdateDefaults(t *testing.T) { tests := []struct { desc string - run func(t *testing.T, n *testNode) + run func(t *testing.T, n *TestNode) }{ { desc: "tailnet-default-false", - run: func(t *testing.T, n *testNode) { + run: func(t *testing.T, n *TestNode) { // First received default "false". sendAndCheckDefault(t, n, false, false) // Should not be changed even if sent "true" later. @@ -1083,7 +1068,7 @@ func TestAutoUpdateDefaults(t *testing.T) { }, { desc: "tailnet-default-true", - run: func(t *testing.T, n *testNode) { + run: func(t *testing.T, n *TestNode) { // First received default "true". sendAndCheckDefault(t, n, true, true) // Should not be changed even if sent "false" later. @@ -1097,7 +1082,7 @@ func TestAutoUpdateDefaults(t *testing.T) { }, { desc: "user-sets-first", - run: func(t *testing.T, n *testNode) { + run: func(t *testing.T, n *TestNode) { // User sets auto-update first, before receiving defaults. 
if out, err := n.TailscaleForOutput("set", "--auto-update=false").CombinedOutput(); err != nil { t.Fatalf("failed to disable auto-update on node: %v\noutput: %s", err, out) @@ -1110,7 +1095,7 @@ func TestAutoUpdateDefaults(t *testing.T) { } for _, tt := range tests { t.Run(tt.desc, func(t *testing.T) { - n := newTestNode(t, env) + n := NewTestNode(t, env) d := n.StartDaemon() defer d.MustCleanShutdown(t) @@ -1132,25 +1117,16 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { if os.Getuid() != 0 { t.Skip("skipping when not root") } - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() n1.MustUp() - - wantIP4 := n1.AwaitIP4() n1.AwaitRunning() - status, err := n1.Status() - if err != nil { - t.Fatalf("failed to get node status: %v", err) - } - selfDNSName, err := dnsname.ToFQDN(status.Self.DNSName) - if err != nil { - t.Fatalf("error converting self dns name to fqdn: %v", err) - } + const dnsSymbolicFQDN = "magicdns.localhost-tailscale-daemon." 
cases := []struct { network string @@ -1166,9 +1142,9 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { }, } for _, c := range cases { - err = tstest.WaitFor(time.Second*5, func() error { + err := tstest.WaitFor(time.Second*5, func() error { m := new(dns.Msg) - m.SetQuestion(selfDNSName.WithTrailingDot(), dns.TypeA) + m.SetQuestion(dnsSymbolicFQDN, dns.TypeA) conn, err := net.DialTimeout(c.network, net.JoinHostPort(c.serviceAddr.String(), "53"), time.Second*1) if err != nil { return err @@ -1193,8 +1169,8 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { return fmt.Errorf("unexpected answer type: %s", resp.Answer[0]) } gotAddr = answer.A - if !bytes.Equal(gotAddr, wantIP4.AsSlice()) { - return fmt.Errorf("got (%s) != want (%s)", gotAddr, wantIP4) + if !bytes.Equal(gotAddr, tsaddr.TailscaleServiceIP().AsSlice()) { + return fmt.Errorf("got (%s) != want (%s)", gotAddr, tsaddr.TailscaleServiceIP()) } return nil }) @@ -1214,12 +1190,12 @@ func TestNetstackTCPLoopback(t *testing.T) { t.Skip("skipping when not root") } - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true loopbackPort := 5201 env.loopbackPort = &loopbackPort loopbackPortStr := strconv.Itoa(loopbackPort) - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -1356,11 +1332,11 @@ func TestNetstackUDPLoopback(t *testing.T) { t.Skip("skipping when not root") } - env := newTestEnv(t) + env := NewTestEnv(t) env.tunMode = true loopbackPort := 5201 env.loopbackPort = &loopbackPort - n1 := newTestNode(t, env) + n1 := NewTestNode(t, env) d1 := n1.StartDaemon() n1.AwaitResponding() @@ -1494,581 +1470,3 @@ func TestNetstackUDPLoopback(t *testing.T) { d1.MustCleanShutdown(t) } - -// testEnv contains the test environment (set of servers) used by one -// or more nodes. 
-type testEnv struct { - t testing.TB - tunMode bool - cli string - daemon string - loopbackPort *int - - LogCatcher *LogCatcher - LogCatcherServer *httptest.Server - - Control *testcontrol.Server - ControlServer *httptest.Server - - TrafficTrap *trafficTrap - TrafficTrapServer *httptest.Server -} - -// controlURL returns e.ControlServer.URL, panicking if it's the empty string, -// which it should never be in tests. -func (e *testEnv) controlURL() string { - s := e.ControlServer.URL - if s == "" { - panic("control server not set") - } - return s -} - -type testEnvOpt interface { - modifyTestEnv(*testEnv) -} - -type configureControl func(*testcontrol.Server) - -func (f configureControl) modifyTestEnv(te *testEnv) { - f(te.Control) -} - -// newTestEnv starts a bunch of services and returns a new test environment. -// newTestEnv arranges for the environment's resources to be cleaned up on exit. -func newTestEnv(t testing.TB, opts ...testEnvOpt) *testEnv { - if runtime.GOOS == "windows" { - t.Skip("not tested/working on Windows yet") - } - derpMap := RunDERPAndSTUN(t, logger.Discard, "127.0.0.1") - logc := new(LogCatcher) - control := &testcontrol.Server{ - DERPMap: derpMap, - } - control.HTTPTestServer = httptest.NewUnstartedServer(control) - trafficTrap := new(trafficTrap) - e := &testEnv{ - t: t, - cli: TailscaleBinary(t), - daemon: TailscaledBinary(t), - LogCatcher: logc, - LogCatcherServer: httptest.NewServer(logc), - Control: control, - ControlServer: control.HTTPTestServer, - TrafficTrap: trafficTrap, - TrafficTrapServer: httptest.NewServer(trafficTrap), - } - for _, o := range opts { - o.modifyTestEnv(e) - } - control.HTTPTestServer.Start() - t.Cleanup(func() { - // Shut down e. 
- if err := e.TrafficTrap.Err(); err != nil { - e.t.Errorf("traffic trap: %v", err) - e.t.Logf("logs: %s", e.LogCatcher.logsString()) - } - e.LogCatcherServer.Close() - e.TrafficTrapServer.Close() - e.ControlServer.Close() - }) - t.Logf("control URL: %v", e.controlURL()) - return e -} - -// testNode is a machine with a tailscale & tailscaled. -// Currently, the test is simplistic and user==node==machine. -// That may grow complexity later to test more. -type testNode struct { - env *testEnv - tailscaledParser *nodeOutputParser - - dir string // temp dir for sock & state - configFile string // or empty for none - sockFile string - stateFile string - upFlagGOOS string // if non-empty, sets TS_DEBUG_UP_FLAG_GOOS for cmd/tailscale CLI - - mu sync.Mutex - onLogLine []func([]byte) -} - -// newTestNode allocates a temp directory for a new test node. -// The node is not started automatically. -func newTestNode(t *testing.T, env *testEnv) *testNode { - dir := t.TempDir() - sockFile := filepath.Join(dir, "tailscale.sock") - if len(sockFile) >= 104 { - // Maximum length for a unix socket on darwin. Try something else. - sockFile = filepath.Join(os.TempDir(), rands.HexString(8)+".sock") - t.Cleanup(func() { os.Remove(sockFile) }) - } - n := &testNode{ - env: env, - dir: dir, - sockFile: sockFile, - stateFile: filepath.Join(dir, "tailscale.state"), - } - - // Look for a data race. Once we see the start marker, start logging the rest. 
- var sawRace bool - var sawPanic bool - n.addLogLineHook(func(line []byte) { - lineB := mem.B(line) - if mem.Contains(lineB, mem.S("WARNING: DATA RACE")) { - sawRace = true - } - if mem.HasPrefix(lineB, mem.S("panic: ")) { - sawPanic = true - } - if sawRace || sawPanic { - t.Logf("%s", line) - } - }) - - return n -} - -func (n *testNode) diskPrefs() *ipn.Prefs { - t := n.env.t - t.Helper() - if _, err := os.ReadFile(n.stateFile); err != nil { - t.Fatalf("reading prefs: %v", err) - } - fs, err := store.NewFileStore(nil, n.stateFile) - if err != nil { - t.Fatalf("reading prefs, NewFileStore: %v", err) - } - p, err := ipnlocal.ReadStartupPrefsForTest(t.Logf, fs) - if err != nil { - t.Fatalf("reading prefs, ReadDiskPrefsForTest: %v", err) - } - return p.AsStruct() -} - -// AwaitResponding waits for n's tailscaled to be up enough to be -// responding, but doesn't wait for any particular state. -func (n *testNode) AwaitResponding() { - t := n.env.t - t.Helper() - n.AwaitListening() - - st := n.MustStatus() - t.Logf("Status: %s", st.BackendState) - - if err := tstest.WaitFor(20*time.Second, func() error { - const sub = `Program starting: ` - if !n.env.LogCatcher.logsContains(mem.S(sub)) { - return fmt.Errorf("log catcher didn't see %#q; got %s", sub, n.env.LogCatcher.logsString()) - } - return nil - }); err != nil { - t.Fatal(err) - } -} - -// addLogLineHook registers a hook f to be called on each tailscaled -// log line output. -func (n *testNode) addLogLineHook(f func([]byte)) { - n.mu.Lock() - defer n.mu.Unlock() - n.onLogLine = append(n.onLogLine, f) -} - -// socks5AddrChan returns a channel that receives the address (e.g. "localhost:23874") -// of the node's SOCKS5 listener, once started. 
-func (n *testNode) socks5AddrChan() <-chan string { - ch := make(chan string, 1) - n.addLogLineHook(func(line []byte) { - const sub = "SOCKS5 listening on " - i := mem.Index(mem.B(line), mem.S(sub)) - if i == -1 { - return - } - addr := strings.TrimSpace(string(line)[i+len(sub):]) - select { - case ch <- addr: - default: - } - }) - return ch -} - -func (n *testNode) AwaitSocksAddr(ch <-chan string) string { - t := n.env.t - t.Helper() - timer := time.NewTimer(10 * time.Second) - defer timer.Stop() - select { - case v := <-ch: - return v - case <-timer.C: - t.Fatal("timeout waiting for node to log its SOCK5 listening address") - panic("unreachable") - } -} - -// nodeOutputParser parses stderr of tailscaled processes, calling the -// per-line callbacks previously registered via -// testNode.addLogLineHook. -type nodeOutputParser struct { - allBuf bytes.Buffer - pendLineBuf bytes.Buffer - n *testNode -} - -func (op *nodeOutputParser) Write(p []byte) (n int, err error) { - tn := op.n - tn.mu.Lock() - defer tn.mu.Unlock() - - op.allBuf.Write(p) - n, err = op.pendLineBuf.Write(p) - op.parseLinesLocked() - return -} - -func (op *nodeOutputParser) parseLinesLocked() { - n := op.n - buf := op.pendLineBuf.Bytes() - for len(buf) > 0 { - nl := bytes.IndexByte(buf, '\n') - if nl == -1 { - break - } - line := buf[:nl+1] - buf = buf[nl+1:] - - for _, f := range n.onLogLine { - f(line) - } - } - if len(buf) == 0 { - op.pendLineBuf.Reset() - } else { - io.CopyN(io.Discard, &op.pendLineBuf, int64(op.pendLineBuf.Len()-len(buf))) - } -} - -type Daemon struct { - Process *os.Process -} - -func (d *Daemon) MustCleanShutdown(t testing.TB) { - d.Process.Signal(os.Interrupt) - ps, err := d.Process.Wait() - if err != nil { - t.Fatalf("tailscaled Wait: %v", err) - } - if ps.ExitCode() != 0 { - t.Errorf("tailscaled ExitCode = %d; want 0", ps.ExitCode()) - } -} - -// StartDaemon starts the node's tailscaled, failing if it fails to start. 
-// StartDaemon ensures that the process will exit when the test completes. -func (n *testNode) StartDaemon() *Daemon { - return n.StartDaemonAsIPNGOOS(runtime.GOOS) -} - -func (n *testNode) StartDaemonAsIPNGOOS(ipnGOOS string) *Daemon { - t := n.env.t - cmd := exec.Command(n.env.daemon) - cmd.Args = append(cmd.Args, - "--state="+n.stateFile, - "--socket="+n.sockFile, - "--socks5-server=localhost:0", - ) - if *verboseTailscaled { - cmd.Args = append(cmd.Args, "-verbose=2") - } - if !n.env.tunMode { - cmd.Args = append(cmd.Args, - "--tun=userspace-networking", - ) - } - if n.configFile != "" { - cmd.Args = append(cmd.Args, "--config="+n.configFile) - } - cmd.Env = append(os.Environ(), - "TS_CONTROL_IS_PLAINTEXT_HTTP=1", - "TS_DEBUG_PERMIT_HTTP_C2N=1", - "TS_LOG_TARGET="+n.env.LogCatcherServer.URL, - "HTTP_PROXY="+n.env.TrafficTrapServer.URL, - "HTTPS_PROXY="+n.env.TrafficTrapServer.URL, - "TS_DEBUG_FAKE_GOOS="+ipnGOOS, - "TS_LOGS_DIR="+t.TempDir(), - "TS_NETCHECK_GENERATE_204_URL="+n.env.ControlServer.URL+"/generate_204", - "TS_ASSUME_NETWORK_UP_FOR_TEST=1", // don't pause control client in airplane mode (no wifi, etc) - "TS_PANIC_IF_HIT_MAIN_CONTROL=1", - "TS_DISABLE_PORTMAPPER=1", // shouldn't be needed; test is all localhost - "TS_DEBUG_LOG_RATE=all", - ) - if n.env.loopbackPort != nil { - cmd.Env = append(cmd.Env, "TS_DEBUG_NETSTACK_LOOPBACK_PORT="+strconv.Itoa(*n.env.loopbackPort)) - } - if version.IsRace() { - cmd.Env = append(cmd.Env, "GORACE=halt_on_error=1") - } - n.tailscaledParser = &nodeOutputParser{n: n} - cmd.Stderr = n.tailscaledParser - if *verboseTailscaled { - cmd.Stdout = os.Stdout - cmd.Stderr = io.MultiWriter(cmd.Stderr, os.Stderr) - } - if runtime.GOOS != "windows" { - pr, pw, err := os.Pipe() - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { pw.Close() }) - cmd.ExtraFiles = append(cmd.ExtraFiles, pr) - cmd.Env = append(cmd.Env, "TS_PARENT_DEATH_FD=3") - } - if err := cmd.Start(); err != nil { - t.Fatalf("starting tailscaled: %v", err) 
- } - t.Cleanup(func() { cmd.Process.Kill() }) - return &Daemon{ - Process: cmd.Process, - } -} - -func (n *testNode) MustUp(extraArgs ...string) { - t := n.env.t - t.Helper() - args := []string{ - "up", - "--login-server=" + n.env.controlURL(), - "--reset", - } - args = append(args, extraArgs...) - cmd := n.Tailscale(args...) - t.Logf("Running %v ...", cmd) - cmd.Stdout = nil // in case --verbose-tailscale was set - cmd.Stderr = nil // in case --verbose-tailscale was set - if b, err := cmd.CombinedOutput(); err != nil { - t.Fatalf("up: %v, %v", string(b), err) - } -} - -func (n *testNode) MustDown() { - t := n.env.t - t.Logf("Running down ...") - if err := n.Tailscale("down", "--accept-risk=all").Run(); err != nil { - t.Fatalf("down: %v", err) - } -} - -func (n *testNode) MustLogOut() { - t := n.env.t - t.Logf("Running logout ...") - if err := n.Tailscale("logout").Run(); err != nil { - t.Fatalf("logout: %v", err) - } -} - -func (n *testNode) Ping(otherNode *testNode) error { - t := n.env.t - ip := otherNode.AwaitIP4().String() - t.Logf("Running ping %v (from %v)...", ip, n.AwaitIP4()) - return n.Tailscale("ping", ip).Run() -} - -// AwaitListening waits for the tailscaled to be serving local clients -// over its localhost IPC mechanism. 
(Unix socket, etc) -func (n *testNode) AwaitListening() { - t := n.env.t - if err := tstest.WaitFor(20*time.Second, func() (err error) { - c, err := safesocket.ConnectContext(context.Background(), n.sockFile) - if err == nil { - c.Close() - } - return err - }); err != nil { - t.Fatal(err) - } -} - -func (n *testNode) AwaitIPs() []netip.Addr { - t := n.env.t - t.Helper() - var addrs []netip.Addr - if err := tstest.WaitFor(20*time.Second, func() error { - cmd := n.Tailscale("ip") - cmd.Stdout = nil // in case --verbose-tailscale was set - cmd.Stderr = nil // in case --verbose-tailscale was set - out, err := cmd.Output() - if err != nil { - return err - } - ips := string(out) - ipslice := strings.Fields(ips) - addrs = make([]netip.Addr, len(ipslice)) - - for i, ip := range ipslice { - netIP, err := netip.ParseAddr(ip) - if err != nil { - t.Fatal(err) - } - addrs[i] = netIP - } - return nil - }); err != nil { - t.Fatalf("awaiting an IP address: %v", err) - } - if len(addrs) == 0 { - t.Fatalf("returned IP address was blank") - } - return addrs -} - -// AwaitIP4 returns the IPv4 address of n. -func (n *testNode) AwaitIP4() netip.Addr { - t := n.env.t - t.Helper() - ips := n.AwaitIPs() - return ips[0] -} - -// AwaitIP6 returns the IPv6 address of n. -func (n *testNode) AwaitIP6() netip.Addr { - t := n.env.t - t.Helper() - ips := n.AwaitIPs() - return ips[1] -} - -// AwaitRunning waits for n to reach the IPN state "Running". 
-func (n *testNode) AwaitRunning() { - n.AwaitBackendState("Running") -} - -func (n *testNode) AwaitBackendState(state string) { - t := n.env.t - t.Helper() - if err := tstest.WaitFor(20*time.Second, func() error { - st, err := n.Status() - if err != nil { - return err - } - if st.BackendState != state { - return fmt.Errorf("in state %q; want %q", st.BackendState, state) - } - return nil - }); err != nil { - t.Fatalf("failure/timeout waiting for transition to Running status: %v", err) - } -} - -// AwaitNeedsLogin waits for n to reach the IPN state "NeedsLogin". -func (n *testNode) AwaitNeedsLogin() { - t := n.env.t - t.Helper() - if err := tstest.WaitFor(20*time.Second, func() error { - st, err := n.Status() - if err != nil { - return err - } - if st.BackendState != "NeedsLogin" { - return fmt.Errorf("in state %q", st.BackendState) - } - return nil - }); err != nil { - t.Fatalf("failure/timeout waiting for transition to NeedsLogin status: %v", err) - } -} - -func (n *testNode) TailscaleForOutput(arg ...string) *exec.Cmd { - cmd := n.Tailscale(arg...) - cmd.Stdout = nil - cmd.Stderr = nil - return cmd -} - -// Tailscale returns a command that runs the tailscale CLI with the provided arguments. -// It does not start the process. -func (n *testNode) Tailscale(arg ...string) *exec.Cmd { - cmd := exec.Command(n.env.cli) - cmd.Args = append(cmd.Args, "--socket="+n.sockFile) - cmd.Args = append(cmd.Args, arg...) 
- cmd.Dir = n.dir - cmd.Env = append(os.Environ(), - "TS_DEBUG_UP_FLAG_GOOS="+n.upFlagGOOS, - "TS_LOGS_DIR="+n.env.t.TempDir(), - ) - if *verboseTailscale { - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - } - return cmd -} - -func (n *testNode) Status() (*ipnstate.Status, error) { - cmd := n.Tailscale("status", "--json") - cmd.Stdout = nil // in case --verbose-tailscale was set - cmd.Stderr = nil // in case --verbose-tailscale was set - out, err := cmd.CombinedOutput() - if err != nil { - return nil, fmt.Errorf("running tailscale status: %v, %s", err, out) - } - st := new(ipnstate.Status) - if err := json.Unmarshal(out, st); err != nil { - return nil, fmt.Errorf("decoding tailscale status JSON: %w", err) - } - return st, nil -} - -func (n *testNode) MustStatus() *ipnstate.Status { - tb := n.env.t - tb.Helper() - st, err := n.Status() - if err != nil { - tb.Fatal(err) - } - return st -} - -// trafficTrap is an HTTP proxy handler to note whether any -// HTTP traffic tries to leave localhost from tailscaled. We don't -// expect any, so any request triggers a failure. -type trafficTrap struct { - atomicErr syncs.AtomicValue[error] -} - -func (tt *trafficTrap) Err() error { - return tt.atomicErr.Load() -} - -func (tt *trafficTrap) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var got bytes.Buffer - r.Write(&got) - err := fmt.Errorf("unexpected HTTP request via proxy: %s", got.Bytes()) - mainError.Store(err) - if tt.Err() == nil { - // Best effort at remembering the first request. 
- tt.atomicErr.Store(err) - } - log.Printf("Error: %v", err) - w.WriteHeader(403) -} - -type authURLParserWriter struct { - buf bytes.Buffer - fn func(urlStr string) error -} - -var authURLRx = regexp.MustCompile(`(https?://\S+/auth/\S+)`) - -func (w *authURLParserWriter) Write(p []byte) (n int, err error) { - n, err = w.buf.Write(p) - m := authURLRx.FindSubmatch(w.buf.Bytes()) - if m != nil { - urlStr := string(m[1]) - w.buf.Reset() // so it's not matched again - if err := w.fn(urlStr); err != nil { - return 0, err - } - } - return n, err -} diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go index 5355155882cd7..15f1269858947 100644 --- a/tstest/integration/nat/nat_test.go +++ b/tstest/integration/nat/nat_test.go @@ -32,6 +32,7 @@ import ( ) var ( + runVMTests = flag.Bool("run-vm-tests", false, "run tests that require a VM") logTailscaled = flag.Bool("log-tailscaled", false, "log tailscaled output") pcapFile = flag.String("pcap", "", "write pcap to file") ) @@ -59,8 +60,25 @@ func newNatTest(tb testing.TB) *natTest { base: filepath.Join(modRoot, "gokrazy/natlabapp.qcow2"), } + if !*runVMTests { + tb.Skip("skipping heavy test; set --run-vm-tests to run") + } + if _, err := os.Stat(nt.base); err != nil { - tb.Skipf("skipping test; base image %q not found", nt.base) + if !os.IsNotExist(err) { + tb.Fatal(err) + } + tb.Logf("building VM image...") + cmd := exec.Command("make", "natlab") + cmd.Dir = filepath.Join(modRoot, "gokrazy") + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + if err := cmd.Run(); err != nil { + tb.Fatalf("error running 'make natlab' in gokrazy directory: %v", err) + } + if _, err := os.Stat(nt.base); err != nil { + tb.Skipf("still can't find VM image: %v", err) + } } nt.kernel, err = findKernelPath(filepath.Join(modRoot, "gokrazy/natlabapp/builddir/github.com/tailscale/gokrazy-kernel/go.mod")) @@ -218,6 +236,22 @@ func hard(c *vnet.Config) *vnet.Node { fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT)) } +func 
 hardNoDERPOrEndpoints(c *vnet.Config) *vnet.Node { + n := c.NumNodes() + 1 + return c.AddNode(c.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT), + vnet.TailscaledEnv{ + Key: "TS_DEBUG_STRIP_ENDPOINTS", + Value: "1", + }, + vnet.TailscaledEnv{ + Key: "TS_DEBUG_STRIP_HOME_DERP", + Value: "1", + }, + ) +} + func hardPMP(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -492,6 +526,26 @@ func TestEasyEasy(t *testing.T) { nt.want(routeDirect) } +// Issue tailscale/corp#26438: use learned DERP route as send path of last +// resort +// +// See (*magicsock.Conn).fallbackDERPRegionForPeer and its comment for +// background. +// +// This sets up a test with two nodes that must use DERP to communicate but the +// target of the ping (the second node) additionally is not getting DERP or +// Endpoint updates from the control plane. (Or rather, it's getting them but is +// configured to scrub them right when they come off the network before being +// processed.) This then tests whether node2, upon receiving a packet, will be +// able to reply to node1 since it knows neither node1's endpoints nor its home +// DERP. The only reply route it can use is the fact that it just received a +// packet over a particular DERP from that peer. +func TestFallbackDERPRegionForPeer(t *testing.T) { + nt := newNatTest(t) + nt.runTest(hard, hardNoDERPOrEndpoints) + nt.want(routeDERP) +} + func TestSingleJustIPv6(t *testing.T) { nt := newNatTest(t) nt.runTest(just6) diff --git a/tstest/integration/tailscaled_deps_test_darwin.go b/tstest/integration/tailscaled_deps_test_darwin.go index 6676ee22cbd1c..321ba25668c1f 100644 --- a/tstest/integration/tailscaled_deps_test_darwin.go +++ b/tstest/integration/tailscaled_deps_test_darwin.go @@ -11,12 +11,13 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. 
_ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -47,6 +48,7 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/multierr" _ "tailscale.com/util/osshare" _ "tailscale.com/version" diff --git a/tstest/integration/tailscaled_deps_test_freebsd.go b/tstest/integration/tailscaled_deps_test_freebsd.go index 6676ee22cbd1c..321ba25668c1f 100644 --- a/tstest/integration/tailscaled_deps_test_freebsd.go +++ b/tstest/integration/tailscaled_deps_test_freebsd.go @@ -11,12 +11,13 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. 
_ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -47,6 +48,7 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/multierr" _ "tailscale.com/util/osshare" _ "tailscale.com/version" diff --git a/tstest/integration/tailscaled_deps_test_linux.go b/tstest/integration/tailscaled_deps_test_linux.go index 6676ee22cbd1c..321ba25668c1f 100644 --- a/tstest/integration/tailscaled_deps_test_linux.go +++ b/tstest/integration/tailscaled_deps_test_linux.go @@ -11,12 +11,13 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. 
_ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -47,6 +48,7 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/multierr" _ "tailscale.com/util/osshare" _ "tailscale.com/version" diff --git a/tstest/integration/tailscaled_deps_test_openbsd.go b/tstest/integration/tailscaled_deps_test_openbsd.go index 6676ee22cbd1c..321ba25668c1f 100644 --- a/tstest/integration/tailscaled_deps_test_openbsd.go +++ b/tstest/integration/tailscaled_deps_test_openbsd.go @@ -11,12 +11,13 @@ import ( // transitive deps when we run "go install tailscaled" in a child // process and can cache a prior success when a dependency changes. 
_ "tailscale.com/chirp" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" @@ -47,6 +48,7 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/multierr" _ "tailscale.com/util/osshare" _ "tailscale.com/version" diff --git a/tstest/integration/tailscaled_deps_test_windows.go b/tstest/integration/tailscaled_deps_test_windows.go index bbf46d8c21938..b5919b9628760 100644 --- a/tstest/integration/tailscaled_deps_test_windows.go +++ b/tstest/integration/tailscaled_deps_test_windows.go @@ -18,16 +18,20 @@ import ( _ "golang.org/x/sys/windows/svc/mgr" _ "golang.zx2c4.com/wintun" _ "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" - _ "tailscale.com/client/tailscale" + _ "tailscale.com/client/local" _ "tailscale.com/cmd/tailscaled/childproc" + _ "tailscale.com/cmd/tailscaled/tailscaledhooks" _ "tailscale.com/control/controlclient" _ "tailscale.com/derp/derphttp" _ "tailscale.com/drive/driveimpl" _ "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" _ "tailscale.com/health" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" + _ "tailscale.com/ipn/auditlog" _ "tailscale.com/ipn/conffile" + _ "tailscale.com/ipn/desktop" _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/ipnserver" _ "tailscale.com/ipn/store" @@ -54,11 +58,13 @@ import ( _ "tailscale.com/types/logger" _ "tailscale.com/types/logid" _ "tailscale.com/util/clientmetric" + _ "tailscale.com/util/eventbus" _ "tailscale.com/util/multierr" _ "tailscale.com/util/osdiag" _ "tailscale.com/util/osshare" _ "tailscale.com/util/syspolicy" _ "tailscale.com/util/winutil" + _ 
"tailscale.com/util/winutil/gp" _ "tailscale.com/version" _ "tailscale.com/version/distro" _ "tailscale.com/wf" diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index bbcf277d171e1..71205f897aad8 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -6,6 +6,7 @@ package testcontrol import ( "bytes" + "cmp" "context" "encoding/binary" "encoding/json" @@ -26,7 +27,7 @@ import ( "time" "golang.org/x/net/http2" - "tailscale.com/control/controlhttp" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/net/netaddr" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -54,6 +55,10 @@ type Server struct { MagicDNSDomain string HandleC2N http.Handler // if non-nil, used for /some-c2n-path/ in tests + // AllNodesSameUser, if true, makes all created nodes + // belong to the same user. + AllNodesSameUser bool + // ExplicitBaseURL or HTTPTestServer must be set. ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL @@ -95,9 +100,9 @@ type Server struct { logins map[key.NodePublic]*tailcfg.Login updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath - nodeKeyAuthed map[key.NodePublic]bool // key => true once authenticated - msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse - allExpired bool // All nodes will be told their node key is expired. + nodeKeyAuthed set.Set[key.NodePublic] + msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse + allExpired bool // All nodes will be told their node key is expired. } // BaseURL returns the server's base URL, without trailing slash. 
@@ -288,7 +293,7 @@ func (s *Server) serveNoiseUpgrade(w http.ResponseWriter, r *http.Request) { s.mu.Lock() noisePrivate := s.noisePrivKey s.mu.Unlock() - cc, err := controlhttp.AcceptHTTP(ctx, w, r, noisePrivate, nil) + cc, err := controlhttpserver.AcceptHTTP(ctx, w, r, noisePrivate, nil) if err != nil { log.Printf("AcceptHTTP: %v", err) return @@ -476,13 +481,22 @@ func (s *Server) AddFakeNode() { // TODO: send updates to other (non-fake?) nodes } -func (s *Server) AllUsers() (users []*tailcfg.User) { +func (s *Server) allUserProfiles() (res []tailcfg.UserProfile) { s.mu.Lock() defer s.mu.Unlock() - for _, u := range s.users { - users = append(users, u.Clone()) + for k, u := range s.users { + up := tailcfg.UserProfile{ + ID: u.ID, + DisplayName: u.DisplayName, + } + if login, ok := s.logins[k]; ok { + up.LoginName = login.LoginName + up.ProfilePicURL = cmp.Or(up.ProfilePicURL, login.ProfilePicURL) + up.DisplayName = cmp.Or(up.DisplayName, login.DisplayName) + } + res = append(res, up) } - return users + return res } func (s *Server) AllNodes() (nodes []*tailcfg.Node) { @@ -512,6 +526,10 @@ func (s *Server) getUser(nodeKey key.NodePublic) (*tailcfg.User, *tailcfg.Login) return u, s.logins[nodeKey] } id := tailcfg.UserID(len(s.users) + 1) + if s.AllNodesSameUser { + id = 123 + } + s.logf("Created user %v for node %s", id, nodeKey) loginName := fmt.Sprintf("user-%d@%s", id, domain) displayName := fmt.Sprintf("User %d", id) login := &tailcfg.Login{ @@ -523,9 +541,7 @@ func (s *Server) getUser(nodeKey key.NodePublic) (*tailcfg.User, *tailcfg.Login) } user := &tailcfg.User{ ID: id, - LoginName: loginName, DisplayName: displayName, - Logins: []tailcfg.LoginID{login.ID}, } s.users[nodeKey] = user s.logins[nodeKey] = login @@ -574,10 +590,8 @@ func (s *Server) CompleteAuth(authPathOrURL string) bool { if ap.nodeKey.IsZero() { panic("zero AuthPath.NodeKey") } - if s.nodeKeyAuthed == nil { - s.nodeKeyAuthed = map[key.NodePublic]bool{} - } - s.nodeKeyAuthed[ap.nodeKey] = 
true + s.nodeKeyAuthed.Make() + s.nodeKeyAuthed.Add(ap.nodeKey) ap.CompleteSuccessfully() return true } @@ -637,36 +651,40 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key. if s.nodes == nil { s.nodes = map[key.NodePublic]*tailcfg.Node{} } - + _, ok := s.nodes[nk] machineAuthorized := true // TODO: add Server.RequireMachineAuth + if !ok { - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) - v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) - - allowedIPs := []netip.Prefix{ - v4Prefix, - v6Prefix, - } + nodeID := len(s.nodes) + 1 + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(nodeID>>8), uint8(nodeID)), 32) + v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) - s.nodes[nk] = &tailcfg.Node{ - ID: tailcfg.NodeID(user.ID), - StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(user.ID))), - User: user.ID, - Machine: mkey, - Key: req.NodeKey, - MachineAuthorized: machineAuthorized, - Addresses: allowedIPs, - AllowedIPs: allowedIPs, - Hostinfo: req.Hostinfo.View(), - Name: req.Hostinfo.Hostname, - Capabilities: []tailcfg.NodeCapability{ - tailcfg.CapabilityHTTPS, - tailcfg.NodeAttrFunnel, - tailcfg.CapabilityFunnelPorts + "?ports=8080,443", - }, + allowedIPs := []netip.Prefix{ + v4Prefix, + v6Prefix, + } + node := &tailcfg.Node{ + ID: tailcfg.NodeID(nodeID), + StableID: tailcfg.StableNodeID(fmt.Sprintf("TESTCTRL%08x", int(nodeID))), + User: user.ID, + Machine: mkey, + Key: req.NodeKey, + MachineAuthorized: machineAuthorized, + Addresses: allowedIPs, + AllowedIPs: allowedIPs, + Hostinfo: req.Hostinfo.View(), + Name: req.Hostinfo.Hostname, + Capabilities: []tailcfg.NodeCapability{ + tailcfg.CapabilityHTTPS, + tailcfg.NodeAttrFunnel, + tailcfg.CapabilityFileSharing, + tailcfg.CapabilityFunnelPorts + "?ports=8080,443", + }, + } + s.nodes[nk] = node } requireAuth := s.RequireAuth - if requireAuth && 
s.nodeKeyAuthed[nk] { + if requireAuth && s.nodeKeyAuthed.Contains(nk) { requireAuth = false } allExpired := s.allExpired @@ -797,7 +815,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi node.Hostinfo = req.Hostinfo.View() if ni := node.Hostinfo.NetInfo(); ni.Valid() { if ni.PreferredDERP() != 0 { - node.DERP = fmt.Sprintf("127.3.3.40:%d", ni.PreferredDERP()) + node.HomeDERP = ni.PreferredDERP() } } } @@ -831,15 +849,17 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi w.WriteHeader(200) for { - if resBytes, ok := s.takeRawMapMessage(req.NodeKey); ok { - if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { - s.logf("sendMapMsg of raw message: %v", err) - return - } - if streaming { + // Only send raw map responses to the streaming poll, to avoid a + // non-streaming map request beating the streaming poll in a race and + // potentially dropping the map response. + if streaming { + if resBytes, ok := s.takeRawMapMessage(req.NodeKey); ok { + if err := s.sendMapMsg(w, compress, resBytes); err != nil { + s.logf("sendMapMsg of raw message: %v", err) + return + } continue } - return } if s.canGenerateAutomaticMapResponseFor(req.NodeKey) { @@ -864,7 +884,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi s.logf("json.Marshal: %v", err) return } - if err := s.sendMapMsg(w, mkey, compress, resBytes); err != nil { + if err := s.sendMapMsg(w, compress, resBytes); err != nil { return } } @@ -895,7 +915,7 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi } break keepAliveLoop case <-keepAliveTimerCh: - if err := s.sendMapMsg(w, mkey, compress, keepAliveMsg); err != nil { + if err := s.sendMapMsg(w, compress, keepAliveMsg); err != nil { return } } @@ -941,13 +961,12 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, node.CapMap = nodeCapMap node.Capabilities = append(node.Capabilities, 
tailcfg.NodeAttrDisableUPnP) - user, _ := s.getUser(nk) t := time.Date(2020, 8, 3, 0, 0, 0, 1, time.UTC) dns := s.DNSConfig if dns != nil && s.MagicDNSDomain != "" { dns = dns.Clone() dns.CertDomains = []string{ - fmt.Sprintf(node.Hostinfo.Hostname() + "." + s.MagicDNSDomain), + node.Hostinfo.Hostname() + "." + s.MagicDNSDomain, } } @@ -1001,15 +1020,9 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, sort.Slice(res.Peers, func(i, j int) bool { return res.Peers[i].ID < res.Peers[j].ID }) - for _, u := range s.AllUsers() { - res.UserProfiles = append(res.UserProfiles, tailcfg.UserProfile{ - ID: u.ID, - LoginName: u.LoginName, - DisplayName: u.DisplayName, - }) - } + res.UserProfiles = s.allUserProfiles() - v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(tailcfg.NodeID(user.ID)>>8), uint8(tailcfg.NodeID(user.ID))), 32) + v4Prefix := netip.PrefixFrom(netaddr.IPv4(100, 64, uint8(node.ID>>8), uint8(node.ID)), 32) v6Prefix := netip.PrefixFrom(tsaddr.Tailscale4To6(v4Prefix.Addr()), 128) res.Node.Addresses = []netip.Prefix{ @@ -1060,7 +1073,7 @@ func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok boo return mapResJSON, true } -func (s *Server) sendMapMsg(w http.ResponseWriter, mkey key.MachinePublic, compress bool, msg any) error { +func (s *Server) sendMapMsg(w http.ResponseWriter, compress bool, msg any) error { resBytes, err := s.encode(compress, msg) if err != nil { return err diff --git a/tstest/integration/vms/harness_test.go b/tstest/integration/vms/harness_test.go index 1e080414d72e7..256227d6c64cc 100644 --- a/tstest/integration/vms/harness_test.go +++ b/tstest/integration/vms/harness_test.go @@ -134,11 +134,12 @@ func newHarness(t *testing.T) *Harness { loginServer := fmt.Sprintf("http://%s", ln.Addr()) t.Logf("loginServer: %s", loginServer) + binaries := integration.GetBinaries(t) h := &Harness{ pubKey: string(pubkey), - binaryDir: integration.BinaryDir(t), - cli: integration.TailscaleBinary(t), 
- daemon: integration.TailscaledBinary(t), + binaryDir: binaries.Dir, + cli: binaries.Tailscale.Path, + daemon: binaries.Tailscaled.Path, signer: signer, loginServerURL: loginServer, cs: cs, diff --git a/tstest/integration/vms/vms_test.go b/tstest/integration/vms/vms_test.go index 6d73a3f78d27e..f71f2bdbf2069 100644 --- a/tstest/integration/vms/vms_test.go +++ b/tstest/integration/vms/vms_test.go @@ -28,7 +28,6 @@ import ( "golang.org/x/crypto/ssh" "golang.org/x/sync/semaphore" "tailscale.com/tstest" - "tailscale.com/tstest/integration" "tailscale.com/types/logger" ) @@ -51,13 +50,6 @@ var ( }() ) -func TestMain(m *testing.M) { - flag.Parse() - v := m.Run() - integration.CleanupBinaries() - os.Exit(v) -} - func TestDownloadImages(t *testing.T) { if !*runVMTests { t.Skip("not running integration tests (need --run-vm-tests)") diff --git a/tstest/iosdeps/iosdeps_test.go b/tstest/iosdeps/iosdeps_test.go index ab69f1c2b0649..b533724eb4b3d 100644 --- a/tstest/iosdeps/iosdeps_test.go +++ b/tstest/iosdeps/iosdeps_test.go @@ -24,6 +24,7 @@ func TestDeps(t *testing.T) { "github.com/google/uuid": "see tailscale/tailscale#13760", "tailscale.com/clientupdate/distsign": "downloads via AppStore, not distsign", "github.com/tailscale/hujson": "no config file support on iOS", + "tailscale.com/feature/capture": "no debug packet capture on iOS", }, }.Check(t) } diff --git a/tstest/log.go b/tstest/log.go index cb67c609adcdf..d081c819d8ce2 100644 --- a/tstest/log.go +++ b/tstest/log.go @@ -13,6 +13,7 @@ import ( "go4.org/mem" "tailscale.com/types/logger" + "tailscale.com/util/testenv" ) type testLogWriter struct { @@ -149,7 +150,7 @@ func (ml *MemLogger) String() string { // WhileTestRunningLogger returns a logger.Logf that logs to t.Logf until the // test finishes, at which point it no longer logs anything. 
-func WhileTestRunningLogger(t testing.TB) logger.Logf { +func WhileTestRunningLogger(t testenv.TB) logger.Logf { var ( mu sync.RWMutex done bool diff --git a/tstest/mts/mts.go b/tstest/mts/mts.go new file mode 100644 index 0000000000000..c10d69d8daca4 --- /dev/null +++ b/tstest/mts/mts.go @@ -0,0 +1,599 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux || darwin + +// The mts ("Multiple Tailscale") command runs multiple tailscaled instances for +// development, managing their directories and sockets, and lets you easily direct +// tailscale CLI commands to them. +package main + +import ( + "bufio" + "context" + "encoding/json" + "flag" + "fmt" + "io" + "log" + "maps" + "net" + "net/http" + "net/url" + "os" + "os/exec" + "path/filepath" + "regexp" + "slices" + "strings" + "sync" + "syscall" + "time" + + "tailscale.com/client/local" + "tailscale.com/types/bools" + "tailscale.com/types/lazy" + "tailscale.com/util/mak" +) + +func usage(args ...any) { + var format string + if len(args) > 0 { + format, args = args[0].(string), args[1:] + } + if format != "" { + format = strings.TrimSpace(format) + "\n\n" + fmt.Fprintf(os.Stderr, format, args...) + } + io.WriteString(os.Stderr, strings.TrimSpace(` +usage: + + mts server # manage tailscaled instances + mts server run # run the mts server (parent process of all tailscaled) + mts server list # list all tailscaled and their state + mts server list # show details of named instance + mts server add # add+start new named tailscaled + mts server start # start a previously added tailscaled + mts server stop # stop & remove a named tailscaled + mts server rm # stop & remove a named tailscaled + mts server logs [-f] # get/follow tailscaled logs + + mts [tailscale CLI args] # run Tailscale CLI against a named instance + e.g. 
+ mts gmail1 up + mts github2 status --json + `)+"\n") + os.Exit(1) +} + +func main() { + // Don't use flag.Parse here; we mostly just delegate through + // to the Tailscale CLI. + + if len(os.Args) < 2 { + usage() + } + firstArg, args := os.Args[1], os.Args[2:] + if firstArg == "server" || firstArg == "s" { + if err := runMTSServer(args); err != nil { + log.Fatal(err) + } + } else { + var c Client + inst := firstArg + c.RunCommand(inst, args) + } +} + +func runMTSServer(args []string) error { + if len(args) == 0 { + usage() + } + cmd, args := args[0], args[1:] + if cmd == "run" { + var s Server + return s.Run() + } + + // Commands other than "run" all use the HTTP client to + // hit the mts server over its unix socket. + var c Client + + switch cmd { + default: + usage("unknown mts server subcommand %q", cmd) + case "list", "ls": + list, err := c.List() + if err != nil { + return err + } + if len(args) == 0 { + names := slices.Sorted(maps.Keys(list.Instances)) + for _, name := range names { + running := list.Instances[name].Running + fmt.Printf("%10s %s\n", bools.IfElse(running, "RUNNING", "stopped"), name) + } + } else { + for _, name := range args { + inst, ok := list.Instances[name] + if !ok { + return fmt.Errorf("no instance named %q", name) + } + je := json.NewEncoder(os.Stdout) + je.SetIndent("", " ") + if err := je.Encode(inst); err != nil { + return err + } + } + } + + case "rm": + if len(args) == 0 { + return fmt.Errorf("missing instance name(s) to remove") + } + log.SetFlags(0) + for _, name := range args { + ok, err := c.Remove(name) + if err != nil { + return err + } + if ok { + log.Printf("%s deleted.", name) + } else { + log.Printf("%s didn't exist.", name) + } + } + case "stop": + if len(args) == 0 { + return fmt.Errorf("missing instance name(s) to stop") + } + log.SetFlags(0) + for _, name := range args { + ok, err := c.Stop(name) + if err != nil { + return err + } + if ok { + log.Printf("%s stopped.", name) + } else { + log.Printf("%s didn't 
exist.", name) + } + } + case "start", "restart": + list, err := c.List() + if err != nil { + return err + } + shouldStop := cmd == "restart" + for _, arg := range args { + is, ok := list.Instances[arg] + if !ok { + return fmt.Errorf("no instance named %q", arg) + } + if is.Running { + if shouldStop { + if _, err := c.Stop(arg); err != nil { + return fmt.Errorf("stopping %q: %w", arg, err) + } + } else { + log.SetFlags(0) + log.Printf("%s already running.", arg) + continue + } + } + // Creating an existing one starts it up. + if err := c.Create(arg); err != nil { + return fmt.Errorf("starting %q: %w", arg, err) + } + } + case "add": + if len(args) == 0 { + return fmt.Errorf("missing instance name(s) to add") + } + for _, name := range args { + if err := c.Create(name); err != nil { + return fmt.Errorf("creating %q: %w", name, err) + } + } + case "logs": + fs := flag.NewFlagSet("logs", flag.ExitOnError) + fs.Usage = func() { usage() } + follow := fs.Bool("f", false, "follow logs") + fs.Parse(args) + log.Printf("Parsed; following=%v, args=%q", *follow, fs.Args()) + if fs.NArg() != 1 { + usage() + } + cmd := bools.IfElse(*follow, "tail", "cat") + args := []string{cmd} + if *follow { + args = append(args, "-f") + } + path, err := exec.LookPath(cmd) + if err != nil { + return fmt.Errorf("looking up %q: %w", cmd, err) + } + args = append(args, instLogsFile(fs.Arg(0))) + log.Fatal(syscall.Exec(path, args, os.Environ())) + } + return nil +} + +type Client struct { +} + +func (c *Client) client() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return net.Dial("unix", mtsSock()) + }, + }, + } +} + +func getJSON[T any](res *http.Response, err error) (T, error) { + var ret T + if err != nil { + return ret, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + body, _ := io.ReadAll(res.Body) + return ret, fmt.Errorf("unexpected status: %v: %s", 
res.Status, body) + } + if err := json.NewDecoder(res.Body).Decode(&ret); err != nil { + return ret, err + } + return ret, nil +} + +func (c *Client) List() (listResponse, error) { + return getJSON[listResponse](c.client().Get("http://mts/list")) +} + +func (c *Client) Remove(name string) (found bool, err error) { + return getJSON[bool](c.client().PostForm("http://mts/rm", url.Values{ + "name": []string{name}, + })) +} + +func (c *Client) Stop(name string) (found bool, err error) { + return getJSON[bool](c.client().PostForm("http://mts/stop", url.Values{ + "name": []string{name}, + })) +} + +func (c *Client) Create(name string) error { + req, err := http.NewRequest("POST", "http://mts/create/"+name, nil) + if err != nil { + return err + } + resp, err := c.client().Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("unexpected status: %v: %s", resp.Status, body) + } + return nil +} + +func (c *Client) RunCommand(name string, args []string) { + sock := instSock(name) + lc := &local.Client{ + Socket: sock, + UseSocketOnly: true, + } + probeCtx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond) + defer cancel() + if _, err := lc.StatusWithoutPeers(probeCtx); err != nil { + log.Fatalf("instance %q not running? start with 'mts server start %q'; got error: %v", name, name, err) + } + args = append([]string{"run", "tailscale.com/cmd/tailscale", "--socket=" + sock}, args...) + cmd := exec.Command("go", args...) 
+ cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Stdin = os.Stdin + err := cmd.Run() + if err == nil { + os.Exit(0) + } + if exitErr, ok := err.(*exec.ExitError); ok { + os.Exit(exitErr.ExitCode()) + } + panic(err) +} + +type Server struct { + lazyTailscaled lazy.GValue[string] + + mu sync.Mutex + cmds map[string]*exec.Cmd // running tailscaled instances +} + +func (s *Server) tailscaled() string { + v, err := s.lazyTailscaled.GetErr(func() (string, error) { + out, err := exec.Command("go", "list", "-f", "{{.Target}}", "tailscale.com/cmd/tailscaled").CombinedOutput() + if err != nil { + return "", err + } + return strings.TrimSpace(string(out)), nil + }) + if err != nil { + panic(err) + } + return v +} + +func (s *Server) Run() error { + if err := os.MkdirAll(mtsRoot(), 0700); err != nil { + return err + } + sock := mtsSock() + os.Remove(sock) + log.Printf("Multi-Tailscaled Server running; listening on %q ...", sock) + ln, err := net.Listen("unix", sock) + if err != nil { + return err + } + return http.Serve(ln, s) +} + +var validNameRx = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) + +func validInstanceName(name string) bool { + return validNameRx.MatchString(name) +} + +func (s *Server) InstanceRunning(name string) bool { + s.mu.Lock() + defer s.mu.Unlock() + _, ok := s.cmds[name] + return ok +} + +func (s *Server) Stop(name string) { + s.mu.Lock() + defer s.mu.Unlock() + if cmd, ok := s.cmds[name]; ok { + if err := cmd.Process.Kill(); err != nil { + log.Printf("error killing %q: %v", name, err) + } + delete(s.cmds, name) + } +} + +func (s *Server) RunInstance(name string) error { + s.mu.Lock() + defer s.mu.Unlock() + + if _, ok := s.cmds[name]; ok { + return fmt.Errorf("instance %q already running", name) + } + + if !validInstanceName(name) { + return fmt.Errorf("invalid instance name %q", name) + } + dir := filepath.Join(mtsRoot(), name) + if err := os.MkdirAll(dir, 0700); err != nil { + return err + } + + env := os.Environ() + env = append(env, 
"TS_DEBUG_LOG_RATE=all") + if ef, err := os.Open(instEnvFile(name)); err == nil { + defer ef.Close() + sc := bufio.NewScanner(ef) + for sc.Scan() { + t := strings.TrimSpace(sc.Text()) + if strings.HasPrefix(t, "#") || !strings.Contains(t, "=") { + continue + } + env = append(env, t) + } + } else if os.IsNotExist(err) { + // Write an example one. + os.WriteFile(instEnvFile(name), fmt.Appendf(nil, "# Example mts env.txt file; uncomment/add stuff you want for %q\n\n#TS_DEBUG_MAP=1\n#TS_DEBUG_REGISTER=1\n#TS_NO_LOGS_NO_SUPPORT=1\n", name), 0600) + } + + extraArgs := []string{"--verbose=1"} + if af, err := os.Open(instArgsFile(name)); err == nil { + extraArgs = nil // clear default args + defer af.Close() + sc := bufio.NewScanner(af) + for sc.Scan() { + t := strings.TrimSpace(sc.Text()) + if strings.HasPrefix(t, "#") || t == "" { + continue + } + extraArgs = append(extraArgs, t) + } + } else if os.IsNotExist(err) { + // Write an example one. + os.WriteFile(instArgsFile(name), fmt.Appendf(nil, "# Example mts args.txt file for instance %q.\n# One line per extra arg to tailscaled; no magic string quoting\n\n--verbose=1\n#--socks5-server=127.0.0.1:5000\n", name), 0600) + } + + log.Printf("Running Tailscale daemon %q in %q", name, dir) + + args := []string{ + "--tun=userspace-networking", + "--statedir=" + filepath.Join(dir), + "--socket=" + filepath.Join(dir, "tailscaled.sock"), + } + args = append(args, extraArgs...) + + cmd := exec.Command(s.tailscaled(), args...) 
+ cmd.Dir = dir + cmd.Env = env + + out, err := cmd.StdoutPipe() + if err != nil { + return err + } + cmd.Stderr = cmd.Stdout + + logs := instLogsFile(name) + logFile, err := os.OpenFile(logs, os.O_CREATE|os.O_WRONLY|os.O_APPEND|os.O_TRUNC, 0644) + if err != nil { + return fmt.Errorf("opening logs file: %w", err) + } + + go func() { + bs := bufio.NewScanner(out) + for bs.Scan() { + // TODO(bradfitz): record in memory too, serve via HTTP + line := strings.TrimSpace(bs.Text()) + fmt.Fprintf(logFile, "%s\n", line) + fmt.Printf("tailscaled[%s]: %s\n", name, line) + } + }() + + if err := cmd.Start(); err != nil { + return err + } + go func() { + err := cmd.Wait() + logFile.Close() + log.Printf("Tailscale daemon %q exited: %v", name, err) + s.mu.Lock() + defer s.mu.Unlock() + delete(s.cmds, name) + }() + + mak.Set(&s.cmds, name, cmd) + return nil +} + +type listResponse struct { + // Instances maps instance name to its details. + Instances map[string]listResponseInstance `json:"instances"` +} + +type listResponseInstance struct { + Name string `json:"name"` + Dir string `json:"dir"` + Sock string `json:"sock"` + Running bool `json:"running"` + Env string `json:"env"` + Args string `json:"args"` + Logs string `json:"logs"` +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", " ") + e.Encode(v) +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/list" { + var res listResponse + for _, name := range s.InstanceNames() { + mak.Set(&res.Instances, name, listResponseInstance{ + Name: name, + Dir: instDir(name), + Sock: instSock(name), + Running: s.InstanceRunning(name), + Env: instEnvFile(name), + Args: instArgsFile(name), + Logs: instLogsFile(name), + }) + } + writeJSON(w, res) + return + } + if r.URL.Path == "/rm" || r.URL.Path == "/stop" { + shouldRemove := r.URL.Path == "/rm" + if r.Method != "POST" { + http.Error(w, "POST required", 
http.StatusMethodNotAllowed) + return + } + target := r.FormValue("name") + var ok bool + for _, name := range s.InstanceNames() { + if name != target { + continue + } + ok = true + s.Stop(name) + if shouldRemove { + if err := os.RemoveAll(instDir(name)); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + break + } + writeJSON(w, ok) + return + } + if inst, ok := strings.CutPrefix(r.URL.Path, "/create/"); ok { + if !s.InstanceRunning(inst) { + if err := s.RunInstance(inst); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + fmt.Fprintf(w, "OK\n") + return + } + if r.URL.Path == "/" { + fmt.Fprintf(w, "This is mts, the multi-tailscaled server.\n") + return + } + http.NotFound(w, r) +} + +func (s *Server) InstanceNames() []string { + var ret []string + des, err := os.ReadDir(mtsRoot()) + if err != nil { + if os.IsNotExist(err) { + return nil + } + panic(err) + } + for _, de := range des { + if !de.IsDir() { + continue + } + ret = append(ret, de.Name()) + } + return ret +} + +func mtsRoot() string { + dir, err := os.UserConfigDir() + if err != nil { + panic(err) + } + return filepath.Join(dir, "multi-tailscale-dev") +} + +func instDir(name string) string { + return filepath.Join(mtsRoot(), name) +} + +func instSock(name string) string { + return filepath.Join(instDir(name), "tailscaled.sock") +} + +func instEnvFile(name string) string { + return filepath.Join(mtsRoot(), name, "env.txt") +} + +func instArgsFile(name string) string { + return filepath.Join(mtsRoot(), name, "args.txt") +} + +func instLogsFile(name string) string { + return filepath.Join(mtsRoot(), name, "logs.txt") +} + +func mtsSock() string { + return filepath.Join(mtsRoot(), "mts.sock") +} diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index cf71a66743e1c..07b181540838c 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -10,6 +10,7 @@ import ( "net/netip" "os" "slices" + 
"time" "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" @@ -120,6 +121,8 @@ func (c *Config) AddNode(opts ...any) *Node { n.err = fmt.Errorf("unknown NodeOption %q", o) } } + case MAC: + n.mac = o default: if n.err == nil { n.err = fmt.Errorf("unknown AddNode option type %T", o) @@ -279,10 +282,28 @@ type Network struct { svcs set.Set[NetworkService] + latency time.Duration // latency applied to interface writes + lossRate float64 // chance of packet loss (0.0 to 1.0) + // ... err error // carried error } +// SetLatency sets the simulated network latency for this network. +func (n *Network) SetLatency(d time.Duration) { + n.latency = d +} + +// SetPacketLoss sets the packet loss rate for this network 0.0 (no loss) to 1.0 (total loss). +func (n *Network) SetPacketLoss(rate float64) { + if rate < 0 { + rate = 0 + } else if rate > 1 { + rate = 1 + } + n.lossRate = rate +} + // SetBlackholedIPv4 sets whether the network should blackhole all IPv4 traffic // out to the Internet. (DHCP etc continues to work on the LAN.) 
func (n *Network) SetBlackholedIPv4(v bool) { @@ -361,6 +382,8 @@ func (s *Server) initFromConfig(c *Config) error { wanIP4: conf.wanIP4, lanIP4: conf.lanIP4, breakWAN4: conf.breakWAN4, + latency: conf.latency, + lossRate: conf.lossRate, nodesByIP4: map[netip.Addr]*node{}, nodesByMAC: map[MAC]*node{}, logf: logger.WithPrefix(s.logf, fmt.Sprintf("[net-%v] ", conf.mac)), diff --git a/tstest/natlab/vnet/conf_test.go b/tstest/natlab/vnet/conf_test.go index 15d3c69ef52d9..6566ac8cf4610 100644 --- a/tstest/natlab/vnet/conf_test.go +++ b/tstest/natlab/vnet/conf_test.go @@ -3,7 +3,10 @@ package vnet -import "testing" +import ( + "testing" + "time" +) func TestConfig(t *testing.T) { tests := []struct { @@ -18,6 +21,16 @@ func TestConfig(t *testing.T) { c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT)) }, }, + { + name: "latency-and-loss", + setup: func(c *Config) { + n1 := c.AddNetwork("2.1.1.1", "192.168.1.1/24", EasyNAT, NATPMP) + n1.SetLatency(time.Second) + n1.SetPacketLoss(0.1) + c.AddNode(n1) + c.AddNode(c.AddNetwork("2.2.2.2", "10.2.0.1/16", HardNAT)) + }, + }, { name: "indirect", setup: func(c *Config) { diff --git a/tstest/natlab/vnet/vip.go b/tstest/natlab/vnet/vip.go index c75f17cee5393..190c9e75f1a62 100644 --- a/tstest/natlab/vnet/vip.go +++ b/tstest/natlab/vnet/vip.go @@ -17,7 +17,7 @@ var ( fakeControl = newVIP("control.tailscale", 3) fakeDERP1 = newVIP("derp1.tailscale", "33.4.0.1") // 3340=DERP; 1=derp 1 fakeDERP2 = newVIP("derp2.tailscale", "33.4.0.2") // 3340=DERP; 2=derp 2 - fakeLogCatcher = newVIP("log.tailscale.io", 4) + fakeLogCatcher = newVIP("log.tailscale.com", 4) fakeSyslog = newVIP("syslog.tailscale", 9) ) diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index e7991b3e6ef5d..1fa170d87df50 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -50,10 +50,11 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/waiter" - 
"tailscale.com/client/tailscale" + "tailscale.com/client/local" "tailscale.com/derp" "tailscale.com/derp/derphttp" "tailscale.com/net/netutil" + "tailscale.com/net/netx" "tailscale.com/net/stun" "tailscale.com/syncs" "tailscale.com/tailcfg" @@ -88,6 +89,9 @@ func (s *Server) PopulateDERPMapIPs() error { if n.IPv4 != "" { s.derpIPs.Add(netip.MustParseAddr(n.IPv4)) } + if n.IPv6 != "" { + s.derpIPs.Add(netip.MustParseAddr(n.IPv6)) + } } } return nil @@ -394,7 +398,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { } } -// serveLogCatchConn serves a TCP connection to "log.tailscale.io", speaking the +// serveLogCatchConn serves a TCP connection to "log.tailscale.com", speaking the // logtail/logcatcher protocol. // // We terminate TLS with an arbitrary cert; the client is configured to not @@ -515,6 +519,8 @@ type network struct { wanIP4 netip.Addr // router's LAN IPv4, if any lanIP4 netip.Prefix // router's LAN IP + CIDR (e.g. 192.168.2.1/24) breakWAN4 bool // break WAN IPv4 connectivity + latency time.Duration // latency applied to interface writes + lossRate float64 // probability of dropping a packet (0.0 to 1.0) nodesByIP4 map[netip.Addr]*node // by LAN IPv4 nodesByMAC map[MAC]*node logf func(format string, args ...any) @@ -644,7 +650,7 @@ type Server struct { mu sync.Mutex agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all - agentDialer map[*node]DialFunc + agentDialer map[*node]netx.DialFunc } func (s *Server) logf(format string, args ...any) { @@ -659,8 +665,6 @@ func (s *Server) SetLoggerForTest(logf func(format string, args ...any)) { s.optLogf = logf } -type DialFunc func(ctx context.Context, network, address string) (net.Conn, error) - var derpMap = &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { @@ -977,7 +981,7 @@ func (n *network) writeEth(res []byte) bool { for mac, nw := range n.writers.All() { if mac != srcMAC { 
num++ - nw.write(res) + n.conditionedWrite(nw, res) } } return num > 0 @@ -987,7 +991,7 @@ func (n *network) writeEth(res []byte) bool { return false } if nw, ok := n.writers.Load(dstMAC); ok { - nw.write(res) + n.conditionedWrite(nw, res) return true } @@ -1000,6 +1004,23 @@ func (n *network) writeEth(res []byte) bool { return false } +func (n *network) conditionedWrite(nw networkWriter, packet []byte) { + if n.lossRate > 0 && rand.Float64() < n.lossRate { + // packet lost + return + } + if n.latency > 0 { + // copy the packet as there's no guarantee packet is owned long enough. + // TODO(raggi): this could be optimized substantially if necessary, + // a pool of buffers and a cheaper delay mechanism are both obvious improvements. + var pkt = make([]byte, len(packet)) + copy(pkt, packet) + time.AfterFunc(n.latency, func() { nw.write(pkt) }) + } else { + nw.write(packet) + } +} + var ( macAllNodes = MAC{0: 0x33, 1: 0x33, 5: 0x01} macAllRouters = MAC{0: 0x33, 1: 0x33, 5: 0x02} @@ -2104,11 +2125,11 @@ func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { } type NodeAgentClient struct { - *tailscale.LocalClient + *local.Client HTTPClient *http.Client } -func (s *Server) NodeAgentDialer(n *Node) DialFunc { +func (s *Server) NodeAgentDialer(n *Node) netx.DialFunc { s.mu.Lock() defer s.mu.Unlock() @@ -2129,7 +2150,7 @@ func (s *Server) NodeAgentDialer(n *Node) DialFunc { func (s *Server) NodeAgentClient(n *Node) *NodeAgentClient { d := s.NodeAgentDialer(n) return &NodeAgentClient{ - LocalClient: &tailscale.LocalClient{ + Client: &local.Client{ UseSocketOnly: true, OmitAuth: true, Dial: d, diff --git a/tstest/nettest/nettest.go b/tstest/nettest/nettest.go index 47c8857a57ce3..c78677dd45c59 100644 --- a/tstest/nettest/nettest.go +++ b/tstest/nettest/nettest.go @@ -6,11 +6,22 @@ package nettest import ( + "context" + "flag" + "net" + "net/http" + "net/http/httptest" + "sync" "testing" + "tailscale.com/net/memnet" "tailscale.com/net/netmon" + 
"tailscale.com/net/netx" + "tailscale.com/util/testenv" ) +var useMemNet = flag.Bool("use-test-memnet", false, "prefer using in-memory network for tests") + // SkipIfNoNetwork skips the test if it looks like there's no network // access. func SkipIfNoNetwork(t testing.TB) { @@ -19,3 +30,89 @@ func SkipIfNoNetwork(t testing.TB) { t.Skip("skipping; test requires network but no interface is up") } } + +// PreferMemNetwork reports whether the --use-test-memnet flag is set. +func PreferMemNetwork() bool { + return *useMemNet +} + +// GetNetwork returns the appropriate Network implementation based on +// whether the --use-test-memnet flag is set. +// +// Each call generates a new network. +func GetNetwork(tb testing.TB) netx.Network { + var n netx.Network + if PreferMemNetwork() { + n = &memnet.Network{} + } else { + n = netx.RealNetwork() + } + + detectLeaks := PreferMemNetwork() || !testenv.InParallelTest(tb) + if detectLeaks { + tb.Cleanup(func() { + // TODO: leak detection, making sure no connections + // remain at the end of the test. For real network, + // snapshot conns in pid table before & after. + }) + } + return n +} + +// NewHTTPServer starts and returns a new [httptest.Server]. +// The caller should call Close when finished, to shut it down. +func NewHTTPServer(net netx.Network, handler http.Handler) *httptest.Server { + ts := NewUnstartedHTTPServer(net, handler) + ts.Start() + return ts +} + +// NewUnstartedHTTPServer returns a new [httptest.Server] but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. 
+func NewUnstartedHTTPServer(nw netx.Network, handler http.Handler) *httptest.Server { + s := &httptest.Server{ + Config: &http.Server{Handler: handler}, + } + ln := nw.NewLocalTCPListener() + s.Listener = &listenerOnAddrOnce{ + Listener: ln, + fn: func() { + c := s.Client() + if c == nil { + // This httptest.Server.Start initialization order has been true + // for over 10 years. Let's keep counting on it. + panic("httptest.Server: Client not initialized before Addr called") + } + if c.Transport == nil { + c.Transport = &http.Transport{} + } + tr := c.Transport.(*http.Transport) + if tr.Dial != nil || tr.DialContext != nil { + panic("unexpected non-nil Dial or DialContext in httptest.Server.Client.Transport") + } + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + return nw.Dial(ctx, network, addr) + } + }, + } + return s +} + +// listenerOnAddrOnce is a net.Listener that wraps another net.Listener +// and calls a function the first time its Addr is called. +type listenerOnAddrOnce struct { + net.Listener + once sync.Once + fn func() +} + +func (ln *listenerOnAddrOnce) Addr() net.Addr { + ln.once.Do(func() { + ln.fn() + }) + return ln.Listener.Addr() +} diff --git a/tstest/resource.go b/tstest/resource.go index b094c7911014f..f50bb3330e846 100644 --- a/tstest/resource.go +++ b/tstest/resource.go @@ -7,10 +7,10 @@ import ( "bytes" "runtime" "runtime/pprof" + "slices" + "strings" "testing" "time" - - "github.com/google/go-cmp/cmp" ) // ResourceCheck takes a snapshot of the current goroutines and registers a @@ -44,7 +44,20 @@ func ResourceCheck(tb testing.TB) { if endN <= startN { return } - tb.Logf("goroutine diff:\n%v\n", cmp.Diff(startStacks, endStacks)) + + // Parse and print goroutines. 
+ start := parseGoroutines(startStacks) + end := parseGoroutines(endStacks) + if testing.Verbose() { + tb.Logf("goroutines start:\n%s", printGoroutines(start)) + tb.Logf("goroutines end:\n%s", printGoroutines(end)) + } + + // Print goroutine diff, omitting tstest.ResourceCheck goroutines. + self := func(g goroutine) bool { return bytes.Contains(g.stack, []byte("\ttailscale.com/tstest.goroutines+")) } + start.goroutines = slices.DeleteFunc(start.goroutines, self) + end.goroutines = slices.DeleteFunc(end.goroutines, self) + tb.Logf("goroutine diff (-start +end):\n%s", diffGoroutines(start, end)) // tb.Failed() above won't report on panics, so we shouldn't call Fatal // here or we risk suppressing reporting of the panic. @@ -58,3 +71,208 @@ func goroutines() (int, []byte) { p.WriteTo(b, 1) return p.Count(), b.Bytes() } + +// parseGoroutines takes pprof/goroutines?debug=1 -formatted output sorted by +// count, and splits it into a separate list of goroutines with count and stack +// separated. 
+// +// Example input: +// +// goroutine profile: total 408 +// 48 @ 0x47bc0e 0x136c6b9 0x136c69e 0x136c7ab 0x1379809 0x13797fa 0x483da1 +// # 0x136c6b8 gvisor.dev/gvisor/pkg/sync.Gopark+0x78 gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sync/runtime_unsafe.go:33 +// # 0x136c69d gvisor.dev/gvisor/pkg/sleep.(*Sleeper).nextWaker+0x5d gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sleep/sleep_unsafe.go:210 +// # 0x136c7aa gvisor.dev/gvisor/pkg/sleep.(*Sleeper).fetch+0x2a gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sleep/sleep_unsafe.go:257 +// # 0x1379808 gvisor.dev/gvisor/pkg/sleep.(*Sleeper).Fetch+0xa8 gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/sleep/sleep_unsafe.go:280 +// # 0x13797f9 gvisor.dev/gvisor/pkg/tcpip/transport/tcp.(*processor).start+0x99 gvisor.dev/gvisor@v0.0.0-20250205023644-9414b50a5633/pkg/tcpip/transport/tcp/dispatcher.go:291 +// +// 48 @ 0x47bc0e 0x413705 0x4132b2 0x10fc905 0x483da1 +// # 0x10fc904 github.com/tailscale/wireguard-go/device.(*Device).RoutineDecryption+0x184 github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/receive.go:245 +// +// 48 @ 0x47bc0e 0x413705 0x4132b2 0x10fcd2a 0x483da1 +// # 0x10fcd29 github.com/tailscale/wireguard-go/device.(*Device).RoutineHandshake+0x169 github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/receive.go:279 +// +// 48 @ 0x47bc0e 0x413705 0x4132b2 0x1100ba7 0x483da1 +// # 0x1100ba6 github.com/tailscale/wireguard-go/device.(*Device).RoutineEncryption+0x186 github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/send.go:451 +// +// 26 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +// # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +// +// 13 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +// # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +// +// 7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +// # 0x10fda4c 
github.com/tailscale/wireguard-go/device.(*Peer).RoutineSequentialReceiver+0x16c github.com/tailscale/wireguard-go@v0.0.0-20250107165329-0b8b35511f19/device/receive.go:443 +func parseGoroutines(g []byte) goroutineDump { + head, tail, ok := bytes.Cut(g, []byte("\n")) + if !ok { + return goroutineDump{head: head} + } + + raw := bytes.Split(tail, []byte("\n\n")) + parsed := make([]goroutine, 0, len(raw)) + for _, s := range raw { + count, rem, ok := bytes.Cut(s, []byte(" @ ")) + if !ok { + continue + } + header, stack, _ := bytes.Cut(rem, []byte("\n")) + sort := slices.Clone(header) + reverseWords(sort) + parsed = append(parsed, goroutine{count, header, stack, sort}) + } + + return goroutineDump{head, parsed} +} + +type goroutineDump struct { + head []byte + goroutines []goroutine +} + +// goroutine is a parsed stack trace in pprof goroutine output, e.g. +// "10 @ 0x100 0x001\n# 0x100 test() test.go\n# 0x001 main() test.go". +type goroutine struct { + count []byte // e.g. "10" + header []byte // e.g. "0x100 0x001" + stack []byte // e.g. "# 0x100 test() test.go\n# 0x001 main() test.go" + + // sort is the same pointers as in header, but in reverse order so that we + // can place related goroutines near each other by sorting on this field. + // E.g. "0x001 0x100". + sort []byte +} + +func (g goroutine) Compare(h goroutine) int { + return bytes.Compare(g.sort, h.sort) +} + +// reverseWords repositions the words in b such that they are reversed. +// Words are separated by spaces. New lines are not considered. +// https://sketch.dev/sk/a4ef +func reverseWords(b []byte) { + if len(b) == 0 { + return + } + + // First, reverse the entire slice. + reverse(b) + + // Then reverse each word individually. 
+ start := 0 + for i := 0; i <= len(b); i++ { + if i == len(b) || b[i] == ' ' { + reverse(b[start:i]) + start = i + 1 + } + } +} + +// reverse reverses bytes in place +func reverse(b []byte) { + for i, j := 0, len(b)-1; i < j; i, j = i+1, j-1 { + b[i], b[j] = b[j], b[i] + } +} + +// printGoroutines returns a text representation of h, gs equivalent to the +// pprof ?debug=1 input parsed by parseGoroutines, except the goroutines are +// sorted in an order easier for diffing. +func printGoroutines(g goroutineDump) []byte { + var b bytes.Buffer + b.Write(g.head) + + slices.SortFunc(g.goroutines, goroutine.Compare) + for _, g := range g.goroutines { + b.WriteString("\n\n") + b.Write(g.count) + b.WriteString(" @ ") + b.Write(g.header) + b.WriteString("\n") + if len(g.stack) > 0 { + b.Write(g.stack) + } + } + + return b.Bytes() +} + +// diffGoroutines returns a diff between goroutines of gx and gy. +// Goroutines present in gx and absent from gy are prefixed with "-". +// Goroutines absent from gx and present in gy are prefixed with "+". +// Goroutines present in both but with different counts only show a prefix on the count line. 
+func diffGoroutines(x, y goroutineDump) string { + hx, hy := x.head, y.head + gx, gy := x.goroutines, y.goroutines + var b strings.Builder + if !bytes.Equal(hx, hy) { + b.WriteString("- ") + b.Write(hx) + b.WriteString("\n+ ") + b.Write(hy) + b.WriteString("\n") + } + + slices.SortFunc(gx, goroutine.Compare) + slices.SortFunc(gy, goroutine.Compare) + + writeHeader := func(prefix string, g goroutine) { + b.WriteString(prefix) + b.Write(g.count) + b.WriteString(" @ ") + b.Write(g.header) + b.WriteString("\n") + } + writeStack := func(prefix string, g goroutine) { + s := g.stack + for { + var h []byte + h, s, _ = bytes.Cut(s, []byte("\n")) + if len(h) == 0 && len(s) == 0 { + break + } + b.WriteString(prefix) + b.Write(h) + b.WriteString("\n") + } + } + + i, j := 0, 0 + for { + var d int + switch { + case i < len(gx) && j < len(gy): + d = gx[i].Compare(gy[j]) + case i < len(gx): + d = -1 + case j < len(gy): + d = 1 + default: + return b.String() + } + + switch d { + case -1: + b.WriteString("\n") + writeHeader("- ", gx[i]) + writeStack("- ", gx[i]) + i++ + + case +1: + b.WriteString("\n") + writeHeader("+ ", gy[j]) + writeStack("+ ", gy[j]) + j++ + + case 0: + if !bytes.Equal(gx[i].count, gy[j].count) { + b.WriteString("\n") + writeHeader("- ", gx[i]) + writeHeader("+ ", gy[j]) + writeStack(" ", gy[j]) + } + i++ + j++ + } + } +} diff --git a/tstest/resource_test.go b/tstest/resource_test.go new file mode 100644 index 0000000000000..7199ac5d11cbf --- /dev/null +++ b/tstest/resource_test.go @@ -0,0 +1,256 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package tstest + +import ( + "strings" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestPrintGoroutines(t *testing.T) { + tests := []struct { + name string + in string + want string + }{ + { + name: "empty", + in: "goroutine profile: total 0\n", + want: "goroutine profile: total 0", + }, + { + name: "single goroutine", + in: `goroutine profile: total 1 +1 @ 0x47bc0e 
0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: `goroutine profile: total 1 + +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + }, + { + name: "multiple goroutines sorted", + in: `goroutine profile: total 14 +7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +# 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 + +7 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + want: `goroutine profile: total 14 + +7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +# 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 + +7 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := string(printGoroutines(parseGoroutines([]byte(tt.in)))) + if got != tt.want { + t.Errorf("printGoroutines() = %q, want %q, diff:\n%s", got, tt.want, cmp.Diff(tt.want, got)) + } + }) + } +} + +func TestDiffPprofGoroutines(t *testing.T) { + tests := []struct { + name string + x, y string + want string + }{ + { + name: "no difference", + x: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261`, + y: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: "", + }, + { + name: "different counts", + x: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + y: `goroutine profile: total 2 +2 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: `- goroutine profile: total 1 ++ 
goroutine profile: total 2 + +- 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 ++ 2 @ 0x47bc0e 0x458e57 0x847587 0x483da1 + # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + }, + { + name: "new goroutine", + x: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + y: `goroutine profile: total 2 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 + +1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + want: `- goroutine profile: total 1 ++ goroutine profile: total 2 + ++ 1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 ++ # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + }, + { + name: "removed goroutine", + x: `goroutine profile: total 2 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 + +1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + y: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + want: `- goroutine profile: total 2 ++ goroutine profile: total 1 + +- 1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +- # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + }, + { + name: "removed many goroutine", + x: `goroutine profile: total 2 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 + +1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + y: `goroutine profile: total 0`, + want: `- goroutine profile: total 2 ++ goroutine profile: total 0 + +- 1 @ 0x47bc0e 0x458e57 0x754927 0x483da1 
+- # 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 + +- 1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +- # 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + }, + { + name: "invalid input x", + x: "invalid", + y: "goroutine profile: total 0\n", + want: "- invalid\n+ goroutine profile: total 0\n", + }, + { + name: "invalid input y", + x: "goroutine profile: total 0\n", + y: "invalid", + want: "- goroutine profile: total 0\n+ invalid\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := diffGoroutines( + parseGoroutines([]byte(tt.x)), + parseGoroutines([]byte(tt.y)), + ) + if got != tt.want { + t.Errorf("diffPprofGoroutines() diff:\ngot:\n%s\nwant:\n%s\ndiff (-want +got):\n%s", got, tt.want, cmp.Diff(tt.want, got)) + } + }) + } +} + +func TestParseGoroutines(t *testing.T) { + tests := []struct { + name string + in string + wantHeader string + wantCount int + }{ + { + name: "empty profile", + in: "goroutine profile: total 0\n", + wantHeader: "goroutine profile: total 0", + wantCount: 0, + }, + { + name: "single goroutine", + in: `goroutine profile: total 1 +1 @ 0x47bc0e 0x458e57 0x847587 0x483da1 +# 0x847586 database/sql.(*DB).connectionOpener+0x86 database/sql/sql.go:1261 +`, + wantHeader: "goroutine profile: total 1", + wantCount: 1, + }, + { + name: "multiple goroutines", + in: `goroutine profile: total 14 +7 @ 0x47bc0e 0x413705 0x4132b2 0x10fda4d 0x483da1 +# 0x10fda4c github.com/user/pkg.RoutineA+0x16c pkg/a.go:443 + +7 @ 0x47bc0e 0x458e57 0x754927 0x483da1 +# 0x754926 net/http.(*persistConn).writeLoop+0xe6 net/http/transport.go:2596 +`, + wantHeader: "goroutine profile: total 14", + wantCount: 2, + }, + { + name: "invalid format", + in: "invalid", + wantHeader: "invalid", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g := parseGoroutines([]byte(tt.in)) + + if got := string(g.head); got != tt.wantHeader { + t.Errorf("parseGoroutines() header 
= %q, want %q", got, tt.wantHeader) + } + if got := len(g.goroutines); got != tt.wantCount { + t.Errorf("parseGoroutines() goroutine count = %d, want %d", got, tt.wantCount) + } + + // Verify that the sort field is correctly reversed + for _, g := range g.goroutines { + original := strings.Fields(string(g.header)) + sorted := strings.Fields(string(g.sort)) + if len(original) != len(sorted) { + t.Errorf("sort field has different number of words: got %d, want %d", len(sorted), len(original)) + continue + } + for i := 0; i < len(original); i++ { + if original[i] != sorted[len(sorted)-1-i] { + t.Errorf("sort field word mismatch at position %d: got %q, want %q", i, sorted[len(sorted)-1-i], original[i]) + } + } + } + }) + } +} diff --git a/tstest/tailmac/Swift/Common/Config.swift b/tstest/tailmac/Swift/Common/Config.swift index 01d5069b0049d..18b68ae9b9d14 100644 --- a/tstest/tailmac/Swift/Common/Config.swift +++ b/tstest/tailmac/Swift/Common/Config.swift @@ -14,6 +14,7 @@ class Config: Codable { var mac = "52:cc:cc:cc:cc:01" var ethermac = "52:cc:cc:cc:ce:01" var port: UInt32 = 51009 + var sharedDir: String? // The virtual machines ID. Also double as the directory name under which // we will store configuration, block device, etc. diff --git a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift index 00f999a158c19..c0961c883fdbb 100644 --- a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift +++ b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift @@ -141,5 +141,18 @@ struct TailMacConfigHelper { func createKeyboardConfiguration() -> VZKeyboardConfiguration { return VZMacKeyboardConfiguration() } + + func createDirectoryShareConfiguration(tag: String) -> VZDirectorySharingDeviceConfiguration? 
{ + guard let dir = config.sharedDir else { return nil } + + let sharedDir = VZSharedDirectory(url: URL(fileURLWithPath: dir), readOnly: false) + let share = VZSingleDirectoryShare(directory: sharedDir) + + // Create the VZVirtioFileSystemDeviceConfiguration and assign it a unique tag. + let sharingConfiguration = VZVirtioFileSystemDeviceConfiguration(tag: tag) + sharingConfiguration.share = share + + return sharingConfiguration + } } diff --git a/tstest/tailmac/Swift/Host/HostCli.swift b/tstest/tailmac/Swift/Host/HostCli.swift index 1318a09fa546e..c31478cc39d45 100644 --- a/tstest/tailmac/Swift/Host/HostCli.swift +++ b/tstest/tailmac/Swift/Host/HostCli.swift @@ -19,10 +19,12 @@ var config: Config = Config() extension HostCli { struct Run: ParsableCommand { @Option var id: String + @Option var share: String? mutating func run() { - print("Running vm with identifier \(id)") config = Config(id) + config.sharedDir = share + print("Running vm with identifier \(id) and sharedDir \(share ?? "")") _ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv) } } diff --git a/tstest/tailmac/Swift/Host/VMController.swift b/tstest/tailmac/Swift/Host/VMController.swift index 8774894c1157a..fe4a3828b18fe 100644 --- a/tstest/tailmac/Swift/Host/VMController.swift +++ b/tstest/tailmac/Swift/Host/VMController.swift @@ -95,6 +95,13 @@ class VMController: NSObject, VZVirtualMachineDelegate { virtualMachineConfiguration.keyboards = [helper.createKeyboardConfiguration()] virtualMachineConfiguration.socketDevices = [helper.createSocketDeviceConfiguration()] + if let dir = config.sharedDir, let shareConfig = helper.createDirectoryShareConfiguration(tag: "vmshare") { + print("Sharing \(dir) as vmshare. Use: mount_virtiofs vmshare in the guest to mount.") + virtualMachineConfiguration.directorySharingDevices = [shareConfig] + } else { + print("No shared directory created. \(config.sharedDir ?? "none") was requested.") + } + try! virtualMachineConfiguration.validate() try! 
virtualMachineConfiguration.validateSaveRestoreSupport() diff --git a/tstest/tailmac/Swift/TailMac/TailMac.swift b/tstest/tailmac/Swift/TailMac/TailMac.swift index 56f651696e12c..84aa5e498a008 100644 --- a/tstest/tailmac/Swift/TailMac/TailMac.swift +++ b/tstest/tailmac/Swift/TailMac/TailMac.swift @@ -95,12 +95,16 @@ extension Tailmac { extension Tailmac { struct Run: ParsableCommand { @Option(help: "The vm identifier") var id: String + @Option(help: "Optional share directory") var share: String? @Flag(help: "Tail the TailMac log output instead of returning immediatly") var tail mutating func run() { let process = Process() let stdOutPipe = Pipe() - let appPath = "./Host.app/Contents/MacOS/Host" + + let executablePath = CommandLine.arguments[0] + let executableDirectory = (executablePath as NSString).deletingLastPathComponent + let appPath = executableDirectory + "/Host.app/Contents/MacOS/Host" process.executableURL = URL( fileURLWithPath: appPath, @@ -109,10 +113,15 @@ extension Tailmac { ) if !FileManager.default.fileExists(atPath: appPath) { - fatalError("Could not find Host.app. This must be co-located with the tailmac utility") + fatalError("Could not find Host.app at \(appPath). This must be co-located with the tailmac utility") } - process.arguments = ["run", "--id", id] + var args = ["run", "--id", id] + if let share { + args.append("--share") + args.append(share) + } + process.arguments = args do { process.standardOutput = stdOutPipe @@ -121,26 +130,18 @@ extension Tailmac { fatalError("Unable to launch the vm process") } - // This doesn't print until we exit which is not ideal, but at least we - // get the output if tail != 0 { + // (jonathan)TODO: How do we get the process output in real time? 
+ // The child process only seems to flush to stdout on completion let outHandle = stdOutPipe.fileHandleForReading - - let queue = OperationQueue() - NotificationCenter.default.addObserver( - forName: NSNotification.Name.NSFileHandleDataAvailable, - object: outHandle, queue: queue) - { - notification -> Void in - let data = outHandle.availableData + outHandle.readabilityHandler = { handle in + let data = handle.availableData if data.count > 0 { if let str = String(data: data, encoding: String.Encoding.utf8) { print(str) } } - outHandle.waitForDataInBackgroundAndNotify() } - outHandle.waitForDataInBackgroundAndNotify() process.waitUntilExit() } } diff --git a/tstime/tstime.go b/tstime/tstime.go index 1c006355f8726..6e5b7f9f47146 100644 --- a/tstime/tstime.go +++ b/tstime/tstime.go @@ -6,6 +6,7 @@ package tstime import ( "context" + "encoding" "strconv" "strings" "time" @@ -183,3 +184,40 @@ func (StdClock) AfterFunc(d time.Duration, f func()) TimerController { func (StdClock) Since(t time.Time) time.Duration { return time.Since(t) } + +// GoDuration is a [time.Duration] but JSON serializes with [time.Duration.String]. +// +// Note that this format is specific to Go and non-standard, +// but excels in being most humanly readable compared to alternatives. +// The wider industry still lacks consensus for the representation +// of a time duration in humanly-readable text. +// See https://go.dev/issue/71631 for more discussion. +// +// Regardless of how the industry evolves into the future, +// this type explicitly uses the Go format. +type GoDuration struct{ time.Duration } + +var ( + _ encoding.TextAppender = (*GoDuration)(nil) + _ encoding.TextMarshaler = (*GoDuration)(nil) + _ encoding.TextUnmarshaler = (*GoDuration)(nil) +) + +func (d GoDuration) AppendText(b []byte) ([]byte, error) { + // The String method is inlineable (see https://go.dev/cl/520602), + // so this may not allocate since the string does not escape. 
+ return append(b, d.String()...), nil +} + +func (d GoDuration) MarshalText() ([]byte, error) { + return []byte(d.String()), nil +} + +func (d *GoDuration) UnmarshalText(b []byte) error { + d2, err := time.ParseDuration(string(b)) + if err != nil { + return err + } + d.Duration = d2 + return nil +} diff --git a/tstime/tstime_test.go b/tstime/tstime_test.go index 3ffeaf0fff1b8..556ad4e8bb1d0 100644 --- a/tstime/tstime_test.go +++ b/tstime/tstime_test.go @@ -4,8 +4,11 @@ package tstime import ( + "encoding/json" "testing" "time" + + "tailscale.com/util/must" ) func TestParseDuration(t *testing.T) { @@ -34,3 +37,17 @@ func TestParseDuration(t *testing.T) { } } } + +func TestGoDuration(t *testing.T) { + wantDur := GoDuration{time.Hour + time.Minute + time.Second + time.Millisecond + time.Microsecond + time.Nanosecond} + gotJSON := string(must.Get(json.Marshal(wantDur))) + wantJSON := `"1h1m1.001001001s"` + if gotJSON != wantJSON { + t.Errorf("json.Marshal(%v) = %s, want %s", wantDur, gotJSON, wantJSON) + } + var gotDur GoDuration + must.Do(json.Unmarshal([]byte(wantJSON), &gotDur)) + if gotDur != wantDur { + t.Errorf("json.Unmarshal(%s) = %v, want %v", wantJSON, gotDur, wantDur) + } +} diff --git a/tsweb/debug.go b/tsweb/debug.go index 6db3f25cf06d5..4c0fabaff4aea 100644 --- a/tsweb/debug.go +++ b/tsweb/debug.go @@ -9,12 +9,11 @@ import ( "html" "io" "net/http" - "net/http/pprof" "net/url" "os" "runtime" - "tailscale.com/tsweb/promvarz" + "tailscale.com/feature" "tailscale.com/tsweb/varz" "tailscale.com/version" ) @@ -34,8 +33,14 @@ type DebugHandler struct { kvs []func(io.Writer) // output one
  • ...
  • each, see KV() urls []string // one
  • ...
  • block with link each sections []func(io.Writer, *http.Request) // invoked in registration order prior to outputting + title string // title displayed on index page } +// PrometheusHandler is an optional hook to enable native Prometheus +// support in the debug handler. It is disabled by default. Import the +// tailscale.com/tsweb/promvarz package to enable this feature. +var PrometheusHandler feature.Hook[func(*DebugHandler)] + // Debugger returns the DebugHandler registered on mux at /debug/, // creating it if necessary. func Debugger(mux *http.ServeMux) *DebugHandler { @@ -44,25 +49,21 @@ func Debugger(mux *http.ServeMux) *DebugHandler { return d } ret := &DebugHandler{ - mux: mux, + mux: mux, + title: fmt.Sprintf("%s debug", version.CmdName()), } mux.Handle("/debug/", ret) ret.KVFunc("Uptime", func() any { return varz.Uptime() }) ret.KV("Version", version.Long()) ret.Handle("vars", "Metrics (Go)", expvar.Handler()) - ret.Handle("varz", "Metrics (Prometheus)", http.HandlerFunc(promvarz.Handler)) - ret.Handle("pprof/", "pprof (index)", http.HandlerFunc(pprof.Index)) - // the CPU profile handler is special because it responds - // streamily, unlike every other pprof handler. This means it's - // not made available through pprof.Index the way all the other - // pprof types are, you have to register the CPU profile handler - // separately. Use HandleSilent for that to not pollute the human - // debug list with a link that produces streaming line noise if - // you click it. 
- ret.HandleSilent("pprof/profile", http.HandlerFunc(pprof.Profile)) - ret.URL("/debug/pprof/goroutine?debug=1", "Goroutines (collapsed)") - ret.URL("/debug/pprof/goroutine?debug=2", "Goroutines (full)") + if PrometheusHandler.IsSet() { + PrometheusHandler.Get()(ret) + } else { + ret.Handle("varz", "Metrics (Prometheus)", http.HandlerFunc(varz.Handler)) + } + + addProfilingHandlers(ret) ret.Handle("gc", "force GC", http.HandlerFunc(gcHandler)) hostname, err := os.Hostname() if err == nil { @@ -85,7 +86,7 @@ func (d *DebugHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { AddBrowserHeaders(w) f := func(format string, args ...any) { fmt.Fprintf(w, format, args...) } - f("

    %s debug

      ", version.CmdName()) + f("

      %s

        ", html.EscapeString(d.title)) for _, kv := range d.kvs { kv(w) } @@ -103,14 +104,20 @@ func (d *DebugHandler) handle(slug string, handler http.Handler) string { return href } -// Handle registers handler at /debug/ and creates a descriptive -// entry in /debug/ for it. +// Handle registers handler at /debug/ and adds a link to it +// on /debug/ with the provided description. func (d *DebugHandler) Handle(slug, desc string, handler http.Handler) { href := d.handle(slug, handler) d.URL(href, desc) } -// HandleSilent registers handler at /debug/. It does not create +// Handle registers handler at /debug/ and adds a link to it +// on /debug/ with the provided description. +func (d *DebugHandler) HandleFunc(slug, desc string, handler http.HandlerFunc) { + d.Handle(slug, desc, handler) +} + +// HandleSilent registers handler at /debug/. It does not add // a descriptive entry in /debug/ for it. This should be used // sparingly, for things that need to be registered but would pollute // the list of debug links. @@ -118,6 +125,14 @@ func (d *DebugHandler) HandleSilent(slug string, handler http.Handler) { d.handle(slug, handler) } +// HandleSilent registers handler at /debug/. It does not add +// a descriptive entry in /debug/ for it. This should be used +// sparingly, for things that need to be registered but would pollute +// the list of debug links. +func (d *DebugHandler) HandleSilentFunc(slug string, handler http.HandlerFunc) { + d.HandleSilent(slug, handler) +} + // KV adds a key/value list item to /debug/. func (d *DebugHandler) KV(k string, v any) { val := html.EscapeString(fmt.Sprintf("%v", v)) @@ -149,6 +164,11 @@ func (d *DebugHandler) Section(f func(w io.Writer, r *http.Request)) { d.sections = append(d.sections, f) } +// Title sets the title at the top of the debug page. 
+func (d *DebugHandler) Title(title string) { + d.title = title +} + func gcHandler(w http.ResponseWriter, r *http.Request) { w.Write([]byte("running GC...\n")) if f, ok := w.(http.Flusher); ok { diff --git a/tsweb/pprof_default.go b/tsweb/pprof_default.go new file mode 100644 index 0000000000000..7d22a61619855 --- /dev/null +++ b/tsweb/pprof_default.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !js && !wasm + +package tsweb + +import ( + "net/http" + "net/http/pprof" +) + +func addProfilingHandlers(d *DebugHandler) { + // pprof.Index serves everything that runtime/pprof.Lookup finds: + // goroutine, threadcreate, heap, allocs, block, mutex + d.Handle("pprof/", "pprof (index)", http.HandlerFunc(pprof.Index)) + // But register the other ones from net/http/pprof directly: + d.HandleSilent("pprof/cmdline", http.HandlerFunc(pprof.Cmdline)) + d.HandleSilent("pprof/profile", http.HandlerFunc(pprof.Profile)) + d.HandleSilent("pprof/symbol", http.HandlerFunc(pprof.Symbol)) + d.HandleSilent("pprof/trace", http.HandlerFunc(pprof.Trace)) + d.URL("/debug/pprof/goroutine?debug=1", "Goroutines (collapsed)") + d.URL("/debug/pprof/goroutine?debug=2", "Goroutines (full)") +} diff --git a/tsweb/pprof_js.go b/tsweb/pprof_js.go new file mode 100644 index 0000000000000..1212b37e86f5a --- /dev/null +++ b/tsweb/pprof_js.go @@ -0,0 +1,10 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build js && wasm + +package tsweb + +func addProfilingHandlers(d *DebugHandler) { + // No pprof in js builds, pprof doesn't work and bloats the build. 
+} diff --git a/tsweb/promvarz/promvarz.go b/tsweb/promvarz/promvarz.go index d0e1e52baeadb..1d978c7677328 100644 --- a/tsweb/promvarz/promvarz.go +++ b/tsweb/promvarz/promvarz.go @@ -11,12 +11,21 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/common/expfmt" + "tailscale.com/tsweb" "tailscale.com/tsweb/varz" ) -// Handler returns Prometheus metrics exported by our expvar converter +func init() { + tsweb.PrometheusHandler.Set(registerVarz) +} + +func registerVarz(debug *tsweb.DebugHandler) { + debug.Handle("varz", "Metrics (Prometheus)", http.HandlerFunc(handler)) +} + +// handler returns Prometheus metrics exported by our expvar converter // and the official Prometheus client. -func Handler(w http.ResponseWriter, r *http.Request) { +func handler(w http.ResponseWriter, r *http.Request) { if err := gatherNativePrometheusMetrics(w); err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte(err.Error())) diff --git a/tsweb/promvarz/promvarz_test.go b/tsweb/promvarz/promvarz_test.go index a3f4e66f11a42..9f91b5d12380e 100644 --- a/tsweb/promvarz/promvarz_test.go +++ b/tsweb/promvarz/promvarz_test.go @@ -23,7 +23,7 @@ func TestHandler(t *testing.T) { testVar1.Set(42) testVar2.Set(4242) - svr := httptest.NewServer(http.HandlerFunc(Handler)) + svr := httptest.NewServer(http.HandlerFunc(handler)) defer svr.Close() want := ` diff --git a/tsweb/request_id.go b/tsweb/request_id.go index 8516b8f72161e..46e52385240ca 100644 --- a/tsweb/request_id.go +++ b/tsweb/request_id.go @@ -6,9 +6,10 @@ package tsweb import ( "context" "net/http" + "time" "tailscale.com/util/ctxkey" - "tailscale.com/util/fastuuid" + "tailscale.com/util/rands" ) // RequestID is an opaque identifier for a HTTP request, used to correlate @@ -41,10 +42,12 @@ const RequestIDHeader = "X-Tailscale-Request-Id" // GenerateRequestID generates a new request ID with the current format. 
func GenerateRequestID() RequestID { - // REQ-1 indicates the version of the RequestID pattern. It is - // currently arbitrary but allows for forward compatible - // transitions if needed. - return RequestID("REQ-1" + fastuuid.NewUUID().String()) + // Return a string of the form "REQ-<...>" + // Previously we returned "REQ-1". + // Now we return "REQ-2" version, where the "2" doubles as the year 2YYY + // in a leading date. + now := time.Now().UTC() + return RequestID("REQ-" + now.Format("20060102150405") + rands.HexString(16)) } // SetRequestID is an HTTP middleware that injects a RequestID in the diff --git a/tsweb/tsweb.go b/tsweb/tsweb.go index 9ddb3fad5d710..119fed2e61012 100644 --- a/tsweb/tsweb.go +++ b/tsweb/tsweb.go @@ -15,7 +15,6 @@ import ( "io" "net" "net/http" - _ "net/http/pprof" "net/netip" "net/url" "os" diff --git a/tsweb/tsweb_test.go b/tsweb/tsweb_test.go index 13840c01225e3..d4c9721e97215 100644 --- a/tsweb/tsweb_test.go +++ b/tsweb/tsweb_test.go @@ -1307,6 +1307,28 @@ func TestBucket(t *testing.T) { } } +func TestGenerateRequestID(t *testing.T) { + t0 := time.Now() + got := GenerateRequestID() + t.Logf("Got: %q", got) + if !strings.HasPrefix(string(got), "REQ-2") { + t.Errorf("expect REQ-2 prefix; got %q", got) + } + const wantLen = len("REQ-2024112022140896f8ead3d3f3be27") + if len(got) != wantLen { + t.Fatalf("len = %d; want %d", len(got), wantLen) + } + d := got[len("REQ-"):][:14] + timeBack, err := time.Parse("20060102150405", string(d)) + if err != nil { + t.Fatalf("parsing time back: %v", err) + } + elapsed := timeBack.Sub(t0) + if elapsed > 3*time.Second { // allow for slow github actions runners :) + t.Fatalf("time back was %v; want within 3s", elapsed) + } +} + func ExampleMiddlewareStack() { // setHeader returns a middleware that sets header k = vs. 
setHeader := func(k string, vs ...string) Middleware { diff --git a/tsweb/varz/varz.go b/tsweb/varz/varz.go index 561b2487710e3..c6d66fbe2beda 100644 --- a/tsweb/varz/varz.go +++ b/tsweb/varz/varz.go @@ -5,6 +5,7 @@ package varz import ( + "bufio" "cmp" "expvar" "fmt" @@ -13,20 +14,29 @@ import ( "reflect" "runtime" "sort" + "strconv" "strings" "sync" "time" "unicode" "unicode/utf8" + "golang.org/x/exp/constraints" "tailscale.com/metrics" + "tailscale.com/types/logger" "tailscale.com/version" ) +// StaticStringVar returns a new expvar.Var that always returns s. +func StaticStringVar(s string) expvar.Var { + var v any = s // box s into an interface just once + return expvar.Func(func() any { return v }) +} + func init() { expvar.Publish("process_start_unix_time", expvar.Func(func() any { return timeStart.Unix() })) - expvar.Publish("version", expvar.Func(func() any { return version.Long() })) - expvar.Publish("go_version", expvar.Func(func() any { return runtime.Version() })) + expvar.Publish("version", StaticStringVar(version.Long())) + expvar.Publish("go_version", StaticStringVar(runtime.Version())) expvar.Publish("counter_uptime_sec", expvar.Func(func() any { return int64(Uptime().Seconds()) })) expvar.Publish("gauge_goroutines", expvar.Func(func() any { return runtime.NumGoroutine() })) } @@ -310,21 +320,52 @@ type PrometheusMetricsReflectRooter interface { var expvarDo = expvar.Do // pulled out for tests -func writeMemstats(w io.Writer, ms *runtime.MemStats) { - out := func(name, typ string, v uint64, help string) { - if help != "" { - fmt.Fprintf(w, "# HELP memstats_%s %s\n", name, help) - } - fmt.Fprintf(w, "# TYPE memstats_%s %s\nmemstats_%s %v\n", name, typ, name, v) +func writeMemstat[V constraints.Integer | constraints.Float](bw *bufio.Writer, typ, name string, v V, help string) { + if help != "" { + bw.WriteString("# HELP memstats_") + bw.WriteString(name) + bw.WriteString(" ") + bw.WriteString(help) + bw.WriteByte('\n') } - g := func(name string, v 
uint64, help string) { out(name, "gauge", v, help) } - c := func(name string, v uint64, help string) { out(name, "counter", v, help) } - g("heap_alloc", ms.HeapAlloc, "current bytes of allocated heap objects (up/down smoothly)") - c("total_alloc", ms.TotalAlloc, "cumulative bytes allocated for heap objects") - g("sys", ms.Sys, "total bytes of memory obtained from the OS") - c("mallocs", ms.Mallocs, "cumulative count of heap objects allocated") - c("frees", ms.Frees, "cumulative count of heap objects freed") - c("num_gc", uint64(ms.NumGC), "number of completed GC cycles") + bw.WriteString("# TYPE memstats_") + bw.WriteString(name) + bw.WriteString(" ") + bw.WriteString(typ) + bw.WriteByte('\n') + bw.WriteString("memstats_") + bw.WriteString(name) + bw.WriteByte(' ') + rt := reflect.TypeOf(v) + switch { + case rt == reflect.TypeFor[int]() || + rt == reflect.TypeFor[uint]() || + rt == reflect.TypeFor[int8]() || + rt == reflect.TypeFor[uint8]() || + rt == reflect.TypeFor[int16]() || + rt == reflect.TypeFor[uint16]() || + rt == reflect.TypeFor[int32]() || + rt == reflect.TypeFor[uint32]() || + rt == reflect.TypeFor[int64]() || + rt == reflect.TypeFor[uint64]() || + rt == reflect.TypeFor[uintptr](): + bw.Write(strconv.AppendInt(bw.AvailableBuffer(), int64(v), 10)) + case rt == reflect.TypeFor[float32]() || rt == reflect.TypeFor[float64](): + bw.Write(strconv.AppendFloat(bw.AvailableBuffer(), float64(v), 'f', -1, 64)) + } + bw.WriteByte('\n') +} + +func writeMemstats(w io.Writer, ms *runtime.MemStats) { + fmt.Fprintf(w, "%v", logger.ArgWriter(func(bw *bufio.Writer) { + writeMemstat(bw, "gauge", "heap_alloc", ms.HeapAlloc, "current bytes of allocated heap objects (up/down smoothly)") + writeMemstat(bw, "counter", "total_alloc", ms.TotalAlloc, "cumulative bytes allocated for heap objects") + writeMemstat(bw, "gauge", "sys", ms.Sys, "total bytes of memory obtained from the OS") + writeMemstat(bw, "counter", "mallocs", ms.Mallocs, "cumulative count of heap objects allocated") 
+ writeMemstat(bw, "counter", "frees", ms.Frees, "cumulative count of heap objects freed") + writeMemstat(bw, "counter", "num_gc", ms.NumGC, "number of completed GC cycles") + writeMemstat(bw, "gauge", "gc_cpu_fraction", ms.GCCPUFraction, "fraction of CPU time used by GC") + })) } // sortedStructField is metadata about a struct field used both for sorting once diff --git a/tsweb/varz/varz_test.go b/tsweb/varz/varz_test.go index 7e094b0e72608..f7a9d880199e2 100644 --- a/tsweb/varz/varz_test.go +++ b/tsweb/varz/varz_test.go @@ -4,14 +4,17 @@ package varz import ( + "bytes" "expvar" "net/http/httptest" "reflect" + "runtime" "strings" "testing" "tailscale.com/metrics" "tailscale.com/tstest" + "tailscale.com/util/racebuild" "tailscale.com/version" ) @@ -418,3 +421,75 @@ func TestVarzHandlerSorting(t *testing.T) { } } } + +func TestWriteMemestats(t *testing.T) { + memstats := &runtime.MemStats{ + Alloc: 1, + TotalAlloc: 2, + Sys: 3, + Lookups: 4, + Mallocs: 5, + Frees: 6, + HeapAlloc: 7, + HeapSys: 8, + HeapIdle: 9, + HeapInuse: 10, + HeapReleased: 11, + HeapObjects: 12, + StackInuse: 13, + StackSys: 14, + MSpanInuse: 15, + MSpanSys: 16, + MCacheInuse: 17, + MCacheSys: 18, + BuckHashSys: 19, + GCSys: 20, + OtherSys: 21, + NextGC: 22, + LastGC: 23, + PauseTotalNs: 24, + // PauseNs: [256]int64{}, + NumGC: 26, + NumForcedGC: 27, + GCCPUFraction: 0.28, + } + + var buf bytes.Buffer + writeMemstats(&buf, memstats) + lines := strings.Split(buf.String(), "\n") + + checkFor := func(name, typ, value string) { + var foundType, foundValue bool + for _, line := range lines { + if line == "memstats_"+name+" "+value { + foundValue = true + } + if line == "# TYPE memstats_"+name+" "+typ { + foundType = true + } + if foundValue && foundType { + return + } + } + t.Errorf("memstats_%s foundType=%v foundValue=%v", name, foundType, foundValue) + } + + t.Logf("memstats:\n %s", buf.String()) + + checkFor("heap_alloc", "gauge", "7") + checkFor("total_alloc", "counter", "2") + checkFor("sys", 
"gauge", "3") + checkFor("mallocs", "counter", "5") + checkFor("frees", "counter", "6") + checkFor("num_gc", "counter", "26") + checkFor("gc_cpu_fraction", "gauge", "0.28") + + if !racebuild.On { + if allocs := testing.AllocsPerRun(1000, func() { + buf.Reset() + writeMemstats(&buf, memstats) + }); allocs != 1 { + t.Errorf("allocs = %v; want max %v", allocs, 1) + } + } +} diff --git a/types/bools/bools.go b/types/bools/bools.go new file mode 100644 index 0000000000000..e64068746ed9e --- /dev/null +++ b/types/bools/bools.go @@ -0,0 +1,37 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package bools contains the [Int], [Compare], and [IfElse] functions. +package bools + +// Int returns 1 for true and 0 for false. +func Int(v bool) int { + if v { + return 1 + } else { + return 0 + } +} + +// Compare compares two boolean values as if false is ordered before true. +func Compare[T ~bool](x, y T) int { + switch { + case x == false && y == true: + return -1 + case x == true && y == false: + return +1 + default: + return 0 + } +} + +// IfElse is a ternary operator that returns trueVal if condExpr is true +// otherwise it returns falseVal. +// IfElse(c, a, b) is roughly equivalent to (c ? a : b) in languages like C. 
+func IfElse[T any](condExpr bool, trueVal T, falseVal T) T { + if condExpr { + return trueVal + } else { + return falseVal + } +} diff --git a/types/bools/compare_test.go b/types/bools/bools_test.go similarity index 56% rename from types/bools/compare_test.go rename to types/bools/bools_test.go index 280294621e719..67faf3bcc92d8 100644 --- a/types/bools/compare_test.go +++ b/types/bools/bools_test.go @@ -5,6 +5,15 @@ package bools import "testing" +func TestInt(t *testing.T) { + if got := Int(true); got != 1 { + t.Errorf("Int(true) = %v, want 1", got) + } + if got := Int(false); got != 0 { + t.Errorf("Int(false) = %v, want 0", got) + } +} + func TestCompare(t *testing.T) { if got := Compare(false, false); got != 0 { t.Errorf("Compare(false, false) = %v, want 0", got) @@ -19,3 +28,12 @@ func TestCompare(t *testing.T) { t.Errorf("Compare(true, true) = %v, want 0", got) } } + +func TestIfElse(t *testing.T) { + if got := IfElse(true, 0, 1); got != 0 { + t.Errorf("IfElse(true, 0, 1) = %v, want 0", got) + } + if got := IfElse(false, 0, 1); got != 1 { + t.Errorf("IfElse(false, 0, 1) = %v, want 1", got) + } +} diff --git a/types/bools/compare.go b/types/bools/compare.go deleted file mode 100644 index ac433b240755a..0000000000000 --- a/types/bools/compare.go +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package bools contains the bools.Compare function. -package bools - -// Compare compares two boolean values as if false is ordered before true. 
-func Compare[T ~bool](x, y T) int { - switch { - case x == false && y == true: - return -1 - case x == true && y == false: - return +1 - default: - return 0 - } -} diff --git a/types/dnstype/dnstype_view.go b/types/dnstype/dnstype_view.go index c0e2b28ffb9b4..c77ff9a406106 100644 --- a/types/dnstype/dnstype_view.go +++ b/types/dnstype/dnstype_view.go @@ -15,7 +15,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type=Resolver -// View returns a readonly view of Resolver. +// View returns a read-only view of Resolver. func (p *Resolver) View() ResolverView { return ResolverView{Đļ: p} } @@ -31,7 +31,7 @@ type ResolverView struct { Đļ *Resolver } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v ResolverView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with diff --git a/types/iox/io.go b/types/iox/io.go new file mode 100644 index 0000000000000..a5ca1be43f737 --- /dev/null +++ b/types/iox/io.go @@ -0,0 +1,23 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package iox provides types to implement [io] functionality. +package iox + +// TODO(https://go.dev/issue/21670): Deprecate or remove this functionality +// once the Go language supports implementing an 1-method interface directly +// using a function value of a matching signature. + +// ReaderFunc implements [io.Reader] using the underlying function value. +type ReaderFunc func([]byte) (int, error) + +func (f ReaderFunc) Read(b []byte) (int, error) { + return f(b) +} + +// WriterFunc implements [io.Writer] using the underlying function value. 
+type WriterFunc func([]byte) (int, error) + +func (f WriterFunc) Write(b []byte) (int, error) { + return f(b) +} diff --git a/types/iox/io_test.go b/types/iox/io_test.go new file mode 100644 index 0000000000000..9fba39605d28d --- /dev/null +++ b/types/iox/io_test.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package iox + +import ( + "bytes" + "io" + "testing" + "testing/iotest" + + "tailscale.com/util/must" +) + +func TestCopy(t *testing.T) { + const testdata = "the quick brown fox jumped over the lazy dog" + src := testdata + bb := new(bytes.Buffer) + if got := must.Get(io.Copy(bb, ReaderFunc(func(b []byte) (n int, err error) { + n = copy(b[:min(len(b), 7)], src) + src = src[n:] + if len(src) == 0 { + err = io.EOF + } + return n, err + }))); int(got) != len(testdata) { + t.Errorf("copy = %d, want %d", got, len(testdata)) + } + var dst []byte + if got := must.Get(io.Copy(WriterFunc(func(b []byte) (n int, err error) { + dst = append(dst, b...) + return len(b), nil + }), iotest.OneByteReader(bb))); int(got) != len(testdata) { + t.Errorf("copy = %d, want %d", got, len(testdata)) + } + if string(dst) != testdata { + t.Errorf("copy = %q, want %q", dst, testdata) + } +} diff --git a/types/jsonx/json.go b/types/jsonx/json.go new file mode 100644 index 0000000000000..3f01ea358df30 --- /dev/null +++ b/types/jsonx/json.go @@ -0,0 +1,171 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package jsonx contains helper types and functionality to use with +// [github.com/go-json-experiment/json], which is positioned to be +// merged into the Go standard library as [encoding/json/v2]. 
+// +// See https://go.dev/issues/71497 +package jsonx + +import ( + "errors" + "fmt" + "reflect" + + "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" +) + +var ( + errUnknownTypeName = errors.New("unknown type name") + errNonSingularValue = errors.New("dynamic value must only have exactly one member") +) + +// MakeInterfaceCoders constructs a pair of marshal and unmarshal functions +// to serialize a Go interface type T. A bijective mapping for the set +// of concrete types that implement T is provided, +// where the key is a stable type name to use in the JSON representation, +// while the value is any value of a concrete type that implements T. +// By convention, only the zero value of concrete types is passed. +// +// The JSON representation for a dynamic value is a JSON object +// with a single member, where the member name is the type name, +// and the value is the JSON representation for the Go value. +// For example, the JSON serialization for a concrete type named Foo +// would be {"Foo": ...}, where ... is the JSON representation +// of the concrete value of the Foo type. +// +// Example instantiation: +// +// // Interface is a union type implemented by [FooType] and [BarType]. +// type Interface interface { ... } +// +// var interfaceCoders = MakeInterfaceCoders(map[string]Interface{ +// "FooType": FooType{}, +// "BarType": (*BarType)(nil), +// }) +// +// The pair of Marshal and Unmarshal functions can be used with the [json] +// package with either type-specified or caller-specified serialization. +// The result of this constructor is usually stored into a global variable. +// +// Example usage with type-specified serialization: +// +// // InterfaceWrapper is a concrete type that wraps [Interface]. +// // It extends [Interface] to implement +// // [json.MarshalerTo] and [json.UnmarshalerFrom]. 
+// type InterfaceWrapper struct{ Interface } +// +// func (w InterfaceWrapper) MarshalJSONTo(enc *jsontext.Encoder) error { +// return interfaceCoders.Marshal(enc, &w.Interface) +// } +// +// func (w *InterfaceWrapper) UnmarshalJSONFrom(dec *jsontext.Decoder) error { +// return interfaceCoders.Unmarshal(dec, &w.Interface) +// } +// +// Example usage with caller-specified serialization: +// +// var opts json.Options = json.JoinOptions( +// json.WithMarshalers(json.MarshalToFunc(interfaceCoders.Marshal)), +// json.WithUnmarshalers(json.UnmarshalFromFunc(interfaceCoders.Unmarshal)), +// ) +// +// var v Interface +// ... := json.Marshal(v, opts) +// ... := json.Unmarshal(&v, opts) +// +// The function panics if T is not a named interface kind, +// or if valuesByName contains distinct entries with the same concrete type. +func MakeInterfaceCoders[T any](valuesByName map[string]T) (c struct { + Marshal func(*jsontext.Encoder, *T) error + Unmarshal func(*jsontext.Decoder, *T) error +}) { + // Verify that T is a named interface. + switch t := reflect.TypeFor[T](); { + case t.Kind() != reflect.Interface: + panic(fmt.Sprintf("%v must be an interface kind", t)) + case t.Name() == "": + panic(fmt.Sprintf("%v must be a named type", t)) + } + + // Construct a bijective mapping of names to types. + typesByName := make(map[string]reflect.Type) + namesByType := make(map[reflect.Type]string) + for name, value := range valuesByName { + t := reflect.TypeOf(value) + if t == nil { + panic(fmt.Sprintf("nil value for %s", name)) + } + if name2, ok := namesByType[t]; ok { + panic(fmt.Sprintf("type %v cannot have multiple names %s and %v", t, name, name2)) + } + typesByName[name] = t + namesByType[t] = name + } + + // Construct the marshal and unmarshal functions. 
+ c.Marshal = func(enc *jsontext.Encoder, val *T) error { + t := reflect.TypeOf(*val) + if t == nil { + return enc.WriteToken(jsontext.Null) + } + name := namesByType[t] + if name == "" { + return fmt.Errorf("Go type %v: %w", t, errUnknownTypeName) + } + + if err := enc.WriteToken(jsontext.BeginObject); err != nil { + return err + } + if err := enc.WriteToken(jsontext.String(name)); err != nil { + return err + } + if err := json.MarshalEncode(enc, *val); err != nil { + return err + } + if err := enc.WriteToken(jsontext.EndObject); err != nil { + return err + } + return nil + } + c.Unmarshal = func(dec *jsontext.Decoder, val *T) error { + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() == 'n': + var zero T + *val = zero // store nil interface value for JSON null + return nil + case tok.Kind() != '{': + return &json.SemanticError{JSONKind: tok.Kind(), GoType: reflect.TypeFor[T]()} + } + var v reflect.Value + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() != '"': + return errNonSingularValue + default: + t := typesByName[tok.String()] + if t == nil { + return errUnknownTypeName + } + v = reflect.New(t) + } + if err := json.UnmarshalDecode(dec, v.Interface()); err != nil { + return err + } + *val = v.Elem().Interface().(T) + switch tok, err := dec.ReadToken(); { + case err != nil: + return err + case tok.Kind() != '}': + return errNonSingularValue + } + return nil + } + + return c +} diff --git a/types/jsonx/json_test.go b/types/jsonx/json_test.go new file mode 100644 index 0000000000000..0f2a646c40d6d --- /dev/null +++ b/types/jsonx/json_test.go @@ -0,0 +1,140 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package jsonx + +import ( + "errors" + "testing" + + "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "github.com/google/go-cmp/cmp" + "tailscale.com/types/ptr" +) + +type Interface interface { + 
implementsInterface() +} + +type Foo string + +func (Foo) implementsInterface() {} + +type Bar int + +func (Bar) implementsInterface() {} + +type Baz struct{ Fizz, Buzz string } + +func (*Baz) implementsInterface() {} + +var interfaceCoders = MakeInterfaceCoders(map[string]Interface{ + "Foo": Foo(""), + "Bar": (*Bar)(nil), + "Baz": (*Baz)(nil), +}) + +type InterfaceWrapper struct{ Interface } + +func (w InterfaceWrapper) MarshalJSONTo(enc *jsontext.Encoder) error { + return interfaceCoders.Marshal(enc, &w.Interface) +} + +func (w *InterfaceWrapper) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + return interfaceCoders.Unmarshal(dec, &w.Interface) +} + +func TestInterfaceCoders(t *testing.T) { + var opts json.Options = json.JoinOptions( + json.WithMarshalers(json.MarshalToFunc(interfaceCoders.Marshal)), + json.WithUnmarshalers(json.UnmarshalFromFunc(interfaceCoders.Unmarshal)), + ) + + errSkipMarshal := errors.New("skip marshal") + makeFiller := func() InterfaceWrapper { + return InterfaceWrapper{&Baz{"fizz", "buzz"}} + } + + for _, tt := range []struct { + label string + wantVal InterfaceWrapper + wantJSON string + wantMarshalError error + wantUnmarshalError error + }{{ + label: "Null", + wantVal: InterfaceWrapper{}, + wantJSON: `null`, + }, { + label: "Foo", + wantVal: InterfaceWrapper{Foo("hello")}, + wantJSON: `{"Foo":"hello"}`, + }, { + label: "BarPointer", + wantVal: InterfaceWrapper{ptr.To(Bar(5))}, + wantJSON: `{"Bar":5}`, + }, { + label: "BarValue", + wantVal: InterfaceWrapper{Bar(5)}, + // NOTE: We could handle BarValue just like BarPointer, + // but round-trip marshal/unmarshal would not be identical. 
+ wantMarshalError: errUnknownTypeName, + }, { + label: "Baz", + wantVal: InterfaceWrapper{&Baz{"alpha", "omega"}}, + wantJSON: `{"Baz":{"Fizz":"alpha","Buzz":"omega"}}`, + }, { + label: "Unknown", + wantVal: makeFiller(), + wantJSON: `{"Unknown":[1,2,3]}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errUnknownTypeName, + }, { + label: "Empty", + wantVal: makeFiller(), + wantJSON: `{}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errNonSingularValue, + }, { + label: "Duplicate", + wantVal: InterfaceWrapper{Foo("hello")}, // first entry wins + wantJSON: `{"Foo":"hello","Bar":5}`, + wantMarshalError: errSkipMarshal, + wantUnmarshalError: errNonSingularValue, + }} { + t.Run(tt.label, func(t *testing.T) { + if tt.wantMarshalError != errSkipMarshal { + switch gotJSON, err := json.Marshal(&tt.wantVal); { + case !errors.Is(err, tt.wantMarshalError): + t.Fatalf("json.Marshal(%v) error = %v, want %v", tt.wantVal, err, tt.wantMarshalError) + case string(gotJSON) != tt.wantJSON: + t.Fatalf("json.Marshal(%v) = %s, want %s", tt.wantVal, gotJSON, tt.wantJSON) + } + switch gotJSON, err := json.Marshal(&tt.wantVal.Interface, opts); { + case !errors.Is(err, tt.wantMarshalError): + t.Fatalf("json.Marshal(%v) error = %v, want %v", tt.wantVal, err, tt.wantMarshalError) + case string(gotJSON) != tt.wantJSON: + t.Fatalf("json.Marshal(%v) = %s, want %s", tt.wantVal, gotJSON, tt.wantJSON) + } + } + + if tt.wantJSON != "" { + gotVal := makeFiller() + if err := json.Unmarshal([]byte(tt.wantJSON), &gotVal); !errors.Is(err, tt.wantUnmarshalError) { + t.Fatalf("json.Unmarshal(%v) error = %v, want %v", tt.wantJSON, err, tt.wantUnmarshalError) + } + if d := cmp.Diff(gotVal, tt.wantVal); d != "" { + t.Fatalf("json.Unmarshal(%v):\n%s", tt.wantJSON, d) + } + gotVal = makeFiller() + if err := json.Unmarshal([]byte(tt.wantJSON), &gotVal.Interface, opts); !errors.Is(err, tt.wantUnmarshalError) { + t.Fatalf("json.Unmarshal(%v) error = %v, want %v", tt.wantJSON, err, 
tt.wantUnmarshalError) + } + if d := cmp.Diff(gotVal, tt.wantVal); d != "" { + t.Fatalf("json.Unmarshal(%v):\n%s", tt.wantJSON, d) + } + } + }) + } +} diff --git a/types/lazy/deferred.go b/types/lazy/deferred.go index 964553cef6524..973082914c48c 100644 --- a/types/lazy/deferred.go +++ b/types/lazy/deferred.go @@ -22,7 +22,14 @@ type DeferredInit struct { // until the owner's [DeferredInit.Do] method is called // for the first time. // -// DeferredFuncs is safe for concurrent use. +// DeferredFuncs is safe for concurrent use. The execution +// order of functions deferred by different goroutines is +// unspecified and must not be relied upon. +// However, functions deferred by the same goroutine are +// executed in the same relative order they were deferred. +// Warning: this is the opposite of the behavior of Go's +// defer statement, which executes deferred functions in +// reverse order. type DeferredFuncs struct { m sync.Mutex funcs []func() error diff --git a/types/lazy/deferred_test.go b/types/lazy/deferred_test.go index 9de16c67a6067..98cacbfce7088 100644 --- a/types/lazy/deferred_test.go +++ b/types/lazy/deferred_test.go @@ -205,16 +205,38 @@ func TestDeferredErr(t *testing.T) { } } +// TestDeferAfterDo checks all of the following: +// - Deferring a function before [DeferredInit.Do] is called should always succeed. +// - All successfully deferred functions are executed by the time [DeferredInit.Do] completes. +// - No functions can be deferred after [DeferredInit.Do] is called, meaning: +// - [DeferredInit.Defer] should return false. +// - The deferred function should not be executed. +// +// This test is intentionally racy as it attempts to defer functions from multiple goroutines +// and then calls [DeferredInit.Do] without waiting for them to finish. Waiting would alter +// the observable behavior and render the test pointless. 
func TestDeferAfterDo(t *testing.T) { var di DeferredInit var deferred, called atomic.Int32 + // deferOnce defers a test function once and fails the test + // if [DeferredInit.Defer] returns true after [DeferredInit.Do] + // has already been called and any deferred functions have been executed. + // It's called concurrently by multiple goroutines. deferOnce := func() bool { + // canDefer is whether it's acceptable for Defer to return true. + // (but not it necessarily must return true) + // If its func has run before, it's definitely not okay for it to + // accept more Defer funcs. + canDefer := called.Load() == 0 ok := di.Defer(func() error { called.Add(1) return nil }) if ok { + if !canDefer { + t.Error("An init function was deferred after DeferredInit.Do() was already called") + } deferred.Add(1) } return ok @@ -242,19 +264,17 @@ func TestDeferAfterDo(t *testing.T) { if err := di.Do(); err != nil { t.Fatalf("DeferredInit.Do() failed: %v", err) } - wantDeferred, wantCalled := deferred.Load(), called.Load() + // The number of called funcs should remain unchanged after [DeferredInit.Do] returns. + wantCalled := called.Load() if deferOnce() { t.Error("An init func was deferred after DeferredInit.Do() returned") } // Wait for the goroutines deferring init funcs to exit. - // No funcs should be deferred after DeferredInit.Do() has returned, - // so the deferred and called counters should remain unchanged. + // No funcs should be called after DeferredInit.Do() has returned, + // and the number of called funcs should be equal to the number of deferred funcs. wg.Wait() - if gotDeferred := deferred.Load(); gotDeferred != wantDeferred { - t.Errorf("An init func was deferred after DeferredInit.Do() returned. Got %d, want %d", gotDeferred, wantDeferred) - } if gotCalled := called.Load(); gotCalled != wantCalled { t.Errorf("An init func was called after DeferredInit.Do() returned. 
Got %d, want %d", gotCalled, wantCalled) } diff --git a/types/lazy/lazy.go b/types/lazy/lazy.go index 43325512d9cb0..f5d7be4940a11 100644 --- a/types/lazy/lazy.go +++ b/types/lazy/lazy.go @@ -120,44 +120,9 @@ func (z *SyncValue[T]) PeekErr() (v T, err error, ok bool) { return zero, nil, false } -// SyncFunc wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's result on every subsequent call. -// -// The returned function is safe for concurrent use. -func SyncFunc[T any](fill func() T) func() T { - var ( - once sync.Once - v T - ) - return func() T { - once.Do(func() { v = fill() }) - return v - } -} - -// SyncFuncErr wraps a function to make it lazy. -// -// The returned function calls fill the first time it's called, and returns -// fill's results on every subsequent call. -// -// The returned function is safe for concurrent use. -func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) { - var ( - once sync.Once - v T - err error - ) - return func() (T, error) { - once.Do(func() { v, err = fill() }) - return v, err - } -} - -// TB is a subset of testing.TB that we use to set up test helpers. +// testing_TB is a subset of testing.TB that we use to set up test helpers. // It's defined here to avoid pulling in the testing package. -type TB interface { +type testing_TB interface { Helper() Cleanup(func()) } @@ -167,7 +132,9 @@ type TB interface { // subtests complete. // It is not safe for concurrent use and must not be called concurrently with // any SyncValue methods, including another call to itself. -func (z *SyncValue[T]) SetForTest(tb TB, val T, err error) { +// +// The provided tb should be a [*testing.T] or [*testing.B]. 
+func (z *SyncValue[T]) SetForTest(tb testing_TB, val T, err error) { tb.Helper() oldErr, oldVal := z.err.Load(), z.v diff --git a/types/lazy/sync_test.go b/types/lazy/sync_test.go index 5578eee0cfed9..4d1278253955b 100644 --- a/types/lazy/sync_test.go +++ b/types/lazy/sync_test.go @@ -354,46 +354,3 @@ func TestSyncValueSetForTest(t *testing.T) { }) } } - -func TestSyncFunc(t *testing.T) { - f := SyncFunc(fortyTwo) - - n := int(testing.AllocsPerRun(1000, func() { - got := f() - if got != 42 { - t.Fatalf("got %v; want 42", got) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} - -func TestSyncFuncErr(t *testing.T) { - f := SyncFuncErr(func() (int, error) { - return 42, nil - }) - n := int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 42 || err != nil { - t.Fatalf("got %v, %v; want 42, nil", got, err) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } - - wantErr := errors.New("test error") - f = SyncFuncErr(func() (int, error) { - return 0, wantErr - }) - n = int(testing.AllocsPerRun(1000, func() { - got, err := f() - if got != 0 || err != wantErr { - t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) - } - })) - if n != 0 { - t.Errorf("allocs = %v; want 0", n) - } -} diff --git a/types/logger/logger.go b/types/logger/logger.go index 11596b357cabb..6c4edf6336005 100644 --- a/types/logger/logger.go +++ b/types/logger/logger.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "log" + "runtime" "strings" "sync" "time" @@ -23,6 +24,7 @@ import ( "go4.org/mem" "tailscale.com/envknob" "tailscale.com/util/ctxkey" + "tailscale.com/util/testenv" ) // Logf is the basic Tailscale logger type: a printf-like func. @@ -162,6 +164,10 @@ func RateLimitedFnWithClock(logf Logf, f time.Duration, burst int, maxCache int, if envknob.String("TS_DEBUG_LOG_RATE") == "all" { return logf } + if runtime.GOOS == "plan9" { + // To ease bring-up. 
+ return logf + } var ( mu sync.Mutex msgLim = make(map[string]*limitData) // keyed by logf format @@ -317,6 +323,7 @@ func (fn ArgWriter) Format(f fmt.State, _ rune) { bw.Reset(f) fn(bw) bw.Flush() + bw.Reset(io.Discard) argBufioPool.Put(bw) } @@ -379,16 +386,10 @@ func (a asJSONResult) Format(s fmt.State, verb rune) { s.Write(v) } -// TBLogger is the testing.TB subset needed by TestLogger. -type TBLogger interface { - Helper() - Logf(format string, args ...any) -} - // TestLogger returns a logger that logs to tb.Logf // with a prefix to make it easier to distinguish spam // from explicit test failures. -func TestLogger(tb TBLogger) Logf { +func TestLogger(tb testenv.TB) Logf { return func(format string, args ...any) { tb.Helper() tb.Logf(" ... "+format, args...) diff --git a/types/mapx/ordered.go b/types/mapx/ordered.go new file mode 100644 index 0000000000000..1991f039d7726 --- /dev/null +++ b/types/mapx/ordered.go @@ -0,0 +1,111 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package mapx contains extra map types and functions. +package mapx + +import ( + "iter" + "slices" +) + +// OrderedMap is a map that maintains the order of its keys. +// +// It is meant for maps that only grow or that are small; +// is it not optimized for deleting keys. +// +// The zero value is ready to use. +// +// Locking-wise, it has the same rules as a regular Go map: +// concurrent reads are safe, but not writes. +type OrderedMap[K comparable, V any] struct { + // m is the underlying map. + m map[K]V + + // keys is the order of keys in the map. + keys []K +} + +func (m *OrderedMap[K, V]) init() { + if m.m == nil { + m.m = make(map[K]V) + } +} + +// Set sets the value for the given key in the map. +// +// If the key already exists, it updates the value and keeps the order. 
+func (m *OrderedMap[K, V]) Set(key K, value V) { + m.init() + len0 := len(m.keys) + m.m[key] = value + if len(m.m) > len0 { + // New key (not an update) + m.keys = append(m.keys, key) + } +} + +// Get returns the value for the given key in the map. +// If the key does not exist, it returns the zero value for V. +func (m *OrderedMap[K, V]) Get(key K) V { + return m.m[key] +} + +// GetOk returns the value for the given key in the map +// and whether it was present in the map. +func (m *OrderedMap[K, V]) GetOk(key K) (_ V, ok bool) { + v, ok := m.m[key] + return v, ok +} + +// Contains reports whether the map contains the given key. +func (m *OrderedMap[K, V]) Contains(key K) bool { + _, ok := m.m[key] + return ok +} + +// Delete removes the key from the map. +// +// The cost is O(n) in the number of keys in the map. +func (m *OrderedMap[K, V]) Delete(key K) { + len0 := len(m.m) + delete(m.m, key) + if len(m.m) == len0 { + // Wasn't present; no need to adjust keys. + return + } + was := m.keys + m.keys = m.keys[:0] + for _, k := range was { + if k != key { + m.keys = append(m.keys, k) + } + } +} + +// All yields all the keys and values, in the order they were inserted. +func (m *OrderedMap[K, V]) All() iter.Seq2[K, V] { + return func(yield func(K, V) bool) { + for _, k := range m.keys { + if !yield(k, m.m[k]) { + return + } + } + } +} + +// Keys yields the map keys, in the order they were inserted. +func (m *OrderedMap[K, V]) Keys() iter.Seq[K] { + return slices.Values(m.keys) +} + +// Values yields the map values, in the order they were inserted. 
+func (m *OrderedMap[K, V]) Values() iter.Seq[V] { + return func(yield func(V) bool) { + for _, k := range m.keys { + if !yield(m.m[k]) { + return + } + } + } +} diff --git a/types/mapx/ordered_test.go b/types/mapx/ordered_test.go new file mode 100644 index 0000000000000..7dcb7e40558c3 --- /dev/null +++ b/types/mapx/ordered_test.go @@ -0,0 +1,56 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package mapx + +import ( + "fmt" + "slices" + "testing" +) + +func TestOrderedMap(t *testing.T) { + // Test the OrderedMap type and its methods. + var m OrderedMap[string, int] + m.Set("d", 4) + m.Set("a", 1) + m.Set("b", 1) + m.Set("b", 2) + m.Set("c", 3) + m.Delete("d") + m.Delete("e") + + want := map[string]int{ + "a": 1, + "b": 2, + "c": 3, + "d": 0, + } + for k, v := range want { + if m.Get(k) != v { + t.Errorf("Get(%q) = %d, want %d", k, m.Get(k), v) + continue + } + got, ok := m.GetOk(k) + if got != v { + t.Errorf("GetOk(%q) = %d, want %d", k, got, v) + } + if ok != m.Contains(k) { + t.Errorf("GetOk and Contains don't agree for %q", k) + } + } + + if got, want := slices.Collect(m.Keys()), []string{"a", "b", "c"}; !slices.Equal(got, want) { + t.Errorf("Keys() = %q, want %q", got, want) + } + if got, want := slices.Collect(m.Values()), []int{1, 2, 3}; !slices.Equal(got, want) { + t.Errorf("Values() = %v, want %v", got, want) + } + var allGot []string + for k, v := range m.All() { + allGot = append(allGot, fmt.Sprintf("%s:%d", k, v)) + } + if got, want := allGot, []string{"a:1", "b:2", "c:3"}; !slices.Equal(got, want) { + t.Errorf("All() = %q, want %q", got, want) + } +} diff --git a/types/netmap/netmap.go b/types/netmap/netmap.go index 5e06229221e1f..c6250c49ce9c9 100644 --- a/types/netmap/netmap.go +++ b/types/netmap/netmap.go @@ -76,10 +76,9 @@ type NetworkMap struct { // If this is empty, then data-plane audit logging is disabled. 
DomainAuditLogID string - UserProfiles map[tailcfg.UserID]tailcfg.UserProfile - - // MaxKeyDuration describes the MaxKeyDuration setting for the tailnet. - MaxKeyDuration time.Duration + // UserProfiles contains the profile information of UserIDs referenced + // in SelfNode and Peers. + UserProfiles map[tailcfg.UserID]tailcfg.UserProfileView } // User returns nm.SelfNode.User if nm.SelfNode is non-nil, otherwise it returns @@ -101,6 +100,62 @@ func (nm *NetworkMap) GetAddresses() views.Slice[netip.Prefix] { return nm.SelfNode.Addresses() } +// GetVIPServiceIPMap returns a map of service names to the slice of +// VIP addresses that correspond to the service. The service names are +// with the prefix "svc:". +// +// TODO(tailscale/corp##25997): cache the result of decoding the capmap so that +// we don't have to decode it multiple times after each netmap update. +func (nm *NetworkMap) GetVIPServiceIPMap() tailcfg.ServiceIPMappings { + if nm == nil { + return nil + } + if !nm.SelfNode.Valid() { + return nil + } + + ipMaps, err := tailcfg.UnmarshalNodeCapViewJSON[tailcfg.ServiceIPMappings](nm.SelfNode.CapMap(), tailcfg.NodeAttrServiceHost) + if len(ipMaps) != 1 || err != nil { + return nil + } + + return ipMaps[0] +} + +// GetIPVIPServiceMap returns a map of VIP addresses to the service +// names that has the VIP address. The service names are with the +// prefix "svc:". +func (nm *NetworkMap) GetIPVIPServiceMap() IPServiceMappings { + var res IPServiceMappings + if nm == nil { + return res + } + + if !nm.SelfNode.Valid() { + return res + } + + serviceIPMap := nm.GetVIPServiceIPMap() + if serviceIPMap == nil { + return res + } + res = make(IPServiceMappings) + for svc, addrs := range serviceIPMap { + for _, addr := range addrs { + res[addr] = svc + } + } + return res +} + +// SelfNodeOrZero returns the self node, or a zero value if nm is nil. 
+func (nm *NetworkMap) SelfNodeOrZero() tailcfg.NodeView { + if nm == nil { + return tailcfg.NodeView{} + } + return nm.SelfNode +} + // AnyPeersAdvertiseRoutes reports whether any peer is advertising non-exit node routes. func (nm *NetworkMap) AnyPeersAdvertiseRoutes() bool { for _, p := range nm.Peers { @@ -197,21 +252,11 @@ func (nm *NetworkMap) DomainName() string { return nm.Domain } -// SelfCapabilities returns SelfNode.Capabilities if nm and nm.SelfNode are -// non-nil. This is a method so we can use it in envknob/logknob without a -// circular dependency. -func (nm *NetworkMap) SelfCapabilities() views.Slice[tailcfg.NodeCapability] { - var zero views.Slice[tailcfg.NodeCapability] - if nm == nil || !nm.SelfNode.Valid() { - return zero - } - out := nm.SelfNode.Capabilities().AsSlice() - nm.SelfNode.CapMap().Range(func(k tailcfg.NodeCapability, _ views.Slice[tailcfg.RawMessage]) (cont bool) { - out = append(out, k) - return true - }) - - return views.SliceOf(out) +// HasSelfCapability reports whether nm.SelfNode contains capability c. +// +// It exists to satisify an unused (as of 2025-01-04) interface in the logknob package. +func (nm *NetworkMap) HasSelfCapability(c tailcfg.NodeCapability) bool { + return nm.AllCaps.Contains(c) } func (nm *NetworkMap) String() string { @@ -251,7 +296,12 @@ func (nm *NetworkMap) PeerWithStableID(pid tailcfg.StableNodeID) (_ tailcfg.Node func (nm *NetworkMap) printConciseHeader(buf *strings.Builder) { fmt.Fprintf(buf, "netmap: self: %v auth=%v", nm.NodeKey.ShortString(), nm.GetMachineStatus()) - login := nm.UserProfiles[nm.User()].LoginName + + var login string + up, ok := nm.UserProfiles[nm.User()] + if ok { + login = up.LoginName() + } if login == "" { if nm.User().IsZero() { login = "?" @@ -279,15 +329,14 @@ func (a *NetworkMap) equalConciseHeader(b *NetworkMap) bool { // in nodeConciseEqual in sync. 
func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { aip := make([]string, p.AllowedIPs().Len()) - for i := range aip { - a := p.AllowedIPs().At(i) - s := strings.TrimSuffix(fmt.Sprint(a), "/32") + for i, a := range p.AllowedIPs().All() { + s := strings.TrimSuffix(a.String(), "/32") aip[i] = s } - ep := make([]string, p.Endpoints().Len()) - for i := range ep { - e := p.Endpoints().At(i).String() + epStrs := make([]string, p.Endpoints().Len()) + for i, ep := range p.Endpoints().All() { + e := ep.String() // Align vertically on the ':' between IP and port colon := strings.IndexByte(e, ':') spaces := 0 @@ -295,14 +344,11 @@ func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { spaces++ colon-- } - ep[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces)) + epStrs[i] = fmt.Sprintf("%21v", e+strings.Repeat(" ", spaces)) } - derp := p.DERP() - const derpPrefix = "127.3.3.40:" - if strings.HasPrefix(derp, derpPrefix) { - derp = "D" + derp[len(derpPrefix):] - } + derp := fmt.Sprintf("D%d", p.HomeDERP()) + var discoShort string if !p.DiscoKey().IsZero() { discoShort = p.DiscoKey().ShortString() + " " @@ -316,13 +362,13 @@ func printPeerConcise(buf *strings.Builder, p tailcfg.NodeView) { discoShort, derp, strings.Join(aip, " "), - strings.Join(ep, " ")) + strings.Join(epStrs, " ")) } // nodeConciseEqual reports whether a and b are equal for the fields accessed by printPeerConcise. func nodeConciseEqual(a, b tailcfg.NodeView) bool { return a.Key() == b.Key() && - a.DERP() == b.DERP() && + a.HomeDERP() == b.HomeDERP() && a.DiscoKey() == b.DiscoKey() && views.SliceEqual(a.AllowedIPs(), b.AllowedIPs()) && views.SliceEqual(a.Endpoints(), b.Endpoints()) @@ -391,3 +437,19 @@ const ( _ WGConfigFlags = 1 << iota AllowSubnetRoutes ) + +// IPServiceMappings maps IP addresses to service names. This is the inverse of +// [tailcfg.ServiceIPMappings], and is used to track which service a VIP +// is associated with.
This is set to b.ipVIPServiceMap every time the netmap is +// updated. This is used to reduce the cost for looking up the service name for +// the dst IP address in the netStack packet processing workflow. +// +// This is of the form: +// +// { +// "100.65.32.1": "svc:samba", +// "fd7a:115c:a1e0::1234": "svc:samba", +// "100.102.42.3": "svc:web", +// "fd7a:115c:a1e0::abcd": "svc:web", +// } +type IPServiceMappings map[netip.Addr]tailcfg.ServiceName diff --git a/types/netmap/netmap_test.go b/types/netmap/netmap_test.go index e7e2d19575c44..40f504741bfea 100644 --- a/types/netmap/netmap_test.go +++ b/types/netmap/netmap_test.go @@ -63,12 +63,12 @@ func TestNetworkMapConcise(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { Key: testNodeKey(3), - DERP: "127.3.3.40:4", + HomeDERP: 4, Endpoints: eps("10.2.0.100:12", "10.1.0.100:12345"), }, }), @@ -102,7 +102,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -112,7 +112,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -126,7 +126,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -136,7 +136,7 @@ func TestConciseDiffFrom(t *testing.T) { Peers: nodeViews([]*tailcfg.Node{ { Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -151,7 +151,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, 
Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -162,19 +162,19 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 1, Key: testNodeKey(1), - DERP: "127.3.3.40:1", + HomeDERP: 1, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 3, Key: testNodeKey(3), - DERP: "127.3.3.40:3", + HomeDERP: 3, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -189,19 +189,19 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 1, Key: testNodeKey(1), - DERP: "127.3.3.40:1", + HomeDERP: 1, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, { ID: 3, Key: testNodeKey(3), - DERP: "127.3.3.40:3", + HomeDERP: 3, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -212,7 +212,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "192.168.0.100:12354"), }, }), @@ -227,7 +227,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "1.1.1.1:1"), }, }), @@ -238,7 +238,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:12", "1.1.1.1:2"), }, }), @@ -253,7 +253,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), DiscoKey: testDiscoKey("f00f00f00f"), AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, @@ -266,7 +266,7 @@ func TestConciseDiffFrom(t *testing.T) { { ID: 2, Key: testNodeKey(2), - DERP: "127.3.3.40:2", + HomeDERP: 2, 
Endpoints: eps("192.168.0.100:41641", "1.1.1.1:41641"), DiscoKey: testDiscoKey("ba4ba4ba4b"), AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 102, 103, 104), 32)}, diff --git a/types/netmap/nodemut.go b/types/netmap/nodemut.go index 46fbaefc640e0..e31c731becbf1 100644 --- a/types/netmap/nodemut.go +++ b/types/netmap/nodemut.go @@ -5,7 +5,6 @@ package netmap import ( "cmp" - "fmt" "net/netip" "reflect" "slices" @@ -35,7 +34,7 @@ type NodeMutationDERPHome struct { } func (m NodeMutationDERPHome) Apply(n *tailcfg.Node) { - n.DERP = fmt.Sprintf("127.3.3.40:%v", m.DERPRegion) + n.HomeDERP = m.DERPRegion } // NodeMutation is a NodeMutation that says a node's endpoints have changed. @@ -177,6 +176,5 @@ func mapResponseContainsNonPatchFields(res *tailcfg.MapResponse) bool { // function is called, so it should never be set anyway. But for // completedness, and for tests, check it too: res.PeersChanged != nil || - res.DefaultAutoUpdate != "" || - res.MaxKeyDuration > 0 + res.DefaultAutoUpdate != "" } diff --git a/types/opt/value.go b/types/opt/value.go index 54fab7a538270..c71c53e511aca 100644 --- a/types/opt/value.go +++ b/types/opt/value.go @@ -36,7 +36,7 @@ func ValueOf[T any](v T) Value[T] { } // String implements [fmt.Stringer]. -func (o *Value[T]) String() string { +func (o Value[T]) String() string { if !o.set { return fmt.Sprintf("(empty[%T])", o.value) } @@ -100,31 +100,31 @@ func (o Value[T]) Equal(v Value[T]) bool { return false } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (o Value[T]) MarshalJSONV2(enc *jsontext.Encoder, opts jsonv2.Options) error { +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (o Value[T]) MarshalJSONTo(enc *jsontext.Encoder) error { if !o.set { return enc.WriteToken(jsontext.Null) } - return jsonv2.MarshalEncode(enc, &o.value, opts) + return jsonv2.MarshalEncode(enc, &o.value) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. 
-func (o *Value[T]) UnmarshalJSONV2(dec *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (o *Value[T]) UnmarshalJSONFrom(dec *jsontext.Decoder) error { if dec.PeekKind() == 'n' { *o = Value[T]{} _, err := dec.ReadToken() // read null return err } o.set = true - return jsonv2.UnmarshalDecode(dec, &o.value, opts) + return jsonv2.UnmarshalDecode(dec, &o.value) } // MarshalJSON implements [json.Marshaler]. func (o Value[T]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(o) // uses MarshalJSONV2 + return jsonv2.Marshal(o) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (o *Value[T]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, o) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, o) // uses UnmarshalJSONFrom } diff --git a/types/opt/value_test.go b/types/opt/value_test.go index 93d935e27581f..890f9a5795cb3 100644 --- a/types/opt/value_test.go +++ b/types/opt/value_test.go @@ -9,6 +9,13 @@ import ( "testing" jsonv2 "github.com/go-json-experiment/json" + "tailscale.com/types/bools" + "tailscale.com/util/must" +) + +var ( + _ jsonv2.MarshalerTo = (*Value[bool])(nil) + _ jsonv2.UnmarshalerFrom = (*Value[bool])(nil) ) type testStruct struct { @@ -87,7 +94,14 @@ func TestValue(t *testing.T) { False: ValueOf(false), ExplicitUnset: Value[bool]{}, }, - want: `{"True":true,"False":false,"Unset":null,"ExplicitUnset":null}`, + want: bools.IfElse( + // Detect whether v1 "encoding/json" supports `omitzero` or not. + // TODO(Go1.24): Remove this after `omitzero` is supported. 
+ string(must.Get(json.Marshal(struct { + X int `json:",omitzero"` + }{}))) == `{}`, + `{"True":true,"False":false}`, // omitzero supported + `{"True":true,"False":false,"Unset":null,"ExplicitUnset":null}`), // omitzero not supported wantBack: struct { True Value[bool] `json:",omitzero"` False Value[bool] `json:",omitzero"` diff --git a/types/persist/persist.go b/types/persist/persist.go index 8b555abd42c1e..d888a6afb6af5 100644 --- a/types/persist/persist.go +++ b/types/persist/persist.go @@ -21,17 +21,6 @@ import ( type Persist struct { _ structs.Incomparable - // LegacyFrontendPrivateMachineKey is here temporarily - // (starting 2020-09-28) during migration of Windows users' - // machine keys from frontend storage to the backend. On the - // first LocalBackend.Start call, the backend will initialize - // the real (backend-owned) machine key from the frontend's - // provided value (if non-zero), picking a new random one if - // needed. This field should be considered read-only from GUI - // frontends. The real value should not be written back in - // this field, lest the frontend persist it to disk. 
- LegacyFrontendPrivateMachineKey key.MachinePrivate `json:"PrivateMachineKey"` - PrivateNodeKey key.NodePrivate OldPrivateNodeKey key.NodePrivate // needed to request key rotation UserProfile tailcfg.UserProfile @@ -95,8 +84,7 @@ func (p *Persist) Equals(p2 *Persist) bool { return false } - return p.LegacyFrontendPrivateMachineKey.Equal(p2.LegacyFrontendPrivateMachineKey) && - p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && + return p.PrivateNodeKey.Equal(p2.PrivateNodeKey) && p.OldPrivateNodeKey.Equal(p2.OldPrivateNodeKey) && p.UserProfile.Equal(&p2.UserProfile) && p.NetworkLockKey.Equal(p2.NetworkLockKey) && @@ -106,18 +94,14 @@ func (p *Persist) Equals(p2 *Persist) bool { func (p *Persist) Pretty() string { var ( - mk key.MachinePublic ok, nk key.NodePublic ) - if !p.LegacyFrontendPrivateMachineKey.IsZero() { - mk = p.LegacyFrontendPrivateMachineKey.Public() - } if !p.OldPrivateNodeKey.IsZero() { ok = p.OldPrivateNodeKey.Public() } if !p.PrivateNodeKey.IsZero() { nk = p.PublicNodeKey() } - return fmt.Sprintf("Persist{lm=%v, o=%v, n=%v u=%#v}", - mk.ShortString(), ok.ShortString(), nk.ShortString(), p.UserProfile.LoginName) + return fmt.Sprintf("Persist{o=%v, n=%v u=%#v}", + ok.ShortString(), nk.ShortString(), p.UserProfile.LoginName) } diff --git a/types/persist/persist_clone.go b/types/persist/persist_clone.go index 95dd65ac18e67..680419ff2f30b 100644 --- a/types/persist/persist_clone.go +++ b/types/persist/persist_clone.go @@ -25,12 +25,11 @@ func (src *Persist) Clone() *Persist { // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
var _PersistCloneNeedsRegeneration = Persist(struct { - _ structs.Incomparable - LegacyFrontendPrivateMachineKey key.MachinePrivate - PrivateNodeKey key.NodePrivate - OldPrivateNodeKey key.NodePrivate - UserProfile tailcfg.UserProfile - NetworkLockKey key.NLPrivate - NodeID tailcfg.StableNodeID - DisallowedTKAStateIDs []string + _ structs.Incomparable + PrivateNodeKey key.NodePrivate + OldPrivateNodeKey key.NodePrivate + UserProfile tailcfg.UserProfile + NetworkLockKey key.NLPrivate + NodeID tailcfg.StableNodeID + DisallowedTKAStateIDs []string }{}) diff --git a/types/persist/persist_test.go b/types/persist/persist_test.go index 6b159573d4302..dbf2a6d8c7662 100644 --- a/types/persist/persist_test.go +++ b/types/persist/persist_test.go @@ -21,13 +21,12 @@ func fieldsOf(t reflect.Type) (fields []string) { } func TestPersistEqual(t *testing.T) { - persistHandles := []string{"LegacyFrontendPrivateMachineKey", "PrivateNodeKey", "OldPrivateNodeKey", "UserProfile", "NetworkLockKey", "NodeID", "DisallowedTKAStateIDs"} + persistHandles := []string{"PrivateNodeKey", "OldPrivateNodeKey", "UserProfile", "NetworkLockKey", "NodeID", "DisallowedTKAStateIDs"} if have := fieldsOf(reflect.TypeFor[Persist]()); !reflect.DeepEqual(have, persistHandles) { t.Errorf("Persist.Equal check might be out of sync\nfields: %q\nhandled: %q\n", have, persistHandles) } - m1 := key.NewMachine() k1 := key.NewNode() nl1 := key.NewNLPrivate() tests := []struct { @@ -39,17 +38,6 @@ func TestPersistEqual(t *testing.T) { {&Persist{}, nil, false}, {&Persist{}, &Persist{}, true}, - { - &Persist{LegacyFrontendPrivateMachineKey: m1}, - &Persist{LegacyFrontendPrivateMachineKey: key.NewMachine()}, - false, - }, - { - &Persist{LegacyFrontendPrivateMachineKey: m1}, - &Persist{LegacyFrontendPrivateMachineKey: m1}, - true, - }, - { &Persist{PrivateNodeKey: k1}, &Persist{PrivateNodeKey: key.NewNode()}, diff --git a/types/persist/persist_view.go b/types/persist/persist_view.go index 1d479b3bf10e7..55eb40c51ac47 
100644 --- a/types/persist/persist_view.go +++ b/types/persist/persist_view.go @@ -17,7 +17,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Persist -// View returns a readonly view of Persist. +// View returns a read-only view of Persist. func (p *Persist) View() PersistView { return PersistView{Đļ: p} } @@ -33,7 +33,7 @@ type PersistView struct { Đļ *Persist } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v PersistView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -62,9 +62,6 @@ func (v *PersistView) UnmarshalJSON(b []byte) error { return nil } -func (v PersistView) LegacyFrontendPrivateMachineKey() key.MachinePrivate { - return v.Đļ.LegacyFrontendPrivateMachineKey -} func (v PersistView) PrivateNodeKey() key.NodePrivate { return v.Đļ.PrivateNodeKey } func (v PersistView) OldPrivateNodeKey() key.NodePrivate { return v.Đļ.OldPrivateNodeKey } func (v PersistView) UserProfile() tailcfg.UserProfile { return v.Đļ.UserProfile } @@ -76,12 +73,11 @@ func (v PersistView) DisallowedTKAStateIDs() views.Slice[string] { // A compilation failure here means this code must be regenerated, with the command at the top of this file. 
var _PersistViewNeedsRegeneration = Persist(struct { - _ structs.Incomparable - LegacyFrontendPrivateMachineKey key.MachinePrivate - PrivateNodeKey key.NodePrivate - OldPrivateNodeKey key.NodePrivate - UserProfile tailcfg.UserProfile - NetworkLockKey key.NLPrivate - NodeID tailcfg.StableNodeID - DisallowedTKAStateIDs []string + _ structs.Incomparable + PrivateNodeKey key.NodePrivate + OldPrivateNodeKey key.NodePrivate + UserProfile tailcfg.UserProfile + NetworkLockKey key.NLPrivate + NodeID tailcfg.StableNodeID + DisallowedTKAStateIDs []string }{}) diff --git a/types/prefs/item.go b/types/prefs/item.go index 1032041471a75..717a0c76cf291 100644 --- a/types/prefs/item.go +++ b/types/prefs/item.go @@ -152,15 +152,15 @@ func (iv ItemView[T, V]) Equal(iv2 ItemView[T, V]) bool { return iv.Đļ.Equal(*iv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (iv ItemView[T, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return iv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (iv ItemView[T, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return iv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (iv *ItemView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (iv *ItemView[T, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x Item[T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } iv.Đļ = &x @@ -169,10 +169,10 @@ func (iv *ItemView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Opti // MarshalJSON implements [json.Marshaler]. func (iv ItemView[T, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(iv) // uses MarshalJSONV2 + return jsonv2.Marshal(iv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. 
func (iv *ItemView[T, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, iv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, iv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/list.go b/types/prefs/list.go index 9830e79de86cb..7db473887d195 100644 --- a/types/prefs/list.go +++ b/types/prefs/list.go @@ -157,15 +157,20 @@ func (lv ListView[T]) Equal(lv2 ListView[T]) bool { return lv.Đļ.Equal(*lv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (lv ListView[T]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return lv.Đļ.MarshalJSONV2(out, opts) +var ( + _ jsonv2.MarshalerTo = (*ListView[bool])(nil) + _ jsonv2.UnmarshalerFrom = (*ListView[bool])(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (lv ListView[T]) MarshalJSONTo(out *jsontext.Encoder) error { + return lv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (lv *ListView[T]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (lv *ListView[T]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x List[T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } lv.Đļ = &x @@ -174,10 +179,10 @@ func (lv *ListView[T]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options // MarshalJSON implements [json.Marshaler]. func (lv ListView[T]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(lv) // uses MarshalJSONV2 + return jsonv2.Marshal(lv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. 
func (lv *ListView[T]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/map.go b/types/prefs/map.go index 2bd32bfbdec75..4b64690ed1351 100644 --- a/types/prefs/map.go +++ b/types/prefs/map.go @@ -133,15 +133,15 @@ func (mv MapView[K, V]) Equal(mv2 MapView[K, V]) bool { return mv.Đļ.Equal(*mv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (mv MapView[K, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return mv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (mv MapView[K, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return mv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (mv *MapView[K, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (mv *MapView[K, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x Map[K, V] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } mv.Đļ = &x @@ -150,10 +150,10 @@ func (mv *MapView[K, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Optio // MarshalJSON implements [json.Marshaler]. func (mv MapView[K, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(mv) // uses MarshalJSONV2 + return jsonv2.Marshal(mv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (mv *MapView[K, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/prefs.go b/types/prefs/prefs.go index 3bbd237fe5efe..a6caf12838b79 100644 --- a/types/prefs/prefs.go +++ b/types/prefs/prefs.go @@ -29,8 +29,8 @@ import ( var ( // ErrManaged is the error returned when attempting to modify a managed preference. 
ErrManaged = errors.New("cannot modify a managed preference") - // ErrReadOnly is the error returned when attempting to modify a readonly preference. - ErrReadOnly = errors.New("cannot modify a readonly preference") + // ErrReadOnly is the error returned when attempting to modify a read-only preference. + ErrReadOnly = errors.New("cannot modify a read-only preference") ) // metadata holds type-agnostic preference metadata. @@ -158,22 +158,27 @@ func (p *preference[T]) SetReadOnly(readonly bool) { p.s.Metadata.ReadOnly = readonly } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (p preference[T]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return jsonv2.MarshalEncode(out, &p.s, opts) +var ( + _ jsonv2.MarshalerTo = (*preference[struct{}])(nil) + _ jsonv2.UnmarshalerFrom = (*preference[struct{}])(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (p preference[T]) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &p.s) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (p *preference[T]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { - return jsonv2.UnmarshalDecode(in, &p.s, opts) +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (p *preference[T]) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &p.s) } // MarshalJSON implements [json.Marshaler]. func (p preference[T]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(p) // uses MarshalJSONV2 + return jsonv2.Marshal(p) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. 
func (p *preference[T]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONFrom } diff --git a/types/prefs/prefs_example/prefs_example_view.go b/types/prefs/prefs_example/prefs_example_view.go index 0256bd7e6d25b..9aaac6e9c3ed6 100644 --- a/types/prefs/prefs_example/prefs_example_view.go +++ b/types/prefs/prefs_example/prefs_example_view.go @@ -20,7 +20,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=Prefs,AutoUpdatePrefs,AppConnectorPrefs -// View returns a readonly view of Prefs. +// View returns a read-only view of Prefs. func (p *Prefs) View() PrefsView { return PrefsView{Đļ: p} } @@ -36,7 +36,7 @@ type PrefsView struct { Đļ *Prefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v PrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -132,7 +132,7 @@ var _PrefsViewNeedsRegeneration = Prefs(struct { Persist *persist.Persist }{}) -// View returns a readonly view of AutoUpdatePrefs. +// View returns a read-only view of AutoUpdatePrefs. func (p *AutoUpdatePrefs) View() AutoUpdatePrefsView { return AutoUpdatePrefsView{Đļ: p} } @@ -148,7 +148,7 @@ type AutoUpdatePrefsView struct { Đļ *AutoUpdatePrefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v AutoUpdatePrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -186,7 +186,7 @@ var _AutoUpdatePrefsViewNeedsRegeneration = AutoUpdatePrefs(struct { Apply prefs.Item[opt.Bool] }{}) -// View returns a readonly view of AppConnectorPrefs. +// View returns a read-only view of AppConnectorPrefs. 
func (p *AppConnectorPrefs) View() AppConnectorPrefsView { return AppConnectorPrefsView{Đļ: p} } @@ -202,7 +202,7 @@ type AppConnectorPrefsView struct { Đļ *AppConnectorPrefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v AppConnectorPrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with diff --git a/types/prefs/prefs_example/prefs_types.go b/types/prefs/prefs_example/prefs_types.go index 49f0d8c3c4b57..c35f1f62fde3d 100644 --- a/types/prefs/prefs_example/prefs_types.go +++ b/types/prefs/prefs_example/prefs_types.go @@ -48,10 +48,10 @@ import ( // the `omitzero` JSON tag option. This option is not supported by the // [encoding/json] package as of 2024-08-21; see golang/go#45669. // It is recommended that a prefs type implements both -// [jsonv2.MarshalerV2]/[jsonv2.UnmarshalerV2] and [json.Marshaler]/[json.Unmarshaler] +// [jsonv2.MarshalerTo]/[jsonv2.UnmarshalerFrom] and [json.Marshaler]/[json.Unmarshaler] // to ensure consistent and more performant marshaling, regardless of the JSON package // used at the call sites; the standard marshalers can be implemented via [jsonv2]. -// See [Prefs.MarshalJSONV2], [Prefs.UnmarshalJSONV2], [Prefs.MarshalJSON], +// See [Prefs.MarshalJSONTo], [Prefs.UnmarshalJSONFrom], [Prefs.MarshalJSON], // and [Prefs.UnmarshalJSON] for an example implementation. type Prefs struct { ControlURL prefs.Item[string] `json:",omitzero"` @@ -128,34 +128,39 @@ type AppConnectorPrefs struct { Advertise prefs.Item[bool] `json:",omitzero"` } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. +var ( + _ jsonv2.MarshalerTo = (*Prefs)(nil) + _ jsonv2.UnmarshalerFrom = (*Prefs)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. // It is implemented as a performance improvement and to enable omission of // unconfigured preferences from the JSON output. See the [Prefs] doc for details. 
-func (p Prefs) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { +func (p Prefs) MarshalJSONTo(out *jsontext.Encoder) error { // The prefs type shadows the Prefs's method set, // causing [jsonv2] to use the default marshaler and avoiding // infinite recursion. type prefs Prefs - return jsonv2.MarshalEncode(out, (*prefs)(&p), opts) + return jsonv2.MarshalEncode(out, (*prefs)(&p)) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (p *Prefs) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (p *Prefs) UnmarshalJSONFrom(in *jsontext.Decoder) error { // The prefs type shadows the Prefs's method set, // causing [jsonv2] to use the default unmarshaler and avoiding // infinite recursion. type prefs Prefs - return jsonv2.UnmarshalDecode(in, (*prefs)(p), opts) + return jsonv2.UnmarshalDecode(in, (*prefs)(p)) } // MarshalJSON implements [json.Marshaler]. func (p Prefs) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(p) // uses MarshalJSONV2 + return jsonv2.Marshal(p) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. 
func (p *Prefs) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONFrom } type marshalAsTrueInJSON struct{} diff --git a/types/prefs/prefs_test.go b/types/prefs/prefs_test.go index ea4729366bc23..d6af745bf83b8 100644 --- a/types/prefs/prefs_test.go +++ b/types/prefs/prefs_test.go @@ -19,6 +19,20 @@ import ( //go:generate go run tailscale.com/cmd/viewer --tags=test --type=TestPrefs,TestBundle,TestValueStruct,TestGenericStruct,TestPrefsGroup +var ( + _ jsonv2.MarshalerTo = (*ItemView[*TestBundle, TestBundleView])(nil) + _ jsonv2.UnmarshalerFrom = (*ItemView[*TestBundle, TestBundleView])(nil) + + _ jsonv2.MarshalerTo = (*MapView[string, string])(nil) + _ jsonv2.UnmarshalerFrom = (*MapView[string, string])(nil) + + _ jsonv2.MarshalerTo = (*StructListView[*TestBundle, TestBundleView])(nil) + _ jsonv2.UnmarshalerFrom = (*StructListView[*TestBundle, TestBundleView])(nil) + + _ jsonv2.MarshalerTo = (*StructMapView[string, *TestBundle, TestBundleView])(nil) + _ jsonv2.UnmarshalerFrom = (*StructMapView[string, *TestBundle, TestBundleView])(nil) +) + type TestPrefs struct { Int32Item Item[int32] `json:",omitzero"` UInt64Item Item[uint64] `json:",omitzero"` @@ -53,32 +67,37 @@ type TestPrefs struct { Group TestPrefsGroup `json:",omitzero"` } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (p TestPrefs) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { +var ( + _ jsonv2.MarshalerTo = (*TestPrefs)(nil) + _ jsonv2.UnmarshalerFrom = (*TestPrefs)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (p TestPrefs) MarshalJSONTo(out *jsontext.Encoder) error { // The testPrefs type shadows the TestPrefs's method set, // causing jsonv2 to use the default marshaler and avoiding // infinite recursion. 
type testPrefs TestPrefs - return jsonv2.MarshalEncode(out, (*testPrefs)(&p), opts) + return jsonv2.MarshalEncode(out, (*testPrefs)(&p)) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (p *TestPrefs) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (p *TestPrefs) UnmarshalJSONFrom(in *jsontext.Decoder) error { // The testPrefs type shadows the TestPrefs's method set, // causing jsonv2 to use the default unmarshaler and avoiding // infinite recursion. type testPrefs TestPrefs - return jsonv2.UnmarshalDecode(in, (*testPrefs)(p), opts) + return jsonv2.UnmarshalDecode(in, (*testPrefs)(p)) } // MarshalJSON implements [json.Marshaler]. func (p TestPrefs) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(p) // uses MarshalJSONV2 + return jsonv2.Marshal(p) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (p *TestPrefs) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, p) // uses UnmarshalJSONFrom } // TestBundle is an example structure type that, diff --git a/types/prefs/prefs_view_test.go b/types/prefs/prefs_view_test.go index d76eebb43e9ef..f6cfc918d02c0 100644 --- a/types/prefs/prefs_view_test.go +++ b/types/prefs/prefs_view_test.go @@ -13,7 +13,7 @@ import ( //go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=TestPrefs,TestBundle,TestValueStruct,TestGenericStruct,TestPrefsGroup -tags=test -// View returns a readonly view of TestPrefs. +// View returns a read-only view of TestPrefs. func (p *TestPrefs) View() TestPrefsView { return TestPrefsView{Đļ: p} } @@ -29,7 +29,7 @@ type TestPrefsView struct { Đļ *TestPrefs } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v TestPrefsView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -117,7 +117,7 @@ var _TestPrefsViewNeedsRegeneration = TestPrefs(struct { Group TestPrefsGroup }{}) -// View returns a readonly view of TestBundle. +// View returns a read-only view of TestBundle. func (p *TestBundle) View() TestBundleView { return TestBundleView{Đļ: p} } @@ -133,7 +133,7 @@ type TestBundleView struct { Đļ *TestBundle } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestBundleView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -162,15 +162,8 @@ func (v *TestBundleView) UnmarshalJSON(b []byte) error { return nil } -func (v TestBundleView) Name() string { return v.Đļ.Name } -func (v TestBundleView) Nested() *TestValueStruct { - if v.Đļ.Nested == nil { - return nil - } - x := *v.Đļ.Nested - return &x -} - +func (v TestBundleView) Name() string { return v.Đļ.Name } +func (v TestBundleView) Nested() TestValueStructView { return v.Đļ.Nested.View() } func (v TestBundleView) Equal(v2 TestBundleView) bool { return v.Đļ.Equal(v2.Đļ) } // A compilation failure here means this code must be regenerated, with the command at the top of this file. @@ -179,7 +172,7 @@ var _TestBundleViewNeedsRegeneration = TestBundle(struct { Nested *TestValueStruct }{}) -// View returns a readonly view of TestValueStruct. +// View returns a read-only view of TestValueStruct. func (p *TestValueStruct) View() TestValueStructView { return TestValueStructView{Đļ: p} } @@ -195,7 +188,7 @@ type TestValueStructView struct { Đļ *TestValueStruct } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. 
func (v TestValueStructView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -232,7 +225,7 @@ var _TestValueStructViewNeedsRegeneration = TestValueStruct(struct { Value int }{}) -// View returns a readonly view of TestGenericStruct. +// View returns a read-only view of TestGenericStruct. func (p *TestGenericStruct[T]) View() TestGenericStructView[T] { return TestGenericStructView[T]{Đļ: p} } @@ -248,7 +241,7 @@ type TestGenericStructView[T ImmutableType] struct { Đļ *TestGenericStruct[T] } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestGenericStructView[T]) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with @@ -287,7 +280,7 @@ func _TestGenericStructViewNeedsRegeneration[T ImmutableType](TestGenericStruct[ }{}) } -// View returns a readonly view of TestPrefsGroup. +// View returns a read-only view of TestPrefsGroup. func (p *TestPrefsGroup) View() TestPrefsGroupView { return TestPrefsGroupView{Đļ: p} } @@ -303,7 +296,7 @@ type TestPrefsGroupView struct { Đļ *TestPrefsGroup } -// Valid reports whether underlying value is non-nil. +// Valid reports whether v's underlying value is non-nil. func (v TestPrefsGroupView) Valid() bool { return v.Đļ != nil } // AsStruct returns a clone of the underlying value which aliases no memory with diff --git a/types/prefs/struct_list.go b/types/prefs/struct_list.go index 872cb232655e3..65f11011af8fb 100644 --- a/types/prefs/struct_list.go +++ b/types/prefs/struct_list.go @@ -169,15 +169,15 @@ func (lv StructListView[T, V]) Equal(lv2 StructListView[T, V]) bool { return lv.Đļ.Equal(*lv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. 
-func (lv StructListView[T, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return lv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (lv StructListView[T, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return lv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (lv *StructListView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (lv *StructListView[T, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x StructList[T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } lv.Đļ = &x @@ -186,10 +186,10 @@ func (lv *StructListView[T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv // MarshalJSON implements [json.Marshaler]. func (lv StructListView[T, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(lv) // uses MarshalJSONV2 + return jsonv2.Marshal(lv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (lv *StructListView[T, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, lv) // uses UnmarshalJSONFrom } diff --git a/types/prefs/struct_map.go b/types/prefs/struct_map.go index 2003eebe323fa..a081f7c7468e2 100644 --- a/types/prefs/struct_map.go +++ b/types/prefs/struct_map.go @@ -83,7 +83,7 @@ type StructMapView[K MapKeyType, T views.ViewCloner[T, V], V views.StructView[T] Đļ *StructMap[K, T] } -// StructMapViewOf returns a readonly view of m. +// StructMapViewOf returns a read-only view of m. // It is used by [tailscale.com/cmd/viewer]. 
func StructMapViewOf[K MapKeyType, T views.ViewCloner[T, V], V views.StructView[T]](m *StructMap[K, T]) StructMapView[K, T, V] { return StructMapView[K, T, V]{m} @@ -149,15 +149,15 @@ func (mv StructMapView[K, T, V]) Equal(mv2 StructMapView[K, T, V]) bool { return mv.Đļ.Equal(*mv2.Đļ) } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (mv StructMapView[K, T, V]) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return mv.Đļ.MarshalJSONV2(out, opts) +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (mv StructMapView[K, T, V]) MarshalJSONTo(out *jsontext.Encoder) error { + return mv.Đļ.MarshalJSONTo(out) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (mv *StructMapView[K, T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (mv *StructMapView[K, T, V]) UnmarshalJSONFrom(in *jsontext.Decoder) error { var x StructMap[K, T] - if err := x.UnmarshalJSONV2(in, opts); err != nil { + if err := x.UnmarshalJSONFrom(in); err != nil { return err } mv.Đļ = &x @@ -166,10 +166,10 @@ func (mv *StructMapView[K, T, V]) UnmarshalJSONV2(in *jsontext.Decoder, opts jso // MarshalJSON implements [json.Marshaler]. func (mv StructMapView[K, T, V]) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(mv) // uses MarshalJSONV2 + return jsonv2.Marshal(mv) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (mv *StructMapView[K, T, V]) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, mv) // uses UnmarshalJSONFrom } diff --git a/types/result/result.go b/types/result/result.go new file mode 100644 index 0000000000000..6bd1c2ea62004 --- /dev/null +++ b/types/result/result.go @@ -0,0 +1,49 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package result contains the Of result type, which is +// either a value or an error. 
+package result + +// Of is either a T value or an error. +// +// Think of it like Rust or Swift's result types. +// It's named "Of" because the fully qualified name +// for callers reads result.Of[T]. +type Of[T any] struct { + v T // valid if Err is nil; invalid if Err is non-nil + err error +} + +// Value returns a new result with value v, +// without an error. +func Value[T any](v T) Of[T] { + return Of[T]{v: v} +} + +// Error returns a new result with error err. +// If err is nil, the returned result is equivalent +// to calling Value with T's zero value. +func Error[T any](err error) Of[T] { + return Of[T]{err: err} +} + +// MustValue returns r's result value. +// It panics if r.Err returns non-nil. +func (r Of[T]) MustValue() T { + if r.err != nil { + panic(r.err) + } + return r.v +} + +// Value returns r's result value and error. +func (r Of[T]) Value() (T, error) { + return r.v, r.err +} + +// Err returns r's error, if any. +// When r.Err returns nil, it's safe to call r.MustValue without it panicking. +func (r Of[T]) Err() error { + return r.err +} diff --git a/types/views/views.go b/types/views/views.go index 19aa69d4a8edb..3911f111258a8 100644 --- a/types/views/views.go +++ b/types/views/views.go @@ -16,6 +16,7 @@ import ( "slices" "go4.org/mem" + "tailscale.com/types/ptr" ) func unmarshalSliceFromJSON[T any](b []byte, x *[]T) error { @@ -329,6 +330,12 @@ func SliceEqual[T comparable](a, b Slice[T]) bool { return slices.Equal(a.Đļ, b.Đļ) } +// shortOOOLen (short Out-of-Order length) is the slice length at or +// under which we attempt to compare two slices quadratically rather +// than allocating memory for a map in SliceEqualAnyOrder and +// SliceEqualAnyOrderFunc. +const shortOOOLen = 5 + // SliceEqualAnyOrder reports whether a and b contain the same elements, regardless of order. // The underlying slices for a and b can be nil. 
func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { @@ -346,13 +353,63 @@ func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { return true } - // count the occurrences of remaining values and compare - valueCount := make(map[T]int) - for i, n := diffStart, a.Len(); i < n; i++ { - valueCount[a.At(i)]++ - valueCount[b.At(i)]-- + a, b = a.SliceFrom(diffStart), b.SliceFrom(diffStart) + cmp := func(v T) T { return v } + + // For a small number of items, avoid the allocation of a map and just + // do the quadratic thing. + if a.Len() <= shortOOOLen { + return unorderedSliceEqualAnyOrderSmall(a, b, cmp) + } + return unorderedSliceEqualAnyOrder(a, b, cmp) +} + +// SliceEqualAnyOrderFunc reports whether a and b contain the same elements, +// regardless of order. The underlying slices for a and b can be nil. +// +// The provided function should return a comparable value for each element. +func SliceEqualAnyOrderFunc[T any, V comparable](a, b Slice[T], cmp func(T) V) bool { + if a.Len() != b.Len() { + return false + } + + var diffStart int // beginning index where a and b differ + for n := a.Len(); diffStart < n; diffStart++ { + av := cmp(a.At(diffStart)) + bv := cmp(b.At(diffStart)) + if av != bv { + break + } } - for _, count := range valueCount { + if diffStart == a.Len() { + return true + } + + a, b = a.SliceFrom(diffStart), b.SliceFrom(diffStart) + // For a small number of items, avoid the allocation of a map and just + // do the quadratic thing. + if a.Len() <= shortOOOLen { + return unorderedSliceEqualAnyOrderSmall(a, b, cmp) + } + return unorderedSliceEqualAnyOrder(a, b, cmp) +} + +// unorderedSliceEqualAnyOrder reports whether a and b contain the same elements +// using a map. The cmp function maps from a T slice element to a comparable +// value. 
+func unorderedSliceEqualAnyOrder[T any, V comparable](a, b Slice[T], cmp func(T) V) bool { + if a.Len() != b.Len() { + panic("internal error") + } + if a.Len() == 0 { + return true + } + m := make(map[V]int) + for i := range a.Len() { + m[cmp(a.At(i))]++ + m[cmp(b.At(i))]-- + } + for _, count := range m { if count != 0 { return false } @@ -360,6 +417,60 @@ func SliceEqualAnyOrder[T comparable](a, b Slice[T]) bool { return true } +// unorderedSliceEqualAnyOrderSmall reports whether a and b (which must be the +// same length, and shortOOOLen or shorter) contain the same elements (using cmp +// to map from T to a comparable value) in some order. +// +// This is the quadratic-time implementation for small slices that doesn't +// allocate. +func unorderedSliceEqualAnyOrderSmall[T any, V comparable](a, b Slice[T], cmp func(T) V) bool { + if a.Len() != b.Len() || a.Len() > shortOOOLen { + panic("internal error") + } + + // These track which elements in a and b have been matched, so + // that we don't treat arrays with differing number of + // duplicate elements as equal (e.g. [1, 1, 2] and [1, 2, 2]). + var aMatched, bMatched [shortOOOLen]bool + + // Compare each element in a to each element in b + for i := range a.Len() { + av := cmp(a.At(i)) + found := false + for j := range a.Len() { + // Skip elements in b that have already been + // used to match an item in a. + if bMatched[j] { + continue + } + + bv := cmp(b.At(j)) + if av == bv { + // Mark these elements as already + // matched, so that a future loop + // iteration (of a duplicate element) + // doesn't match it again. + aMatched[i] = true + bMatched[j] = true + found = true + break + } + } + if !found { + return false + } + } + + // Verify all elements were matched exactly once. + for i := range a.Len() { + if !aMatched[i] || !bMatched[i] { + return false + } + } + + return true +} + // MapSlice is a view over a map whose values are slices. 
type MapSlice[K comparable, V any] struct { // Đļ is the underlying mutable value, named with a hard-to-type @@ -415,16 +526,6 @@ func (m *MapSlice[K, V]) UnmarshalJSON(b []byte) error { return json.Unmarshal(b, &m.Đļ) } -// Range calls f for every k,v pair in the underlying map. -// It stops iteration immediately if f returns false. -func (m MapSlice[K, V]) Range(f MapRangeFn[K, Slice[V]]) { - for k, v := range m.Đļ { - if !f(k, SliceOf(v)) { - return - } - } -} - // AsMap returns a shallow-clone of the underlying map. // // If V is a pointer type, it is the caller's responsibility to make sure the @@ -523,20 +624,51 @@ func (m Map[K, V]) AsMap() map[K]V { return maps.Clone(m.Đļ) } -// MapRangeFn is the func called from a Map.Range call. -// Implementations should return false to stop range. -type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool) +// NOTE: the type constraints for MapViewsEqual and MapViewsEqualFunc are based +// on those for maps.Equal and maps.EqualFunc. -// Range calls f for every k,v pair in the underlying map. -// It stops iteration immediately if f returns false. -func (m Map[K, V]) Range(f MapRangeFn[K, V]) { - for k, v := range m.Đļ { - if !f(k, v) { - return +// MapViewsEqual returns whether the two given [Map]s are equal. Both K and V +// must be comparable; if V is non-comparable, use [MapViewsEqualFunc] instead. +func MapViewsEqual[K, V comparable](a, b Map[K, V]) bool { + if a.Len() != b.Len() || a.IsNil() != b.IsNil() { + return false + } + if a.IsNil() { + return true // both nil; can exit early + } + + for k, v := range a.All() { + bv, ok := b.GetOk(k) + if !ok || v != bv { + return false + } + } + return true +} + +// MapViewsEqualFunc returns whether the two given [Map]s are equal, using the +// given function to compare two values. 
+func MapViewsEqualFunc[K comparable, V1, V2 any](a Map[K, V1], b Map[K, V2], eq func(V1, V2) bool) bool { + if a.Len() != b.Len() || a.IsNil() != b.IsNil() { + return false + } + if a.IsNil() { + return true // both nil; can exit early + } + + for k, v := range a.All() { + bv, ok := b.GetOk(k) + if !ok || !eq(v, bv) { + return false } } + return true } +// MapRangeFn is the func called from a Map.Range call. +// Implementations should return false to stop range. +type MapRangeFn[K comparable, V any] func(k K, v V) (cont bool) + // All returns an iterator iterating over the keys // and values of m. func (m Map[K, V]) All() iter.Seq2[K, V] { @@ -600,16 +732,6 @@ func (m MapFn[K, T, V]) GetOk(k K) (V, bool) { return m.wrapv(v), ok } -// Range calls f for every k,v pair in the underlying map. -// It stops iteration immediately if f returns false. -func (m MapFn[K, T, V]) Range(f MapRangeFn[K, V]) { - for k, v := range m.Đļ { - if !f(k, m.wrapv(v)) { - return - } - } -} - // All returns an iterator iterating over the keys and value views of m. func (m MapFn[K, T, V]) All() iter.Seq2[K, V] { return func(yield func(K, V) bool) { @@ -621,6 +743,85 @@ func (m MapFn[K, T, V]) All() iter.Seq2[K, V] { } } +// ValuePointer provides a read-only view of a pointer to a value type, +// such as a primitive type or an immutable struct. Its Value and ValueOk +// methods return a stack-allocated shallow copy of the underlying value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +type ValuePointer[T any] struct { + // Đļ is the underlying value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + Đļ *T +} + +// Valid reports whether the underlying pointer is non-nil. 
+func (p ValuePointer[T]) Valid() bool { + return p.Đļ != nil +} + +// Get returns a shallow copy of the value if the underlying pointer is non-nil. +// Otherwise, it returns a zero value. +func (p ValuePointer[T]) Get() T { + v, _ := p.GetOk() + return v +} + +// GetOk returns a shallow copy of the underlying value and true if the underlying +// pointer is non-nil. Otherwise, it returns a zero value and false. +func (p ValuePointer[T]) GetOk() (value T, ok bool) { + if p.Đļ == nil { + return value, false // value holds a zero value + } + return *p.Đļ, true +} + +// GetOr returns a shallow copy of the underlying value if it is non-nil. +// Otherwise, it returns the provided default value. +func (p ValuePointer[T]) GetOr(def T) T { + if p.Đļ == nil { + return def + } + return *p.Đļ +} + +// Clone returns a shallow copy of the underlying value. +func (p ValuePointer[T]) Clone() *T { + if p.Đļ == nil { + return nil + } + return ptr.To(*p.Đļ) +} + +// String implements [fmt.Stringer]. +func (p ValuePointer[T]) String() string { + if p.Đļ == nil { + return "nil" + } + return fmt.Sprint(p.Đļ) +} + +// ValuePointerOf returns an immutable view of a pointer to an immutable value. +// It is the caller's responsibility to ensure that T +// is free from memory aliasing/mutation concerns. +func ValuePointerOf[T any](v *T) ValuePointer[T] { + return ValuePointer[T]{v} +} + +// MarshalJSON implements [json.Marshaler]. +func (p ValuePointer[T]) MarshalJSON() ([]byte, error) { + return json.Marshal(p.Đļ) +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (p *ValuePointer[T]) UnmarshalJSON(b []byte) error { + if p.Đļ != nil { + return errors.New("already initialized") + } + return json.Unmarshal(b, &p.Đļ) +} + // ContainsPointers reports whether T contains any pointers, // either explicitly or implicitly. 
// It has special handling for some types that contain pointers diff --git a/types/views/views_test.go b/types/views/views_test.go index 8a1ff3fddfc9e..2205cbc03ab74 100644 --- a/types/views/views_test.go +++ b/types/views/views_test.go @@ -15,6 +15,7 @@ import ( "unsafe" qt "github.com/frankban/quicktest" + "tailscale.com/types/structs" ) type viewStruct struct { @@ -152,6 +153,161 @@ func TestViewUtils(t *testing.T) { qt.Equals, true) } +func TestSliceEqualAnyOrderFunc(t *testing.T) { + type nc struct { + _ structs.Incomparable + v string + } + + // ncFrom returns a Slice[nc] from a slice of []string + ncFrom := func(s ...string) Slice[nc] { + var out []nc + for _, v := range s { + out = append(out, nc{v: v}) + } + return SliceOf(out) + } + + // cmp returns a comparable value for a nc + cmp := func(a nc) string { return a.v } + + v := ncFrom("foo", "bar") + c := qt.New(t) + + // Simple case of slice equal to itself. + c.Check(SliceEqualAnyOrderFunc(v, v, cmp), qt.Equals, true) + + // Different order. 
+ c.Check(SliceEqualAnyOrderFunc(v, ncFrom("bar", "foo"), cmp), qt.Equals, true) + + // Different values, same length + c.Check(SliceEqualAnyOrderFunc(v, ncFrom("foo", "baz"), cmp), qt.Equals, false) + + // Different values, different length + c.Check(SliceEqualAnyOrderFunc(v, ncFrom("foo"), cmp), qt.Equals, false) + + // Nothing shared + c.Check(SliceEqualAnyOrderFunc(v, ncFrom("baz", "qux"), cmp), qt.Equals, false) + + // Long slice that matches + longSlice := ncFrom("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + longSame := ncFrom("b", "a", "c", "d", "e", "f", "g", "h", "i", "j") // first 2 elems swapped + c.Check(SliceEqualAnyOrderFunc(longSlice, longSame, cmp), qt.Equals, true) + + // Long difference; past the quadratic limit + longDiff := ncFrom("b", "a", "c", "d", "e", "f", "g", "h", "i", "k") // differs at end + c.Check(SliceEqualAnyOrderFunc(longSlice, longDiff, cmp), qt.Equals, false) + + // The short slice optimization had a bug where it wouldn't handle + // duplicate elements; test various cases here driven by code coverage. 
+ shortTestCases := []struct { + name string + s1, s2 Slice[nc] + want bool + }{ + { + name: "duplicates_same_length", + s1: ncFrom("a", "a", "b"), + s2: ncFrom("a", "b", "b"), + want: false, + }, + { + name: "duplicates_different_matched", + s1: ncFrom("x", "y", "a", "a", "b"), + s2: ncFrom("x", "y", "b", "a", "a"), + want: true, + }, + { + name: "item_in_a_not_b", + s1: ncFrom("x", "y", "a", "b", "c"), + s2: ncFrom("x", "y", "b", "c", "q"), + want: false, + }, + } + for _, tc := range shortTestCases { + t.Run("short_"+tc.name, func(t *testing.T) { + c.Check(SliceEqualAnyOrderFunc(tc.s1, tc.s2, cmp), qt.Equals, tc.want) + }) + } +} + +func TestSliceEqualAnyOrderAllocs(t *testing.T) { + ss := func(s ...string) Slice[string] { return SliceOf(s) } + cmp := func(s string) string { return s } + + t.Run("no-allocs-short-unordered", func(t *testing.T) { + // No allocations for short comparisons + short1 := ss("a", "b", "c") + short2 := ss("c", "b", "a") + if n := testing.AllocsPerRun(1000, func() { + if !SliceEqualAnyOrder(short1, short2) { + t.Fatal("not equal") + } + if !SliceEqualAnyOrderFunc(short1, short2, cmp) { + t.Fatal("not equal") + } + }); n > 0 { + t.Fatalf("allocs = %v; want 0", n) + } + }) + + t.Run("no-allocs-long-match", func(t *testing.T) { + long1 := ss("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + long2 := ss("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + + if n := testing.AllocsPerRun(1000, func() { + if !SliceEqualAnyOrder(long1, long2) { + t.Fatal("not equal") + } + if !SliceEqualAnyOrderFunc(long1, long2, cmp) { + t.Fatal("not equal") + } + }); n > 0 { + t.Fatalf("allocs = %v; want 0", n) + } + }) + + t.Run("allocs-long-unordered", func(t *testing.T) { + // We do unfortunately allocate for long comparisons. 
+ long1 := ss("a", "b", "c", "d", "e", "f", "g", "h", "i", "j") + long2 := ss("c", "b", "a", "e", "d", "f", "g", "h", "i", "j") + + if n := testing.AllocsPerRun(1000, func() { + if !SliceEqualAnyOrder(long1, long2) { + t.Fatal("not equal") + } + if !SliceEqualAnyOrderFunc(long1, long2, cmp) { + t.Fatal("not equal") + } + }); n == 0 { + t.Fatalf("unexpectedly didn't allocate") + } + }) +} + +func BenchmarkSliceEqualAnyOrder(b *testing.B) { + b.Run("short", func(b *testing.B) { + b.ReportAllocs() + s1 := SliceOf([]string{"foo", "bar"}) + s2 := SliceOf([]string{"bar", "foo"}) + for range b.N { + if !SliceEqualAnyOrder(s1, s2) { + b.Fatal() + } + } + }) + b.Run("long", func(b *testing.B) { + b.ReportAllocs() + s1 := SliceOf([]string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"}) + s2 := SliceOf([]string{"c", "b", "a", "e", "d", "f", "g", "h", "i", "j"}) + for range b.N { + if !SliceEqualAnyOrder(s1, s2) { + b.Fatal() + } + } + }) +} + func TestSliceEqual(t *testing.T) { a := SliceOf([]string{"foo", "bar"}) b := SliceOf([]string{"foo", "bar"}) @@ -501,3 +657,87 @@ func TestMapFnIter(t *testing.T) { t.Errorf("got %q; want %q", got, want) } } + +func TestMapViewsEqual(t *testing.T) { + testCases := []struct { + name string + a, b map[string]string + want bool + }{ + { + name: "both_nil", + a: nil, + b: nil, + want: true, + }, + { + name: "both_empty", + a: map[string]string{}, + b: map[string]string{}, + want: true, + }, + { + name: "one_nil", + a: nil, + b: map[string]string{"a": "1"}, + want: false, + }, + { + name: "different_length", + a: map[string]string{"a": "1"}, + b: map[string]string{"a": "1", "b": "2"}, + want: false, + }, + { + name: "different_values", + a: map[string]string{"a": "1"}, + b: map[string]string{"a": "2"}, + want: false, + }, + { + name: "different_keys", + a: map[string]string{"a": "1"}, + b: map[string]string{"b": "1"}, + want: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := 
MapViewsEqual(MapOf(tc.a), MapOf(tc.b)) + if got != tc.want { + t.Errorf("MapViewsEqual: got=%v, want %v", got, tc.want) + } + + got = MapViewsEqualFunc(MapOf(tc.a), MapOf(tc.b), func(a, b string) bool { + return a == b + }) + if got != tc.want { + t.Errorf("MapViewsEqualFunc: got=%v, want %v", got, tc.want) + } + }) + } +} + +func TestMapViewsEqualFunc(t *testing.T) { + // Test that we can compare maps with two different non-comparable + // values using a custom comparison function. + type customStruct1 struct { + _ structs.Incomparable + Field1 string + } + type customStruct2 struct { + _ structs.Incomparable + Field2 string + } + + a := map[string]customStruct1{"a": {Field1: "1"}} + b := map[string]customStruct2{"a": {Field2: "1"}} + + got := MapViewsEqualFunc(MapOf(a), MapOf(b), func(a customStruct1, b customStruct2) bool { + return a.Field1 == b.Field2 + }) + if !got { + t.Errorf("MapViewsEqualFunc: got=%v, want true", got) + } +} diff --git a/util/clientmetric/clientmetric.go b/util/clientmetric/clientmetric.go index 584a24f73dca8..5c11160194fdc 100644 --- a/util/clientmetric/clientmetric.go +++ b/util/clientmetric/clientmetric.go @@ -270,7 +270,7 @@ func (c *AggregateCounter) UnregisterAll() { // a sum of expvar variables registered with it. func NewAggregateCounter(name string) *AggregateCounter { c := &AggregateCounter{counters: set.Set[*expvar.Int]{}} - NewGaugeFunc(name, c.Value) + NewCounterFunc(name, c.Value) return c } diff --git a/util/codegen/codegen.go b/util/codegen/codegen.go index d998d925d9143..1b3af10e03ee1 100644 --- a/util/codegen/codegen.go +++ b/util/codegen/codegen.go @@ -97,6 +97,11 @@ func (it *ImportTracker) Import(pkg string) { } } +// Has reports whether the specified package has been imported. 
+func (it *ImportTracker) Has(pkg string) bool { + return it.packages[pkg] +} + func (it *ImportTracker) qualifier(pkg *types.Package) string { if it.thisPkg == pkg { return "" @@ -272,11 +277,16 @@ func IsInvalid(t types.Type) bool { // It has special handling for some types that contain pointers // that we know are free from memory aliasing/mutation concerns. func ContainsPointers(typ types.Type) bool { - switch typ.String() { + s := typ.String() + switch s { case "time.Time": - // time.Time contains a pointer that does not need copying + // time.Time contains a pointer that does not need cloning. return false - case "inet.af/netip.Addr", "net/netip.Addr", "net/netip.Prefix", "net/netip.AddrPort": + case "inet.af/netip.Addr": + return false + } + if strings.HasPrefix(s, "unique.Handle[") { + // unique.Handle contains a pointer that does not need cloning. return false } switch ft := typ.Underlying().(type) { diff --git a/util/codegen/codegen_test.go b/util/codegen/codegen_test.go index 28ddaed2bac36..74715eecae6ef 100644 --- a/util/codegen/codegen_test.go +++ b/util/codegen/codegen_test.go @@ -10,6 +10,8 @@ import ( "strings" "sync" "testing" + "time" + "unique" "unsafe" "golang.org/x/exp/constraints" @@ -84,6 +86,16 @@ type PointerUnionParam[T netip.Prefix | BasicType | IntPtr] struct { V T } +type StructWithUniqueHandle struct{ _ unique.Handle[[32]byte] } + +type StructWithTime struct{ _ time.Time } + +type StructWithNetipTypes struct { + _ netip.Addr + _ netip.AddrPort + _ netip.Prefix +} + type Interface interface { Method() } @@ -161,6 +173,18 @@ func TestGenericContainsPointers(t *testing.T) { typ: "PointerUnionParam", wantPointer: true, }, + { + typ: "StructWithUniqueHandle", + wantPointer: false, + }, + { + typ: "StructWithTime", + wantPointer: false, + }, + { + typ: "StructWithNetipTypes", + wantPointer: false, + }, } for _, tt := range tests { diff --git a/util/cstruct/cstruct.go b/util/cstruct/cstruct.go index 464dc5dc3cadf..4d1d0a98b8032 100644 --- 
a/util/cstruct/cstruct.go +++ b/util/cstruct/cstruct.go @@ -6,10 +6,9 @@ package cstruct import ( + "encoding/binary" "errors" "io" - - "github.com/josharian/native" ) // Size of a pointer-typed value, in bits @@ -120,7 +119,7 @@ func (d *Decoder) Uint16() uint16 { d.err = err return 0 } - return native.Endian.Uint16(d.dbuf[0:2]) + return binary.NativeEndian.Uint16(d.dbuf[0:2]) } // Uint32 returns a uint32 decoded from the buffer. @@ -133,7 +132,7 @@ func (d *Decoder) Uint32() uint32 { d.err = err return 0 } - return native.Endian.Uint32(d.dbuf[0:4]) + return binary.NativeEndian.Uint32(d.dbuf[0:4]) } // Uint64 returns a uint64 decoded from the buffer. @@ -146,7 +145,7 @@ func (d *Decoder) Uint64() uint64 { d.err = err return 0 } - return native.Endian.Uint64(d.dbuf[0:8]) + return binary.NativeEndian.Uint64(d.dbuf[0:8]) } // Uintptr returns a uintptr decoded from the buffer. diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index d5584def33937..413893ff967d2 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -23,18 +23,11 @@ import ( "go4.org/mem" "go4.org/netipx" "tailscale.com/tailcfg" - "tailscale.com/types/dnstype" - "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/ptr" - "tailscale.com/types/views" "tailscale.com/util/deephash/testtype" - "tailscale.com/util/dnsname" "tailscale.com/util/hashx" "tailscale.com/version" - "tailscale.com/wgengine/filter" - "tailscale.com/wgengine/router" - "tailscale.com/wgengine/wgcfg" ) type appendBytes []byte @@ -197,21 +190,6 @@ func TestHash(t *testing.T) { } } -func TestDeepHash(t *testing.T) { - // v contains the types of values we care about for our current callers. - // Mostly we're just testing that we don't panic on handled types. 
- v := getVal() - hash1 := Hash(v) - t.Logf("hash: %v", hash1) - for range 20 { - v := getVal() - hash2 := Hash(v) - if hash1 != hash2 { - t.Error("second hash didn't match") - } - } -} - // Tests that we actually hash map elements. Whoops. func TestIssue4868(t *testing.T) { m1 := map[int]string{1: "foo"} @@ -255,110 +233,6 @@ func TestQuick(t *testing.T) { } } -type tailscaleTypes struct { - WGConfig *wgcfg.Config - RouterConfig *router.Config - MapFQDNAddrs map[dnsname.FQDN][]netip.Addr - MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort - MapDiscoPublics map[key.DiscoPublic]bool - MapResponse *tailcfg.MapResponse - FilterMatch filter.Match -} - -func getVal() *tailscaleTypes { - return &tailscaleTypes{ - &wgcfg.Config{ - Name: "foo", - Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)}, - Peers: []wgcfg.Peer{ - { - PublicKey: key.NodePublic{}, - }, - }, - }, - &router.Config{ - Routes: []netip.Prefix{ - netip.MustParsePrefix("1.2.3.0/24"), - netip.MustParsePrefix("1234::/64"), - }, - }, - map[dnsname.FQDN][]netip.Addr{ - dnsname.FQDN("a."): {netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("4.3.2.1")}, - dnsname.FQDN("b."): {netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("9.9.9.9")}, - dnsname.FQDN("c."): {netip.MustParseAddr("6.6.6.6"), netip.MustParseAddr("7.7.7.7")}, - dnsname.FQDN("d."): {netip.MustParseAddr("6.7.6.6"), netip.MustParseAddr("7.7.7.8")}, - dnsname.FQDN("e."): {netip.MustParseAddr("6.8.6.6"), netip.MustParseAddr("7.7.7.9")}, - dnsname.FQDN("f."): {netip.MustParseAddr("6.9.6.6"), netip.MustParseAddr("7.7.7.0")}, - }, - map[dnsname.FQDN][]netip.AddrPort{ - dnsname.FQDN("a."): {netip.MustParseAddrPort("1.2.3.4:11"), netip.MustParseAddrPort("4.3.2.1:22")}, - dnsname.FQDN("b."): {netip.MustParseAddrPort("8.8.8.8:11"), netip.MustParseAddrPort("9.9.9.9:22")}, - dnsname.FQDN("c."): {netip.MustParseAddrPort("8.8.8.8:12"), netip.MustParseAddrPort("9.9.9.9:23")}, - dnsname.FQDN("d."): 
{netip.MustParseAddrPort("8.8.8.8:13"), netip.MustParseAddrPort("9.9.9.9:24")}, - dnsname.FQDN("e."): {netip.MustParseAddrPort("8.8.8.8:14"), netip.MustParseAddrPort("9.9.9.9:25")}, - }, - map[key.DiscoPublic]bool{ - key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 31: 0})): true, - key.DiscoPublicFromRaw32(mem.B([]byte{1: 2, 31: 0})): false, - key.DiscoPublicFromRaw32(mem.B([]byte{1: 3, 31: 0})): true, - key.DiscoPublicFromRaw32(mem.B([]byte{1: 4, 31: 0})): false, - }, - &tailcfg.MapResponse{ - DERPMap: &tailcfg.DERPMap{ - Regions: map[int]*tailcfg.DERPRegion{ - 1: { - RegionID: 1, - RegionCode: "foo", - Nodes: []*tailcfg.DERPNode{ - { - Name: "n1", - RegionID: 1, - HostName: "foo.com", - }, - { - Name: "n2", - RegionID: 1, - HostName: "bar.com", - }, - }, - }, - }, - }, - DNSConfig: &tailcfg.DNSConfig{ - Resolvers: []*dnstype.Resolver{ - {Addr: "10.0.0.1"}, - }, - }, - PacketFilter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"1.2.3.4"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "1.2.3.4/32", - Ports: tailcfg.PortRange{First: 1, Last: 2}, - }, - }, - }, - }, - Peers: []*tailcfg.Node{ - { - ID: 1, - }, - { - ID: 2, - }, - }, - UserProfiles: []tailcfg.UserProfile{ - {ID: 1, LoginName: "foo@bar.com"}, - {ID: 2, LoginName: "bar@foo.com"}, - }, - }, - filter.Match{ - IPProto: views.SliceOf([]ipproto.Proto{1, 2, 3}), - }, - } -} - type IntThenByte struct { _ int _ byte @@ -758,14 +632,6 @@ func TestInterfaceCycle(t *testing.T) { var sink Sum -func BenchmarkHash(b *testing.B) { - b.ReportAllocs() - v := getVal() - for range b.N { - sink = Hash(v) - } -} - // filterRules is a packet filter that has both everything populated (in its // first element) and also a few entries that are the typical shape for regular // packet filters as sent to clients. 
@@ -1072,16 +938,6 @@ func FuzzAddr(f *testing.F) { }) } -func TestAppendTo(t *testing.T) { - v := getVal() - h := Hash(v) - sum := h.AppendTo(nil) - - if s := h.String(); s != string(sum) { - t.Errorf("hash sum mismatch; h.String()=%q h.AppendTo()=%q", s, string(sum)) - } -} - func TestFilterFields(t *testing.T) { type T struct { A int @@ -1126,15 +982,3 @@ func TestFilterFields(t *testing.T) { } } } - -func BenchmarkAppendTo(b *testing.B) { - b.ReportAllocs() - v := getVal() - h := Hash(v) - - hashBuf := make([]byte, 0, 100) - b.ResetTimer() - for range b.N { - hashBuf = h.AppendTo(hashBuf[:0]) - } -} diff --git a/util/deephash/tailscale_types_test.go b/util/deephash/tailscale_types_test.go new file mode 100644 index 0000000000000..d760253990048 --- /dev/null +++ b/util/deephash/tailscale_types_test.go @@ -0,0 +1,177 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// This file contains tests and benchmarks that use types from other packages +// in the Tailscale codebase. Unlike other deephash tests, these are in the _test +// package to avoid circular dependencies. + +package deephash_test + +import ( + "net/netip" + "testing" + + "go4.org/mem" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" + "tailscale.com/types/ipproto" + "tailscale.com/types/key" + "tailscale.com/types/views" + "tailscale.com/util/dnsname" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/router" + "tailscale.com/wgengine/wgcfg" + + . "tailscale.com/util/deephash" +) + +var sink Sum + +func BenchmarkHash(b *testing.B) { + b.ReportAllocs() + v := getVal() + for range b.N { + sink = Hash(v) + } +} + +func BenchmarkAppendTo(b *testing.B) { + b.ReportAllocs() + v := getVal() + h := Hash(v) + + hashBuf := make([]byte, 0, 100) + b.ResetTimer() + for range b.N { + hashBuf = h.AppendTo(hashBuf[:0]) + } +} + +func TestDeepHash(t *testing.T) { + // v contains the types of values we care about for our current callers. 
+ // Mostly we're just testing that we don't panic on handled types. + v := getVal() + hash1 := Hash(v) + t.Logf("hash: %v", hash1) + for range 20 { + v := getVal() + hash2 := Hash(v) + if hash1 != hash2 { + t.Error("second hash didn't match") + } + } +} + +func TestAppendTo(t *testing.T) { + v := getVal() + h := Hash(v) + sum := h.AppendTo(nil) + + if s := h.String(); s != string(sum) { + t.Errorf("hash sum mismatch; h.String()=%q h.AppendTo()=%q", s, string(sum)) + } +} + +type tailscaleTypes struct { + WGConfig *wgcfg.Config + RouterConfig *router.Config + MapFQDNAddrs map[dnsname.FQDN][]netip.Addr + MapFQDNAddrPorts map[dnsname.FQDN][]netip.AddrPort + MapDiscoPublics map[key.DiscoPublic]bool + MapResponse *tailcfg.MapResponse + FilterMatch filter.Match +} + +func getVal() *tailscaleTypes { + return &tailscaleTypes{ + &wgcfg.Config{ + Name: "foo", + Addresses: []netip.Prefix{netip.PrefixFrom(netip.AddrFrom16([16]byte{3: 3}).Unmap(), 5)}, + Peers: []wgcfg.Peer{ + { + PublicKey: key.NodePublic{}, + }, + }, + }, + &router.Config{ + Routes: []netip.Prefix{ + netip.MustParsePrefix("1.2.3.0/24"), + netip.MustParsePrefix("1234::/64"), + }, + }, + map[dnsname.FQDN][]netip.Addr{ + dnsname.FQDN("a."): {netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("4.3.2.1")}, + dnsname.FQDN("b."): {netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("9.9.9.9")}, + dnsname.FQDN("c."): {netip.MustParseAddr("6.6.6.6"), netip.MustParseAddr("7.7.7.7")}, + dnsname.FQDN("d."): {netip.MustParseAddr("6.7.6.6"), netip.MustParseAddr("7.7.7.8")}, + dnsname.FQDN("e."): {netip.MustParseAddr("6.8.6.6"), netip.MustParseAddr("7.7.7.9")}, + dnsname.FQDN("f."): {netip.MustParseAddr("6.9.6.6"), netip.MustParseAddr("7.7.7.0")}, + }, + map[dnsname.FQDN][]netip.AddrPort{ + dnsname.FQDN("a."): {netip.MustParseAddrPort("1.2.3.4:11"), netip.MustParseAddrPort("4.3.2.1:22")}, + dnsname.FQDN("b."): {netip.MustParseAddrPort("8.8.8.8:11"), netip.MustParseAddrPort("9.9.9.9:22")}, + dnsname.FQDN("c."): 
{netip.MustParseAddrPort("8.8.8.8:12"), netip.MustParseAddrPort("9.9.9.9:23")}, + dnsname.FQDN("d."): {netip.MustParseAddrPort("8.8.8.8:13"), netip.MustParseAddrPort("9.9.9.9:24")}, + dnsname.FQDN("e."): {netip.MustParseAddrPort("8.8.8.8:14"), netip.MustParseAddrPort("9.9.9.9:25")}, + }, + map[key.DiscoPublic]bool{ + key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 31: 0})): true, + key.DiscoPublicFromRaw32(mem.B([]byte{1: 2, 31: 0})): false, + key.DiscoPublicFromRaw32(mem.B([]byte{1: 3, 31: 0})): true, + key.DiscoPublicFromRaw32(mem.B([]byte{1: 4, 31: 0})): false, + }, + &tailcfg.MapResponse{ + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "foo", + Nodes: []*tailcfg.DERPNode{ + { + Name: "n1", + RegionID: 1, + HostName: "foo.com", + }, + { + Name: "n2", + RegionID: 1, + HostName: "bar.com", + }, + }, + }, + }, + }, + DNSConfig: &tailcfg.DNSConfig{ + Resolvers: []*dnstype.Resolver{ + {Addr: "10.0.0.1"}, + }, + }, + PacketFilter: []tailcfg.FilterRule{ + { + SrcIPs: []string{"1.2.3.4"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "1.2.3.4/32", + Ports: tailcfg.PortRange{First: 1, Last: 2}, + }, + }, + }, + }, + Peers: []*tailcfg.Node{ + { + ID: 1, + }, + { + ID: 2, + }, + }, + UserProfiles: []tailcfg.UserProfile{ + {ID: 1, LoginName: "foo@bar.com"}, + {ID: 2, LoginName: "bar@foo.com"}, + }, + }, + filter.Match{ + IPProto: views.SliceOf([]ipproto.Proto{1, 2, 3}), + }, + } +} diff --git a/util/dnsname/dnsname.go b/util/dnsname/dnsname.go index dde0baaeda2f2..6404a9af1cc2f 100644 --- a/util/dnsname/dnsname.go +++ b/util/dnsname/dnsname.go @@ -5,9 +5,9 @@ package dnsname import ( - "errors" - "fmt" "strings" + + "tailscale.com/util/vizerror" ) const ( @@ -36,7 +36,7 @@ func ToFQDN(s string) (FQDN, error) { totalLen += 1 // account for missing dot } if totalLen > maxNameLength { - return "", fmt.Errorf("%q is too long to be a DNS name", s) + return "", vizerror.Errorf("%q is too long to be a DNS name", s) } st := 0 @@ 
-54,7 +54,7 @@ func ToFQDN(s string) (FQDN, error) { // // See https://github.com/tailscale/tailscale/issues/2024 for more. if len(label) == 0 || len(label) > maxLabelLength { - return "", fmt.Errorf("%q is not a valid DNS label", label) + return "", vizerror.Errorf("%q is not a valid DNS label", label) } st = i + 1 } @@ -94,26 +94,27 @@ func (f FQDN) Contains(other FQDN) bool { return strings.HasSuffix(other.WithTrailingDot(), cmp) } -// ValidLabel reports whether label is a valid DNS label. +// ValidLabel reports whether label is a valid DNS label. All errors are +// [vizerror.Error]. func ValidLabel(label string) error { if len(label) == 0 { - return errors.New("empty DNS label") + return vizerror.New("empty DNS label") } if len(label) > maxLabelLength { - return fmt.Errorf("%q is too long, max length is %d bytes", label, maxLabelLength) + return vizerror.Errorf("%q is too long, max length is %d bytes", label, maxLabelLength) } if !isalphanum(label[0]) { - return fmt.Errorf("%q is not a valid DNS label: must start with a letter or number", label) + return vizerror.Errorf("%q is not a valid DNS label: must start with a letter or number", label) } if !isalphanum(label[len(label)-1]) { - return fmt.Errorf("%q is not a valid DNS label: must end with a letter or number", label) + return vizerror.Errorf("%q is not a valid DNS label: must end with a letter or number", label) } if len(label) < 2 { return nil } for i := 1; i < len(label)-1; i++ { if !isdnschar(label[i]) { - return fmt.Errorf("%q is not a valid DNS label: contains invalid character %q", label, label[i]) + return vizerror.Errorf("%q is not a valid DNS label: contains invalid character %q", label, label[i]) } } return nil diff --git a/util/eventbus/assets/event.html b/util/eventbus/assets/event.html new file mode 100644 index 0000000000000..8e016f583a250 --- /dev/null +++ b/util/eventbus/assets/event.html @@ -0,0 +1,6 @@ +
      • +
        + {{.Count}}: {{.Type}} from {{.Event.From.Name}}, {{len .Event.To}} recipients + {{.Event.Event}} +
        +
      • diff --git a/util/eventbus/assets/htmx-websocket.min.js.gz b/util/eventbus/assets/htmx-websocket.min.js.gz new file mode 100644 index 0000000000000..4ed53be492425 Binary files /dev/null and b/util/eventbus/assets/htmx-websocket.min.js.gz differ diff --git a/util/eventbus/assets/htmx.min.js.gz b/util/eventbus/assets/htmx.min.js.gz new file mode 100644 index 0000000000000..b75fea8d146df Binary files /dev/null and b/util/eventbus/assets/htmx.min.js.gz differ diff --git a/util/eventbus/assets/main.html b/util/eventbus/assets/main.html new file mode 100644 index 0000000000000..51d6b22addc6a --- /dev/null +++ b/util/eventbus/assets/main.html @@ -0,0 +1,97 @@ + + + + + + + + +

        Event bus

        + +
        +

        General

        + {{with $.PublishQueue}} + {{len .}} pending + {{end}} + + +
        + +
        +

        Clients

        + + + + + + + + + + + {{range .Clients}} + + + + + + + {{end}} +
        NamePublishingSubscribingPending
        {{.Name}} +
          + {{range .Publish}} +
        • {{.}}
        • + {{end}} +
        +
        +
          + {{range .Subscribe}} +
        • {{.}}
        • + {{end}} +
        +
        + {{len ($.SubscribeQueue .Client)}} +
        +
        + +
        +

        Types

        + + {{range .Types}} + +
        +

        {{.Name}}

        +

        Definition

        + {{prettyPrintStruct .}} + +

        Published by:

        + {{if len (.Publish)}} +
          + {{range .Publish}} +
        • {{.Name}}
        • + {{end}} +
        + {{else}} +
          +
        • No publishers.
        • +
        + {{end}} + +

        Received by:

        + {{if len (.Subscribe)}} +
          + {{range .Subscribe}} +
        • {{.Name}}
        • + {{end}} +
        + {{else}} +
          +
        • No subscribers.
        • +
        + {{end}} +
        + {{end}} + +
        + + diff --git a/util/eventbus/assets/monitor.html b/util/eventbus/assets/monitor.html new file mode 100644 index 0000000000000..1af5bdce63455 --- /dev/null +++ b/util/eventbus/assets/monitor.html @@ -0,0 +1,5 @@ +
        +
          +
        + +
        diff --git a/util/eventbus/assets/style.css b/util/eventbus/assets/style.css new file mode 100644 index 0000000000000..690bd4f176270 --- /dev/null +++ b/util/eventbus/assets/style.css @@ -0,0 +1,90 @@ +/* CSS reset, thanks Josh Comeau: https://www.joshwcomeau.com/css/custom-css-reset/ */ +*, *::before, *::after { box-sizing: border-box; } +* { margin: 0; } +input, button, textarea, select { font: inherit; } +p, h1, h2, h3, h4, h5, h6 { overflow-wrap: break-word; } +p { text-wrap: pretty; } +h1, h2, h3, h4, h5, h6 { text-wrap: balance; } +#root, #__next { isolation: isolate; } +body { + line-height: 1.5; + -webkit-font-smoothing: antialiased; +} +img, picture, video, canvas, svg { + display: block; + max-width: 100%; +} + +/* Local styling begins */ + +body { + padding: 12px; +} + +div { + width: 100%; +} + +section { + display: flex; + flex-direction: column; + flex-gap: 6px; + align-items: flex-start; + padding: 12px 0; +} + +section > * { + margin-left: 24px; +} + +section > h2, section > h3 { + margin-left: 0; + padding-bottom: 6px; + padding-top: 12px; +} + +details { + padding-bottom: 12px; +} + +table { + table-layout: fixed; + width: calc(100% - 48px); + border-collapse: collapse; + border: 1px solid black; +} + +th, td { + padding: 12px; + border: 1px solid black; +} + +td.list { + vertical-align: top; +} + +ul { + list-style: none; +} + +td ul { + margin: 0; + padding: 0; +} + +code { + padding: 12px; + white-space: pre; +} + +#monitor { + width: calc(100% - 48px); + resize: vertical; + padding: 12px; + overflow: scroll; + height: 15lh; + border: 1px inset; + min-height: 1em; + display: flex; + flex-direction: column-reverse; +} diff --git a/util/eventbus/bench_test.go b/util/eventbus/bench_test.go new file mode 100644 index 0000000000000..25f5b80020880 --- /dev/null +++ b/util/eventbus/bench_test.go @@ -0,0 +1,125 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus_test + +import ( + 
"math/rand/v2" + "testing" + + "tailscale.com/util/eventbus" +) + +func BenchmarkBasicThroughput(b *testing.B) { + bus := eventbus.New() + pcli := bus.Client(b.Name() + "-pub") + scli := bus.Client(b.Name() + "-sub") + + type emptyEvent [0]byte + + // One publisher and a corresponding subscriber shoveling events as fast as + // they can through the plumbing. + pub := eventbus.Publish[emptyEvent](pcli) + sub := eventbus.Subscribe[emptyEvent](scli) + + go func() { + for { + select { + case <-sub.Events(): + continue + case <-sub.Done(): + return + } + } + }() + + for b.Loop() { + pub.Publish(emptyEvent{}) + } + bus.Close() +} + +func BenchmarkSubsThroughput(b *testing.B) { + bus := eventbus.New() + pcli := bus.Client(b.Name() + "-pub") + scli1 := bus.Client(b.Name() + "-sub1") + scli2 := bus.Client(b.Name() + "-sub2") + + type emptyEvent [0]byte + + // One publisher and two subscribers shoveling events as fast as they can + // through the plumbing. + pub := eventbus.Publish[emptyEvent](pcli) + sub1 := eventbus.Subscribe[emptyEvent](scli1) + sub2 := eventbus.Subscribe[emptyEvent](scli2) + + for _, sub := range []*eventbus.Subscriber[emptyEvent]{sub1, sub2} { + go func() { + for { + select { + case <-sub.Events(): + continue + case <-sub.Done(): + return + } + } + }() + } + + for b.Loop() { + pub.Publish(emptyEvent{}) + } + bus.Close() +} + +func BenchmarkMultiThroughput(b *testing.B) { + bus := eventbus.New() + cli := bus.Client(b.Name()) + + type eventA struct{} + type eventB struct{} + + // Two disjoint event streams routed through the global order. 
+ apub := eventbus.Publish[eventA](cli) + asub := eventbus.Subscribe[eventA](cli) + bpub := eventbus.Publish[eventB](cli) + bsub := eventbus.Subscribe[eventB](cli) + + go func() { + for { + select { + case <-asub.Events(): + continue + case <-asub.Done(): + return + } + } + }() + go func() { + for { + select { + case <-bsub.Events(): + continue + case <-bsub.Done(): + return + } + } + }() + + var rng uint64 + var bits int + for b.Loop() { + if bits == 0 { + rng = rand.Uint64() + bits = 64 + } + if rng&1 == 0 { + apub.Publish(eventA{}) + } else { + bpub.Publish(eventB{}) + } + rng >>= 1 + bits-- + } + bus.Close() +} diff --git a/util/eventbus/bus.go b/util/eventbus/bus.go new file mode 100644 index 0000000000000..45d12da2f3736 --- /dev/null +++ b/util/eventbus/bus.go @@ -0,0 +1,309 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "context" + "reflect" + "slices" + "sync" + + "tailscale.com/util/set" +) + +type PublishedEvent struct { + Event any + From *Client +} + +type RoutedEvent struct { + Event any + From *Client + To []*Client +} + +// Bus is an event bus that distributes published events to interested +// subscribers. +type Bus struct { + router *worker + write chan PublishedEvent + snapshot chan chan []PublishedEvent + routeDebug hook[RoutedEvent] + + topicsMu sync.Mutex + topics map[reflect.Type][]*subscribeState + + // Used for introspection/debugging only, not in the normal event + // publishing path. + clientsMu sync.Mutex + clients set.Set[*Client] +} + +// New returns a new bus. Use [PublisherOf] to make event publishers, +// and [Bus.Queue] and [Subscribe] to make event subscribers. 
+func New() *Bus { + ret := &Bus{ + write: make(chan PublishedEvent), + snapshot: make(chan chan []PublishedEvent), + topics: map[reflect.Type][]*subscribeState{}, + clients: set.Set[*Client]{}, + } + ret.router = runWorker(ret.pump) + return ret +} + +// Client returns a new client with no subscriptions. Use [Subscribe] +// to receive events, and [Publish] to emit events. +// +// The client's name is used only for debugging, to tell humans what +// piece of code a publisher/subscriber belongs to. Aim for something +// short but unique, for example "kernel-route-monitor" or "taildrop", +// not "watcher". +func (b *Bus) Client(name string) *Client { + ret := &Client{ + name: name, + bus: b, + pub: set.Set[publisher]{}, + } + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + b.clients.Add(ret) + return ret +} + +// Debugger returns the debugging facility for the bus. +func (b *Bus) Debugger() *Debugger { + return &Debugger{b} +} + +// Close closes the bus. Implicitly closes all clients, publishers and +// subscribers attached to the bus. +// +// Close blocks until the bus is fully shut down. The bus is +// permanently unusable after closing. +func (b *Bus) Close() { + b.router.StopAndWait() + + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + for c := range b.clients { + c.Close() + } + b.clients = nil +} + +func (b *Bus) pump(ctx context.Context) { + var vals queue[PublishedEvent] + acceptCh := func() chan PublishedEvent { + if vals.Full() { + return nil + } + return b.write + } + for { + // Drain all pending events. Note that while we're draining + // events into subscriber queues, we continue to + // opportunistically accept more incoming events, if we have + // queue space for it. 
+ for !vals.Empty() { + val := vals.Peek() + dests := b.dest(reflect.ValueOf(val.Event).Type()) + + if b.routeDebug.active() { + clients := make([]*Client, len(dests)) + for i := range len(dests) { + clients[i] = dests[i].client + } + b.routeDebug.run(RoutedEvent{ + Event: val.Event, + From: val.From, + To: clients, + }) + } + + for _, d := range dests { + evt := DeliveredEvent{ + Event: val.Event, + From: val.From, + To: d.client, + } + deliverOne: + for { + select { + case d.write <- evt: + break deliverOne + case <-d.closed(): + // Queue closed, don't block but continue + // delivering to others. + break deliverOne + case in := <-acceptCh(): + vals.Add(in) + in.From.publishDebug.run(in) + case <-ctx.Done(): + return + case ch := <-b.snapshot: + ch <- vals.Snapshot() + } + } + } + vals.Drop() + } + + // Inbound queue empty, wait for at least 1 work item before + // resuming. + for vals.Empty() { + select { + case <-ctx.Done(): + return + case in := <-b.write: + vals.Add(in) + in.From.publishDebug.run(in) + case ch := <-b.snapshot: + ch <- nil + } + } + } +} + +func (b *Bus) dest(t reflect.Type) []*subscribeState { + b.topicsMu.Lock() + defer b.topicsMu.Unlock() + return b.topics[t] +} + +func (b *Bus) shouldPublish(t reflect.Type) bool { + if b.routeDebug.active() { + return true + } + + b.topicsMu.Lock() + defer b.topicsMu.Unlock() + return len(b.topics[t]) > 0 +} + +func (b *Bus) listClients() []*Client { + b.clientsMu.Lock() + defer b.clientsMu.Unlock() + return b.clients.Slice() +} + +func (b *Bus) snapshotPublishQueue() []PublishedEvent { + resp := make(chan []PublishedEvent) + select { + case b.snapshot <- resp: + return <-resp + case <-b.router.Done(): + return nil + } +} + +func (b *Bus) subscribe(t reflect.Type, q *subscribeState) (cancel func()) { + b.topicsMu.Lock() + defer b.topicsMu.Unlock() + b.topics[t] = append(b.topics[t], q) + return func() { + b.unsubscribe(t, q) + } +} + +func (b *Bus) unsubscribe(t reflect.Type, q *subscribeState) { + 
b.topicsMu.Lock() + defer b.topicsMu.Unlock() + // Topic slices are accessed by pump without holding a lock, so we + // have to replace the entire slice when unsubscribing. + // Unsubscribing should be infrequent enough that this won't + // matter. + i := slices.Index(b.topics[t], q) + if i < 0 { + return + } + b.topics[t] = slices.Delete(slices.Clone(b.topics[t]), i, i+1) +} + +// A worker runs a worker goroutine and helps coordinate its shutdown. +type worker struct { + ctx context.Context + stop context.CancelFunc + stopped chan struct{} +} + +// runWorker creates a worker goroutine running fn. The context passed +// to fn is canceled by [worker.Stop]. +func runWorker(fn func(context.Context)) *worker { + ctx, stop := context.WithCancel(context.Background()) + ret := &worker{ + ctx: ctx, + stop: stop, + stopped: make(chan struct{}), + } + go ret.run(fn) + return ret +} + +func (w *worker) run(fn func(context.Context)) { + defer close(w.stopped) + fn(w.ctx) +} + +// Stop signals the worker goroutine to shut down. +func (w *worker) Stop() { w.stop() } + +// Done returns a channel that is closed when the worker goroutine +// exits. +func (w *worker) Done() <-chan struct{} { return w.stopped } + +// Wait waits until the worker goroutine has exited. +func (w *worker) Wait() { <-w.stopped } + +// StopAndWait signals the worker goroutine to shut down, then waits +// for it to exit. +func (w *worker) StopAndWait() { + w.stop() + <-w.stopped +} + +// stopFlag is a value that can be watched for a notification. The +// zero value is ready for use. +// +// The flag is notified by running [stopFlag.Stop]. Stop can be called +// multiple times. Upon the first call to Stop, [stopFlag.Done] is +// closed, all pending [stopFlag.Wait] calls return, and future Wait +// calls return immediately. +// +// A stopFlag can only notify once, and is intended for use as a +// one-way shutdown signal that's lighter than a cancellable +// context.Context. 
+type stopFlag struct { + // guards the lazy construction of stopped, and the value of + // alreadyStopped. + mu sync.Mutex + stopped chan struct{} + alreadyStopped bool +} + +func (s *stopFlag) Stop() { + s.mu.Lock() + defer s.mu.Unlock() + if s.alreadyStopped { + return + } + s.alreadyStopped = true + if s.stopped == nil { + s.stopped = make(chan struct{}) + } + close(s.stopped) +} + +func (s *stopFlag) Done() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + if s.stopped == nil { + s.stopped = make(chan struct{}) + } + return s.stopped +} + +func (s *stopFlag) Wait() { + <-s.Done() +} diff --git a/util/eventbus/bus_test.go b/util/eventbus/bus_test.go new file mode 100644 index 0000000000000..e159b6a12608a --- /dev/null +++ b/util/eventbus/bus_test.go @@ -0,0 +1,203 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus_test + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/creachadair/taskgroup" + "github.com/google/go-cmp/cmp" + "tailscale.com/util/eventbus" +) + +type EventA struct { + Counter int +} + +type EventB struct { + Counter int +} + +func TestBus(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c := b.Client("TestSub") + defer c.Close() + s := eventbus.Subscribe[EventA](c) + + go func() { + p := b.Client("TestPub") + defer p.Close() + pa := eventbus.Publish[EventA](p) + defer pa.Close() + pb := eventbus.Publish[EventB](p) + defer pb.Close() + pa.Publish(EventA{1}) + pb.Publish(EventB{2}) + pa.Publish(EventA{3}) + }() + + want := expectEvents(t, EventA{1}, EventA{3}) + for !want.Empty() { + select { + case got := <-s.Events(): + want.Got(got) + case <-s.Done(): + t.Fatalf("queue closed unexpectedly") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for event") + } + } +} + +func TestBusMultipleConsumers(t *testing.T) { + b := eventbus.New() + defer b.Close() + + c1 := b.Client("TestSubA") + defer c1.Close() + s1 := eventbus.Subscribe[EventA](c1) + + 
c2 := b.Client("TestSubB") + defer c2.Close() + s2A := eventbus.Subscribe[EventA](c2) + s2B := eventbus.Subscribe[EventB](c2) + + go func() { + p := b.Client("TestPub") + defer p.Close() + pa := eventbus.Publish[EventA](p) + defer pa.Close() + pb := eventbus.Publish[EventB](p) + defer pb.Close() + pa.Publish(EventA{1}) + pb.Publish(EventB{2}) + pa.Publish(EventA{3}) + }() + + wantA := expectEvents(t, EventA{1}, EventA{3}) + wantB := expectEvents(t, EventA{1}, EventB{2}, EventA{3}) + for !wantA.Empty() || !wantB.Empty() { + select { + case got := <-s1.Events(): + wantA.Got(got) + case got := <-s2A.Events(): + wantB.Got(got) + case got := <-s2B.Events(): + wantB.Got(got) + case <-s1.Done(): + t.Fatalf("queue closed unexpectedly") + case <-s2A.Done(): + t.Fatalf("queue closed unexpectedly") + case <-s2B.Done(): + t.Fatalf("queue closed unexpectedly") + case <-time.After(time.Second): + t.Fatalf("timed out waiting for event") + } + } +} + +func TestSpam(t *testing.T) { + b := eventbus.New() + defer b.Close() + + const ( + publishers = 100 + eventsPerPublisher = 20 + wantEvents = publishers * eventsPerPublisher + subscribers = 100 + ) + + var g taskgroup.Group + + received := make([][]EventA, subscribers) + for i := range subscribers { + c := b.Client(fmt.Sprintf("Subscriber%d", i)) + defer c.Close() + s := eventbus.Subscribe[EventA](c) + g.Go(func() error { + for range wantEvents { + select { + case evt := <-s.Events(): + received[i] = append(received[i], evt) + case <-s.Done(): + t.Errorf("queue done before expected number of events received") + return errors.New("queue prematurely closed") + case <-time.After(5 * time.Second): + t.Errorf("timed out waiting for expected bus event after %d events", len(received[i])) + return errors.New("timeout") + } + } + return nil + }) + } + + published := make([][]EventA, publishers) + for i := range publishers { + g.Run(func() { + c := b.Client(fmt.Sprintf("Publisher%d", i)) + p := eventbus.Publish[EventA](c) + for j := range 
eventsPerPublisher { + evt := EventA{i*eventsPerPublisher + j} + p.Publish(evt) + published[i] = append(published[i], evt) + } + }) + } + + if err := g.Wait(); err != nil { + t.Fatal(err) + } + var last []EventA + for i, got := range received { + if len(got) != wantEvents { + // Receiving goroutine already reported an error, we just need + // to fail early within the main test goroutine. + t.FailNow() + } + if last == nil { + continue + } + if diff := cmp.Diff(got, last); diff != "" { + t.Errorf("Subscriber %d did not see the same events as %d (-got+want):\n%s", i, i-1, diff) + } + last = got + } + for i, sent := range published { + if got := len(sent); got != eventsPerPublisher { + t.Fatalf("Publisher %d sent %d events, want %d", i, got, eventsPerPublisher) + } + } + + // TODO: check that the published sequences are proper + // subsequences of the received slices. +} + +type queueChecker struct { + t *testing.T + want []any +} + +func expectEvents(t *testing.T, want ...any) *queueChecker { + return &queueChecker{t, want} +} + +func (q *queueChecker) Got(v any) { + q.t.Helper() + if q.Empty() { + q.t.Fatalf("queue got unexpected %v", v) + } + if v != q.want[0] { + q.t.Fatalf("queue got %#v, want %#v", v, q.want[0]) + } + q.want = q.want[1:] +} + +func (q *queueChecker) Empty() bool { + return len(q.want) == 0 +} diff --git a/util/eventbus/client.go b/util/eventbus/client.go new file mode 100644 index 0000000000000..a7a88c0a158c7 --- /dev/null +++ b/util/eventbus/client.go @@ -0,0 +1,127 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "reflect" + "sync" + + "tailscale.com/util/set" +) + +// A Client can publish and subscribe to events on its attached +// bus. See [Publish] to publish events, and [Subscribe] to receive +// events. +// +// Subscribers that share the same client receive events one at a +// time, in the order they were published. 
+type Client struct { + name string + bus *Bus + publishDebug hook[PublishedEvent] + + mu sync.Mutex + pub set.Set[publisher] + sub *subscribeState // Lazily created on first subscribe +} + +func (c *Client) Name() string { return c.name } + +// Close closes the client. Implicitly closes all publishers and +// subscribers obtained from this client. +func (c *Client) Close() { + var ( + pub set.Set[publisher] + sub *subscribeState + ) + + c.mu.Lock() + pub, c.pub = c.pub, nil + sub, c.sub = c.sub, nil + c.mu.Unlock() + + if sub != nil { + sub.close() + } + for p := range pub { + p.Close() + } +} + +func (c *Client) snapshotSubscribeQueue() []DeliveredEvent { + return c.peekSubscribeState().snapshotQueue() +} + +func (c *Client) peekSubscribeState() *subscribeState { + c.mu.Lock() + defer c.mu.Unlock() + return c.sub +} + +func (c *Client) publishTypes() []reflect.Type { + c.mu.Lock() + defer c.mu.Unlock() + ret := make([]reflect.Type, 0, len(c.pub)) + for pub := range c.pub { + ret = append(ret, pub.publishType()) + } + return ret +} + +func (c *Client) subscribeTypes() []reflect.Type { + return c.peekSubscribeState().subscribeTypes() +} + +func (c *Client) subscribeState() *subscribeState { + c.mu.Lock() + defer c.mu.Unlock() + if c.sub == nil { + c.sub = newSubscribeState(c) + } + return c.sub +} + +func (c *Client) addPublisher(pub publisher) { + c.mu.Lock() + defer c.mu.Unlock() + c.pub.Add(pub) +} + +func (c *Client) deletePublisher(pub publisher) { + c.mu.Lock() + defer c.mu.Unlock() + c.pub.Delete(pub) +} + +func (c *Client) addSubscriber(t reflect.Type, s *subscribeState) { + c.bus.subscribe(t, s) +} + +func (c *Client) deleteSubscriber(t reflect.Type, s *subscribeState) { + c.bus.unsubscribe(t, s) +} + +func (c *Client) publish() chan<- PublishedEvent { + return c.bus.write +} + +func (c *Client) shouldPublish(t reflect.Type) bool { + return c.publishDebug.active() || c.bus.shouldPublish(t) +} + +// Subscribe requests delivery of events of type T through 
the given +// Queue. Panics if the queue already has a subscriber for T. +func Subscribe[T any](c *Client) *Subscriber[T] { + return newSubscriber[T](c.subscribeState()) +} + +// Publisher returns a publisher for event type T using the given +// client. +func Publish[T any](c *Client) *Publisher[T] { + ret := newPublisher[T](c) + c.mu.Lock() + defer c.mu.Unlock() + c.pub.Add(ret) + return ret +} diff --git a/util/eventbus/debug-demo/main.go b/util/eventbus/debug-demo/main.go new file mode 100644 index 0000000000000..a6d232d882944 --- /dev/null +++ b/util/eventbus/debug-demo/main.go @@ -0,0 +1,103 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// debug-demo is a program that serves a bus's debug interface over +// HTTP, then generates some fake traffic from a handful of +// clients. It is an aid to development, to have something to present +// on the debug interfaces while writing them. +package main + +import ( + "log" + "math/rand/v2" + "net/http" + "net/netip" + "time" + + "tailscale.com/tsweb" + "tailscale.com/types/key" + "tailscale.com/util/eventbus" +) + +func main() { + b := eventbus.New() + c := b.Client("RouteMonitor") + go testPub[RouteAdded](c, 5*time.Second) + go testPub[RouteRemoved](c, 5*time.Second) + c = b.Client("ControlClient") + go testPub[PeerAdded](c, 3*time.Second) + go testPub[PeerRemoved](c, 6*time.Second) + c = b.Client("Portmapper") + go testPub[PortmapAcquired](c, 10*time.Second) + go testPub[PortmapLost](c, 15*time.Second) + go testSub[RouteAdded](c) + c = b.Client("WireguardConfig") + go testSub[PeerAdded](c) + go testSub[PeerRemoved](c) + c = b.Client("Magicsock") + go testPub[PeerPathChanged](c, 5*time.Second) + go testSub[RouteAdded](c) + go testSub[RouteRemoved](c) + go testSub[PortmapAcquired](c) + go testSub[PortmapLost](c) + + m := http.NewServeMux() + d := tsweb.Debugger(m) + b.Debugger().RegisterHTTP(d) + + m.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + 
http.Redirect(w, r, "/debug/bus", http.StatusFound) + }) + log.Printf("Serving debug interface at http://localhost:8185/debug/bus") + http.ListenAndServe(":8185", m) +} + +func testPub[T any](c *eventbus.Client, every time.Duration) { + p := eventbus.Publish[T](c) + for { + jitter := time.Duration(rand.N(2000)) * time.Millisecond + time.Sleep(jitter) + var zero T + log.Printf("%s publish: %T", c.Name(), zero) + p.Publish(zero) + time.Sleep(every) + } +} + +func testSub[T any](c *eventbus.Client) { + s := eventbus.Subscribe[T](c) + for v := range s.Events() { + log.Printf("%s received: %T", c.Name(), v) + } +} + +type RouteAdded struct { + Prefix netip.Prefix + Via netip.Addr + Priority int +} +type RouteRemoved struct { + Prefix netip.Addr +} + +type PeerAdded struct { + ID int + Key key.NodePublic +} +type PeerRemoved struct { + ID int + Key key.NodePublic +} + +type PortmapAcquired struct { + Endpoint netip.Addr +} +type PortmapLost struct { + Endpoint netip.Addr +} + +type PeerPathChanged struct { + ID int + EndpointID int + Quality int +} diff --git a/util/eventbus/debug.go b/util/eventbus/debug.go new file mode 100644 index 0000000000000..832d72ac07dda --- /dev/null +++ b/util/eventbus/debug.go @@ -0,0 +1,188 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "cmp" + "fmt" + "reflect" + "slices" + "sync" + "sync/atomic" + + "tailscale.com/tsweb" +) + +// A Debugger offers access to a bus's privileged introspection and +// debugging facilities. +// +// The debugger's functionality is intended for humans and their tools +// to examine and troubleshoot bus clients, and should not be used in +// normal codepaths. +// +// In particular, the debugger provides access to information that is +// deliberately withheld from bus clients to encourage more robust and +// maintainable code - for example, the sender of an event, or the +// event streams of other clients. 
Please don't use the debugger to +// circumvent these restrictions for purposes other than debugging. +type Debugger struct { + bus *Bus +} + +// Clients returns a list of all clients attached to the bus. +func (d *Debugger) Clients() []*Client { + ret := d.bus.listClients() + slices.SortFunc(ret, func(a, b *Client) int { + return cmp.Compare(a.Name(), b.Name()) + }) + return ret +} + +// PublishQueue returns the contents of the publish queue. +// +// The publish queue contains events that have been accepted by the +// bus from Publish() calls, but have not yet been routed to relevant +// subscribers. +// +// This queue is expected to be almost empty in normal operation. A +// full publish queue indicates that a slow subscriber downstream is +// causing backpressure and stalling the bus. +func (d *Debugger) PublishQueue() []PublishedEvent { + return d.bus.snapshotPublishQueue() +} + +// checkClient verifies that client is attached to the same bus as the +// Debugger, and panics if not. +func (d *Debugger) checkClient(client *Client) { + if client.bus != d.bus { + panic(fmt.Errorf("SubscribeQueue given client belonging to wrong bus")) + } +} + +// SubscribeQueue returns the contents of the given client's subscribe +// queue. +// +// The subscribe queue contains events that are to be delivered to the +// client, but haven't yet been handed off to the relevant +// [Subscriber]. +// +// This queue is expected to be almost empty in normal operation. A +// full subscribe queue indicates that the client is accepting events +// too slowly, and may be causing the rest of the bus to stall. +func (d *Debugger) SubscribeQueue(client *Client) []DeliveredEvent { + d.checkClient(client) + return client.snapshotSubscribeQueue() +} + +// WatchBus streams information about all events passing through the +// bus. +// +// Monitored events are delivered in the bus's global publication +// order (see "Concurrency properties" in the package docs). 
+
+// The caller must consume monitoring events promptly to avoid
+// stalling the bus (see "Expected subscriber behavior" in the package
+// docs).
+func (d *Debugger) WatchBus() *Subscriber[RoutedEvent] {
+	return newMonitor(d.bus.routeDebug.add)
+}
+
+// WatchPublish streams information about all events published by the
+// given client.
+//
+// Monitored events are delivered in the bus's global publication
+// order (see "Concurrency properties" in the package docs).
+//
+// The caller must consume monitoring events promptly to avoid
+// stalling the bus (see "Expected subscriber behavior" in the package
+// docs).
+func (d *Debugger) WatchPublish(client *Client) *Subscriber[PublishedEvent] {
+	d.checkClient(client)
+	return newMonitor(client.publishDebug.add)
+}
+
+// WatchSubscribe streams information about all events received by the
+// given client.
+//
+// Monitored events are delivered in the bus's global publication
+// order (see "Concurrency properties" in the package docs).
+//
+// The caller must consume monitoring events promptly to avoid
+// stalling the bus (see "Expected subscriber behavior" in the package
+// docs).
+func (d *Debugger) WatchSubscribe(client *Client) *Subscriber[DeliveredEvent] {
+	d.checkClient(client)
+	return newMonitor(client.subscribeState().debug.add)
+}
+
+// PublishTypes returns the list of types being published by client.
+//
+// The returned types are those for which the client has obtained a
+// [Publisher]. The client may not have ever sent the type in
+// question.
+func (d *Debugger) PublishTypes(client *Client) []reflect.Type {
+	d.checkClient(client)
+	return client.publishTypes()
+}
+
+// SubscribeTypes returns the list of types being subscribed to by
+// client.
+//
+// The returned types are those for which the client has obtained a
+// [Subscriber]. The client may not have ever received the type in
+// question, and there may not be any publishers of the type.
+func (d *Debugger) SubscribeTypes(client *Client) []reflect.Type { + d.checkClient(client) + return client.subscribeTypes() +} + +func (d *Debugger) RegisterHTTP(td *tsweb.DebugHandler) { registerHTTPDebugger(d, td) } + +// A hook collects hook functions that can be run as a group. +type hook[T any] struct { + sync.Mutex + fns []hookFn[T] +} + +var hookID atomic.Uint64 + +// add registers fn to be called when the hook is run. Returns an +// unregistration function that removes fn from the hook when called. +func (h *hook[T]) add(fn func(T)) (remove func()) { + id := hookID.Add(1) + h.Lock() + defer h.Unlock() + h.fns = append(h.fns, hookFn[T]{id, fn}) + return func() { h.remove(id) } +} + +// remove removes the function with the given ID from the hook. +func (h *hook[T]) remove(id uint64) { + h.Lock() + defer h.Unlock() + h.fns = slices.DeleteFunc(h.fns, func(f hookFn[T]) bool { return f.ID == id }) +} + +// active reports whether any functions are registered with the +// hook. This can be used to skip expensive work when the hook is +// inactive. +func (h *hook[T]) active() bool { + h.Lock() + defer h.Unlock() + return len(h.fns) > 0 +} + +// run calls all registered functions with the value v. 
+func (h *hook[T]) run(v T) { + h.Lock() + defer h.Unlock() + for _, fn := range h.fns { + fn.Fn(v) + } +} + +type hookFn[T any] struct { + ID uint64 + Fn func(T) +} diff --git a/util/eventbus/debughttp.go b/util/eventbus/debughttp.go new file mode 100644 index 0000000000000..a94eaa9cf7ba2 --- /dev/null +++ b/util/eventbus/debughttp.go @@ -0,0 +1,240 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ios && !android + +package eventbus + +import ( + "bytes" + "cmp" + "embed" + "fmt" + "html/template" + "io" + "io/fs" + "log" + "net/http" + "path/filepath" + "reflect" + "slices" + "strings" + "sync" + + "github.com/coder/websocket" + "tailscale.com/tsweb" +) + +type httpDebugger struct { + *Debugger +} + +func registerHTTPDebugger(d *Debugger, td *tsweb.DebugHandler) { + dh := httpDebugger{d} + td.Handle("bus", "Event bus", dh) + td.HandleSilent("bus/monitor", http.HandlerFunc(dh.serveMonitor)) + td.HandleSilent("bus/style.css", serveStatic("style.css")) + td.HandleSilent("bus/htmx.min.js", serveStatic("htmx.min.js.gz")) + td.HandleSilent("bus/htmx-websocket.min.js", serveStatic("htmx-websocket.min.js.gz")) +} + +//go:embed assets/*.html +var templatesSrc embed.FS + +var templates = sync.OnceValue(func() *template.Template { + d, err := fs.Sub(templatesSrc, "assets") + if err != nil { + panic(fmt.Errorf("getting eventbus debughttp templates subdir: %w", err)) + } + ret := template.New("").Funcs(map[string]any{ + "prettyPrintStruct": prettyPrintStruct, + }) + return template.Must(ret.ParseFS(d, "*")) +}) + +//go:generate go run fetch-htmx.go + +//go:embed assets/*.css assets/*.min.js.gz +var static embed.FS + +func serveStatic(name string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.HasSuffix(name, ".css"): + w.Header().Set("Content-Type", "text/css") + case strings.HasSuffix(name, ".min.js.gz"): + w.Header().Set("Content-Type", "text/javascript") 
+ w.Header().Set("Content-Encoding", "gzip") + case strings.HasSuffix(name, ".js"): + w.Header().Set("Content-Type", "text/javascript") + default: + http.Error(w, "not found", http.StatusNotFound) + return + } + + f, err := static.Open(filepath.Join("assets", name)) + if err != nil { + http.Error(w, fmt.Sprintf("opening asset: %v", err), http.StatusInternalServerError) + return + } + defer f.Close() + if _, err := io.Copy(w, f); err != nil { + http.Error(w, fmt.Sprintf("serving asset: %v", err), http.StatusInternalServerError) + return + } + }) +} + +func render(w http.ResponseWriter, name string, data any) { + err := templates().ExecuteTemplate(w, name+".html", data) + if err != nil { + err := fmt.Errorf("rendering template: %v", err) + log.Print(err) + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func (h httpDebugger) ServeHTTP(w http.ResponseWriter, r *http.Request) { + type clientInfo struct { + *Client + Publish []reflect.Type + Subscribe []reflect.Type + } + type typeInfo struct { + reflect.Type + Publish []*Client + Subscribe []*Client + } + type info struct { + *Debugger + Clients map[string]*clientInfo + Types map[string]*typeInfo + } + + data := info{ + Debugger: h.Debugger, + Clients: map[string]*clientInfo{}, + Types: map[string]*typeInfo{}, + } + + getTypeInfo := func(t reflect.Type) *typeInfo { + if data.Types[t.Name()] == nil { + data.Types[t.Name()] = &typeInfo{ + Type: t, + } + } + return data.Types[t.Name()] + } + + for _, c := range h.Clients() { + ci := &clientInfo{ + Client: c, + Publish: h.PublishTypes(c), + Subscribe: h.SubscribeTypes(c), + } + slices.SortFunc(ci.Publish, func(a, b reflect.Type) int { return cmp.Compare(a.Name(), b.Name()) }) + slices.SortFunc(ci.Subscribe, func(a, b reflect.Type) int { return cmp.Compare(a.Name(), b.Name()) }) + data.Clients[c.Name()] = ci + + for _, t := range ci.Publish { + ti := getTypeInfo(t) + ti.Publish = append(ti.Publish, c) + } + for _, t := range ci.Subscribe { + ti := 
getTypeInfo(t) + ti.Subscribe = append(ti.Subscribe, c) + } + } + + render(w, "main", data) +} + +func (h httpDebugger) serveMonitor(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Upgrade") == "websocket" { + h.serveMonitorStream(w, r) + return + } + + render(w, "monitor", nil) +} + +func (h httpDebugger) serveMonitorStream(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer conn.CloseNow() + wsCtx := conn.CloseRead(r.Context()) + + mon := h.WatchBus() + defer mon.Close() + + i := 0 + for { + select { + case <-r.Context().Done(): + return + case <-wsCtx.Done(): + return + case <-mon.Done(): + return + case event := <-mon.Events(): + msg, err := conn.Writer(r.Context(), websocket.MessageText) + if err != nil { + return + } + data := map[string]any{ + "Count": i, + "Type": reflect.TypeOf(event.Event), + "Event": event, + } + i++ + if err := templates().ExecuteTemplate(msg, "event.html", data); err != nil { + log.Println(err) + return + } + if err := msg.Close(); err != nil { + return + } + } + } +} + +func prettyPrintStruct(t reflect.Type) string { + if t.Kind() != reflect.Struct { + return t.String() + } + var rec func(io.Writer, int, reflect.Type) + rec = func(out io.Writer, indent int, t reflect.Type) { + ind := strings.Repeat(" ", indent) + fmt.Fprintf(out, "%s", t.String()) + fs := collectFields(t) + if len(fs) > 0 { + io.WriteString(out, " {\n") + for _, f := range fs { + fmt.Fprintf(out, "%s %s ", ind, f.Name) + if f.Type.Kind() == reflect.Struct { + rec(out, indent+1, f.Type) + } else { + fmt.Fprint(out, f.Type) + } + io.WriteString(out, "\n") + } + fmt.Fprintf(out, "%s}", ind) + } + } + + var ret bytes.Buffer + rec(&ret, 0, t) + return ret.String() +} + +func collectFields(t reflect.Type) (ret []reflect.StructField) { + for _, f := range reflect.VisibleFields(t) { + if !f.IsExported() { + continue + } + ret = append(ret, f) + } + return ret +} diff --git 
a/util/eventbus/debughttp_off.go b/util/eventbus/debughttp_off.go new file mode 100644 index 0000000000000..85330579c8329 --- /dev/null +++ b/util/eventbus/debughttp_off.go @@ -0,0 +1,20 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ios || android + +package eventbus + +import "tailscale.com/tsweb" + +func registerHTTPDebugger(d *Debugger, td *tsweb.DebugHandler) { + // The event bus debugging UI uses html/template, which uses + // reflection for method lookups. This forces the compiler to + // retain a lot more code and information to make dynamic method + // dispatch work, which is unacceptable bloat for the iOS build. + // We also disable it on Android while we're at it, as nobody + // is debugging Tailscale internals on Android. + // + // TODO: https://github.com/tailscale/tailscale/issues/15297 to + // bring the debug UI back to iOS somehow. +} diff --git a/util/eventbus/doc.go b/util/eventbus/doc.go new file mode 100644 index 0000000000000..964a686eae109 --- /dev/null +++ b/util/eventbus/doc.go @@ -0,0 +1,92 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package eventbus provides an in-process event bus. +// +// An event bus connects publishers of typed events with subscribers +// interested in those events. Typically, there is one global event +// bus per process. +// +// # Usage +// +// To send or receive events, first use [Bus.Client] to register with +// the bus. Clients should register with a human-readable name that +// identifies the code using the client, to aid in debugging. +// +// To publish events, use [Publish] on a Client to get a typed +// publisher for your event type, then call [Publisher.Publish] as +// needed. If your event is expensive to construct, you can optionally +// use [Publisher.ShouldPublish] to skip the work if nobody is +// listening for the event. 
+
+// To receive events, use [Subscribe] to get a typed subscriber for
+// each event type you're interested in. Receive the events themselves
+// by selecting over all your [Subscriber.Events] channels, as well as
+// [Subscriber.Done] for shutdown notifications.
+//
+// # Concurrency properties
+//
+// The bus serializes all published events across all publishers, and
+// preserves that ordering when delivering to subscribers that are
+// attached to the same Client. In more detail:
+//
+//   - An event is published to the bus at some instant between the
+//     start and end of the call to [Publisher.Publish].
+//   - Two events cannot be published at the same instant, and so are
+//     totally ordered by their publication time. Given two events E1
+//     and E2, either E1 happens before E2, or E2 happens before E1.
+//   - Clients dispatch events to their Subscribers in publication
+//     order: if E1 happens before E2, the client always delivers E1
+//     before E2.
+//   - Clients do not synchronize subscriptions with each other: given
+//     clients C1 and C2, both subscribed to events E1 and E2, C1 may
+//     deliver both E1 and E2 before C2 delivers E1.
+//
+// Less formally: there is one true timeline of all published events.
+// If you make a Client and subscribe to events, you will receive
+// events one at a time, in the same order as the one true
+// timeline. You will "skip over" events you didn't subscribe to, but
+// your view of the world always moves forward in time, never
+// backwards, and you will observe events in the same order as
+// everyone else.
+//
+// However, you cannot assume that what your client sees as "now" is
+// the same as what other clients see. They may be further behind you in
+// working through the timeline, or running ahead of you. This means
+// you should be careful about reaching out to another component
+// directly after receiving an event, as its view of the world may not
+// yet (or ever) be exactly consistent with yours.
+// +// To make your code more testable and understandable, you should try +// to structure it following the actor model: you have some local +// state over which you have authority, but your only way to interact +// with state elsewhere in the program is to receive and process +// events coming from elsewhere, or to emit events of your own. +// +// # Expected subscriber behavior +// +// Subscribers are expected to promptly receive their events on +// [Subscriber.Events]. The bus has a small, fixed amount of internal +// buffering, meaning that a slow subscriber will eventually cause +// backpressure and block publication of all further events. +// +// In general, you should receive from your subscriber(s) in a loop, +// and only do fast state updates within that loop. Any heavier work +// should be offloaded to another goroutine. +// +// Causing publishers to block from backpressure is considered a bug +// in the slow subscriber causing the backpressure, and should be +// addressed there. Publishers should assume that Publish will not +// block for extended periods of time, and should not make exceptional +// effort to behave gracefully if they do get blocked. +// +// These blocking semantics are provisional and subject to +// change. Please speak up if this causes development pain, so that we +// can adapt the semantics to better suit our needs. +// +// # Debugging facilities +// +// The [Debugger], obtained through [Bus.Debugger], provides +// introspection facilities to monitor events flowing through the bus, +// and inspect publisher and subscriber state. 
+package eventbus diff --git a/util/eventbus/fetch-htmx.go b/util/eventbus/fetch-htmx.go new file mode 100644 index 0000000000000..f80d5025727fd --- /dev/null +++ b/util/eventbus/fetch-htmx.go @@ -0,0 +1,93 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build ignore + +// Program fetch-htmx fetches and installs local copies of the HTMX +// library and its dependencies, used by the debug UI. It is meant to +// be run via go generate. +package main + +import ( + "compress/gzip" + "crypto/sha512" + "encoding/base64" + "fmt" + "io" + "log" + "net/http" + "os" +) + +func main() { + // Hash from https://htmx.org/docs/#installing + htmx, err := fetchHashed("https://unpkg.com/htmx.org@2.0.4", "HGfztofotfshcF7+8n44JQL2oJmowVChPTg48S+jvZoztPfvwD79OC/LTtG6dMp+") + if err != nil { + log.Fatalf("fetching htmx: %v", err) + } + + // Hash SHOULD be from https://htmx.org/extensions/ws/ , but the + // hash is currently incorrect, see + // https://github.com/bigskysoftware/htmx-extensions/issues/153 + // + // Until that bug is resolved, hash was obtained by rebuilding the + // extension from git source, and verifying that the hash matches + // what unpkg is serving. 
+ ws, err := fetchHashed("https://unpkg.com/htmx-ext-ws@2.0.2", "932iIqjARv+Gy0+r6RTGrfCkCKS5MsF539Iqf6Vt8L4YmbnnWI2DSFoMD90bvXd0") + if err != nil { + log.Fatalf("fetching htmx-websockets: %v", err) + } + + if err := writeGz("assets/htmx.min.js.gz", htmx); err != nil { + log.Fatalf("writing htmx.min.js.gz: %v", err) + } + if err := writeGz("assets/htmx-websocket.min.js.gz", ws); err != nil { + log.Fatalf("writing htmx-websocket.min.js.gz: %v", err) + } +} + +func writeGz(path string, bs []byte) error { + f, err := os.Create(path) + if err != nil { + return err + } + defer f.Close() + + g, err := gzip.NewWriterLevel(f, gzip.BestCompression) + if err != nil { + return err + } + + if _, err := g.Write(bs); err != nil { + return err + } + + if err := g.Flush(); err != nil { + return err + } + if err := f.Close(); err != nil { + return err + } + return nil +} + +func fetchHashed(url, wantHash string) ([]byte, error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("fetching %q returned error status: %s", url, resp.Status) + } + ret, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("reading file from %q: %v", url, err) + } + h := sha512.Sum384(ret) + got := base64.StdEncoding.EncodeToString(h[:]) + if got != wantHash { + return nil, fmt.Errorf("wrong hash for %q: got %q, want %q", url, got, wantHash) + } + return ret, nil +} diff --git a/util/eventbus/publish.go b/util/eventbus/publish.go new file mode 100644 index 0000000000000..9897114b64973 --- /dev/null +++ b/util/eventbus/publish.go @@ -0,0 +1,74 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "reflect" +) + +// publisher is a uniformly typed wrapper around Publisher[T], so that +// debugging facilities can look at active publishers. 
+type publisher interface { + publishType() reflect.Type + Close() +} + +// A Publisher publishes typed events on a bus. +type Publisher[T any] struct { + client *Client + stop stopFlag +} + +func newPublisher[T any](c *Client) *Publisher[T] { + ret := &Publisher[T]{ + client: c, + } + c.addPublisher(ret) + return ret +} + +// Close closes the publisher. +// +// Calls to Publish after Close silently do nothing. +func (p *Publisher[T]) Close() { + // Just unblocks any active calls to Publish, no other + // synchronization needed. + p.stop.Stop() + p.client.deletePublisher(p) +} + +func (p *Publisher[T]) publishType() reflect.Type { + return reflect.TypeFor[T]() +} + +// Publish publishes event v on the bus. +func (p *Publisher[T]) Publish(v T) { + // Check for just a stopped publisher or bus before trying to + // write, so that once closed Publish consistently does nothing. + select { + case <-p.stop.Done(): + return + default: + } + + evt := PublishedEvent{ + Event: v, + From: p.client, + } + + select { + case p.client.publish() <- evt: + case <-p.stop.Done(): + } +} + +// ShouldPublish reports whether anyone is subscribed to the events +// that this publisher emits. +// +// ShouldPublish can be used to skip expensive event construction if +// nobody seems to care. Publishers must not assume that someone will +// definitely receive an event if ShouldPublish returns true. +func (p *Publisher[T]) ShouldPublish() bool { + return p.client.shouldPublish(reflect.TypeFor[T]()) +} diff --git a/util/eventbus/queue.go b/util/eventbus/queue.go new file mode 100644 index 0000000000000..a62bf3c62d1d4 --- /dev/null +++ b/util/eventbus/queue.go @@ -0,0 +1,85 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "slices" +) + +const maxQueuedItems = 16 + +// queue is an ordered queue of length up to maxQueuedItems. 
+type queue[T any] struct {
+	vals  []T
+	start int
+}
+
+// canAppend reports whether a value can be appended to q.vals without
+// shifting values around.
+func (q *queue[T]) canAppend() bool {
+	return cap(q.vals) < maxQueuedItems || len(q.vals) < cap(q.vals)
+}
+
+func (q *queue[T]) Full() bool {
+	return q.start == 0 && !q.canAppend()
+}
+
+func (q *queue[T]) Empty() bool {
+	return q.start == len(q.vals)
+}
+
+func (q *queue[T]) Len() int {
+	return len(q.vals) - q.start
+}
+
+// Add adds v to the end of the queue. Panics if the queue is full.
+func (q *queue[T]) Add(v T) {
+	if !q.canAppend() {
+		if q.start == 0 {
+			panic("Add on a full queue")
+		}
+
+		// Slide remaining values back to the start of the array.
+		n := copy(q.vals, q.vals[q.start:])
+		toClear := len(q.vals) - n
+		clear(q.vals[len(q.vals)-toClear:])
+		q.vals = q.vals[:n]
+		q.start = 0
+	}
+
+	q.vals = append(q.vals, v)
+}
+
+// Peek returns the first value in the queue, without removing it from
+// the queue, or the zero value of T if the queue is empty.
+func (q *queue[T]) Peek() T {
+	if q.Empty() {
+		var zero T
+		return zero
+	}
+
+	return q.vals[q.start]
+}
+
+// Drop discards the first value in the queue, if any.
+func (q *queue[T]) Drop() {
+	if q.Empty() {
+		return
+	}
+
+	var zero T
+	q.vals[q.start] = zero
+	q.start++
+	if q.Empty() {
+		// Reset cursor to start of array, it's free to do.
+		q.start = 0
+		q.vals = q.vals[:0]
+	}
+}
+
+// Snapshot returns a copy of the queue's contents.
+func (q *queue[T]) Snapshot() []T { + return slices.Clone(q.vals[q.start:]) +} diff --git a/util/eventbus/subscribe.go b/util/eventbus/subscribe.go new file mode 100644 index 0000000000000..ba17e85484655 --- /dev/null +++ b/util/eventbus/subscribe.go @@ -0,0 +1,254 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package eventbus + +import ( + "context" + "fmt" + "reflect" + "sync" +) + +type DeliveredEvent struct { + Event any + From *Client + To *Client +} + +// subscriber is a uniformly typed wrapper around Subscriber[T], so +// that debugging facilities can look at active subscribers. +type subscriber interface { + subscribeType() reflect.Type + // dispatch is a function that dispatches the head value in vals to + // a subscriber, while also handling stop and incoming queue write + // events. + // + // dispatch exists because of the strongly typed Subscriber[T] + // wrapper around subscriptions: within the bus events are boxed in an + // 'any', and need to be unpacked to their full type before delivery + // to the subscriber. This involves writing to a strongly-typed + // channel, so subscribeState cannot handle that dispatch by itself - + // but if that strongly typed send blocks, we also need to keep + // processing other potential sources of wakeups, which is how we end + // up at this awkward type signature and sharing of internal state + // through dispatch. + dispatch(ctx context.Context, vals *queue[DeliveredEvent], acceptCh func() chan DeliveredEvent, snapshot chan chan []DeliveredEvent) bool + Close() +} + +// subscribeState handles dispatching of events received from a Bus. 
+type subscribeState struct {
+	client *Client
+
+	dispatcher *worker
+	write      chan DeliveredEvent
+	snapshot   chan chan []DeliveredEvent
+	debug      hook[DeliveredEvent]
+
+	outputsMu sync.Mutex
+	outputs   map[reflect.Type]subscriber
+}
+
+func newSubscribeState(c *Client) *subscribeState {
+	ret := &subscribeState{
+		client:   c,
+		write:    make(chan DeliveredEvent),
+		snapshot: make(chan chan []DeliveredEvent),
+		outputs:  map[reflect.Type]subscriber{},
+	}
+	ret.dispatcher = runWorker(ret.pump)
+	return ret
+}
+
+func (q *subscribeState) pump(ctx context.Context) {
+	var vals queue[DeliveredEvent]
+	acceptCh := func() chan DeliveredEvent {
+		if vals.Full() {
+			return nil
+		}
+		return q.write
+	}
+	for {
+		if !vals.Empty() {
+			val := vals.Peek()
+			sub := q.subscriberFor(val.Event)
+			if sub == nil {
+				// Raced with unsubscribe.
+				vals.Drop()
+				continue
+			}
+			if !sub.dispatch(ctx, &vals, acceptCh, q.snapshot) {
+				return
+			}
+
+			if q.debug.active() {
+				q.debug.run(DeliveredEvent{
+					Event: val.Event,
+					From:  val.From,
+					To:    q.client,
+				})
+			}
+		} else {
+			// Keep the cases in this select in sync with
+			// Subscriber.dispatch below. The only difference should be
+			// that this select doesn't deliver queued values to
+			// anyone, and unconditionally accepts new values.
+ select { + case val := <-q.write: + vals.Add(val) + case <-ctx.Done(): + return + case ch := <-q.snapshot: + ch <- vals.Snapshot() + } + } + } +} + +func (s *subscribeState) snapshotQueue() []DeliveredEvent { + if s == nil { + return nil + } + + resp := make(chan []DeliveredEvent) + select { + case s.snapshot <- resp: + return <-resp + case <-s.dispatcher.Done(): + return nil + } +} + +func (s *subscribeState) subscribeTypes() []reflect.Type { + if s == nil { + return nil + } + + s.outputsMu.Lock() + defer s.outputsMu.Unlock() + ret := make([]reflect.Type, 0, len(s.outputs)) + for t := range s.outputs { + ret = append(ret, t) + } + return ret +} + +func (s *subscribeState) addSubscriber(t reflect.Type, sub subscriber) { + s.outputsMu.Lock() + defer s.outputsMu.Unlock() + if s.outputs[t] != nil { + panic(fmt.Errorf("double subscription for event %s", t)) + } + s.outputs[t] = sub + s.client.addSubscriber(t, s) +} + +func (s *subscribeState) deleteSubscriber(t reflect.Type) { + s.outputsMu.Lock() + defer s.outputsMu.Unlock() + delete(s.outputs, t) + s.client.deleteSubscriber(t, s) +} + +func (q *subscribeState) subscriberFor(val any) subscriber { + q.outputsMu.Lock() + defer q.outputsMu.Unlock() + return q.outputs[reflect.TypeOf(val)] +} + +// Close closes the subscribeState. Implicitly closes all Subscribers +// linked to this state, and any pending events are discarded. +func (s *subscribeState) close() { + s.dispatcher.StopAndWait() + + var subs map[reflect.Type]subscriber + s.outputsMu.Lock() + subs, s.outputs = s.outputs, nil + s.outputsMu.Unlock() + for _, sub := range subs { + sub.Close() + } +} + +func (s *subscribeState) closed() <-chan struct{} { + return s.dispatcher.Done() +} + +// A Subscriber delivers one type of event from a [Client]. 
+type Subscriber[T any] struct {
+	stop       stopFlag
+	read       chan T
+	unregister func()
+}
+
+func newSubscriber[T any](r *subscribeState) *Subscriber[T] {
+	t := reflect.TypeFor[T]()
+
+	ret := &Subscriber[T]{
+		read:       make(chan T),
+		unregister: func() { r.deleteSubscriber(t) },
+	}
+	r.addSubscriber(t, ret)
+
+	return ret
+}
+
+func newMonitor[T any](attach func(fn func(T)) (cancel func())) *Subscriber[T] {
+	ret := &Subscriber[T]{
+		read: make(chan T, 100), // arbitrary, large
+	}
+	ret.unregister = attach(ret.monitor)
+	return ret
+}
+
+func (s *Subscriber[T]) subscribeType() reflect.Type {
+	return reflect.TypeFor[T]()
+}
+
+func (s *Subscriber[T]) monitor(debugEvent T) {
+	select {
+	case s.read <- debugEvent:
+	case <-s.stop.Done():
+	}
+}
+
+func (s *Subscriber[T]) dispatch(ctx context.Context, vals *queue[DeliveredEvent], acceptCh func() chan DeliveredEvent, snapshot chan chan []DeliveredEvent) bool {
+	t := vals.Peek().Event.(T)
+	for {
+		// Keep the cases in this select in sync with subscribeState.pump
+		// above. The only difference should be that this select
+		// delivers a value on s.read.
+		select {
+		case s.read <- t:
+			vals.Drop()
+			return true
+		case val := <-acceptCh():
+			vals.Add(val)
+		case <-ctx.Done():
+			return false
+		case ch := <-snapshot:
+			ch <- vals.Snapshot()
+		}
+	}
+}
+
+// Events returns a channel on which the subscriber's events are
+// delivered.
+func (s *Subscriber[T]) Events() <-chan T {
+	return s.read
+}
+
+// Done returns a channel that is closed when the subscriber is
+// closed.
+func (s *Subscriber[T]) Done() <-chan struct{} {
+	return s.stop.Done()
+}
+
+// Close closes the Subscriber, indicating the caller no longer wishes
+// to receive this event type. After Close, receives on
+// [Subscriber.Events] block forever.
+func (s *Subscriber[T]) Close() { + s.stop.Stop() // unblock receivers + s.unregister() +} diff --git a/util/fastuuid/fastuuid.go b/util/fastuuid/fastuuid.go deleted file mode 100644 index 4b115ea4e4974..0000000000000 --- a/util/fastuuid/fastuuid.go +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package fastuuid implements a UUID construction using an in process CSPRNG. -package fastuuid - -import ( - crand "crypto/rand" - "encoding/binary" - "io" - "math/rand/v2" - "sync" - - "github.com/google/uuid" -) - -// NewUUID returns a new UUID using a pool of generators, good for highly -// concurrent use. -func NewUUID() uuid.UUID { - g := pool.Get().(*generator) - defer pool.Put(g) - return g.newUUID() -} - -var pool = sync.Pool{ - New: func() any { - return newGenerator() - }, -} - -type generator struct { - rng rand.ChaCha8 -} - -func seed() [32]byte { - var r [32]byte - if _, err := io.ReadFull(crand.Reader, r[:]); err != nil { - panic(err) - } - return r -} - -func newGenerator() *generator { - return &generator{ - rng: *rand.NewChaCha8(seed()), - } -} - -func (g *generator) newUUID() uuid.UUID { - var u uuid.UUID - binary.NativeEndian.PutUint64(u[:8], g.rng.Uint64()) - binary.NativeEndian.PutUint64(u[8:], g.rng.Uint64()) - u[6] = (u[6] & 0x0f) | 0x40 // Version 4 - u[8] = (u[8] & 0x3f) | 0x80 // Variant 10 - return u -} diff --git a/util/fastuuid/fastuuid_test.go b/util/fastuuid/fastuuid_test.go deleted file mode 100644 index f0d9939043850..0000000000000 --- a/util/fastuuid/fastuuid_test.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package fastuuid - -import ( - "testing" - - "github.com/google/uuid" -) - -func TestNewUUID(t *testing.T) { - g := pool.Get().(*generator) - defer pool.Put(g) - u := g.newUUID() - if u[6] != (u[6]&0x0f)|0x40 { - t.Errorf("version bits are incorrect") - } - if u[8] != (u[8]&0x3f)|0x80 { - 
t.Errorf("variant bits are incorrect") - } -} - -func BenchmarkBasic(b *testing.B) { - b.Run("NewUUID", func(b *testing.B) { - for range b.N { - NewUUID() - } - }) - - b.Run("uuid.New-unpooled", func(b *testing.B) { - uuid.DisableRandPool() - for range b.N { - uuid.New() - } - }) - - b.Run("uuid.New-pooled", func(b *testing.B) { - uuid.EnableRandPool() - for range b.N { - uuid.New() - } - }) -} - -func BenchmarkParallel(b *testing.B) { - b.Run("NewUUID", func(b *testing.B) { - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - NewUUID() - } - }) - }) - - b.Run("uuid.New-unpooled", func(b *testing.B) { - uuid.DisableRandPool() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - uuid.New() - } - }) - }) - - b.Run("uuid.New-pooled", func(b *testing.B) { - uuid.EnableRandPool() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - uuid.New() - } - }) - }) -} diff --git a/util/goroutines/goroutines.go b/util/goroutines/goroutines.go index 9758b07586613..d40cbecb10876 100644 --- a/util/goroutines/goroutines.go +++ b/util/goroutines/goroutines.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// The goroutines package contains utilities for getting active goroutines. +// The goroutines package contains utilities for tracking and getting active goroutines. package goroutines import ( diff --git a/util/goroutines/tracker.go b/util/goroutines/tracker.go new file mode 100644 index 0000000000000..044843d33d155 --- /dev/null +++ b/util/goroutines/tracker.go @@ -0,0 +1,66 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package goroutines + +import ( + "sync" + "sync/atomic" + + "tailscale.com/util/set" +) + +// Tracker tracks a set of goroutines. 
+type Tracker struct { + started atomic.Int64 // counter + running atomic.Int64 // gauge + + mu sync.Mutex + onDone set.HandleSet[func()] +} + +func (t *Tracker) Go(f func()) { + t.started.Add(1) + t.running.Add(1) + go t.goAndDecr(f) +} + +func (t *Tracker) goAndDecr(f func()) { + defer t.decr() + f() +} + +func (t *Tracker) decr() { + t.running.Add(-1) + + t.mu.Lock() + defer t.mu.Unlock() + for _, f := range t.onDone { + go f() + } +} + +// AddDoneCallback adds a callback to be called in a new goroutine +// whenever a goroutine managed by t (excluding ones from this method) +// finishes. It returns a function to remove the callback. +func (t *Tracker) AddDoneCallback(f func()) (remove func()) { + t.mu.Lock() + defer t.mu.Unlock() + if t.onDone == nil { + t.onDone = set.HandleSet[func()]{} + } + h := t.onDone.Add(f) + return func() { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.onDone, h) + } +} + +func (t *Tracker) RunningGoroutines() int64 { + return t.running.Load() +} + +func (t *Tracker) StartedGoroutines() int64 { + return t.started.Load() +} diff --git a/util/lineiter/lineiter.go b/util/lineiter/lineiter.go new file mode 100644 index 0000000000000..5cb1eeef3ee1d --- /dev/null +++ b/util/lineiter/lineiter.go @@ -0,0 +1,72 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lineiter iterates over lines in things. +package lineiter + +import ( + "bufio" + "bytes" + "io" + "iter" + "os" + + "tailscale.com/types/result" +) + +// File returns an iterator that reads lines from the named file. +// +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func File(name string) iter.Seq[result.Of[[]byte]] { + f, err := os.Open(name) + return reader(f, f, err) +} + +// Bytes returns an iterator over the lines in bs. +// The returned substrings don't include the trailing newline. +// Lines may be empty. 
+func Bytes(bs []byte) iter.Seq[[]byte] { + return func(yield func([]byte) bool) { + for len(bs) > 0 { + i := bytes.IndexByte(bs, '\n') + if i < 0 { + yield(bs) + return + } + if !yield(bs[:i]) { + return + } + bs = bs[i+1:] + } + } +} + +// Reader returns an iterator over the lines in r. +// +// The returned substrings don't include the trailing newline. +// Lines may be empty. +func Reader(r io.Reader) iter.Seq[result.Of[[]byte]] { + return reader(r, nil, nil) +} + +func reader(r io.Reader, c io.Closer, err error) iter.Seq[result.Of[[]byte]] { + return func(yield func(result.Of[[]byte]) bool) { + if err != nil { + yield(result.Error[[]byte](err)) + return + } + if c != nil { + defer c.Close() + } + bs := bufio.NewScanner(r) + for bs.Scan() { + if !yield(result.Value(bs.Bytes())) { + return + } + } + if err := bs.Err(); err != nil { + yield(result.Error[[]byte](err)) + } + } +} diff --git a/util/lineiter/lineiter_test.go b/util/lineiter/lineiter_test.go new file mode 100644 index 0000000000000..3373d5fe7b122 --- /dev/null +++ b/util/lineiter/lineiter_test.go @@ -0,0 +1,32 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lineiter + +import ( + "slices" + "strings" + "testing" +) + +func TestBytesLines(t *testing.T) { + var got []string + for line := range Bytes([]byte("foo\n\nbar\nbaz")) { + got = append(got, string(line)) + } + want := []string{"foo", "", "bar", "baz"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} + +func TestReader(t *testing.T) { + var got []string + for line := range Reader(strings.NewReader("foo\n\nbar\nbaz")) { + got = append(got, string(line.MustValue())) + } + want := []string{"foo", "", "bar", "baz"} + if !slices.Equal(got, want) { + t.Errorf("got %q; want %q", got, want) + } +} diff --git a/util/linuxfw/fake_netfilter.go b/util/linuxfw/fake_netfilter.go new file mode 100644 index 0000000000000..a998ed765fd63 --- /dev/null +++ 
b/util/linuxfw/fake_netfilter.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "net/netip" + + "tailscale.com/types/logger" +) + +// FakeNetfilterRunner is a fake netfilter runner for tests. +type FakeNetfilterRunner struct { + // services is a map that tracks the firewall rules added/deleted via + // EnsureDNATRuleForSvc/DeleteDNATRuleForSvc. + services map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + } +} + +// NewFakeNetfilterRunner creates a new FakeNetfilterRunner. +func NewFakeNetfilterRunner() *FakeNetfilterRunner { + return &FakeNetfilterRunner{ + services: make(map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }), + } +} + +func (f *FakeNetfilterRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + f.services[svcName] = struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr + }{origDst, dst} + return nil +} + +func (f *FakeNetfilterRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + delete(f.services, svcName) + return nil +} + +func (f *FakeNetfilterRunner) GetServiceState() map[string]struct { + TailscaleServiceIP netip.Addr + ClusterIP netip.Addr +} { + return f.services +} + +func (f *FakeNetfilterRunner) HasIPV6() bool { + return true +} + +func (f *FakeNetfilterRunner) HasIPV6Filter() bool { + return true +} + +func (f *FakeNetfilterRunner) HasIPV6NAT() bool { + return true +} + +func (f *FakeNetfilterRunner) AddBase(tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelBase() error { return nil } +func (f *FakeNetfilterRunner) AddChains() error { return nil } +func (f *FakeNetfilterRunner) DelChains() error { return nil } +func (f *FakeNetfilterRunner) AddHooks() error { return nil } +func (f *FakeNetfilterRunner) DelHooks(logf logger.Logf) error { return nil } +func (f *FakeNetfilterRunner) AddSNATRule() error 
{ return nil } +func (f *FakeNetfilterRunner) DelSNATRule() error { return nil } +func (f *FakeNetfilterRunner) AddStatefulRule(tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelStatefulRule(tunname string) error { return nil } +func (f *FakeNetfilterRunner) AddLoopbackRule(addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DelLoopbackRule(addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) AddDNATRule(origDst, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error { + return nil +} +func (f *FakeNetfilterRunner) EnsureSNATForDst(src, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return nil } +func (f *FakeNetfilterRunner) AddMagicsockPortRule(port uint16, network string) error { return nil } +func (f *FakeNetfilterRunner) DelMagicsockPortRule(port uint16, network string) error { return nil } +func (f *FakeNetfilterRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { + return nil +} +func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error { + return nil +} +func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { + return nil +} diff --git a/util/linuxfw/iptables_for_svcs.go b/util/linuxfw/iptables_for_svcs.go index 8e0f5d48d0d0a..2cd8716e4622b 100644 --- a/util/linuxfw/iptables_for_svcs.go +++ b/util/linuxfw/iptables_for_svcs.go @@ -13,6 +13,7 @@ import ( // This file contains functionality to insert portmapping rules for a 'service'. // These are currently only used by the Kubernetes operator proxies. // An iptables rule for such a service contains a comment with the service name. 
+// A 'service' corresponds to a VIPService as used by the Kubernetes operator. // EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received // on match port and NOT on the provided interface to target IP and target port. @@ -24,10 +25,10 @@ func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip if err != nil { return fmt.Errorf("error checking if rule exists: %w", err) } - if !exists { - return table.Append("nat", "PREROUTING", args...) + if exists { + return nil } - return nil + return table.Append("nat", "PREROUTING", args...) } // DeleteMapRuleForSvc constructs a prerouting rule as would be created by @@ -40,10 +41,41 @@ func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP if err != nil { return fmt.Errorf("error checking if rule exists: %w", err) } + if !exists { + return nil + } + return table.Delete("nat", "PREROUTING", args...) +} + +// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the +// VIPService IP address to a local address. This is used by the Kubernetes +// operator's network layer proxies to forward tailnet traffic for VIPServices +// to Kubernetes Services. +func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table := i.getIPTByAddr(dst) + args := argsForIngressRule(svcName, origDst, dst) + exists, err := table.Exists("nat", "PREROUTING", args...) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } if exists { - return table.Delete("nat", "PREROUTING", args...) + return nil } - return nil + return table.Append("nat", "PREROUTING", args...) +} + +// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. +func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table := i.getIPTByAddr(dst) + args := argsForIngressRule(svcName, origDst, dst) + exists, err := table.Exists("nat", "PREROUTING", args...) 
+ if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } + if !exists { + return nil + } + return table.Delete("nat", "PREROUTING", args...) } // DeleteSvc constructs all possible rules that would have been created by @@ -72,8 +104,24 @@ func argsForPortMapRule(svc, excludeI string, targetIP netip.Addr, pm PortMap) [ } } +func argsForIngressRule(svcName string, origDst, targetIP netip.Addr) []string { + c := commentForIngressSvc(svcName, origDst, targetIP) + return []string{ + "--destination", origDst.String(), + "-m", "comment", "--comment", c, + "-j", "DNAT", + "--to-destination", targetIP.String(), + } +} + // commentForSvc generates a comment to be added to an iptables DNAT rule for a // service. This is for iptables debugging/readability purposes only. func commentForSvc(svc string, pm PortMap) string { return fmt.Sprintf("%s:%s:%d -> %s:%d", svc, pm.Protocol, pm.MatchPort, pm.Protocol, pm.TargetPort) } + +// commentForIngressSvc generates a comment to be added to an iptables DNAT rule for a +// service. This is for iptables debugging/readability purposes only. 
+func commentForIngressSvc(svc string, vip, clusterIP netip.Addr) string { + return fmt.Sprintf("svc: %s, %s -> %s", svc, vip.String(), clusterIP.String()) +} diff --git a/util/linuxfw/iptables_for_svcs_test.go b/util/linuxfw/iptables_for_svcs_test.go index 99b2f517f1eaf..c3c1b1f65d6fe 100644 --- a/util/linuxfw/iptables_for_svcs_test.go +++ b/util/linuxfw/iptables_for_svcs_test.go @@ -153,6 +153,135 @@ func Test_iptablesRunner_DeleteSvc(t *testing.T) { svcMustExist(t, "svc2", map[string][]string{v4Addr.String(): s2R1, v6Addr.String(): s2R2}, iptr) } +func Test_iptablesRunner_EnsureDNATRuleForSvc(t *testing.T) { + v4OrigDst := netip.MustParseAddr("10.0.0.1") + v4Target := netip.MustParseAddr("10.0.0.2") + v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target) + + tests := []struct { + name string + svcName string + origDst netip.Addr + targetIP netip.Addr + precreateSvcRules [][]string + }{ + { + name: "dnat_for_ipv4", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + }, + { + name: "dnat_for_ipv6", + svcName: "svc:test-2", + origDst: v6OrigDst, + targetIP: v6Target, + }, + { + name: "add_existing_rule", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v4Rule}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptr := NewFakeIPTablesRunner() + table := iptr.getIPTByAddr(tt.targetIP) + for _, ruleset := range tt.precreateSvcRules { + mustPrecreateDNATRule(t, ruleset, table) + } + if err := iptr.EnsureDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil { + t.Errorf("[unexpected error] iptablesRunner.EnsureDNATRuleForSvc() = %v", err) + } + args := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP) + exists, err := table.Exists("nat", "PREROUTING", args...) 
+ if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if !exists { + t.Errorf("expected rule was not created") + } + }) + } +} + +func Test_iptablesRunner_DeleteDNATRuleForSvc(t *testing.T) { + v4OrigDst := netip.MustParseAddr("10.0.0.1") + v4Target := netip.MustParseAddr("10.0.0.2") + v6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + v6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + v4Rule := argsForIngressRule("svc:test", v4OrigDst, v4Target) + v6Rule := argsForIngressRule("svc:test", v6OrigDst, v6Target) + + tests := []struct { + name string + svcName string + origDst netip.Addr + targetIP netip.Addr + precreateSvcRules [][]string + }{ + { + name: "multiple_rules_ipv4_deleted", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v4Rule, v6Rule}, + }, + { + name: "multiple_rules_ipv6_deleted", + svcName: "svc:test", + origDst: v6OrigDst, + targetIP: v6Target, + precreateSvcRules: [][]string{v4Rule, v6Rule}, + }, + { + name: "non-existent_rule_deleted", + svcName: "svc:test", + origDst: v4OrigDst, + targetIP: v4Target, + precreateSvcRules: [][]string{v6Rule}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + iptr := NewFakeIPTablesRunner() + table := iptr.getIPTByAddr(tt.targetIP) + for _, ruleset := range tt.precreateSvcRules { + mustPrecreateDNATRule(t, ruleset, table) + } + if err := iptr.DeleteDNATRuleForSvc(tt.svcName, tt.origDst, tt.targetIP); err != nil { + t.Errorf("iptablesRunner.DeleteDNATRuleForSvc() errored: %v ", err) + } + deletedRule := argsForIngressRule(tt.svcName, tt.origDst, tt.targetIP) + exists, err := table.Exists("nat", "PREROUTING", deletedRule...) 
+ if err != nil { + t.Fatalf("error verifying that rule does not exist after deletion: %v", err) + } + if exists { + t.Errorf("DNAT rule exists after deletion") + } + }) + } +} + +func mustPrecreateDNATRule(t *testing.T, rules []string, table iptablesInterface) { + t.Helper() + exists, err := table.Exists("nat", "PREROUTING", rules...) + if err != nil { + t.Fatalf("error ensuring that nat PREROUTING table exists: %v", err) + } + if exists { + return + } + if err := table.Append("nat", "PREROUTING", rules...); err != nil { + t.Fatalf("error precreating DNAT rule: %v", err) + } +} + func svcMustExist(t *testing.T, svcName string, rules map[string][]string, iptr *iptablesRunner) { t.Helper() for dst, ruleset := range rules { diff --git a/util/linuxfw/nftables.go b/util/linuxfw/nftables.go index 056563071479f..e8b267b5e42ae 100644 --- a/util/linuxfw/nftables.go +++ b/util/linuxfw/nftables.go @@ -8,6 +8,7 @@ package linuxfw import ( "cmp" + "encoding/binary" "fmt" "sort" "strings" @@ -15,7 +16,6 @@ import ( "github.com/google/nftables" "github.com/google/nftables/expr" "github.com/google/nftables/xt" - "github.com/josharian/native" "golang.org/x/sys/unix" "tailscale.com/types/logger" ) @@ -235,8 +235,8 @@ func printMatchInfo(name string, info xt.InfoAny) string { break } - pkttype := int(native.Endian.Uint32(data[0:4])) - invert := int(native.Endian.Uint32(data[4:8])) + pkttype := int(binary.NativeEndian.Uint32(data[0:4])) + invert := int(binary.NativeEndian.Uint32(data[4:8])) var invertPrefix string if invert != 0 { invertPrefix = "!" 
diff --git a/util/linuxfw/nftables_for_svcs.go b/util/linuxfw/nftables_for_svcs.go index 130585b2229e3..474b980869691 100644 --- a/util/linuxfw/nftables_for_svcs.go +++ b/util/linuxfw/nftables_for_svcs.go @@ -119,6 +119,63 @@ func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm [ return n.conn.Flush() } +// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the +// VIPService IP address to a local address. This is used by the Kubernetes +// operator's network layer proxies to forward tailnet traffic for VIPServices +// to Kubernetes Services. +func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error { + t, ch, err := n.ensurePreroutingChain(origDst) + if err != nil { + return fmt.Errorf("error ensuring chain for %s: %w", svc, err) + } + meta := svcRuleMeta(svc, origDst, dst) + rule, err := n.findRuleByMetadata(t, ch, meta) + if err != nil { + return fmt.Errorf("error looking up rule: %w", err) + } + if rule != nil { + return nil + } + rule = dnatRuleForChain(t, ch, origDst, dst, meta) + n.conn.InsertRule(rule) + return n.conn.Flush() +} + +// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. +// We use the metadata attached to the rule to look it up. 
+func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + table, err := n.getNFTByAddr(origDst) + if err != nil { + return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err) + } + t, err := getTableIfExists(n.conn, table.Proto, "nat") + if err != nil { + return fmt.Errorf("error checking if nat table exists: %w", err) + } + if t == nil { + return nil + } + ch, err := getChainFromTable(n.conn, t, "PREROUTING") + if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) { + return nil + } + if err != nil { + return fmt.Errorf("error checking if chain PREROUTING exists: %w", err) + } + meta := svcRuleMeta(svcName, origDst, dst) + rule, err := n.findRuleByMetadata(t, ch, meta) + if err != nil { + return fmt.Errorf("error checking if rule exists: %w", err) + } + if rule == nil { + return nil + } + if err := n.conn.DelRule(rule); err != nil { + return fmt.Errorf("error deleting rule: %w", err) + } + return n.conn.Flush() +} + func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule { var fam uint32 if targetIP.Is4() { @@ -243,3 +300,10 @@ func protoFromString(s string) (uint8, error) { return 0, fmt.Errorf("unrecognized protocol: %q", s) } } + +// svcRuleMeta generates metadata for a rule. +// This metadata can then be used to find the rule. 
+// https://github.com/google/nftables/issues/48 +func svcRuleMeta(svcName string, origDst, dst netip.Addr) []byte { + return []byte(fmt.Sprintf("svc:%s,VIP:%s,ClusterIP:%s", svcName, origDst.String(), dst.String())) +} diff --git a/util/linuxfw/nftables_for_svcs_test.go b/util/linuxfw/nftables_for_svcs_test.go index d2df6e4bdf2ef..73472ce20cbe5 100644 --- a/util/linuxfw/nftables_for_svcs_test.go +++ b/util/linuxfw/nftables_for_svcs_test.go @@ -14,8 +14,9 @@ import ( // This test creates a temporary network namespace for the nftables rules being // set up, so it needs to run in a privileged mode. Locally it needs to be run -// by root, else it will be silently skipped. In CI it runs in a privileged -// container. +// by root, else it will be silently skipped. +// sudo go test -v -run Test_nftablesRunner_EnsurePortMapRuleForSvc ./util/linuxfw/... +// In CI it runs in a privileged container. func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { conn := newSysConn(t) runner := newFakeNftablesRunnerWithConn(t, conn, true) @@ -23,51 +24,215 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { pmTCP := PortMap{MatchPort: 4003, TargetPort: 80, Protocol: "TCP"} pmTCP1 := PortMap{MatchPort: 4004, TargetPort: 443, Protocol: "TCP"} - // Create a rule for service 'foo' to forward TCP traffic to IPv4 endpoint - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) + // Create a rule for service 'svc:foo' to forward TCP traffic to IPv4 endpoint + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP) svcChains(t, 1, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) - // Create another rule for service 'foo' to forward TCP traffic to the + // Create another rule for service 'svc:foo' to forward 
TCP traffic to the // same IPv4 endpoint, but to a different port. - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1) + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP1) svcChains(t, 1, conn) - chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 2, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) - // Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint - runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP) + // Create a rule for service 'svc:foo' to forward TCP traffic to an IPv6 endpoint + runner.EnsurePortMapRuleForSvc("svc:foo", "tailscale0", ipv6, pmTCP) svcChains(t, 2, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6) - checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv6) + checkPortMapRule(t, "svc:foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) - // Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint - runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP) + // Create a rule for service 'svc:bar' to forward TCP traffic to IPv4 endpoint + runner.EnsurePortMapRuleForSvc("svc:bar", "tailscale0", ipv4, pmTCP) svcChains(t, 3, conn) - chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4) - checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv4) + checkPortMapRule(t, "svc:bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) - // Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint - runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP) + // Create a rule for service 'svc:bar' to forward TCP traffic to an IPv6 endpoint + runner.EnsurePortMapRuleForSvc("svc:bar", 
"tailscale0", ipv6, pmTCP) svcChains(t, 4, conn) - chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6) - checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + chainRuleCount(t, "svc:bar", 1, conn, nftables.TableFamilyIPv6) + checkPortMapRule(t, "svc:bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) - // Delete service bar - runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) + // Delete service svc:bar + runner.DeleteSvc("svc:bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) svcChains(t, 2, conn) - // Delete a rule from service foo - runner.DeletePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) + // Delete a rule from service svc:foo + runner.DeletePortMapRuleForSvc("svc:foo", "tailscale0", ipv4, pmTCP) svcChains(t, 2, conn) - chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) + chainRuleCount(t, "svc:foo", 1, conn, nftables.TableFamilyIPv4) - // Delete service foo - runner.DeleteSvc("foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1}) + // Delete service svc:foo + runner.DeleteSvc("svc:foo", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP, pmTCP1}) svcChains(t, 0, conn) } +func Test_nftablesRunner_EnsureDNATRuleForSvc(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + + // Test IPv4 DNAT rule + ipv4OrigDst := netip.MustParseAddr("10.0.0.1") + ipv4Target := netip.MustParseAddr("10.0.0.2") + + // Create DNAT rule for service 'svc:foo' to forward IPv4 traffic + err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating IPv4 DNAT rule: %v", err) + } + checkDNATRule(t, "svc:foo", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4) + + // Test IPv6 DNAT rule + ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + + // Create DNAT rule for service 'svc:foo' to forward IPv6 traffic + err 
= runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error creating IPv6 DNAT rule: %v", err) + } + checkDNATRule(t, "svc:foo", ipv6OrigDst, ipv6Target, runner, nftables.TableFamilyIPv6) + + // Test creating rule for another service + err = runner.EnsureDNATRuleForSvc("svc:bar", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating DNAT rule for service 'svc:bar': %v", err) + } + checkDNATRule(t, "svc:bar", ipv4OrigDst, ipv4Target, runner, nftables.TableFamilyIPv4) +} + +func Test_nftablesRunner_DeleteDNATRuleForSvc(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + + // Test IPv4 DNAT rule deletion + ipv4OrigDst := netip.MustParseAddr("10.0.0.1") + ipv4Target := netip.MustParseAddr("10.0.0.2") + + // Create and then delete IPv4 DNAT rule + err := runner.EnsureDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error creating IPv4 DNAT rule: %v", err) + } + + // Verify rule exists before deletion + table, err := runner.getNFTByAddr(ipv4OrigDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + meta := svcRuleMeta("svc:foo", ipv4OrigDst, ipv4Target) + rule, err := runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("rule does not exist before deletion") + } + + err = runner.DeleteDNATRuleForSvc("svc:foo", ipv4OrigDst, ipv4Target) + if err != nil { + t.Fatalf("error deleting IPv4 DNAT rule: %v", err) + } + + // Verify rule is deleted + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule 
exists: %v", err) + } + if rule != nil { + t.Fatal("rule still exists after deletion") + } + + // Test IPv6 DNAT rule deletion + ipv6OrigDst := netip.MustParseAddr("fd7a:115c:a1e0::1") + ipv6Target := netip.MustParseAddr("fd7a:115c:a1e0::2") + + // Create and then delete IPv6 DNAT rule + err = runner.EnsureDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error creating IPv6 DNAT rule: %v", err) + } + + // Verify rule exists before deletion + table, err = runner.getNFTByAddr(ipv6OrigDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err = getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + ch, err = getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + meta = svcRuleMeta("svc:foo", ipv6OrigDst, ipv6Target) + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("rule does not exist before deletion") + } + + err = runner.DeleteDNATRuleForSvc("svc:foo", ipv6OrigDst, ipv6Target) + if err != nil { + t.Fatalf("error deleting IPv6 DNAT rule: %v", err) + } + + // Verify rule is deleted + rule, err = runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule != nil { + t.Fatal("rule still exists after deletion") + } +} + +// checkDNATRule verifies that a DNAT rule exists for the given service, original destination, and target IP. 
+func checkDNATRule(t *testing.T, svc string, origDst, targetIP netip.Addr, runner *nftablesRunner, fam nftables.TableFamily) { + t.Helper() + table, err := runner.getNFTByAddr(origDst) + if err != nil { + t.Fatalf("error getting table: %v", err) + } + nftTable, err := getTableIfExists(runner.conn, table.Proto, "nat") + if err != nil { + t.Fatalf("error getting nat table: %v", err) + } + if nftTable == nil { + t.Fatal("nat table not found") + } + + ch, err := getChainFromTable(runner.conn, nftTable, "PREROUTING") + if err != nil { + t.Fatalf("error getting PREROUTING chain: %v", err) + } + if ch == nil { + t.Fatal("PREROUTING chain not found") + } + + meta := svcRuleMeta(svc, origDst, targetIP) + rule, err := runner.findRuleByMetadata(nftTable, ch, meta) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if rule == nil { + t.Fatal("DNAT rule not found") + } +} + // svcChains verifies that the expected number of chains exist (for either IP // family) and that each of them is configured as NAT prerouting chain. 
func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) { diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 0f411521bb562..faa02f7c75956 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -107,6 +107,12 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { if err != nil { return err } + rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil) + n.conn.InsertRule(rule) + return n.conn.Flush() +} + +func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule { var daddrOffset, fam, dadderLen uint32 if origDst.Is4() { daddrOffset = 16 @@ -117,9 +123,9 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { dadderLen = 16 fam = unix.NFPROTO_IPV6 } - dnatRule := &nftables.Rule{ - Table: nat, - Chain: preroutingCh, + rule := &nftables.Rule{ + Table: t, + Chain: ch, Exprs: []expr.Any{ &expr.Payload{ DestRegister: 1, @@ -143,8 +149,10 @@ func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { }, }, } - n.conn.InsertRule(dnatRule) - return n.conn.Flush() + if len(meta) > 0 { + rule.UserData = meta + } + return rule } // DNATWithLoadBalancer currently just forwards all traffic destined for origDst @@ -555,6 +563,8 @@ type NetfilterRunner interface { EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error + EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error + DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error @@ -1710,57 +1720,45 @@ func (n *nftablesRunner) AddSNATRule() error { return nil } +func delMatchSubnetRouteMarkMasqRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) error { + + rule, err := 
createMatchSubnetRouteMarkRule(table, chain, Masq) + if err != nil { + return fmt.Errorf("create match subnet route mark rule: %w", err) + } + + SNATRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find SNAT rule v4: %w", err) + } + + if SNATRule != nil { + _ = conn.DelRule(SNATRule) + } + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del SNAT rule: %w", err) + } + + return nil +} + // DelSNATRule removes the netfilter rule to SNAT traffic destined for // local subnets. An error is returned if the rule does not exist. func (n *nftablesRunner) DelSNATRule() error { conn := n.conn - hexTSFwmarkMask := getTailscaleFwmarkMask() - hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() - - exprs := []expr.Any{ - &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: hexTSFwmarkMask, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: hexTSSubnetRouteMark, - }, - &expr.Counter{}, - &expr.Masq{}, - } - for _, table := range n.getTables() { chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { - return fmt.Errorf("get postrouting chain v4: %w", err) - } - - rule := &nftables.Rule{ - Table: table.Nat, - Chain: chain, - Exprs: exprs, + return fmt.Errorf("get postrouting chain: %w", err) } - - SNATRule, err := findRule(conn, rule) + err = delMatchSubnetRouteMarkMasqRule(conn, table.Nat, chain) if err != nil { - return fmt.Errorf("find SNAT rule v4: %w", err) - } - - if SNATRule != nil { - _ = conn.DelRule(SNATRule) + return err } } - if err := conn.Flush(); err != nil { - return fmt.Errorf("flush del SNAT rule: %w", err) - } - return nil } diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index 712a7b93955da..6fb180ed67ce6 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -12,6 +12,7 @@ import ( "net/netip" "os" "runtime" + "slices" "strings" 
"testing" @@ -24,21 +25,21 @@ import ( "tailscale.com/types/logger" ) +func toAnySlice[T any](s []T) []any { + out := make([]any, len(s)) + for i, v := range s { + out[i] = v + } + return out +} + // nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing // users to make sense of large byte literals more easily. func nfdump(b []byte) string { var buf bytes.Buffer - i := 0 - for ; i < len(b); i += 4 { - // TODO: show printable characters as ASCII - fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", - b[i], - b[i+1], - b[i+2], - b[i+3]) - } - for ; i < len(b); i++ { - fmt.Fprintf(&buf, "%02x ", b[i]) + for c := range slices.Chunk(b, 4) { + format := strings.Repeat("%02x ", len(c)) + fmt.Fprintf(&buf, format+"\n", toAnySlice(c)...) } return buf.String() } @@ -75,7 +76,7 @@ func linediff(a, b string) string { return buf.String() } -func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { +func newTestConn(t *testing.T, want [][]byte, reply [][]netlink.Message) *nftables.Conn { conn, err := nftables.New(nftables.WithTestDial( func(req []netlink.Message) ([]netlink.Message, error) { for idx, msg := range req { @@ -96,7 +97,13 @@ func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { } want = want[1:] } - return req, nil + // no reply for batch end message + if len(want) == 0 { + return nil, nil + } + rep := reply[0] + reply = reply[1:] + return rep, nil })) if err != nil { t.Fatal(err) @@ -120,7 +127,7 @@ func TestInsertHookRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -160,7 +167,7 @@ func TestInsertLoopbackRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -196,7 +203,7 @@ func 
TestInsertLoopbackRuleV6(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) tableV6 := testConn.AddTable(&nftables.Table{ Family: protoV6, Name: "ts-filter-test", @@ -232,7 +239,7 @@ func TestAddReturnChromeOSVMRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -264,7 +271,7 @@ func TestAddDropCGNATRangeRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -296,7 +303,7 @@ func TestAddSetSubnetRouteMarkRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -328,7 +335,7 @@ func TestAddDropOutgoingPacketFromCGNATRangeRuleWithTunname(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -360,7 +367,7 @@ func TestAddAcceptOutgoingPacketRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -392,7 +399,7 @@ func TestAddAcceptIncomingPacketRule(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", @@ -420,11 +427,11 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { // nft add 
chain ip ts-nat-test ts-postrouting-test { type nat hook postrouting priority 100; } []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x03\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x04\x08\x00\x02\x00\x00\x00\x00\x64\x08\x00\x07\x00\x6e\x61\x74\x00"), // nft add rule ip ts-nat-test ts-postrouting-test meta mark & 0x00ff0000 == 0x00040000 counter masquerade - []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xf4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + 
[]byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xd8\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x01\x80\x09\x00\x01\x00\x6d\x61\x73\x71\x00\x00\x00\x00\x04\x00\x02\x80"), // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-nat-test", @@ -436,7 +443,46 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { Hooknum: nftables.ChainHookPostrouting, Priority: nftables.ChainPriorityNATSource, }) - err := addMatchSubnetRouteMarkRule(testConn, table, chain, Accept) + err := addMatchSubnetRouteMarkRule(testConn, table, chain, Masq) + if err != nil { + t.Fatal(err) + } +} + +func TestDelMatchSubnetRouteMarkMasqRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + reply := [][]netlink.Message{ + nil, + {{Header: netlink.Header{Length: 0x128, Type: 0xa06, Flags: 0x802, Sequence: 0xa213d55d, PID: 0x11e79}, Data: []uint8{0x2, 0x0, 0x0, 0x8c, 0xd, 0x0, 0x1, 0x0, 0x6e, 0x61, 0x74, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x0, 0x0, 0x0, 0x0, 0x18, 0x0, 0x2, 0x0, 0x74, 0x73, 0x2d, 0x70, 
0x6f, 0x73, 0x74, 0x72, 0x6f, 0x75, 0x74, 0x69, 0x6e, 0x67, 0x2d, 0x74, 0x65, 0x73, 0x74, 0x0, 0xc, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x4, 0xe0, 0x0, 0x4, 0x0, 0x24, 0x0, 0x1, 0x0, 0x9, 0x0, 0x1, 0x0, 0x6d, 0x65, 0x74, 0x61, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x2, 0x0, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x3, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x4c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x62, 0x69, 0x74, 0x77, 0x69, 0x73, 0x65, 0x0, 0x3c, 0x0, 0x2, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x3, 0x0, 0x0, 0x0, 0x0, 0x4, 0x8, 0x0, 0x6, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x4, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0xff, 0x0, 0x0, 0xc, 0x0, 0x5, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2c, 0x0, 0x1, 0x0, 0x8, 0x0, 0x1, 0x0, 0x63, 0x6d, 0x70, 0x0, 0x20, 0x0, 0x2, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x1, 0x8, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x3, 0x0, 0x8, 0x0, 0x1, 0x0, 0x0, 0x4, 0x0, 0x0, 0x2c, 0x0, 0x1, 0x0, 0xc, 0x0, 0x1, 0x0, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x65, 0x72, 0x0, 0x1c, 0x0, 0x2, 0x0, 0xc, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x0, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x1, 0x0, 0x9, 0x0, 0x1, 0x0, 0x6d, 0x61, 0x73, 0x71, 0x0, 0x0, 0x0, 0x0, 0x4, 0x0, 0x2, 0x0}}}, + {{Header: netlink.Header{Length: 0x14, Type: 0x3, Flags: 0x2, Sequence: 0x311fdccb, PID: 0x11e79}, Data: []uint8{0x0, 0x0, 0x0, 0x0}}}, + {{Header: netlink.Header{Length: 0x24, Type: 0x2, Flags: 0x100, Sequence: 0x311fdccb, PID: 0x11e79}, Data: []uint8{0x0, 0x0, 0x0, 0x0, 0x48, 0x0, 0x0, 0x0, 0x8, 0xa, 0x5, 0x0, 0xcb, 0xdc, 0x1f, 0x31, 0x79, 0x1e, 0x1, 0x0}}}, + } + want := [][]byte{ + // get rules in nat-test table ts-postrouting-test chain + []byte("\x02\x00\x00\x00\x0d\x00\x01\x00\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00"), + // batch begin + []byte("\x00\x00\x00\x0a"), 
+ // nft delete rule ip nat-test ts-postrouting-test handle 4 + []byte("\x02\x00\x00\x00\x0d\x00\x01\x00\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x0c\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x04"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + + conn := newTestConn(t, want, reply) + + table := &nftables.Table{ + Family: proto, + Name: "nat-test", + } + chain := &nftables.Chain{ + Name: "ts-postrouting-test", + Table: table, + Type: nftables.ChainTypeNAT, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource, + } + + err := delMatchSubnetRouteMarkMasqRule(conn, table, chain) if err != nil { t.Fatal(err) } @@ -456,7 +502,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { // batch end []byte("\x00\x00\x00\x0a"), } - testConn := newTestConn(t, want) + testConn := newTestConn(t, want, nil) table := testConn.AddTable(&nftables.Table{ Family: proto, Name: "ts-filter-test", diff --git a/util/lru/lru_test.go b/util/lru/lru_test.go index fb538efbe7957..5500e5e0f309f 100644 --- a/util/lru/lru_test.go +++ b/util/lru/lru_test.go @@ -10,7 +10,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" - xmaps "golang.org/x/exp/maps" + "tailscale.com/util/slicesx" ) func TestLRU(t *testing.T) { @@ -75,7 +75,7 @@ func TestStressEvictions(t *testing.T) { for len(vm) < numKeys { vm[rand.Uint64()] = true } - vals := xmaps.Keys(vm) + vals := slicesx.MapKeys(vm) c := Cache[uint64, bool]{ MaxEntries: cacheSize, @@ -106,7 +106,7 @@ func TestStressBatchedEvictions(t *testing.T) { for len(vm) < numKeys { vm[rand.Uint64()] = true } - vals := xmaps.Keys(vm) + vals := slicesx.MapKeys(vm) c := Cache[uint64, bool]{} diff --git a/util/mak/mak.go b/util/mak/mak.go index b421fb0ed5a55..fbdb40b0afd21 100644 --- a/util/mak/mak.go +++ b/util/mak/mak.go @@ -5,11 +5,6 @@ // things, notably to maps, but also slices. 
package mak -import ( - "fmt" - "reflect" -) - // Set populates an entry in a map, making the map if necessary. // // That is, it assigns (*m)[k] = v, making *m if it was nil. @@ -20,35 +15,6 @@ func Set[K comparable, V any, T ~map[K]V](m *T, k K, v V) { (*m)[k] = v } -// NonNil takes a pointer to a Go data structure -// (currently only a slice or a map) and makes sure it's non-nil for -// JSON serialization. (In particular, JavaScript clients usually want -// the field to be defined after they decode the JSON.) -// -// Deprecated: use NonNilSliceForJSON or NonNilMapForJSON instead. -func NonNil(ptr any) { - if ptr == nil { - panic("nil interface") - } - rv := reflect.ValueOf(ptr) - if rv.Kind() != reflect.Ptr { - panic(fmt.Sprintf("kind %v, not Ptr", rv.Kind())) - } - if rv.Pointer() == 0 { - panic("nil pointer") - } - rv = rv.Elem() - if rv.Pointer() != 0 { - return - } - switch rv.Type().Kind() { - case reflect.Slice: - rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) - case reflect.Map: - rv.Set(reflect.MakeMap(rv.Type())) - } -} - // NonNilSliceForJSON makes sure that *slicePtr is non-nil so it will // won't be omitted from JSON serialization and possibly confuse JavaScript // clients expecting it to be present. 
diff --git a/util/mak/mak_test.go b/util/mak/mak_test.go index 4de499a9d5040..e47839a3c8fe9 100644 --- a/util/mak/mak_test.go +++ b/util/mak/mak_test.go @@ -40,35 +40,6 @@ func TestSet(t *testing.T) { }) } -func TestNonNil(t *testing.T) { - var s []string - NonNil(&s) - if len(s) != 0 { - t.Errorf("slice len = %d; want 0", len(s)) - } - if s == nil { - t.Error("slice still nil") - } - - s = append(s, "foo") - NonNil(&s) - if len(s) != 1 { - t.Errorf("len = %d; want 1", len(s)) - } - if s[0] != "foo" { - t.Errorf("value = %q; want foo", s) - } - - var m map[string]string - NonNil(&m) - if len(m) != 0 { - t.Errorf("map len = %d; want 0", len(s)) - } - if m == nil { - t.Error("map still nil") - } -} - func TestNonNilMapForJSON(t *testing.T) { type M map[string]int var m M diff --git a/util/osuser/group_ids.go b/util/osuser/group_ids.go index f25861dbb4519..7c2b5b090cbcc 100644 --- a/util/osuser/group_ids.go +++ b/util/osuser/group_ids.go @@ -19,6 +19,10 @@ import ( // an error. It will first try to use the 'id' command to get the group IDs, // and if that fails, it will fall back to the user.GroupIds method. func GetGroupIds(user *user.User) ([]string, error) { + if runtime.GOOS == "plan9" { + return nil, nil + } + if runtime.GOOS != "linux" { return user.GroupIds() } diff --git a/util/osuser/user.go b/util/osuser/user.go index 2c7f2e24b9b11..8b96194d716ce 100644 --- a/util/osuser/user.go +++ b/util/osuser/user.go @@ -54,9 +54,18 @@ func lookup(usernameOrUID string, std lookupStd, wantShell bool) (*user.User, st // Skip getent entirely on Non-Unix platforms that won't ever have it. // (Using HasPrefix for "wasip1", anticipating that WASI support will // move beyond "preview 1" some day.) 
- if runtime.GOOS == "windows" || runtime.GOOS == "js" || runtime.GOARCH == "wasm" { + if runtime.GOOS == "windows" || runtime.GOOS == "js" || runtime.GOARCH == "wasm" || runtime.GOOS == "plan9" { + var shell string + if wantShell && runtime.GOOS == "plan9" { + shell = "/bin/rc" + } + if runtime.GOOS == "plan9" { + if u, err := user.Current(); err == nil { + return u, shell, nil + } + } u, err := std(usernameOrUID) - return u, "", err + return u, shell, err } // No getent on Gokrazy. So hard-code the login shell. @@ -78,6 +87,16 @@ func lookup(usernameOrUID string, std lookupStd, wantShell bool) (*user.User, st return u, shell, nil } + if runtime.GOOS == "plan9" { + return &user.User{ + Uid: "0", + Gid: "0", + Username: "glenda", + Name: "Glenda", + HomeDir: "/", + }, "/bin/rc", nil + } + // Start with getent if caller wants to get the user shell. if wantShell { return userLookupGetent(usernameOrUID, std) diff --git a/util/pidowner/pidowner_linux.go b/util/pidowner/pidowner_linux.go index 2a5181f14e03c..a07f512427062 100644 --- a/util/pidowner/pidowner_linux.go +++ b/util/pidowner/pidowner_linux.go @@ -8,26 +8,26 @@ import ( "os" "strings" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) func ownerOfPID(pid int) (userID string, err error) { file := fmt.Sprintf("/proc/%d/status", pid) - err = lineread.File(file, func(line []byte) error { + for lr := range lineiter.File(file) { + line, err := lr.Value() + if err != nil { + if os.IsNotExist(err) { + return "", ErrProcessNotFound + } + return "", err + } if len(line) < 4 || string(line[:4]) != "Uid:" { - return nil + continue } f := strings.Fields(string(line)) if len(f) >= 2 { userID = f[1] // real userid } - return nil - }) - if os.IsNotExist(err) { - return "", ErrProcessNotFound - } - if err != nil { - return } if userID == "" { return "", fmt.Errorf("missing Uid line in %s", file) diff --git a/util/set/slice.go b/util/set/slice.go index 38551aee197ad..2fc65b82d1c6e 100644 --- a/util/set/slice.go 
+++ b/util/set/slice.go @@ -67,7 +67,7 @@ func (ss *Slice[T]) Add(vs ...T) { // AddSlice adds all elements in vs to the set. func (ss *Slice[T]) AddSlice(vs views.Slice[T]) { - for i := range vs.Len() { - ss.Add(vs.At(i)) + for _, v := range vs.All() { + ss.Add(v) } } diff --git a/util/singleflight/singleflight.go b/util/singleflight/singleflight.go index 9df47448b70ab..9c7dc54020097 100644 --- a/util/singleflight/singleflight.go +++ b/util/singleflight/singleflight.go @@ -36,7 +36,7 @@ var errGoexit = errors.New("runtime.Goexit was called") // A panicError is an arbitrary value recovered from a panic // with the stack trace during the execution of given function. type panicError struct { - value interface{} + value any stack []byte } @@ -45,7 +45,7 @@ func (p *panicError) Error() string { return fmt.Sprintf("%v\n\n%s", p.value, p.stack) } -func newPanicError(v interface{}) error { +func newPanicError(v any) error { stack := debug.Stack() // The first line of the stack trace is of the form "goroutine N [status]:" diff --git a/util/singleflight/singleflight_test.go b/util/singleflight/singleflight_test.go index 031922736fab6..e938c9dcdd174 100644 --- a/util/singleflight/singleflight_test.go +++ b/util/singleflight/singleflight_test.go @@ -25,7 +25,7 @@ import ( func TestDo(t *testing.T) { var g Group[string, any] - v, err, _ := g.Do("key", func() (interface{}, error) { + v, err, _ := g.Do("key", func() (any, error) { return "bar", nil }) if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want { @@ -39,7 +39,7 @@ func TestDo(t *testing.T) { func TestDoErr(t *testing.T) { var g Group[string, any] someErr := errors.New("Some error") - v, err, _ := g.Do("key", func() (interface{}, error) { + v, err, _ := g.Do("key", func() (any, error) { return nil, someErr }) if err != someErr { @@ -55,7 +55,7 @@ func TestDoDupSuppress(t *testing.T) { var wg1, wg2 sync.WaitGroup c := make(chan string, 1) var calls int32 - fn := func() (interface{}, error) { + fn := 
func() (any, error) { if atomic.AddInt32(&calls, 1) == 1 { // First invocation. wg1.Done() @@ -108,7 +108,7 @@ func TestForget(t *testing.T) { ) go func() { - g.Do("key", func() (i interface{}, e error) { + g.Do("key", func() (i any, e error) { close(firstStarted) <-unblockFirst close(firstFinished) @@ -119,7 +119,7 @@ func TestForget(t *testing.T) { g.Forget("key") unblockSecond := make(chan struct{}) - secondResult := g.DoChan("key", func() (i interface{}, e error) { + secondResult := g.DoChan("key", func() (i any, e error) { <-unblockSecond return 2, nil }) @@ -127,7 +127,7 @@ func TestForget(t *testing.T) { close(unblockFirst) <-firstFinished - thirdResult := g.DoChan("key", func() (i interface{}, e error) { + thirdResult := g.DoChan("key", func() (i any, e error) { return 3, nil }) @@ -141,7 +141,7 @@ func TestForget(t *testing.T) { func TestDoChan(t *testing.T) { var g Group[string, any] - ch := g.DoChan("key", func() (interface{}, error) { + ch := g.DoChan("key", func() (any, error) { return "bar", nil }) @@ -160,7 +160,7 @@ func TestDoChan(t *testing.T) { // See https://github.com/golang/go/issues/41133 func TestPanicDo(t *testing.T) { var g Group[string, any] - fn := func() (interface{}, error) { + fn := func() (any, error) { panic("invalid memory address or nil pointer dereference") } @@ -197,7 +197,7 @@ func TestPanicDo(t *testing.T) { func TestGoexitDo(t *testing.T) { var g Group[string, any] - fn := func() (interface{}, error) { + fn := func() (any, error) { runtime.Goexit() return nil, nil } @@ -238,7 +238,7 @@ func TestPanicDoChan(t *testing.T) { }() g := new(Group[string, any]) - ch := g.DoChan("", func() (interface{}, error) { + ch := g.DoChan("", func() (any, error) { panic("Panicking in DoChan") }) <-ch @@ -283,7 +283,7 @@ func TestPanicDoSharedByDoChan(t *testing.T) { defer func() { recover() }() - g.Do("", func() (interface{}, error) { + g.Do("", func() (any, error) { close(blocked) <-unblock panic("Panicking in Do") @@ -291,7 +291,7 @@ func 
TestPanicDoSharedByDoChan(t *testing.T) {
 	}()
 
 	<-blocked
-	ch := g.DoChan("", func() (interface{}, error) {
+	ch := g.DoChan("", func() (any, error) {
 		panic("DoChan unexpectedly executed callback")
 	})
 	close(unblock)
diff --git a/util/slicesx/slicesx.go b/util/slicesx/slicesx.go
index e0b820eb71e91..ff9d473759fb0 100644
--- a/util/slicesx/slicesx.go
+++ b/util/slicesx/slicesx.go
@@ -95,6 +95,17 @@ func Filter[S ~[]T, T any](dst, src S, fn func(T) bool) S {
 	return dst
 }
 
+// AppendNonzero appends all non-zero elements of src to dst.
+func AppendNonzero[S ~[]T, T comparable](dst, src S) S {
+	var zero T
+	for _, v := range src {
+		if v != zero {
+			dst = append(dst, v)
+		}
+	}
+	return dst
+}
+
 // AppendMatching appends elements in ps to dst if f(x) is true.
 func AppendMatching[T any](dst, ps []T, f func(T) bool) []T {
 	for _, p := range ps {
@@ -148,3 +159,43 @@ func FirstEqual[T comparable](s []T, v T) bool {
 func LastEqual[T comparable](s []T, v T) bool {
 	return len(s) > 0 && s[len(s)-1] == v
 }
+
+// MapKeys returns the keys of the map m.
+//
+// The keys will be in an indeterminate order.
+//
+// It's equivalent to golang.org/x/exp/maps.Keys, which
+// unfortunately has the package name "maps", shadowing
+// the std "maps" package. This version exists for clarity
+// when reading call sites.
+//
+// As opposed to slices.Collect(maps.Keys(m)), this allocates
+// the returned slice once to exactly the right size, rather than
+// appending larger backing arrays as it goes.
+func MapKeys[M ~map[K]V, K comparable, V any](m M) []K {
+	r := make([]K, 0, len(m))
+	for k := range m {
+		r = append(r, k)
+	}
+	return r
+}
+
+// MapValues returns the values of the map m.
+//
+// The values will be in an indeterminate order.
+//
+// It's equivalent to golang.org/x/exp/maps.Values, which
+// unfortunately has the package name "maps", shadowing
+// the std "maps" package. This version exists for clarity
+// when reading call sites.
+// +// As opposed to slices.Collect(maps.Values(m)), this allocates +// the returned slice once to exactly the right size, rather than +// appending larger backing arrays as it goes. +func MapValues[M ~map[K]V, K comparable, V any](m M) []V { + r := make([]V, 0, len(m)) + for _, v := range m { + r = append(r, v) + } + return r +} diff --git a/util/slicesx/slicesx_test.go b/util/slicesx/slicesx_test.go index 597b22b8335fe..34644928465d8 100644 --- a/util/slicesx/slicesx_test.go +++ b/util/slicesx/slicesx_test.go @@ -137,6 +137,19 @@ func TestFilterNoAllocations(t *testing.T) { } } +func TestAppendNonzero(t *testing.T) { + v := []string{"one", "two", "", "four"} + got := AppendNonzero(nil, v) + want := []string{"one", "two", "four"} + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } + got = AppendNonzero(v[:0], v) + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } +} + func TestAppendMatching(t *testing.T) { v := []string{"one", "two", "three", "four"} got := AppendMatching(v[:0], v, func(s string) bool { return len(s) > 3 }) diff --git a/util/stringsx/stringsx.go b/util/stringsx/stringsx.go new file mode 100644 index 0000000000000..6c7a8d20d4221 --- /dev/null +++ b/util/stringsx/stringsx.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package stringsx provides additional string manipulation functions +// that aren't in the standard library's strings package or go4.org/mem. +package stringsx + +import ( + "unicode" + "unicode/utf8" +) + +// CompareFold returns -1, 0, or 1 depending on whether a < b, a == b, or a > b, +// like cmp.Compare, but case insensitively. 
+func CompareFold(a, b string) int {
+	// Track our position in both strings
+	ia, ib := 0, 0
+	for ia < len(a) && ib < len(b) {
+		ra, wa := nextRuneLower(a[ia:])
+		rb, wb := nextRuneLower(b[ib:])
+		if ra < rb {
+			return -1
+		}
+		if ra > rb {
+			return 1
+		}
+		ia += wa
+		ib += wb
+		if wa == 0 || wb == 0 {
+			break
+		}
+	}
+
+	// If we've reached here, one or both strings are exhausted
+	// The shorter string is "less than" if they match up to this point
+	switch {
+	case ia == len(a) && ib == len(b):
+		return 0
+	case ia == len(a):
+		return -1
+	default:
+		return 1
+	}
+}
+
+// nextRuneLower returns the next rune in the string, lowercased, along with its
+// original (consumed) width in bytes. If the string is empty, it returns
+// (utf8.RuneError, 0)
+func nextRuneLower(s string) (r rune, width int) {
+	r, width = utf8.DecodeRuneInString(s)
+	return unicode.ToLower(r), width
+}
diff --git a/util/stringsx/stringsx_test.go b/util/stringsx/stringsx_test.go
new file mode 100644
index 0000000000000..8575c0b278fca
--- /dev/null
+++ b/util/stringsx/stringsx_test.go
@@ -0,0 +1,78 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package stringsx
+
+import (
+	"cmp"
+	"strings"
+	"testing"
+)
+
+func TestCompareFold(t *testing.T) {
+	tests := []struct {
+		a, b string
+	}{
+		// Basic ASCII cases
+		{"", ""},
+		{"a", "a"},
+		{"a", "A"},
+		{"A", "a"},
+		{"a", "b"},
+		{"b", "a"},
+		{"abc", "ABC"},
+		{"ABC", "abc"},
+		{"abc", "abd"},
+		{"abd", "abc"},
+
+		// Length differences
+		{"abc", "ab"},
+		{"ab", "abc"},
+
+		// Unicode cases
+		{"世界", "世界"},
+		{"Hello世界", "hello世界"},
+		{"世界Hello", "世界hello"},
+		{"世界", "世界x"},
+		{"世界x", "世界"},
+
+		// Special case folding examples
+		{"ß", "ss"}, // German sharp s
+		{"ﬁ", "fi"}, // fi ligature
+		{"Σ", "σ"}, // Greek sigma
+		{"İ", "i\u0307"}, // Turkish dotted I
+
+		// Mixed cases
+		{"HelloWorld", "helloworld"},
+		{"HELLOWORLD", "helloworld"},
+		{"helloworld", 
"HELLOWORLD"}, + {"HelloWorld", "helloworld"}, + {"helloworld", "HelloWorld"}, + + // Edge cases + {" ", " "}, + {"1", "1"}, + {"123", "123"}, + {"!@#", "!@#"}, + } + + wants := []int{} + for _, tt := range tests { + got := CompareFold(tt.a, tt.b) + want := cmp.Compare(strings.ToLower(tt.a), strings.ToLower(tt.b)) + if got != want { + t.Errorf("CompareFold(%q, %q) = %v, want %v", tt.a, tt.b, got, want) + } + wants = append(wants, want) + } + + if n := testing.AllocsPerRun(1000, func() { + for i, tt := range tests { + if CompareFold(tt.a, tt.b) != wants[i] { + panic("unexpected") + } + } + }); n > 0 { + t.Errorf("allocs = %v; want 0", int(n)) + } +} diff --git a/util/syspolicy/caching_handler.go b/util/syspolicy/caching_handler.go deleted file mode 100644 index 5192958bc45a5..0000000000000 --- a/util/syspolicy/caching_handler.go +++ /dev/null @@ -1,122 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "sync" -) - -// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested -// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached, -// otherwise the actual error is returned and the next read for that key will retry using the handler. -type CachingHandler struct { - mu sync.Mutex - strings map[string]string - uint64s map[string]uint64 - bools map[string]bool - strArrs map[string][]string - notFound map[string]bool - handler Handler -} - -// NewCachingHandler creates a CachingHandler given a handler. -func NewCachingHandler(handler Handler) *CachingHandler { - return &CachingHandler{ - handler: handler, - strings: make(map[string]string), - uint64s: make(map[string]uint64), - bools: make(map[string]bool), - strArrs: make(map[string][]string), - notFound: make(map[string]bool), - } -} - -// ReadString reads the policy settings value string given the key. 
-// ReadString first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadString(key string) (string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strings[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return "", ErrNoSuchKey - } - val, err := ch.handler.ReadString(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return "", err - } else if err != nil { - return "", err - } - ch.strings[key] = val - return val, nil -} - -// ReadUInt64 reads the policy settings uint64 value given the key. -// ReadUInt64 first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.uint64s[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return 0, ErrNoSuchKey - } - val, err := ch.handler.ReadUInt64(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return 0, err - } else if err != nil { - return 0, err - } - ch.uint64s[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. -func (ch *CachingHandler) ReadBoolean(key string) (bool, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.bools[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return false, ErrNoSuchKey - } - val, err := ch.handler.ReadBoolean(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return false, err - } else if err != nil { - return false, err - } - ch.bools[key] = val - return val, nil -} - -// ReadBoolean reads the policy settings boolean value given the key. -// ReadBoolean first reads from the handler's cache before resorting to using the handler. 
-func (ch *CachingHandler) ReadStringArray(key string) ([]string, error) { - ch.mu.Lock() - defer ch.mu.Unlock() - if val, ok := ch.strArrs[key]; ok { - return val, nil - } - if notFound := ch.notFound[key]; notFound { - return nil, ErrNoSuchKey - } - val, err := ch.handler.ReadStringArray(key) - if errors.Is(err, ErrNoSuchKey) { - ch.notFound[key] = true - return nil, err - } else if err != nil { - return nil, err - } - ch.strArrs[key] = val - return val, nil -} diff --git a/util/syspolicy/caching_handler_test.go b/util/syspolicy/caching_handler_test.go deleted file mode 100644 index 881f6ff83c0f8..0000000000000 --- a/util/syspolicy/caching_handler_test.go +++ /dev/null @@ -1,262 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "testing" -) - -func TestHandlerReadString(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue string - handlerError error - preserveHandler bool - wantValue string - wantErr error - strings map[string]string - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - strings: map[string]string{"test": "foo"}, - wantValue: "foo", - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: "foo", - wantValue: "foo", - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - s: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.strings != nil { 
- cache.strings = tt.strings - } - got, err := cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadString(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } -} - -func TestHandlerReadUint64(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue uint64 - handlerError error - preserveHandler bool - wantValue uint64 - wantErr error - uint64s map[string]uint64 - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - uint64s: map[string]uint64{"test": 1}, - wantValue: 1, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: 1, - wantValue: 1, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - u64: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.uint64s != nil { - cache.uint64s = tt.uint64s - } - got, err := cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != 
tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil - } - got, err = cache.ReadUInt64(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} - -func TestHandlerReadBool(t *testing.T) { - tests := []struct { - name string - key string - handlerKey Key - handlerValue bool - handlerError error - preserveHandler bool - wantValue bool - wantErr error - bools map[string]bool - expectedCalls int - }{ - { - name: "read existing cached values", - key: "test", - handlerKey: "do not read", - bools: map[string]bool{"test": true}, - wantValue: true, - expectedCalls: 0, - }, - { - name: "read existing values not cached", - key: "test", - handlerKey: "test", - handlerValue: true, - wantValue: true, - expectedCalls: 1, - }, - { - name: "error no such key", - key: "test", - handlerKey: "test", - handlerError: ErrNoSuchKey, - wantErr: ErrNoSuchKey, - expectedCalls: 1, - }, - { - name: "other error", - key: "test", - handlerKey: "test", - handlerError: someOtherError, - wantErr: someOtherError, - preserveHandler: true, - expectedCalls: 2, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - testHandler := &testHandler{ - t: t, - key: tt.handlerKey, - b: tt.handlerValue, - err: tt.handlerError, - } - cache := NewCachingHandler(testHandler) - if tt.bools != nil { - cache.bools = tt.bools - } - got, err := cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("got %v want %v", got, cache.strings[tt.key]) - } - if !tt.preserveHandler { - testHandler.key, testHandler.s, testHandler.err = "do not 
read", "", nil - } - got, err = cache.ReadBoolean(tt.key) - if err != tt.wantErr { - t.Errorf("repeat err=%v want %v", err, tt.wantErr) - } - if got != tt.wantValue { - t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) - } - if testHandler.calls != tt.expectedCalls { - t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) - } - }) - } - -} diff --git a/util/syspolicy/handler.go b/util/syspolicy/handler.go index f1fad97709a3f..c4bfd9de92594 100644 --- a/util/syspolicy/handler.go +++ b/util/syspolicy/handler.go @@ -4,16 +4,17 @@ package syspolicy import ( - "errors" - "sync/atomic" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" ) -var ( - handlerUsed atomic.Bool - handler Handler = defaultHandler{} -) +// TODO(nickkhyl): delete this file once other repos are updated. // Handler reads system policies from OS-specific storage. +// +// Deprecated: implementing a [source.Store] should be preferred. type Handler interface { // ReadString reads the policy setting's string value for the given key. // It should return ErrNoSuchKey if the key does not have a value set. @@ -29,55 +30,84 @@ type Handler interface { ReadStringArray(key string) ([]string, error) } -// ErrNoSuchKey is returned by a Handler when the specified key does not have a -// value set. -var ErrNoSuchKey = errors.New("no such key") - -// defaultHandler is the catch all syspolicy type for anything that isn't windows or apple. -type defaultHandler struct{} +// RegisterHandler wraps and registers the specified handler as the device's +// policy [source.Store] for the program's lifetime. +// +// Deprecated: using [RegisterStore] should be preferred. 
+func RegisterHandler(h Handler) { + rsop.RegisterStore("DeviceHandler", setting.DeviceScope, WrapHandler(h)) +} -func (defaultHandler) ReadString(_ string) (string, error) { - return "", ErrNoSuchKey +// SetHandlerForTest wraps and sets the specified handler as the device's policy +// [source.Store] for the duration of tb. +// +// Deprecated: using [MustRegisterStoreForTest] should be preferred. +func SetHandlerForTest(tb testenv.TB, h Handler) { + RegisterWellKnownSettingsForTest(tb) + MustRegisterStoreForTest(tb, "DeviceHandler-TestOnly", setting.DefaultScope(), WrapHandler(h)) } -func (defaultHandler) ReadUInt64(_ string) (uint64, error) { - return 0, ErrNoSuchKey +var _ source.Store = (*handlerStore)(nil) + +// handlerStore is a [source.Store] that calls the underlying [Handler]. +// +// TODO(nickkhyl): remove it when the corp and android repos are updated. +type handlerStore struct { + h Handler } -func (defaultHandler) ReadBoolean(_ string) (bool, error) { - return false, ErrNoSuchKey +// WrapHandler returns a [source.Store] that wraps the specified [Handler]. +func WrapHandler(h Handler) source.Store { + return handlerStore{h} } -func (defaultHandler) ReadStringArray(_ string) ([]string, error) { - return nil, ErrNoSuchKey +// Lock implements [source.Lockable]. +func (s handlerStore) Lock() error { + if lockable, ok := s.h.(source.Lockable); ok { + return lockable.Lock() + } + return nil } -// markHandlerInUse is called before handler methods are called. -func markHandlerInUse() { - handlerUsed.Store(true) +// Unlock implements [source.Lockable]. +func (s handlerStore) Unlock() { + if lockable, ok := s.h.(source.Lockable); ok { + lockable.Unlock() + } } -// RegisterHandler initializes the policy handler and ensures registration will happen once. -func RegisterHandler(h Handler) { - // Technically this assignment is not concurrency safe, but in the - // event that there was any risk of a data race, we will panic due to - // the CompareAndSwap failing. 
- handler = h - if !handlerUsed.CompareAndSwap(false, true) { - panic("handler was already used before registration") +// RegisterChangeCallback implements [source.Changeable]. +func (s handlerStore) RegisterChangeCallback(callback func()) (unregister func(), err error) { + if changeable, ok := s.h.(source.Changeable); ok { + return changeable.RegisterChangeCallback(callback) } + return func() {}, nil } -// TB is a subset of testing.TB that we use to set up test helpers. -// It's defined here to avoid pulling in the testing package. -type TB interface { - Helper() - Cleanup(func()) +// ReadString implements [source.Store]. +func (s handlerStore) ReadString(key setting.Key) (string, error) { + return s.h.ReadString(string(key)) } -func SetHandlerForTest(tb TB, h Handler) { - tb.Helper() - oldHandler := handler - handler = h - tb.Cleanup(func() { handler = oldHandler }) +// ReadUInt64 implements [source.Store]. +func (s handlerStore) ReadUInt64(key setting.Key) (uint64, error) { + return s.h.ReadUInt64(string(key)) +} + +// ReadBoolean implements [source.Store]. +func (s handlerStore) ReadBoolean(key setting.Key) (bool, error) { + return s.h.ReadBoolean(string(key)) +} + +// ReadStringArray implements [source.Store]. +func (s handlerStore) ReadStringArray(key setting.Key) ([]string, error) { + return s.h.ReadStringArray(string(key)) +} + +// Done implements [source.Expirable]. 
+func (s handlerStore) Done() <-chan struct{} { + if expirable, ok := s.h.(source.Expirable); ok { + return expirable.Done() + } + return nil } diff --git a/util/syspolicy/handler_test.go b/util/syspolicy/handler_test.go deleted file mode 100644 index 39b18936f176d..0000000000000 --- a/util/syspolicy/handler_test.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import "testing" - -func TestDefaultHandlerReadValues(t *testing.T) { - var h defaultHandler - - got, err := h.ReadString(string(AdminConsoleVisibility)) - if got != "" || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", got, err) - } - result, err := h.ReadUInt64(string(LogSCMInteractions)) - if result != 0 || err != ErrNoSuchKey { - t.Fatalf("got %v err %v", result, err) - } -} diff --git a/util/syspolicy/handler_windows.go b/util/syspolicy/handler_windows.go deleted file mode 100644 index 661853ead5d53..0000000000000 --- a/util/syspolicy/handler_windows.go +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -import ( - "errors" - "fmt" - - "tailscale.com/util/clientmetric" - "tailscale.com/util/winutil" -) - -var ( - windowsErrors = clientmetric.NewCounter("windows_syspolicy_errors") - windowsAny = clientmetric.NewGauge("windows_syspolicy_any") -) - -type windowsHandler struct{} - -func init() { - RegisterHandler(NewCachingHandler(windowsHandler{})) - - keyList := []struct { - isSet func(Key) bool - keys []Key - }{ - { - isSet: func(k Key) bool { - _, err := handler.ReadString(string(k)) - return err == nil - }, - keys: stringKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadBoolean(string(k)) - return err == nil - }, - keys: boolKeys, - }, - { - isSet: func(k Key) bool { - _, err := handler.ReadUInt64(string(k)) - return err == nil - }, - keys: uint64Keys, - }, - } - - var anySet bool - for _, l := range 
keyList { - for _, k := range l.keys { - if !l.isSet(k) { - continue - } - clientmetric.NewGauge(fmt.Sprintf("windows_syspolicy_%s", k)).Set(1) - anySet = true - } - } - if anySet { - windowsAny.Set(1) - } -} - -func (windowsHandler) ReadString(key string) (string, error) { - s, err := winutil.GetPolicyString(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - - return s, err -} - -func (windowsHandler) ReadUInt64(key string) (uint64, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} - -func (windowsHandler) ReadBoolean(key string) (bool, error) { - value, err := winutil.GetPolicyInteger(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value != 0, err -} - -func (windowsHandler) ReadStringArray(key string) ([]string, error) { - value, err := winutil.GetPolicyStringArray(key) - if errors.Is(err, winutil.ErrNoValue) { - err = ErrNoSuchKey - } else if err != nil { - windowsErrors.Add(1) - } - return value, err -} diff --git a/util/syspolicy/internal/internal.go b/util/syspolicy/internal/internal.go index 8f28896259abf..6ab147de6d096 100644 --- a/util/syspolicy/internal/internal.go +++ b/util/syspolicy/internal/internal.go @@ -10,6 +10,7 @@ import ( "github.com/go-json-experiment/json/jsontext" "tailscale.com/types/lazy" + "tailscale.com/util/testenv" "tailscale.com/version" ) @@ -25,22 +26,10 @@ func OS() string { return OSForTesting.Get(version.OS) } -// TB is a subset of testing.TB that we use to set up test helpers. -// It's defined here to avoid pulling in the testing package. 
-type TB interface { - Helper() - Cleanup(func()) - Logf(format string, args ...any) - Error(args ...any) - Errorf(format string, args ...any) - Fatal(args ...any) - Fatalf(format string, args ...any) -} - // EqualJSONForTest compares the JSON in j1 and j2 for semantic equality. // It returns "", "", true if j1 and j2 are equal. Otherwise, it returns // indented versions of j1 and j2 and false. -func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) { +func EqualJSONForTest(tb testenv.TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) { tb.Helper() j1 = j1.Clone() j2 = j2.Clone() @@ -56,10 +45,10 @@ func EqualJSONForTest(tb TB, j1, j2 jsontext.Value) (s1, s2 string, equal bool) return "", "", true } // Otherwise, format the values for display and return false. - if err := j1.Indent("", "\t"); err != nil { + if err := j1.Indent(); err != nil { tb.Fatal(err) } - if err := j2.Indent("", "\t"); err != nil { + if err := j2.Indent(); err != nil { tb.Fatal(err) } return j1.String(), j2.String(), false diff --git a/util/syspolicy/internal/loggerx/logger.go b/util/syspolicy/internal/loggerx/logger.go index c29a5f0845cd6..d1f48cbb428fe 100644 --- a/util/syspolicy/internal/loggerx/logger.go +++ b/util/syspolicy/internal/loggerx/logger.go @@ -10,7 +10,7 @@ import ( "tailscale.com/types/lazy" "tailscale.com/types/logger" - "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/testenv" ) const ( @@ -58,7 +58,7 @@ func verbosef(format string, args ...any) { // SetForTest sets the specified printf and verbosef functions for the duration // of tb and its subtests. 
-func SetForTest(tb internal.TB, printf, verbosef logger.Logf) { +func SetForTest(tb testenv.TB, printf, verbosef logger.Logf) { lazyPrintf.SetForTest(tb, printf, nil) lazyVerbosef.SetForTest(tb, verbosef, nil) } diff --git a/util/syspolicy/internal/metrics/metrics.go b/util/syspolicy/internal/metrics/metrics.go index 2ea02278afc92..43f2a285a26ea 100644 --- a/util/syspolicy/internal/metrics/metrics.go +++ b/util/syspolicy/internal/metrics/metrics.go @@ -259,7 +259,7 @@ var addMetricTestHook, setMetricTestHook syncs.AtomicValue[metricFn] // SetHooksForTest sets the specified addMetric and setMetric functions // as the metric functions for the duration of tb and all its subtests. -func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { +func SetHooksForTest(tb testenv.TB, addMetric, setMetric metricFn) { oldAddMetric := addMetricTestHook.Swap(addMetric) oldSetMetric := setMetricTestHook.Swap(setMetric) tb.Cleanup(func() { @@ -284,12 +284,13 @@ func SetHooksForTest(tb internal.TB, addMetric, setMetric metricFn) { } func newSettingMetric(key setting.Key, scope setting.Scope, suffix string, typ clientmetric.Type) metric { - name := strings.ReplaceAll(string(key), setting.KeyPathSeparator, "_") + name := strings.ReplaceAll(string(key), string(setting.KeyPathSeparator), "_") + name = strings.ReplaceAll(name, ".", "_") // dots are not allowed in metric names return newMetric([]string{name, metricScopeName(scope), suffix}, typ) } func newMetric(nameParts []string, typ clientmetric.Type) metric { - name := strings.Join(slicesx.Filter([]string{internal.OS(), "syspolicy"}, nameParts, isNonEmpty), "_") + name := strings.Join(slicesx.AppendNonzero([]string{internal.OS(), "syspolicy"}, nameParts), "_") switch { case !ShouldReport(): return &funcMetric{name: name, typ: typ} @@ -304,8 +305,6 @@ func newMetric(nameParts []string, typ clientmetric.Type) metric { } } -func isNonEmpty(s string) bool { return s != "" } - func metricScopeName(scope setting.Scope) string { 
switch scope { case setting.DeviceSetting: diff --git a/util/syspolicy/internal/metrics/test_handler.go b/util/syspolicy/internal/metrics/test_handler.go index f9e4846092be3..36c3f2cad876a 100644 --- a/util/syspolicy/internal/metrics/test_handler.go +++ b/util/syspolicy/internal/metrics/test_handler.go @@ -9,6 +9,7 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/set" "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/testenv" ) // TestState represents a metric name and its expected value. @@ -19,13 +20,13 @@ type TestState struct { // TestHandler facilitates testing of the code that uses metrics. type TestHandler struct { - t internal.TB + t testenv.TB m map[string]int64 } // NewTestHandler returns a new TestHandler. -func NewTestHandler(t internal.TB) *TestHandler { +func NewTestHandler(t testenv.TB) *TestHandler { return &TestHandler{t, make(map[string]int64)} } diff --git a/util/syspolicy/policy_keys.go b/util/syspolicy/policy_keys.go index ec0556a942cc6..29b2dfd281c4a 100644 --- a/util/syspolicy/policy_keys.go +++ b/util/syspolicy/policy_keys.go @@ -3,15 +3,51 @@ package syspolicy -import "tailscale.com/util/syspolicy/setting" +import ( + "tailscale.com/types/lazy" + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" +) +// Key is a string that uniquely identifies a policy and must remain unchanged +// once established and documented for a given policy setting. It may contain +// alphanumeric characters and zero or more [KeyPathSeparator]s to group +// individual policy settings into categories. type Key = setting.Key +// The const block below lists known policy keys. +// When adding a key to this list, remember to add a corresponding +// [setting.Definition] to [implicitDefinitions] below. +// Otherwise, the [TestKnownKeysRegistered] test will fail as a reminder. 
+ const ( // Keys with a string value ControlURL Key = "LoginURL" // default ""; if blank, ipn uses ipn.DefaultControlURL. LogTarget Key = "LogTarget" // default ""; if blank logging uses logtail.DefaultHost. Tailnet Key = "Tailnet" // default ""; if blank, no tailnet name is sent to the server. + + // AlwaysOn is a boolean key that controls whether Tailscale + // should always remain in a connected state, and the user should + // not be able to disconnect at their discretion. + // + // Warning: This policy setting is experimental and may change or be removed in the future. + // It may also not be fully supported by all Tailscale clients until it is out of experimental status. + // See tailscale/corp#26247, tailscale/corp#26248 and tailscale/corp#26249 for more information. + AlwaysOn Key = "AlwaysOn.Enabled" + + // AlwaysOnOverrideWithReason is a boolean key that alters the behavior + // of [AlwaysOn]. When true, the user is allowed to disconnect Tailscale + // by providing a reason. The reason is logged and sent to the control + // for auditing purposes. It has no effect when [AlwaysOn] is false. + AlwaysOnOverrideWithReason Key = "AlwaysOn.OverrideWithReason" + + // ReconnectAfter is a string value formatted for use with time.ParseDuration() + // that defines the duration after which the client should automatically reconnect + // to the Tailscale network following a user-initiated disconnect. + // An empty string or a zero duration disables automatic reconnection. + ReconnectAfter Key = "ReconnectAfter" + // ExitNodeID is the exit node's node id. default ""; if blank, no exit node is forced. // Exit node ID takes precedence over exit node IP. // To find the node ID, go to /api.md#device. 
@@ -27,6 +63,14 @@ const ( ExitNodeAllowLANAccess Key = "ExitNodeAllowLANAccess" EnableTailscaleDNS Key = "UseTailscaleDNSSettings" EnableTailscaleSubnets Key = "UseTailscaleSubnets" + + // EnableDNSRegistration is a string value that can be set to "always", "never" + // or "user-decides". It controls whether DNS registration and dynamic DNS + // updates are enabled for the Tailscale interface. For historical reasons + // and to maintain compatibility with existing setups, the default is "never". + // It is only used on Windows. + EnableDNSRegistration Key = "EnableDNSRegistration" + // CheckUpdates is the key to signal if the updater should periodically // check for updates. CheckUpdates Key = "CheckUpdates" @@ -63,6 +107,9 @@ const ( // SuggestedExitNodeVisibility controls the visibility of suggested exit nodes in the client GUI. // When this system policy is set to 'hide', an exit node suggestion won't be presented to the user as part of the exit nodes picker. SuggestedExitNodeVisibility Key = "SuggestedExitNode" + // OnboardingFlowVisibility controls the visibility of the onboarding flow in the client GUI. + // When this system policy is set to 'hide', the onboarding flow is never shown to the user. + OnboardingFlowVisibility Key = "OnboardingFlow" // Keys with a string value formatted for use with time.ParseDuration(). KeyExpirationNoticeTime Key = "KeyExpirationNotice" // default 24 hours @@ -106,7 +153,105 @@ const ( // Example: "CN=Tailscale Inc Test Root CA,OU=Tailscale Inc Test Certificate Authority,O=Tailscale Inc,ST=ON,C=CA" MachineCertificateSubject Key = "MachineCertificateSubject" + // Hostname is the hostname of the device that is running Tailscale. + // When this policy is set, it overrides the hostname that the client + // would otherwise obtain from the OS, e.g. by calling os.Hostname(). + Hostname Key = "Hostname" + // Keys with a string array value. 
// AllowedSuggestedExitNodes's string array value is a list of exit node IDs that restricts which exit nodes are considered when generating suggestions for exit nodes. AllowedSuggestedExitNodes Key = "AllowedSuggestedExitNodes" ) + +// implicitDefinitions is a list of [setting.Definition] that will be registered +// automatically when the policy setting definitions are first used by the syspolicy package hierarchy. +// This includes the first time a policy needs to be read from any source. +var implicitDefinitions = []*setting.Definition{ + // Device policy settings (can only be configured on a per-device basis): + setting.NewDefinition(AllowedSuggestedExitNodes, setting.DeviceSetting, setting.StringListValue), + setting.NewDefinition(AlwaysOn, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(AlwaysOnOverrideWithReason, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(ApplyUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(AuthKey, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(CheckUpdates, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ControlURL, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(DeviceSerialNumber, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(EnableDNSRegistration, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableIncomingConnections, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableRunExitNode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableServerMode, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableTailscaleDNS, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(EnableTailscaleSubnets, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ExitNodeAllowLANAccess, 
setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ExitNodeID, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(ExitNodeIP, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(FlushDNSOnSessionUnlock, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(Hostname, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(LogSCMInteractions, setting.DeviceSetting, setting.BooleanValue), + setting.NewDefinition(LogTarget, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(MachineCertificateSubject, setting.DeviceSetting, setting.StringValue), + setting.NewDefinition(PostureChecking, setting.DeviceSetting, setting.PreferenceOptionValue), + setting.NewDefinition(ReconnectAfter, setting.DeviceSetting, setting.DurationValue), + setting.NewDefinition(Tailnet, setting.DeviceSetting, setting.StringValue), + + // User policy settings (can be configured on a user- or device-basis): + setting.NewDefinition(AdminConsoleVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(AutoUpdateVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(ExitNodeMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(KeyExpirationNoticeTime, setting.UserSetting, setting.DurationValue), + setting.NewDefinition(ManagedByCaption, setting.UserSetting, setting.StringValue), + setting.NewDefinition(ManagedByOrganizationName, setting.UserSetting, setting.StringValue), + setting.NewDefinition(ManagedByURL, setting.UserSetting, setting.StringValue), + setting.NewDefinition(NetworkDevicesVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(PreferencesMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(ResetToDefaultsVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(RunExitNodeVisibility, setting.UserSetting, 
setting.VisibilityValue), + setting.NewDefinition(SuggestedExitNodeVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(TestMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(UpdateMenuVisibility, setting.UserSetting, setting.VisibilityValue), + setting.NewDefinition(OnboardingFlowVisibility, setting.UserSetting, setting.VisibilityValue), +} + +func init() { + internal.Init.MustDefer(func() error { + // Avoid implicit [setting.Definition] registration during tests. + // Each test should control which policy settings to register. + // Use [setting.SetDefinitionsForTest] to specify necessary definitions, + // or [setWellKnownSettingsForTest] to set implicit definitions for the test duration. + if testenv.InTest() { + return nil + } + for _, d := range implicitDefinitions { + setting.RegisterDefinition(d) + } + return nil + }) +} + +var implicitDefinitionMap lazy.SyncValue[setting.DefinitionMap] + +// WellKnownSettingDefinition returns a well-known, implicit setting definition by its key, +// or an [ErrNoSuchKey] if a policy setting with the specified key does not exist +// among implicit policy definitions. +func WellKnownSettingDefinition(k Key) (*setting.Definition, error) { + m, err := implicitDefinitionMap.GetErr(func() (setting.DefinitionMap, error) { + return setting.DefinitionMapOf(implicitDefinitions) + }) + if err != nil { + return nil, err + } + if d, ok := m[k]; ok { + return d, nil + } + return nil, ErrNoSuchKey +} + +// RegisterWellKnownSettingsForTest registers all implicit setting definitions +// for the duration of the test. +func RegisterWellKnownSettingsForTest(tb testenv.TB) { + tb.Helper() + err := setting.SetDefinitionsForTest(tb, implicitDefinitions...) 
+ if err != nil { + tb.Fatalf("Failed to register well-known settings: %v", err) + } +} diff --git a/util/syspolicy/policy_keys_test.go b/util/syspolicy/policy_keys_test.go new file mode 100644 index 0000000000000..4d3260f3e0e60 --- /dev/null +++ b/util/syspolicy/policy_keys_test.go @@ -0,0 +1,95 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + "go/types" + "os" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/setting" +) + +func TestKnownKeysRegistered(t *testing.T) { + keyConsts, err := listStringConsts[Key]("policy_keys.go") + if err != nil { + t.Fatalf("listStringConsts failed: %v", err) + } + + m, err := setting.DefinitionMapOf(implicitDefinitions) + if err != nil { + t.Fatalf("definitionMapOf failed: %v", err) + } + + for _, key := range keyConsts { + t.Run(string(key), func(t *testing.T) { + d := m[key] + if d == nil { + t.Fatalf("%q was not registered", key) + } + if d.Key() != key { + t.Fatalf("d.Key got: %s, want %s", d.Key(), key) + } + }) + } +} + +func TestNotAWellKnownSetting(t *testing.T) { + d, err := WellKnownSettingDefinition("TestSettingDoesNotExist") + if d != nil || err == nil { + t.Fatalf("got %v, %v; want nil, %v", d, err, ErrNoSuchKey) + } +} + +func listStringConsts[T ~string](filename string) (map[string]T, error) { + fset := token.NewFileSet() + src, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + + f, err := parser.ParseFile(fset, filename, src, 0) + if err != nil { + return nil, err + } + + consts := make(map[string]T) + typeName := reflect.TypeFor[T]().Name() + for _, d := range f.Decls { + g, ok := d.(*ast.GenDecl) + if !ok || g.Tok != token.CONST { + continue + } + + for _, s := range g.Specs { + vs, ok := s.(*ast.ValueSpec) + if !ok || len(vs.Names) != len(vs.Values) { + continue + } + if typ, ok := vs.Type.(*ast.Ident); !ok || typ.Name != typeName { + continue + } + + for 
i, n := range vs.Names { + lit, ok := vs.Values[i].(*ast.BasicLit) + if !ok { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, types.ExprString(vs.Values[i])) + } + val, err := strconv.Unquote(lit.Value) + if err != nil { + return nil, fmt.Errorf("unexpected string literal: %v = %v", n.Name, lit.Value) + } + consts[n.Name] = T(val) + } + } + } + + return consts, nil +} diff --git a/util/syspolicy/policy_keys_windows.go b/util/syspolicy/policy_keys_windows.go deleted file mode 100644 index 5e9a716957bdb..0000000000000 --- a/util/syspolicy/policy_keys_windows.go +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package syspolicy - -var stringKeys = []Key{ - ControlURL, - LogTarget, - Tailnet, - ExitNodeID, - ExitNodeIP, - EnableIncomingConnections, - EnableServerMode, - ExitNodeAllowLANAccess, - EnableTailscaleDNS, - EnableTailscaleSubnets, - AdminConsoleVisibility, - NetworkDevicesVisibility, - TestMenuVisibility, - UpdateMenuVisibility, - RunExitNodeVisibility, - PreferencesMenuVisibility, - ExitNodeMenuVisibility, - AutoUpdateVisibility, - ResetToDefaultsVisibility, - KeyExpirationNoticeTime, - PostureChecking, - ManagedByOrganizationName, - ManagedByCaption, - ManagedByURL, -} - -var boolKeys = []Key{ - LogSCMInteractions, - FlushDNSOnSessionUnlock, -} - -var uint64Keys = []Key{} diff --git a/util/syspolicy/rsop/resultant_policy.go b/util/syspolicy/rsop/resultant_policy.go index 019b8f602f86d..297d26f9f6fe5 100644 --- a/util/syspolicy/rsop/resultant_policy.go +++ b/util/syspolicy/rsop/resultant_policy.go @@ -13,6 +13,7 @@ import ( "tailscale.com/util/syspolicy/internal/loggerx" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" "tailscale.com/util/syspolicy/source" ) @@ -447,3 +448,9 @@ func (p *Policy) Close() { go p.closeInternal() } } + +func setForTest[T any](tb testenv.TB, target *T, newValue T) { + oldValue := *target + tb.Cleanup(func() { *target = 
oldValue }) + *target = newValue +} diff --git a/util/syspolicy/rsop/resultant_policy_test.go b/util/syspolicy/rsop/resultant_policy_test.go index b2408c7f71519..e4bfb1a886878 100644 --- a/util/syspolicy/rsop/resultant_policy_test.go +++ b/util/syspolicy/rsop/resultant_policy_test.go @@ -574,9 +574,6 @@ func TestPolicyChangeHasChanged(t *testing.T) { } func TestChangePolicySetting(t *testing.T) { - setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) - setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) - // Register policy settings used in this test. settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) @@ -589,6 +586,10 @@ func TestChangePolicySetting(t *testing.T) { if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { t.Fatalf("Failed to register policy store: %v", err) } + + setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) + setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) + policy, err := policyForTest(t, setting.DeviceScope) if err != nil { t.Fatalf("Failed to get effective policy: %v", err) @@ -978,9 +979,3 @@ func policyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) { }) return policy, nil } - -func setForTest[T any](tb testing.TB, target *T, newValue T) { - oldValue := *target - tb.Cleanup(func() { *target = oldValue }) - *target = newValue -} diff --git a/util/syspolicy/rsop/store_registration.go b/util/syspolicy/rsop/store_registration.go index 09c83e98804ca..a7c354b6d5678 100644 --- a/util/syspolicy/rsop/store_registration.go +++ b/util/syspolicy/rsop/store_registration.go @@ -7,10 +7,11 @@ import ( "errors" "sync" "sync/atomic" + "time" - "tailscale.com/util/syspolicy/internal" "tailscale.com/util/syspolicy/setting" "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" ) // ErrAlreadyConsumed is the error returned 
when [StoreRegistration.ReplaceStore] @@ -32,7 +33,10 @@ func RegisterStore(name string, scope setting.PolicyScope, store source.Store) ( // RegisterStoreForTest is like [RegisterStore], but unregisters the store when // tb and all its subtests complete. -func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { +func RegisterStoreForTest(tb testenv.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { + setForTest(tb, &policyReloadMinDelay, 10*time.Millisecond) + setForTest(tb, &policyReloadMaxDelay, 500*time.Millisecond) + reg, err := RegisterStore(name, scope, store) if err == nil { tb.Cleanup(func() { diff --git a/util/syspolicy/setting/key.go b/util/syspolicy/setting/key.go index 406fde1321cc2..aa7606d36324a 100644 --- a/util/syspolicy/setting/key.go +++ b/util/syspolicy/setting/key.go @@ -10,4 +10,4 @@ package setting type Key string // KeyPathSeparator allows logical grouping of policy settings into categories. -const KeyPathSeparator = "/" +const KeyPathSeparator = '/' diff --git a/util/syspolicy/setting/origin.go b/util/syspolicy/setting/origin.go index 078ef758e9150..4c7cc7025cc48 100644 --- a/util/syspolicy/setting/origin.go +++ b/util/syspolicy/setting/origin.go @@ -50,22 +50,27 @@ func (s Origin) String() string { return s.Scope().String() } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (s Origin) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return jsonv2.MarshalEncode(out, &s.data, opts) +var ( + _ jsonv2.MarshalerTo = (*Origin)(nil) + _ jsonv2.UnmarshalerFrom = (*Origin)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (s Origin) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &s.data) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. 
-func (s *Origin) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { - return jsonv2.UnmarshalDecode(in, &s.data, opts) +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (s *Origin) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &s.data) } // MarshalJSON implements [json.Marshaler]. func (s Origin) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(s) // uses MarshalJSONV2 + return jsonv2.Marshal(s) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (s *Origin) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONFrom } diff --git a/util/syspolicy/setting/raw_item.go b/util/syspolicy/setting/raw_item.go index 30480d8923f71..9a96073b01297 100644 --- a/util/syspolicy/setting/raw_item.go +++ b/util/syspolicy/setting/raw_item.go @@ -5,7 +5,11 @@ package setting import ( "fmt" + "reflect" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "tailscale.com/types/opt" "tailscale.com/types/structs" ) @@ -17,10 +21,15 @@ import ( // or converted from strings, these setting types predate the typed policy // hierarchies, and must be supported at this layer. type RawItem struct { - _ structs.Incomparable - value any - err *ErrorText - origin *Origin // or nil + _ structs.Incomparable + data rawItemJSON +} + +// rawItemJSON holds JSON-marshallable data for [RawItem]. +type rawItemJSON struct { + Value RawValue `json:",omitzero"` + Error *ErrorText `json:",omitzero"` // or nil + Origin *Origin `json:",omitzero"` // or nil } // RawItemOf returns a [RawItem] with the specified value. @@ -30,20 +39,20 @@ func RawItemOf(value any) RawItem { // RawItemWith returns a [RawItem] with the specified value, error and origin. 
func RawItemWith(value any, err *ErrorText, origin *Origin) RawItem { - return RawItem{value: value, err: err, origin: origin} + return RawItem{data: rawItemJSON{Value: RawValue{opt.ValueOf(value)}, Error: err, Origin: origin}} } // Value returns the value of the policy setting, or nil if the policy setting // is not configured, or an error occurred while reading it. func (i RawItem) Value() any { - return i.value + return i.data.Value.Get() } // Error returns the error that occurred when reading the policy setting, // or nil if no error occurred. func (i RawItem) Error() error { - if i.err != nil { - return i.err + if i.data.Error != nil { + return i.data.Error } return nil } @@ -51,17 +60,113 @@ func (i RawItem) Error() error { // Origin returns an optional [Origin] indicating where the policy setting is // configured. func (i RawItem) Origin() *Origin { - return i.origin + return i.data.Origin } // String implements [fmt.Stringer]. func (i RawItem) String() string { var suffix string - if i.origin != nil { - suffix = fmt.Sprintf(" - {%v}", i.origin) + if i.data.Origin != nil { + suffix = fmt.Sprintf(" - {%v}", i.data.Origin) + } + if i.data.Error != nil { + return fmt.Sprintf("Error{%q}%s", i.data.Error.Error(), suffix) + } + return fmt.Sprintf("%v%s", i.data.Value.Value, suffix) +} + +var ( + _ jsonv2.MarshalerTo = (*RawItem)(nil) + _ jsonv2.UnmarshalerFrom = (*RawItem)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (i RawItem) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &i.data) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (i *RawItem) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &i.data) +} + +// MarshalJSON implements [json.Marshaler]. +func (i RawItem) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(i) // uses MarshalJSONTo +} + +// UnmarshalJSON implements [json.Unmarshaler]. 
+func (i *RawItem) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, i) // uses UnmarshalJSONFrom +} + +// RawValue represents a raw policy setting value read from a policy store. +// It is JSON-marshallable and facilitates unmarshalling of JSON values +// into corresponding policy setting types, with special handling for JSON numbers +// (unmarshalled as float64) and JSON string arrays (unmarshalled as []string). +// See also [RawValue.UnmarshalJSONFrom]. +type RawValue struct { + opt.Value[any] +} + +// RawValueType is a constraint that permits raw setting value types. +type RawValueType interface { + bool | uint64 | string | []string +} + +// RawValueOf returns a new [RawValue] holding the specified value. +func RawValueOf[T RawValueType](v T) RawValue { + return RawValue{opt.ValueOf[any](v)} +} + +var ( + _ jsonv2.MarshalerTo = (*RawValue)(nil) + _ jsonv2.UnmarshalerFrom = (*RawValue)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v RawValue) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, v.Value) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom] by attempting to unmarshal +// a JSON value as one of the supported policy setting value types (bool, string, uint64, or []string), +// based on the JSON value type. It fails if the JSON value is an object, if it's a JSON number that +// cannot be represented as a uint64, or if a JSON array contains anything other than strings. 
+func (v *RawValue) UnmarshalJSONFrom(in *jsontext.Decoder) error { + var valPtr any + switch k := in.PeekKind(); k { + case 't', 'f': + valPtr = new(bool) + case '"': + valPtr = new(string) + case '0': + valPtr = new(uint64) // unmarshal JSON numbers as uint64 + case '[', 'n': + valPtr = new([]string) // unmarshal arrays as string slices + case '{': + return fmt.Errorf("unexpected token: %v", k) + default: + panic("unreachable") } - if i.err != nil { - return fmt.Sprintf("Error{%q}%s", i.err.Error(), suffix) + if err := jsonv2.UnmarshalDecode(in, valPtr); err != nil { + v.Value.Clear() + return err } - return fmt.Sprintf("%v%s", i.value, suffix) + value := reflect.ValueOf(valPtr).Elem().Interface() + v.Value = opt.ValueOf(value) + return nil +} + +// MarshalJSON implements [json.Marshaler]. +func (v RawValue) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(v) // uses MarshalJSONTo +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (v *RawValue) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, v) // uses UnmarshalJSONFrom } + +// RawValues is a map of keyed setting values that can be read from a JSON. 
+type RawValues map[Key]RawValue diff --git a/util/syspolicy/setting/raw_item_test.go b/util/syspolicy/setting/raw_item_test.go new file mode 100644 index 0000000000000..05562d78c41f3 --- /dev/null +++ b/util/syspolicy/setting/raw_item_test.go @@ -0,0 +1,101 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package setting + +import ( + "math" + "reflect" + "strconv" + "testing" + + jsonv2 "github.com/go-json-experiment/json" +) + +func TestMarshalUnmarshalRawValue(t *testing.T) { + tests := []struct { + name string + json string + want RawValue + wantErr bool + }{ + { + name: "Bool/True", + json: `true`, + want: RawValueOf(true), + }, + { + name: "Bool/False", + json: `false`, + want: RawValueOf(false), + }, + { + name: "String/Empty", + json: `""`, + want: RawValueOf(""), + }, + { + name: "String/NonEmpty", + json: `"Test"`, + want: RawValueOf("Test"), + }, + { + name: "StringSlice/Null", + json: `null`, + want: RawValueOf([]string(nil)), + }, + { + name: "StringSlice/Empty", + json: `[]`, + want: RawValueOf([]string{}), + }, + { + name: "StringSlice/NonEmpty", + json: `["A", "B", "C"]`, + want: RawValueOf([]string{"A", "B", "C"}), + }, + { + name: "StringSlice/NonStrings", + json: `[1, 2, 3]`, + wantErr: true, + }, + { + name: "Number/Integer/0", + json: `0`, + want: RawValueOf(uint64(0)), + }, + { + name: "Number/Integer/1", + json: `1`, + want: RawValueOf(uint64(1)), + }, + { + name: "Number/Integer/MaxUInt64", + json: strconv.FormatUint(math.MaxUint64, 10), + want: RawValueOf(uint64(math.MaxUint64)), + }, + { + name: "Number/Integer/Negative", + json: `-1`, + wantErr: true, + }, + { + name: "Object", + json: `{}`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got RawValue + gotErr := jsonv2.Unmarshal([]byte(tt.json), &got) + if (gotErr != nil) != tt.wantErr { + t.Fatalf("Error: got %v; want %v", gotErr, tt.wantErr) + } + + if !tt.wantErr && !reflect.DeepEqual(got, 
tt.want) { + t.Fatalf("Value: got %v; want %v", got, tt.want) + } + }) + } +} diff --git a/util/syspolicy/setting/setting.go b/util/syspolicy/setting/setting.go index 70fb0a931e250..13c7a2a5fc1a9 100644 --- a/util/syspolicy/setting/setting.go +++ b/util/syspolicy/setting/setting.go @@ -16,6 +16,7 @@ import ( "tailscale.com/types/lazy" "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/testenv" ) // Scope indicates the broadest scope at which a policy setting may apply, @@ -277,7 +278,7 @@ func DefinitionMapOf(settings []*Definition) (DefinitionMap, error) { // for the test duration. It is not concurrency-safe, but unlike [Register], // it does not panic and can be called anytime. // It returns an error if ds contains two different settings with the same [Key]. -func SetDefinitionsForTest(tb lazy.TB, ds ...*Definition) error { +func SetDefinitionsForTest(tb testenv.TB, ds ...*Definition) error { m, err := DefinitionMapOf(ds) if err != nil { return err diff --git a/util/syspolicy/setting/snapshot.go b/util/syspolicy/setting/snapshot.go index 512bc487c5b98..087325a04c6f1 100644 --- a/util/syspolicy/setting/snapshot.go +++ b/util/syspolicy/setting/snapshot.go @@ -4,11 +4,14 @@ package setting import ( + "errors" "iter" "maps" "slices" "strings" + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" xmaps "golang.org/x/exp/maps" "tailscale.com/util/deephash" ) @@ -65,6 +68,9 @@ func (s *Snapshot) GetSetting(k Key) (setting RawItem, ok bool) { // Equal reports whether s and s2 are equal. func (s *Snapshot) Equal(s2 *Snapshot) bool { + if s == s2 { + return true + } if !s.EqualItems(s2) { return false } @@ -135,6 +141,50 @@ func (s *Snapshot) String() string { return sb.String() } +// snapshotJSON holds JSON-marshallable data for [Snapshot]. 
+type snapshotJSON struct { + Summary Summary `json:",omitzero"` + Settings map[Key]RawItem `json:",omitempty"` +} + +var ( + _ jsonv2.MarshalerTo = (*Snapshot)(nil) + _ jsonv2.UnmarshalerFrom = (*Snapshot)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (s *Snapshot) MarshalJSONTo(out *jsontext.Encoder) error { + data := &snapshotJSON{} + if s != nil { + data.Summary = s.summary + data.Settings = s.m + } + return jsonv2.MarshalEncode(out, data) +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (s *Snapshot) UnmarshalJSONFrom(in *jsontext.Decoder) error { + if s == nil { + return errors.New("s must not be nil") + } + data := &snapshotJSON{} + if err := jsonv2.UnmarshalDecode(in, data); err != nil { + return err + } + *s = Snapshot{m: data.Settings, sig: deephash.Hash(&data.Settings), summary: data.Summary} + return nil +} + +// MarshalJSON implements [json.Marshaler]. +func (s *Snapshot) MarshalJSON() ([]byte, error) { + return jsonv2.Marshal(s) // uses MarshalJSONTo +} + +// UnmarshalJSON implements [json.Unmarshaler]. +func (s *Snapshot) UnmarshalJSON(b []byte) error { + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONFrom +} + // MergeSnapshots returns a [Snapshot] that contains all [RawItem]s // from snapshot1 and snapshot2 and the [Summary] with the narrower [PolicyScope]. 
// If there's a conflict between policy settings in the two snapshots, diff --git a/util/syspolicy/setting/snapshot_test.go b/util/syspolicy/setting/snapshot_test.go index e198d4a58bfdb..d41b362f06976 100644 --- a/util/syspolicy/setting/snapshot_test.go +++ b/util/syspolicy/setting/snapshot_test.go @@ -4,8 +4,13 @@ package setting import ( + "cmp" + "encoding/json" "testing" "time" + + jsonv2 "github.com/go-json-experiment/json" + "tailscale.com/util/syspolicy/internal" ) func TestMergeSnapshots(t *testing.T) { @@ -30,134 +35,134 @@ func TestMergeSnapshots(t *testing.T) { name: "first-nil", s1: nil, s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), }, { name: "first-empty", s1: NewSnapshot(map[Key]RawItem{}), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), }, { name: "second-nil", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: nil, want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": 
RawItemOf("String"), + "Setting3": RawItemOf(true), }), }, { name: "second-empty", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{}), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), }, { name: "no-conflicts", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{ - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }), }, { name: "with-conflicts", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + 
"Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }), }, { name: "with-scope-first-wins", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), }, { name: "with-scope-second-wins", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting3": RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, DeviceScope), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 456}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, - "Setting4": {value: 2 * time.Hour}, + "Setting1": RawItemOf(456), + "Setting2": RawItemOf("String"), + "Setting3": 
RawItemOf(false), + "Setting4": RawItemOf(2 * time.Hour), }, CurrentUserScope), }, { @@ -170,28 +175,27 @@ func TestMergeSnapshots(t *testing.T) { name: "with-scope-first-empty", s1: NewSnapshot(map[Key]RawItem{}, CurrentUserScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}}, - DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)), + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true)}, DeviceScope, NewNamedOrigin("TestPolicy", DeviceScope)), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope, NewNamedOrigin("TestPolicy", DeviceScope)), }, { name: "with-scope-second-empty", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), s2: NewSnapshot(map[Key]RawItem{}), want: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }, CurrentUserScope), }, } @@ -244,9 +248,9 @@ func TestSnapshotEqual(t *testing.T) { name: "first-nil", s1: nil, s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: false, wantEqualItems: false, @@ -255,9 +259,9 @@ func TestSnapshotEqual(t *testing.T) { name: "first-empty", s1: NewSnapshot(map[Key]RawItem{}), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": 
{value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: false, wantEqualItems: false, @@ -265,9 +269,9 @@ func TestSnapshotEqual(t *testing.T) { { name: "second-nil", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: true}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(true), }), s2: nil, wantEqual: false, @@ -276,9 +280,9 @@ func TestSnapshotEqual(t *testing.T) { { name: "second-empty", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{}), wantEqual: false, @@ -287,14 +291,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-same-order-no-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }), wantEqual: true, wantEqualItems: true, @@ -302,14 +306,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-same-order-same-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": 
RawItemOf(false), }, DeviceScope), wantEqual: true, wantEqualItems: true, @@ -317,14 +321,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-different-order-same-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting3": {value: false}, - "Setting1": {value: 123}, - "Setting2": {value: "String"}, + "Setting3": RawItemOf(false), + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), }, DeviceScope), wantEqual: true, wantEqualItems: true, @@ -332,14 +336,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "same-items-same-order-different-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, CurrentUserScope), wantEqual: false, wantEqualItems: true, @@ -347,14 +351,14 @@ func TestSnapshotEqual(t *testing.T) { { name: "different-items-same-scope", s1: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 123}, - "Setting2": {value: "String"}, - "Setting3": {value: false}, + "Setting1": RawItemOf(123), + "Setting2": RawItemOf("String"), + "Setting3": RawItemOf(false), }, DeviceScope), s2: NewSnapshot(map[Key]RawItem{ - "Setting4": {value: 2 * time.Hour}, - "Setting5": {value: VisibleByPolicy}, - "Setting6": {value: ShowChoiceByPolicy}, + "Setting4": RawItemOf(2 * time.Hour), + "Setting5": RawItemOf(VisibleByPolicy), + "Setting6": RawItemOf(ShowChoiceByPolicy), }, DeviceScope), wantEqual: false, wantEqualItems: 
false, @@ -401,9 +405,9 @@ func TestSnapshotString(t *testing.T) { { name: "non-empty", snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 2 * time.Hour}, - "Setting2": {value: VisibleByPolicy}, - "Setting3": {value: ShowChoiceByPolicy}, + "Setting1": RawItemOf(2 * time.Hour), + "Setting2": RawItemOf(VisibleByPolicy), + "Setting3": RawItemOf(ShowChoiceByPolicy), }, NewNamedOrigin("Test Policy", DeviceScope)), wantString: `{Test Policy (Device)} Setting1 = 2h0m0s @@ -413,14 +417,14 @@ Setting3 = user-decides`, { name: "non-empty-with-item-origin", snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {value: 42, origin: NewNamedOrigin("Test Policy", DeviceScope)}, + "Setting1": RawItemWith(42, nil, NewNamedOrigin("Test Policy", DeviceScope)), }), wantString: `Setting1 = 42 - {Test Policy (Device)}`, }, { name: "non-empty-with-item-error", snapshot: NewSnapshot(map[Key]RawItem{ - "Setting1": {err: NewErrorText("bang!")}, + "Setting1": RawItemWith(nil, NewErrorText("bang!"), nil), }), wantString: `Setting1 = Error{"bang!"}`, }, @@ -433,3 +437,133 @@ Setting3 = user-decides`, }) } } + +func TestMarshalUnmarshalSnapshot(t *testing.T) { + tests := []struct { + name string + snapshot *Snapshot + wantJSON string + wantBack *Snapshot + }{ + { + name: "Nil", + snapshot: (*Snapshot)(nil), + wantJSON: "null", + wantBack: NewSnapshot(nil), + }, + { + name: "Zero", + snapshot: &Snapshot{}, + wantJSON: "{}", + }, + { + name: "Bool/True", + snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(true)}), + wantJSON: `{"Settings": {"BoolPolicy": {"Value": true}}}`, + }, + { + name: "Bool/False", + snapshot: NewSnapshot(map[Key]RawItem{"BoolPolicy": RawItemOf(false)}), + wantJSON: `{"Settings": {"BoolPolicy": {"Value": false}}}`, + }, + { + name: "String/Non-Empty", + snapshot: NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("StringValue")}), + wantJSON: `{"Settings": {"StringPolicy": {"Value": "StringValue"}}}`, + }, + { + name: "String/Empty", + snapshot: 
NewSnapshot(map[Key]RawItem{"StringPolicy": RawItemOf("")}), + wantJSON: `{"Settings": {"StringPolicy": {"Value": ""}}}`, + }, + { + name: "Integer/NonZero", + snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(42))}), + wantJSON: `{"Settings": {"IntPolicy": {"Value": 42}}}`, + }, + { + name: "Integer/Zero", + snapshot: NewSnapshot(map[Key]RawItem{"IntPolicy": RawItemOf(uint64(0))}), + wantJSON: `{"Settings": {"IntPolicy": {"Value": 0}}}`, + }, + { + name: "String-List", + snapshot: NewSnapshot(map[Key]RawItem{"ListPolicy": RawItemOf([]string{"Value1", "Value2"})}), + wantJSON: `{"Settings": {"ListPolicy": {"Value": ["Value1", "Value2"]}}}`, + }, + { + name: "Empty/With-Summary", + snapshot: NewSnapshot( + map[Key]RawItem{}, + SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)), + ), + wantJSON: `{"Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}}`, + }, + { + name: "Setting/With-Summary", + snapshot: NewSnapshot( + map[Key]RawItem{"PolicySetting": RawItemOf(uint64(42))}, + SummaryWith(CurrentUserScope, NewNamedOrigin("TestSource", DeviceScope)), + ), + wantJSON: `{ + "Summary": {"Origin": {"Name": "TestSource", "Scope": "Device"}, "Scope": "User"}, + "Settings": {"PolicySetting": {"Value": 42}} + }`, + }, + { + name: "Settings/With-Origins", + snapshot: NewSnapshot( + map[Key]RawItem{ + "SettingA": RawItemWith(uint64(42), nil, NewNamedOrigin("SourceA", DeviceScope)), + "SettingB": RawItemWith("B", nil, NewNamedOrigin("SourceB", CurrentProfileScope)), + "SettingC": RawItemWith(true, nil, NewNamedOrigin("SourceC", CurrentUserScope)), + }, + ), + wantJSON: `{ + "Settings": { + "SettingA": {"Value": 42, "Origin": {"Name": "SourceA", "Scope": "Device"}}, + "SettingB": {"Value": "B", "Origin": {"Name": "SourceB", "Scope": "Profile"}}, + "SettingC": {"Value": true, "Origin": {"Name": "SourceC", "Scope": "User"}} + } + }`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
doTest := func(t *testing.T, useJSONv2 bool) { + var gotJSON []byte + var err error + if useJSONv2 { + gotJSON, err = jsonv2.Marshal(tt.snapshot) + } else { + gotJSON, err = json.Marshal(tt.snapshot) + } + if err != nil { + t.Fatal(err) + } + + if got, want, equal := internal.EqualJSONForTest(t, gotJSON, []byte(tt.wantJSON)); !equal { + t.Errorf("JSON: got %s; want %s", got, want) + } + + gotBack := &Snapshot{} + if useJSONv2 { + err = jsonv2.Unmarshal(gotJSON, &gotBack) + } else { + err = json.Unmarshal(gotJSON, &gotBack) + } + if err != nil { + t.Fatal(err) + } + + if wantBack := cmp.Or(tt.wantBack, tt.snapshot); !gotBack.Equal(wantBack) { + t.Errorf("Snapshot: got %+v; want %+v", gotBack, wantBack) + } + } + + t.Run("json", func(t *testing.T) { doTest(t, false) }) + t.Run("jsonv2", func(t *testing.T) { doTest(t, true) }) + }) + } +} diff --git a/util/syspolicy/setting/summary.go b/util/syspolicy/setting/summary.go index 5ff20e0aa2752..9864822f7a235 100644 --- a/util/syspolicy/setting/summary.go +++ b/util/syspolicy/setting/summary.go @@ -54,24 +54,29 @@ func (s Summary) String() string { return s.data.Scope.String() } -// MarshalJSONV2 implements [jsonv2.MarshalerV2]. -func (s Summary) MarshalJSONV2(out *jsontext.Encoder, opts jsonv2.Options) error { - return jsonv2.MarshalEncode(out, &s.data, opts) +var ( + _ jsonv2.MarshalerTo = (*Summary)(nil) + _ jsonv2.UnmarshalerFrom = (*Summary)(nil) +) + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (s Summary) MarshalJSONTo(out *jsontext.Encoder) error { + return jsonv2.MarshalEncode(out, &s.data) } -// UnmarshalJSONV2 implements [jsonv2.UnmarshalerV2]. -func (s *Summary) UnmarshalJSONV2(in *jsontext.Decoder, opts jsonv2.Options) error { - return jsonv2.UnmarshalDecode(in, &s.data, opts) +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (s *Summary) UnmarshalJSONFrom(in *jsontext.Decoder) error { + return jsonv2.UnmarshalDecode(in, &s.data) } // MarshalJSON implements [json.Marshaler]. 
func (s Summary) MarshalJSON() ([]byte, error) { - return jsonv2.Marshal(s) // uses MarshalJSONV2 + return jsonv2.Marshal(s) // uses MarshalJSONTo } // UnmarshalJSON implements [json.Unmarshaler]. func (s *Summary) UnmarshalJSON(b []byte) error { - return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONV2 + return jsonv2.Unmarshal(b, s) // uses UnmarshalJSONFrom } // SummaryOption is an option that configures [Summary] diff --git a/util/syspolicy/source/env_policy_store.go b/util/syspolicy/source/env_policy_store.go new file mode 100644 index 0000000000000..299132b4e11b3 --- /dev/null +++ b/util/syspolicy/source/env_policy_store.go @@ -0,0 +1,159 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "errors" + "fmt" + "os" + "strconv" + "strings" + "unicode/utf8" + + "tailscale.com/util/syspolicy/setting" +) + +var lookupEnv = os.LookupEnv // test hook + +var _ Store = (*EnvPolicyStore)(nil) + +// EnvPolicyStore is a [Store] that reads policy settings from environment variables. +type EnvPolicyStore struct{} + +// ReadString implements [Store]. +func (s *EnvPolicyStore) ReadString(key setting.Key) (string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil { + return "", err + } + return str, nil +} + +// ReadUInt64 implements [Store]. +func (s *EnvPolicyStore) ReadUInt64(key setting.Key) (uint64, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return 0, err + } + if str == "" { + return 0, setting.ErrNotConfigured + } + value, err := strconv.ParseUint(str, 0, 64) + if err != nil { + return 0, fmt.Errorf("%s: %w: %q is not a valid uint64", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadBoolean implements [Store]. 
+func (s *EnvPolicyStore) ReadBoolean(key setting.Key) (bool, error) { + name, str, err := s.lookupSettingVariable(key) + if err != nil { + return false, err + } + if str == "" { + return false, setting.ErrNotConfigured + } + value, err := strconv.ParseBool(str) + if err != nil { + return false, fmt.Errorf("%s: %w: %q is not a valid bool", name, setting.ErrTypeMismatch, str) + } + return value, nil +} + +// ReadStringArray implements [Store]. +func (s *EnvPolicyStore) ReadStringArray(key setting.Key) ([]string, error) { + _, str, err := s.lookupSettingVariable(key) + if err != nil || str == "" { + return nil, err + } + var dst int + res := strings.Split(str, ",") + for src := range res { + res[dst] = strings.TrimSpace(res[src]) + if res[dst] != "" { + dst++ + } + } + return res[0:dst], nil +} + +func (s *EnvPolicyStore) lookupSettingVariable(key setting.Key) (name, value string, err error) { + name, err = keyToEnvVarName(key) + if err != nil { + return "", "", err + } + value, ok := lookupEnv(name) + if !ok { + return name, "", setting.ErrNotConfigured + } + return name, value, nil +} + +var ( + errEmptyKey = errors.New("key must not be empty") + errInvalidKey = errors.New("key must consist of alphanumeric characters and slashes") +) + +// keyToEnvVarName returns the environment variable name for a given policy +// setting key, or an error if the key is invalid. It converts CamelCase keys into +// underscore-separated words and prepends the variable name with the TS prefix. +// For example: AuthKey => TS_AUTH_KEY, ExitNodeAllowLANAccess => TS_EXIT_NODE_ALLOW_LAN_ACCESS, etc. +// +// It's fine to use this in [EnvPolicyStore] without caching variable names since it's not a hot path. +// [EnvPolicyStore] is not a [Changeable] policy store, so the conversion will only happen once. 
+func keyToEnvVarName(key setting.Key) (string, error) { + if len(key) == 0 { + return "", errEmptyKey + } + + isLower := func(c byte) bool { return 'a' <= c && c <= 'z' } + isUpper := func(c byte) bool { return 'A' <= c && c <= 'Z' } + isLetter := func(c byte) bool { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') } + isDigit := func(c byte) bool { return '0' <= c && c <= '9' } + + words := make([]string, 0, 8) + words = append(words, "TS_DEBUGSYSPOLICY") + var currentWord strings.Builder + for i := 0; i < len(key); i++ { + c := key[i] + if c >= utf8.RuneSelf { + return "", errInvalidKey + } + + var split bool + switch { + case isLower(c): + c -= 'a' - 'A' // make upper + split = currentWord.Len() > 0 && !isLetter(key[i-1]) + case isUpper(c): + if currentWord.Len() > 0 { + prevUpper := isUpper(key[i-1]) + nextLower := i < len(key)-1 && isLower(key[i+1]) + split = !prevUpper || nextLower // split on case transition + } + case isDigit(c): + split = currentWord.Len() > 0 && !isDigit(key[i-1]) + case c == setting.KeyPathSeparator: + words = append(words, currentWord.String()) + currentWord.Reset() + continue + default: + return "", errInvalidKey + } + + if split { + words = append(words, currentWord.String()) + currentWord.Reset() + } + + currentWord.WriteByte(c) + } + + if currentWord.Len() > 0 { + words = append(words, currentWord.String()) + } + + return strings.Join(words, "_"), nil +} diff --git a/util/syspolicy/source/env_policy_store_test.go b/util/syspolicy/source/env_policy_store_test.go new file mode 100644 index 0000000000000..9eacf6378b450 --- /dev/null +++ b/util/syspolicy/source/env_policy_store_test.go @@ -0,0 +1,359 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package source + +import ( + "cmp" + "errors" + "math" + "reflect" + "strconv" + "testing" + + "tailscale.com/util/syspolicy/setting" +) + +func TestKeyToEnvVarName(t *testing.T) { + tests := []struct { + name string + key setting.Key + want 
string // suffix after "TS_DEBUGSYSPOLICY_" + wantErr error + }{ + { + name: "empty", + key: "", + wantErr: errEmptyKey, + }, + { + name: "lowercase", + key: "tailnet", + want: "TAILNET", + }, + { + name: "CamelCase", + key: "AuthKey", + want: "AUTH_KEY", + }, + { + name: "LongerCamelCase", + key: "ManagedByOrganizationName", + want: "MANAGED_BY_ORGANIZATION_NAME", + }, + { + name: "UPPERCASE", + key: "UPPERCASE", + want: "UPPERCASE", + }, + { + name: "WithAbbrev/Front", + key: "DNSServer", + want: "DNS_SERVER", + }, + { + name: "WithAbbrev/Middle", + key: "ExitNodeAllowLANAccess", + want: "EXIT_NODE_ALLOW_LAN_ACCESS", + }, + { + name: "WithAbbrev/Back", + key: "ExitNodeID", + want: "EXIT_NODE_ID", + }, + { + name: "WithDigits/Single/Front", + key: "0TestKey", + want: "0_TEST_KEY", + }, + { + name: "WithDigits/Multi/Front", + key: "64TestKey", + want: "64_TEST_KEY", + }, + { + name: "WithDigits/Single/Middle", + key: "Test0Key", + want: "TEST_0_KEY", + }, + { + name: "WithDigits/Multi/Middle", + key: "Test64Key", + want: "TEST_64_KEY", + }, + { + name: "WithDigits/Single/Back", + key: "TestKey0", + want: "TEST_KEY_0", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TEST_KEY_64", + }, + { + name: "WithDigits/Multi/Back", + key: "TestKey64", + want: "TEST_KEY_64", + }, + { + name: "WithPathSeparators/Single", + key: "Key/Subkey", + want: "KEY_SUBKEY", + }, + { + name: "WithPathSeparators/Multi", + key: "Root/Level1/Level2", + want: "ROOT_LEVEL_1_LEVEL_2", + }, + { + name: "Mixed", + key: "Network/DNSServer/IPAddress", + want: "NETWORK_DNS_SERVER_IP_ADDRESS", + }, + { + name: "Non-Alphanumeric/NonASCII/1", + key: "Đļ", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/NonASCII/2", + key: "KeyĐļName", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Space", + key: "Key Name", + wantErr: errInvalidKey, + }, + { + name: "Non-Alphanumeric/Punct", + key: "Key!Name", + wantErr: errInvalidKey, + }, + { + name: 
"Non-Alphanumeric/Backslash", + key: `Key\Name`, + wantErr: errInvalidKey, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + got, err := keyToEnvVarName(tt.key) + checkError(t, err, tt.wantErr, true) + + want := tt.want + if want != "" { + want = "TS_DEBUGSYSPOLICY_" + want + } + if got != want { + t.Fatalf("got %q; want %q", got, want) + } + }) + } +} + +func TestEnvPolicyStore(t *testing.T) { + blankEnv := func(string) (string, bool) { return "", false } + makeEnv := func(wantName, value string) func(string) (string, bool) { + wantName = "TS_DEBUGSYSPOLICY_" + wantName + return func(gotName string) (string, bool) { + if gotName != wantName { + return "", false + } + return value, true + } + } + tests := []struct { + name string + key setting.Key + lookup func(string) (string, bool) + want any + wantErr error + }{ + { + name: "NotConfigured/String", + key: "AuthKey", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: "", + }, + { + name: "Configured/String/Empty", + key: "AuthKey", + lookup: makeEnv("AUTH_KEY", ""), + want: "", + }, + { + name: "Configured/String/NonEmpty", + key: "AuthKey", + lookup: makeEnv("AUTH_KEY", "ABC123"), + want: "ABC123", + }, + { + name: "NotConfigured/UInt64", + key: "IntegerSetting", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Empty", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", ""), + wantErr: setting.ErrNotConfigured, + want: uint64(0), + }, + { + name: "Configured/UInt64/Zero", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "0"), + want: uint64(0), + }, + { + name: "Configured/UInt64/NonZero", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "12345"), + want: uint64(12345), + }, + { + name: "Configured/UInt64/MaxUInt64", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", strconv.FormatUint(math.MaxUint64, 10)), + want: uint64(math.MaxUint64), + 
}, + { + name: "Configured/UInt64/Negative", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "-1"), + wantErr: setting.ErrTypeMismatch, + want: uint64(0), + }, + { + name: "Configured/UInt64/Hex", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "0xDEADBEEF"), + want: uint64(0xDEADBEEF), + }, + { + name: "NotConfigured/Bool", + key: "LogSCMInteractions", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/Empty", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", ""), + wantErr: setting.ErrNotConfigured, + want: false, + }, + { + name: "Configured/Bool/True", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "true"), + want: true, + }, + { + name: "Configured/Bool/False", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "False"), + want: false, + }, + { + name: "Configured/Bool/1", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "1"), + want: true, + }, + { + name: "Configured/Bool/0", + key: "LogSCMInteractions", + lookup: makeEnv("LOG_SCM_INTERACTIONS", "0"), + want: false, + }, + { + name: "Configured/Bool/Invalid", + key: "IntegerSetting", + lookup: makeEnv("INTEGER_SETTING", "NotABool"), + wantErr: setting.ErrTypeMismatch, + want: false, + }, + { + name: "NotConfigured/StringArray", + key: "AllowedSuggestedExitNodes", + lookup: blankEnv, + wantErr: setting.ErrNotConfigured, + want: []string(nil), + }, + { + name: "Configured/StringArray/Empty", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", ""), + want: []string(nil), + }, + { + name: "Configured/StringArray/Spaces", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", " \t "), + want: []string{}, + }, + { + name: "Configured/StringArray/Single", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA"), + want: []string{"NodeA"}, + }, + { 
+ name: "Configured/StringArray/Multi", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,NodeB,NodeC"), + want: []string{"NodeA", "NodeB", "NodeC"}, + }, + { + name: "Configured/StringArray/WithBlank", + key: "AllowedSuggestedExitNodes", + lookup: makeEnv("ALLOWED_SUGGESTED_EXIT_NODES", "NodeA,\t,, ,NodeB"), + want: []string{"NodeA", "NodeB"}, + }, + } + for _, tt := range tests { + t.Run(cmp.Or(tt.name, string(tt.key)), func(t *testing.T) { + oldLookupEnv := lookupEnv + t.Cleanup(func() { lookupEnv = oldLookupEnv }) + lookupEnv = tt.lookup + + var got any + var err error + var store EnvPolicyStore + switch tt.want.(type) { + case string: + got, err = store.ReadString(tt.key) + case uint64: + got, err = store.ReadUInt64(tt.key) + case bool: + got, err = store.ReadBoolean(tt.key) + case []string: + got, err = store.ReadStringArray(tt.key) + } + checkError(t, err, tt.wantErr, false) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func checkError(tb testing.TB, got, want error, fatal bool) { + tb.Helper() + f := tb.Errorf + if fatal { + f = tb.Fatalf + } + if (want == nil && got != nil) || + (want != nil && got == nil) || + (want != nil && got != nil && !errors.Is(got, want) && want.Error() != got.Error()) { + f("gotErr: %v; wantErr: %v", got, want) + } +} diff --git a/util/syspolicy/source/policy_store_windows.go b/util/syspolicy/source/policy_store_windows.go index f526b4ce1c666..621701e84f23c 100644 --- a/util/syspolicy/source/policy_store_windows.go +++ b/util/syspolicy/source/policy_store_windows.go @@ -12,6 +12,7 @@ import ( "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" "tailscale.com/util/set" + "tailscale.com/util/syspolicy/internal/loggerx" "tailscale.com/util/syspolicy/setting" "tailscale.com/util/winutil/gp" ) @@ -29,6 +30,18 @@ var ( _ Expirable = (*PlatformPolicyStore)(nil) ) +// lockableCloser is a [Lockable] that can also be closed. 
+// It is implemented by [gp.PolicyLock] and [optionalPolicyLock]. +type lockableCloser interface { + Lockable + Close() error +} + +var ( + _ lockableCloser = (*gp.PolicyLock)(nil) + _ lockableCloser = (*optionalPolicyLock)(nil) +) + // PlatformPolicyStore implements [Store] by providing read access to // Registry-based Tailscale policies, such as those configured via Group Policy or MDM. // For better performance and consistency, it is recommended to lock it when @@ -55,7 +68,7 @@ type PlatformPolicyStore struct { // they are being read. // // When both policyLock and mu need to be taken, mu must be taken before policyLock. - policyLock *gp.PolicyLock + policyLock lockableCloser mu sync.Mutex tsKeys []registry.Key // or nil if the [PlatformPolicyStore] hasn't been locked. @@ -108,7 +121,7 @@ func newPlatformPolicyStore(scope gp.Scope, softwareKey registry.Key, policyLock scope: scope, softwareKey: softwareKey, done: make(chan struct{}), - policyLock: policyLock, + policyLock: &optionalPolicyLock{PolicyLock: policyLock}, } } @@ -319,9 +332,9 @@ func (ps *PlatformPolicyStore) ReadStringArray(key setting.Key) ([]string, error // If there are no [setting.KeyPathSeparator]s in the key, the policy setting value // is meant to be stored directly under {HKLM,HKCU}\Software\Policies\Tailscale. 
func splitSettingKey(key setting.Key) (path, valueName string) { - if idx := strings.LastIndex(string(key), setting.KeyPathSeparator); idx != -1 { - path = strings.ReplaceAll(string(key[:idx]), setting.KeyPathSeparator, `\`) - valueName = string(key[idx+len(setting.KeyPathSeparator):]) + if idx := strings.LastIndexByte(string(key), setting.KeyPathSeparator); idx != -1 { + path = strings.ReplaceAll(string(key[:idx]), string(setting.KeyPathSeparator), `\`) + valueName = string(key[idx+1:]) return path, valueName } return "", string(key) @@ -448,3 +461,68 @@ func tailscaleKeyNamesFor(scope gp.Scope) []string { panic("unreachable") } } + +type gpLockState int + +const ( + gpUnlocked = gpLockState(iota) + gpLocked + gpLockRestricted // the lock could not be acquired due to a restriction in place +) + +// optionalPolicyLock is a wrapper around [gp.PolicyLock] that locks +// and unlocks the underlying [gp.PolicyLock]. +// +// If the [gp.PolicyLock.Lock] returns [gp.ErrLockRestricted], the error is ignored, +// and calling [optionalPolicyLock.Unlock] is a no-op. +// +// The underlying GP lock is kinda optional: it is safe to read policy settings +// from the Registry without acquiring it, but it is recommended to lock it anyway +// when reading multiple policy settings to avoid potentially inconsistent results. +// +// It is not safe for concurrent use. +type optionalPolicyLock struct { + *gp.PolicyLock + state gpLockState +} + +// Lock acquires the underlying [gp.PolicyLock], returning an error on failure. +// If the lock cannot be acquired due to a restriction in place +// (e.g., attempting to acquire a lock while the service is starting), +// the lock is considered to be held, the method returns nil, and a subsequent +// call to [Unlock] is a no-op. +// It is a runtime error to call Lock when the lock is already held. 
+func (o *optionalPolicyLock) Lock() error { + if o.state != gpUnlocked { + panic("already locked") + } + switch err := o.PolicyLock.Lock(); err { + case nil: + o.state = gpLocked + return nil + case gp.ErrLockRestricted: + loggerx.Errorf("GP lock not acquired: %v", err) + o.state = gpLockRestricted + return nil + default: + return err + } +} + +// Unlock releases the underlying [gp.PolicyLock], if it was previously acquired. +// It is a runtime error to call Unlock when the lock is not held. +func (o *optionalPolicyLock) Unlock() { + switch o.state { + case gpLocked: + o.PolicyLock.Unlock() + case gpLockRestricted: + // The GP lock wasn't acquired due to a restriction in place + // when [optionalPolicyLock.Lock] was called. Unlock is a no-op. + case gpUnlocked: + panic("not locked") + default: + panic("unreachable") + } + + o.state = gpUnlocked +} diff --git a/util/syspolicy/source/test_store.go b/util/syspolicy/source/test_store.go index 1f19bbb4386b9..4b175611fef0d 100644 --- a/util/syspolicy/source/test_store.go +++ b/util/syspolicy/source/test_store.go @@ -11,8 +11,9 @@ import ( xmaps "golang.org/x/exp/maps" "tailscale.com/util/mak" "tailscale.com/util/set" - "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/slicesx" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/testenv" ) var ( @@ -78,7 +79,7 @@ func (r TestExpectedReads) operation() testReadOperation { // TestStore is a [Store] that can be used in tests. type TestStore struct { - tb internal.TB + tb testenv.TB done chan struct{} @@ -97,7 +98,7 @@ type TestStore struct { // NewTestStore returns a new [TestStore]. // The tb will be used to report coding errors detected by the [TestStore]. 
-func NewTestStore(tb internal.TB) *TestStore { +func NewTestStore(tb testenv.TB) *TestStore { m := make(map[setting.Key]any) store := &TestStore{ tb: tb, @@ -111,7 +112,7 @@ func NewTestStore(tb internal.TB) *TestStore { // NewTestStoreOf is a shorthand for [NewTestStore] followed by [TestStore.SetBooleans], // [TestStore.SetUInt64s], [TestStore.SetStrings] or [TestStore.SetStringLists]. -func NewTestStoreOf[T TestValueType](tb internal.TB, settings ...TestSetting[T]) *TestStore { +func NewTestStoreOf[T TestValueType](tb testenv.TB, settings ...TestSetting[T]) *TestStore { store := NewTestStore(tb) switch settings := any(settings).(type) { case []TestSetting[bool]: @@ -418,7 +419,7 @@ func (s *TestStore) NotifyPolicyChanged() { s.mu.RUnlock() return } - cbs := xmaps.Values(s.cbs) + cbs := slicesx.MapValues(s.cbs) s.mu.RUnlock() var wg sync.WaitGroup diff --git a/util/syspolicy/syspolicy.go b/util/syspolicy/syspolicy.go index abe42ed90f8c7..afcc28ff1fd86 100644 --- a/util/syspolicy/syspolicy.go +++ b/util/syspolicy/syspolicy.go @@ -1,51 +1,83 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -// Package syspolicy provides functions to retrieve system settings of a device. +// Package syspolicy facilitates retrieval of the current policy settings +// applied to the device or user and receiving notifications when the policy +// changes. +// +// It provides functions that return specific policy settings by their unique +// [setting.Key]s, such as [GetBoolean], [GetUint64], [GetString], +// [GetStringArray], [GetPreferenceOption], [GetVisibility] and [GetDuration]. 
package syspolicy import ( "errors" + "fmt" + "reflect" "time" "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/rsop" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" ) -func GetString(key Key, defaultValue string) (string, error) { - markHandlerInUse() - v, err := handler.ReadString(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil +var ( + // ErrNotConfigured is returned when the requested policy setting is not configured. + ErrNotConfigured = setting.ErrNotConfigured + // ErrTypeMismatch is returned when there's a type mismatch between the actual type + // of the setting value and the expected type. + ErrTypeMismatch = setting.ErrTypeMismatch + // ErrNoSuchKey is returned by [setting.DefinitionOf] when no policy setting + // has been registered with the specified key. + // + // This error is also returned by a (now deprecated) [Handler] when the specified + // key does not have a value set. While the package maintains compatibility with this + // usage of ErrNoSuchKey, it is recommended to return [ErrNotConfigured] from newer + // [source.Store] implementations. + ErrNoSuchKey = setting.ErrNoSuchKey +) + +// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope]. +// +// It is a shorthand for [rsop.RegisterStore]. +func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*rsop.StoreRegistration, error) { + return rsop.RegisterStore(name, scope, store) +} + +// MustRegisterStoreForTest is like [rsop.RegisterStoreForTest], but it fails the test if the store could not be registered. 
+func MustRegisterStoreForTest(tb testenv.TB, name string, scope setting.PolicyScope, store source.Store) *rsop.StoreRegistration { + tb.Helper() + reg, err := rsop.RegisterStoreForTest(tb, name, scope, store) + if err != nil { + tb.Fatalf("Failed to register policy store %q as a %v policy source: %v", name, scope, err) } - return v, err + return reg } +// GetString returns a string policy setting with the specified key, +// or defaultValue if it does not exist. +func GetString(key Key, defaultValue string) (string, error) { + return getCurrentPolicySettingValue(key, defaultValue) +} + +// GetUint64 returns a numeric policy setting with the specified key, +// or defaultValue if it does not exist. func GetUint64(key Key, defaultValue uint64) (uint64, error) { - markHandlerInUse() - v, err := handler.ReadUInt64(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } +// GetBoolean returns a boolean policy setting with the specified key, +// or defaultValue if it does not exist. func GetBoolean(key Key, defaultValue bool) (bool, error) { - markHandlerInUse() - v, err := handler.ReadBoolean(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } +// GetStringArray returns a multi-string policy setting with the specified key, +// or defaultValue if it does not exist. func GetStringArray(key Key, defaultValue []string) ([]string, error) { - markHandlerInUse() - v, err := handler.ReadStringArray(string(key)) - if errors.Is(err, ErrNoSuchKey) { - return defaultValue, nil - } - return v, err + return getCurrentPolicySettingValue(key, defaultValue) } // GetPreferenceOption loads a policy from the registry that can be @@ -55,13 +87,14 @@ func GetStringArray(key Key, defaultValue []string) ([]string, error) { // "always" and "never" remove the user's ability to make a selection. 
If not // present or set to a different value, "user-decides" is the default. func GetPreferenceOption(name Key) (setting.PreferenceOption, error) { - s, err := GetString(name, "user-decides") - if err != nil { - return setting.ShowChoiceByPolicy, err - } - var opt setting.PreferenceOption - err = opt.UnmarshalText([]byte(s)) - return opt, err + return getCurrentPolicySettingValue(name, setting.ShowChoiceByPolicy) +} + +// GetPreferenceOptionOrDefault is like [GetPreferenceOption], but allows +// specifying a default value to return if the policy setting is not configured. +// It can be used in situations where "user-decides" is not the default. +func GetPreferenceOptionOrDefault(name Key, defaultValue setting.PreferenceOption) (setting.PreferenceOption, error) { + return getCurrentPolicySettingValue(name, defaultValue) } // GetVisibility loads a policy from the registry that can be managed @@ -70,13 +103,7 @@ func GetPreferenceOption(name Key) (setting.PreferenceOption, error) { // true) or "hide" (return true). If not present or set to a different value, // "show" (return false) is the default. func GetVisibility(name Key) (setting.Visibility, error) { - s, err := GetString(name, "show") - if err != nil { - return setting.VisibleByPolicy, err - } - var visibility setting.Visibility - visibility.UnmarshalText([]byte(s)) - return visibility, nil + return getCurrentPolicySettingValue(name, setting.VisibleByPolicy) } // GetDuration loads a policy from the registry that can be managed @@ -85,15 +112,58 @@ func GetVisibility(name Key) (setting.Visibility, error) { // understands. If the registry value is "" or can not be processed, // defaultValue is returned instead. 
func GetDuration(name Key, defaultValue time.Duration) (time.Duration, error) { - opt, err := GetString(name, "") - if opt == "" || err != nil { - return defaultValue, err + d, err := getCurrentPolicySettingValue(name, defaultValue) + if err != nil { + return d, err } - v, err := time.ParseDuration(opt) - if err != nil || v < 0 { + if d < 0 { return defaultValue, nil } - return v, nil + return d, nil +} + +// RegisterChangeCallback adds a function that will be called whenever the effective policy +// for the default scope changes. The returned function can be used to unregister the callback. +func RegisterChangeCallback(cb rsop.PolicyChangeCallback) (unregister func(), err error) { + effective, err := rsop.PolicyFor(setting.DefaultScope()) + if err != nil { + return nil, err + } + return effective.RegisterChangeCallback(cb), nil +} + +// getCurrentPolicySettingValue returns the value of the policy setting +// specified by its key from the [rsop.Policy] of the [setting.DefaultScope]. It +// returns def if the policy setting is not configured, or an error if it has +// an error or could not be converted to the specified type T. +func getCurrentPolicySettingValue[T setting.ValueType](key Key, def T) (T, error) { + effective, err := rsop.PolicyFor(setting.DefaultScope()) + if err != nil { + return def, err + } + value, err := effective.Get().GetErr(key) + if err != nil { + if errors.Is(err, setting.ErrNotConfigured) || errors.Is(err, setting.ErrNoSuchKey) { + return def, nil + } + return def, err + } + if res, ok := value.(T); ok { + return res, nil + } + return convertPolicySettingValueTo(value, def) +} + +func convertPolicySettingValueTo[T setting.ValueType](value any, def T) (T, error) { + // Convert [PreferenceOption], [Visibility], or [time.Duration] back to a string + // if someone requests a string instead of the actual setting's value. + // TODO(nickkhyl): check if this behavior is relied upon anywhere besides the old tests. 
+ if reflect.TypeFor[T]().Kind() == reflect.String { + if str, ok := value.(fmt.Stringer); ok { + return any(str.String()).(T), nil + } + } + return def, fmt.Errorf("%w: got %T, want %T", setting.ErrTypeMismatch, value, def) } // SelectControlURL returns the ControlURL to use based on a value in diff --git a/util/syspolicy/syspolicy_test.go b/util/syspolicy/syspolicy_test.go index 8280aa1dfbdac..fc01f364597c1 100644 --- a/util/syspolicy/syspolicy_test.go +++ b/util/syspolicy/syspolicy_test.go @@ -9,57 +9,16 @@ import ( "testing" "time" + "tailscale.com/types/logger" + "tailscale.com/util/syspolicy/internal/loggerx" + "tailscale.com/util/syspolicy/internal/metrics" "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" ) -// testHandler encompasses all data types returned when testing any of the syspolicy -// methods that involve getting a policy value. -// For keys and the corresponding values, check policy_keys.go. -type testHandler struct { - t *testing.T - key Key - s string - u64 uint64 - b bool - sArr []string - err error - calls int // used for testing reads from cache vs. 
handler -} - var someOtherError = errors.New("error other than not found") -func (th *testHandler) ReadString(key string) (string, error) { - if key != string(th.key) { - th.t.Errorf("ReadString(%q) want %q", key, th.key) - } - th.calls++ - return th.s, th.err -} - -func (th *testHandler) ReadUInt64(key string) (uint64, error) { - if key != string(th.key) { - th.t.Errorf("ReadUint64(%q) want %q", key, th.key) - } - th.calls++ - return th.u64, th.err -} - -func (th *testHandler) ReadBoolean(key string) (bool, error) { - if key != string(th.key) { - th.t.Errorf("ReadBool(%q) want %q", key, th.key) - } - th.calls++ - return th.b, th.err -} - -func (th *testHandler) ReadStringArray(key string) ([]string, error) { - if key != string(th.key) { - th.t.Errorf("ReadStringArray(%q) want %q", key, th.key) - } - th.calls++ - return th.sArr, th.err -} - func TestGetString(t *testing.T) { tests := []struct { name string @@ -69,23 +28,28 @@ func TestGetString(t *testing.T) { defaultValue string wantValue string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: AdminConsoleVisibility, handlerValue: "hide", wantValue: "hide", + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", key: EnableServerMode, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non-blank default", key: EnableServerMode, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, defaultValue: "test", wantValue: "test", wantError: nil, @@ -95,24 +59,43 @@ func TestGetString(t *testing.T) { key: NetworkDevicesVisibility, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_NetworkDevices_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { 
t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetString(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -129,7 +112,7 @@ func TestGetUint64(t *testing.T) { }{ { name: "read existing value", - key: KeyExpirationNoticeTime, + key: LogSCMInteractions, handlerValue: 1, wantValue: 1, }, @@ -137,14 +120,14 @@ func TestGetUint64(t *testing.T) { name: "read non-existing value", key: LogSCMInteractions, handlerValue: 0, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 0, }, { name: "read non-existing value, non-zero default", key: LogSCMInteractions, defaultValue: 2, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 2, }, { @@ -157,14 +140,23 @@ func TestGetUint64(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - u64: tt.handlerValue, - err: tt.handlerError, - }) + // None of the policy settings tested here are integers. + // In fact, we don't have any integer policies as of 2024-10-08. 
+ // However, we can register each of them as an integer policy setting + // for the duration of the test, providing us with something to test against. + if err := setting.SetDefinitionsForTest(t, setting.NewDefinition(tt.key, setting.DeviceSetting, setting.IntegerValue)); err != nil { + t.Fatalf("SetDefinitionsForTest failed: %v", err) + } + + s := source.TestSetting[uint64]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetUint64(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { @@ -183,45 +175,69 @@ func TestGetBoolean(t *testing.T) { defaultValue bool wantValue bool wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: FlushDNSOnSessionUnlock, handlerValue: true, wantValue: true, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock", Value: 1}, + }, }, { name: "read non-existing value", key: LogSCMInteractions, handlerValue: false, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: false, }, { name: "reading value returns other error", key: FlushDNSOnSessionUnlock, handlerError: someOtherError, - wantError: someOtherError, + wantError: someOtherError, // expect error... defaultValue: true, - wantValue: false, + wantValue: true, // ...AND default value if the handler fails. 
+ wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_FlushDNSOnSessionUnlock_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - b: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[bool]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetBoolean(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if value != tt.wantValue { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) 
}) } } @@ -234,29 +250,42 @@ func TestGetPreferenceOption(t *testing.T) { handlerError error wantValue setting.PreferenceOption wantError error + wantMetrics []metrics.TestState }{ { name: "always by policy", key: EnableIncomingConnections, handlerValue: "always", wantValue: setting.AlwaysByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "never by policy", key: EnableIncomingConnections, handlerValue: "never", wantValue: setting.NeverByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "use default", key: EnableIncomingConnections, handlerValue: "", wantValue: setting.ShowChoiceByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections", Value: 1}, + }, }, { name: "read non-existing value", key: EnableIncomingConnections, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: setting.ShowChoiceByPolicy, }, { @@ -265,24 +294,43 @@ func TestGetPreferenceOption(t *testing.T) { handlerError: someOtherError, wantValue: setting.ShowChoiceByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowIncomingConnections_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + option, err := GetPreferenceOption(tt.key) - if err != tt.wantError { + if 
!errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if option != tt.wantValue { t.Errorf("option=%v, want %v", option, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -295,24 +343,33 @@ func TestGetVisibility(t *testing.T) { handlerError error wantValue setting.Visibility wantError error + wantMetrics []metrics.TestState }{ { name: "hidden by policy", key: AdminConsoleVisibility, handlerValue: "hide", wantValue: setting.HiddenByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "visibility default", key: AdminConsoleVisibility, handlerValue: "show", wantValue: setting.VisibleByPolicy, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AdminConsole", Value: 1}, + }, }, { name: "read non-existing value", key: AdminConsoleVisibility, handlerValue: "show", - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: setting.VisibleByPolicy, }, { @@ -322,24 +379,43 @@ func TestGetVisibility(t *testing.T) { handlerError: someOtherError, wantValue: setting.VisibleByPolicy, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AdminConsole_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: 
tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + visibility, err := GetVisibility(tt.key) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if visibility != tt.wantValue { t.Errorf("visibility=%v, want %v", visibility, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } @@ -353,6 +429,7 @@ func TestGetDuration(t *testing.T) { defaultValue time.Duration wantValue time.Duration wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", @@ -360,25 +437,34 @@ func TestGetDuration(t *testing.T) { handlerValue: "2h", wantValue: 2 * time.Hour, defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice", Value: 1}, + }, }, { name: "invalid duration value", key: KeyExpirationNoticeTime, handlerValue: "-20", wantValue: 24 * time.Hour, + wantError: errors.New(`time: missing unit in duration "-20"`), defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, { name: "read non-existing value", key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 24 * time.Hour, defaultValue: 24 * time.Hour, }, { name: "read non-existing value different default", key: KeyExpirationNoticeTime, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantValue: 0 * time.Second, defaultValue: 0 * time.Second, }, @@ -389,24 +475,43 @@ func TestGetDuration(t *testing.T) { wantValue: 24 * time.Hour, wantError: someOtherError, 
defaultValue: 24 * time.Hour, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_KeyExpirationNotice_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - s: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + duration, err := GetDuration(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if duration != tt.wantValue { t.Errorf("duration=%v, want %v", duration, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) 
}) } } @@ -420,23 +525,28 @@ func TestGetStringArray(t *testing.T) { defaultValue []string wantValue []string wantError error + wantMetrics []metrics.TestState }{ { name: "read existing value", key: AllowedSuggestedExitNodes, handlerValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_any", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes", Value: 1}, + }, }, { name: "read non-existing value", key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, wantError: nil, }, { name: "read non-existing value, non nil default", key: AllowedSuggestedExitNodes, - handlerError: ErrNoSuchKey, + handlerError: ErrNotConfigured, defaultValue: []string{"foo", "bar"}, wantValue: []string{"foo", "bar"}, wantError: nil, @@ -446,28 +556,68 @@ func TestGetStringArray(t *testing.T) { key: AllowedSuggestedExitNodes, handlerError: someOtherError, wantError: someOtherError, + wantMetrics: []metrics.TestState{ + {Name: "$os_syspolicy_errors", Value: 1}, + {Name: "$os_syspolicy_AllowedSuggestedExitNodes_error", Value: 1}, + }, }, } + RegisterWellKnownSettingsForTest(t) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - SetHandlerForTest(t, &testHandler{ - t: t, - key: tt.key, - sArr: tt.handlerValue, - err: tt.handlerError, - }) + h := metrics.NewTestHandler(t) + metrics.SetHooksForTest(t, h.AddMetric, h.SetMetric) + + s := source.TestSetting[[]string]{ + Key: tt.key, + Value: tt.handlerValue, + Error: tt.handlerError, + } + registerSingleSettingStoreForTest(t, s) + value, err := GetStringArray(tt.key, tt.defaultValue) - if err != tt.wantError { + if !errorsMatchForTest(err, tt.wantError) { t.Errorf("err=%q, want %q", err, tt.wantError) } if !slices.Equal(tt.wantValue, value) { t.Errorf("value=%v, want %v", value, tt.wantValue) } + wantMetrics := tt.wantMetrics + if !metrics.ShouldReport() { + // Check that metrics are not reported on platforms + // where 
they shouldn't be reported. + // As of 2024-09-04, syspolicy only reports metrics + // on Windows and Android. + wantMetrics = nil + } + h.MustEqual(wantMetrics...) }) } } +func registerSingleSettingStoreForTest[T source.TestValueType](tb testenv.TB, s source.TestSetting[T]) { + policyStore := source.NewTestStoreOf(tb, s) + MustRegisterStoreForTest(tb, "TestStore", setting.DeviceScope, policyStore) +} + +func BenchmarkGetString(b *testing.B) { + loggerx.SetForTest(b, logger.Discard, logger.Discard) + RegisterWellKnownSettingsForTest(b) + + wantControlURL := "https://login.tailscale.com" + registerSingleSettingStoreForTest(b, source.TestSettingOf(ControlURL, wantControlURL)) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + gotControlURL, _ := GetString(ControlURL, "https://controlplane.tailscale.com") + if gotControlURL != wantControlURL { + b.Fatalf("got %v; want %v", gotControlURL, wantControlURL) + } + } +} + func TestSelectControlURL(t *testing.T) { tests := []struct { reg, disk, want string @@ -499,3 +649,13 @@ func TestSelectControlURL(t *testing.T) { } } } + +func errorsMatchForTest(got, want error) bool { + if got == nil && want == nil { + return true + } + if got == nil || want == nil { + return false + } + return errors.Is(got, want) || got.Error() == want.Error() +} diff --git a/util/syspolicy/syspolicy_windows.go b/util/syspolicy/syspolicy_windows.go new file mode 100644 index 0000000000000..ca0fd329aca04 --- /dev/null +++ b/util/syspolicy/syspolicy_windows.go @@ -0,0 +1,92 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package syspolicy + +import ( + "errors" + "fmt" + "os/user" + + "tailscale.com/util/syspolicy/internal" + "tailscale.com/util/syspolicy/rsop" + "tailscale.com/util/syspolicy/setting" + "tailscale.com/util/syspolicy/source" + "tailscale.com/util/testenv" +) + +func init() { + // On Windows, we should automatically register the Registry-based policy + // store for the device. 
If we are running in a user's security context + // (e.g., we're the GUI), we should also register the Registry policy store for + // the user. In the future, we should register (and unregister) user policy + // stores whenever a user connects to (or disconnects from) the local backend. + // This ensures the backend is aware of the user's policy settings and can send + // them to the GUI/CLI/Web clients on demand or whenever they change. + // + // Other platforms, such as macOS, iOS and Android, should register their + // platform-specific policy stores via [RegisterStore] + // (or [RegisterHandler] until they implement the [source.Store] interface). + // + // External code, such as the ipnlocal package, may choose to register + // additional policy stores, such as config files and policies received from + // the control plane. + internal.Init.MustDefer(func() error { + // Do not register or use default policy stores during tests. + // Each test should set up its own necessary configurations. + if testenv.InTest() { + return nil + } + return configureSyspolicy(nil) + }) +} + +// configureSyspolicy configures syspolicy for use on Windows, +// either in test or regular builds depending on whether tb has a non-nil value. +func configureSyspolicy(tb testenv.TB) error { + const localSystemSID = "S-1-5-18" + // Always create and register a machine policy store that reads + // policy settings from the HKEY_LOCAL_MACHINE registry hive. + machineStore, err := source.NewMachinePlatformPolicyStore() + if err != nil { + return fmt.Errorf("failed to create the machine policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.DeviceScope, machineStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.DeviceScope, machineStore) + } + if err != nil { + return err + } + // Check whether the current process is running as Local System or not. 
+ u, err := user.Current() + if err != nil { + return err + } + if u.Uid == localSystemSID { + return nil + } + // If it's not a Local System's process (e.g., it's the GUI rather than the tailscaled service), + // we should create and use a policy store for the current user that reads + // policy settings from that user's registry hive (HKEY_CURRENT_USER). + userStore, err := source.NewUserPlatformPolicyStore(0) + if err != nil { + return fmt.Errorf("failed to create the current user's policy store: %v", err) + } + if tb == nil { + _, err = rsop.RegisterStore("Platform", setting.CurrentUserScope, userStore) + } else { + _, err = rsop.RegisterStoreForTest(tb, "Platform", setting.CurrentUserScope, userStore) + } + if err != nil { + return err + } + // And also set [setting.CurrentUserScope] as the [setting.DefaultScope], so [GetString], + // [GetVisibility] and similar functions would be returning a merged result + // of the machine's and user's policies. + if !setting.SetDefaultScope(setting.CurrentUserScope) { + return errors.New("current scope already set") + } + return nil +} diff --git a/util/systemd/systemd_linux.go b/util/systemd/systemd_linux.go index 909cfcb20ac6e..fdfd1bba05451 100644 --- a/util/systemd/systemd_linux.go +++ b/util/systemd/systemd_linux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build linux +//go:build linux && !android package systemd diff --git a/util/systemd/systemd_nonlinux.go b/util/systemd/systemd_nonlinux.go index 36214020ce566..5d7772bb3e61f 100644 --- a/util/systemd/systemd_nonlinux.go +++ b/util/systemd/systemd_nonlinux.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !linux +//go:build !linux || android package systemd diff --git a/util/testenv/testenv.go b/util/testenv/testenv.go index 12ada9003052b..aa6660411c91b 100644 --- a/util/testenv/testenv.go +++ b/util/testenv/testenv.go @@ -6,6 +6,7 @@ package 
// TB is testing.TB, to avoid importing "testing" in non-test code.
type TB interface {
	Cleanup(func())
	Error(args ...any)
	Errorf(format string, args ...any)
	Fail()
	FailNow()
	Failed() bool
	Fatal(args ...any)
	Fatalf(format string, args ...any)
	Helper()
	Log(args ...any)
	Logf(format string, args ...any)
	Name() string
	Setenv(key, value string)
	Chdir(dir string)
	Skip(args ...any)
	SkipNow()
	Skipf(format string, args ...any)
	Skipped() bool
	TempDir() string
	Context() context.Context
}

// InParallelTest reports whether t is running as a parallel test.
//
// Use of this function taints t such that its Parallel method (assuming t is an
// actual *testing.T) will panic if called after this function.
func InParallelTest(t TB) (isParallel bool) {
	defer func() {
		// Chdir panics when t.Parallel has already been called;
		// translate that panic into a true result (and swallow it).
		isParallel = recover() != nil
	}()
	t.Chdir(".") // a no-op directory change; panics in a parallel test
	return false
}
+func AssertInTest() { + if !InTest() { + panic("func called outside of test binary") + } +} diff --git a/util/testenv/testenv_test.go b/util/testenv/testenv_test.go index 43c332b26a5a1..c647d9aec1ea4 100644 --- a/util/testenv/testenv_test.go +++ b/util/testenv/testenv_test.go @@ -16,3 +16,16 @@ func TestDeps(t *testing.T) { }, }.Check(t) } + +func TestInParallelTestTrue(t *testing.T) { + t.Parallel() + if !InParallelTest(t) { + t.Fatal("InParallelTest should return true once t.Parallel has been called") + } +} + +func TestInParallelTestFalse(t *testing.T) { + if InParallelTest(t) { + t.Fatal("InParallelTest should return false before t.Parallel has been called") + } +} diff --git a/util/uniq/slice.go b/util/uniq/slice.go deleted file mode 100644 index 4ab933a9d82d1..0000000000000 --- a/util/uniq/slice.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -// Package uniq provides removal of adjacent duplicate elements in slices. -// It is similar to the unix command uniq. -package uniq - -// ModifySlice removes adjacent duplicate elements from the given slice. It -// adjusts the length of the slice appropriately and zeros the tail. -// -// ModifySlice does O(len(*slice)) operations. -func ModifySlice[E comparable](slice *[]E) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if (*slice)[i] == (*slice)[dst] { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} - -// ModifySliceFunc is the same as ModifySlice except that it allows using a -// custom comparison function. -// -// eq should report whether the two provided elements are equal. 
-func ModifySliceFunc[E any](slice *[]E, eq func(i, j E) bool) { - // Remove duplicates - dst := 0 - for i := 1; i < len(*slice); i++ { - if eq((*slice)[dst], (*slice)[i]) { - continue - } - dst++ - (*slice)[dst] = (*slice)[i] - } - - // Zero out the elements we removed at the end of the slice - end := dst + 1 - var zero E - for i := end; i < len(*slice); i++ { - (*slice)[i] = zero - } - - // Truncate the slice - if end < len(*slice) { - *slice = (*slice)[:end] - } -} diff --git a/util/uniq/slice_test.go b/util/uniq/slice_test.go deleted file mode 100644 index 564fc08660332..0000000000000 --- a/util/uniq/slice_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package uniq_test - -import ( - "reflect" - "strconv" - "testing" - - "tailscale.com/util/uniq" -) - -func runTests(t *testing.T, cb func(*[]uint32)) { - tests := []struct { - // Use uint32 to be different from an int-typed slice index - in []uint32 - want []uint32 - }{ - {in: []uint32{0, 1, 2}, want: []uint32{0, 1, 2}}, - {in: []uint32{0, 1, 2, 2}, want: []uint32{0, 1, 2}}, - {in: []uint32{0, 0, 1, 2}, want: []uint32{0, 1, 2}}, - {in: []uint32{0, 1, 0, 2}, want: []uint32{0, 1, 0, 2}}, - {in: []uint32{0}, want: []uint32{0}}, - {in: []uint32{0, 0}, want: []uint32{0}}, - {in: []uint32{}, want: []uint32{}}, - } - - for _, test := range tests { - in := make([]uint32, len(test.in)) - copy(in, test.in) - cb(&test.in) - if !reflect.DeepEqual(test.in, test.want) { - t.Errorf("uniq.Slice(%v) = %v, want %v", in, test.in, test.want) - } - start := len(test.in) - test.in = test.in[:cap(test.in)] - for i := start; i < len(in); i++ { - if test.in[i] != 0 { - t.Errorf("uniq.Slice(%v): non-0 in tail of %v at index %v", in, test.in, i) - } - } - } -} - -func TestModifySlice(t *testing.T) { - runTests(t, func(slice *[]uint32) { - uniq.ModifySlice(slice) - }) -} - -func TestModifySliceFunc(t *testing.T) { - runTests(t, func(slice *[]uint32) { - 
// Metrics contains user-facing metrics that are used by multiple packages.
type Metrics struct {
	// initOnce guards the lazy, one-time creation of the metric maps
	// below; see [Registry.initOnce].
	initOnce sync.Once

	// droppedPacketsInbound and droppedPacketsOutbound count dropped
	// packets per [DropLabels] in the receive and send directions,
	// respectively. Both are nil until [Registry.initOnce] has run.
	droppedPacketsInbound  *metrics.MultiLabelMap[DropLabels]
	droppedPacketsOutbound *metrics.MultiLabelMap[DropLabels]
}

// DropReason is the reason why a packet was dropped.
type DropReason string

const (
	// ReasonACL means that the packet was not permitted by ACL.
	ReasonACL DropReason = "acl"

	// ReasonMulticast means that the packet was dropped because it was a multicast packet.
	ReasonMulticast DropReason = "multicast"

	// ReasonLinkLocalUnicast means that the packet was dropped because it was a link-local unicast packet.
	ReasonLinkLocalUnicast DropReason = "link_local_unicast"

	// ReasonTooShort means that the packet was dropped because it was a bad packet,
	// e.g. shorter than the minimum the parser accepts.
	ReasonTooShort DropReason = "too_short"

	// ReasonFragment means that the packet was dropped because it was an IP fragment.
	ReasonFragment DropReason = "fragment"

	// ReasonUnknownProtocol means that the packet was dropped because it was an unknown protocol.
	ReasonUnknownProtocol DropReason = "unknown_protocol"

	// ReasonError means that the packet was dropped because of an error.
	ReasonError DropReason = "error"
)

// DropLabels contains common label(s) for dropped packet counters.
type DropLabels struct {
	Reason DropReason
}

// initOnce initializes the common metrics. It is idempotent and safe for
// concurrent use: the underlying sync.Once ensures the metric maps are
// created and registered exactly once per Registry.
func (r *Registry) initOnce() {
	r.m.initOnce.Do(func() {
		r.m.droppedPacketsInbound = NewMultiLabelMapWithRegistry[DropLabels](
			r,
			"tailscaled_inbound_dropped_packets_total",
			"counter",
			"Counts the number of dropped packets received by the node from other peers",
		)
		r.m.droppedPacketsOutbound = NewMultiLabelMapWithRegistry[DropLabels](
			r,
			"tailscaled_outbound_dropped_packets_total",
			"counter",
			"Counts the number of packets dropped while being sent to other peers",
		)
	})
}

// DroppedPacketsOutbound returns the outbound dropped packet metric, creating it
// if necessary.
func (r *Registry) DroppedPacketsOutbound() *metrics.MultiLabelMap[DropLabels] {
	r.initOnce()
	return r.m.droppedPacketsOutbound
}
// MetricNames returns the names of all the metrics currently registered
// in the registry. The order of the returned names is unspecified.
func (r *Registry) MetricNames() []string {
	ret := make(set.Set[string])
	r.vars.Do(func(kv expvar.KeyValue) {
		ret.Add(kv.Key)
	})

	return ret.Slice()
}
// policyLockRestricted counts the restrictions currently in effect.
// While it is positive, [PolicyLock.Lock] fails with [ErrLockRestricted].
var policyLockRestricted atomic.Int32

// RestrictPolicyLocks forces all [PolicyLock.Lock] calls to return [ErrLockRestricted]
// until the returned function is called to remove the restriction.
//
// Calling the returned function more than once is safe; only the first call
// removes the restriction. Each call to [RestrictPolicyLocks] must be matched
// by a call to its returned function before locking is fully re-enabled.
//
// It is primarily used to prevent certain deadlocks, such as when tailscaled attempts to acquire
// a policy lock during startup. If the service starts due to Tailscale being installed by GPSI,
// the write lock will be held by the Group Policy service throughout the installation,
// preventing tailscaled from acquiring the read lock. Since Group Policy waits for the installation
// to complete, and therefore for tailscaled to start, before releasing the write lock, this scenario
// would result in a deadlock. See tailscale/tailscale#14416 for more information.
func RestrictPolicyLocks() (removeRestriction func()) {
	policyLockRestricted.Add(1)
	var once sync.Once
	return func() {
		once.Do(func() {
			policyLockRestricted.Add(-1)
		})
	}
}
+// It returns [ErrInvalidLockState] if l has a zero value or has already been closed, +// [ErrLockRestricted] if the lock cannot be acquired due to a restriction in place, +// or a [syscall.Errno] if the underlying Group Policy lock cannot be acquired. // -// As a special case, it fails with windows.ERROR_ACCESS_DENIED +// As a special case, it fails with [windows.ERROR_ACCESS_DENIED] // if l is a user policy lock, and the corresponding user is not logged in // interactively at the time of the call. func (l *PolicyLock) Lock() error { + if policyLockRestricted.Load() > 0 { + return ErrLockRestricted + } + l.mu.Lock() defer l.mu.Unlock() if l.lockCnt.Add(2)&1 == 0 { diff --git a/util/winutil/s4u/s4u_windows.go b/util/winutil/s4u/s4u_windows.go index a12b4786a0d06..8926aaedc5071 100644 --- a/util/winutil/s4u/s4u_windows.go +++ b/util/winutil/s4u/s4u_windows.go @@ -17,6 +17,7 @@ import ( "slices" "strconv" "strings" + "sync" "sync/atomic" "unsafe" @@ -128,9 +129,10 @@ func Login(logf logger.Logf, srcName string, u *user.User, capLevel CapabilityLe if err != nil { return nil, err } + tokenCloseOnce := sync.OnceFunc(func() { token.Close() }) defer func() { if err != nil { - token.Close() + tokenCloseOnce() } }() @@ -162,6 +164,7 @@ func Login(logf logger.Logf, srcName string, u *user.User, capLevel CapabilityLe sessToken.Close() } }() + tokenCloseOnce() } userProfile, err := winutil.LoadUserProfile(sessToken, u) diff --git a/version-embed.go b/version-embed.go index 2d517339d571c..17bf578dd33f1 100644 --- a/version-embed.go +++ b/version-embed.go @@ -26,6 +26,7 @@ var AlpineDockerTag string //go:embed go.toolchain.rev var GoToolchainRev string +//lint:ignore U1000 used by tests + assert_ts_toolchain_match.go w/ right build tags func tailscaleToolchainRev() (gitHash string, ok bool) { bi, ok := debug.ReadBuildInfo() if !ok { diff --git a/version/distro/distro.go b/version/distro/distro.go index 8865a834b97d3..f7997e1d9f81b 100644 --- a/version/distro/distro.go +++ 
b/version/distro/distro.go @@ -6,13 +6,12 @@ package distro import ( "bytes" - "io" "os" "runtime" "strconv" "tailscale.com/types/lazy" - "tailscale.com/util/lineread" + "tailscale.com/util/lineiter" ) type Distro string @@ -31,6 +30,7 @@ const ( WDMyCloud = Distro("wdmycloud") Unraid = Distro("unraid") Alpine = Distro("alpine") + UBNT = Distro("ubnt") // Ubiquiti Networks ) var distro lazy.SyncValue[Distro] @@ -76,6 +76,12 @@ func linuxDistro() Distro { case have("/usr/local/bin/freenas-debug"): // TrueNAS Scale runs on debian return TrueNAS + case have("/usr/bin/ubnt-device-info"): + // UBNT runs on Debian-based systems. This MUST be checked before Debian. + // + // Currently supported product families: + // - UDM (UniFi Dream Machine, UDM-Pro) + return UBNT case have("/etc/debian_version"): return Debian case have("/etc/arch-release"): @@ -132,18 +138,19 @@ func DSMVersion() int { return v } // But when run from the command line, we have to read it from the file: - lineread.File("/etc/VERSION", func(line []byte) error { + for lr := range lineiter.File("/etc/VERSION") { + line, err := lr.Value() + if err != nil { + break // but otherwise ignore + } line = bytes.TrimSpace(line) if string(line) == `majorversion="7"` { - v = 7 - return io.EOF + return 7 } if string(line) == `majorversion="6"` { - v = 6 - return io.EOF + return 6 } - return nil - }) - return v + } + return 0 }) } diff --git a/version/print.go b/version/print.go index 7d8554279f255..be90432cc85df 100644 --- a/version/print.go +++ b/version/print.go @@ -7,11 +7,10 @@ import ( "fmt" "runtime" "strings" - - "tailscale.com/types/lazy" + "sync" ) -var stringLazy = lazy.SyncFunc(func() string { +var stringLazy = sync.OnceValue(func() string { var ret strings.Builder ret.WriteString(Short()) ret.WriteByte('\n') diff --git a/version/prop.go b/version/prop.go index fee76c65fe0f2..9327e6fe6d0f4 100644 --- a/version/prop.go +++ b/version/prop.go @@ -9,6 +9,7 @@ import ( "runtime" "strconv" "strings" + "sync" 
"tailscale.com/tailcfg" "tailscale.com/types/lazy" @@ -61,26 +62,21 @@ func IsSandboxedMacOS() bool { // Tailscale for macOS, either the main GUI process (non-sandboxed) or the // system extension (sandboxed). func IsMacSys() bool { - return IsMacSysExt() || IsMacSysApp() + return IsMacSysExt() || IsMacSysGUI() } var isMacSysApp lazy.SyncValue[bool] -// IsMacSysApp reports whether this process is the main, non-sandboxed GUI process +// IsMacSysGUI reports whether this process is the main, non-sandboxed GUI process // that ships with the Standalone variant of Tailscale for macOS. -func IsMacSysApp() bool { +func IsMacSysGUI() bool { if runtime.GOOS != "darwin" { return false } return isMacSysApp.Get(func() bool { - exe, err := os.Executable() - if err != nil { - return false - } - // Check that this is the GUI binary, and it is not sandboxed. The GUI binary - // shipped in the App Store will always have the App Sandbox enabled. - return strings.HasSuffix(exe, "/Contents/MacOS/Tailscale") && !IsMacAppStore() + return strings.Contains(os.Getenv("HOME"), "/Containers/io.tailscale.ipn.macsys/") || + strings.Contains(os.Getenv("XPC_SERVICE_NAME"), "io.tailscale.ipn.macsys") }) } @@ -94,10 +90,6 @@ func IsMacSysExt() bool { return false } return isMacSysExt.Get(func() bool { - if strings.Contains(os.Getenv("HOME"), "/Containers/io.tailscale.ipn.macsys/") || - strings.Contains(os.Getenv("XPC_SERVICE_NAME"), "io.tailscale.ipn.macsys") { - return true - } exe, err := os.Executable() if err != nil { return false @@ -108,8 +100,8 @@ func IsMacSysExt() bool { var isMacAppStore lazy.SyncValue[bool] -// IsMacAppStore whether this binary is from the App Store version of Tailscale -// for macOS. +// IsMacAppStore returns whether this binary is from the App Store version of Tailscale +// for macOS. Returns true for both the network extension and the GUI app. 
// isValidLongWithTwoRepos reports whether v looks like a fully-stamped
// version.Long value of the form "MAJOR.MINOR.PATCH-tHASH-gHASH", where
// MAJOR and MINOR fit in 8 bits, PATCH fits in 16 bits, and each HASH is
// at least six lowercase hex digits.
func isValidLongWithTwoRepos(v string) bool {
	parts := strings.Split(v, "-")
	if len(parts) != 3 {
		return false
	}

	// isLowerHex reports whether s is at least six chars of [0-9a-f].
	isLowerHex := func(s string) bool {
		if len(s) < 6 {
			return false
		}
		for _, b := range []byte(s) {
			isDigit := '0' <= b && b <= '9'
			isHexLetter := 'a' <= b && b <= 'f'
			if !isDigit && !isHexLetter {
				return false
			}
		}
		return true
	}

	ver, tpart, gpart := parts[0], parts[1], parts[2]
	if !strings.HasPrefix(tpart, "t") || !isLowerHex(tpart[1:]) {
		return false
	}
	if !strings.HasPrefix(gpart, "g") || !isLowerHex(gpart[1:]) {
		return false
	}

	dotted := strings.Split(ver, ".")
	if len(dotted) != 3 {
		return false
	}
	for i, n := range dotted {
		// Major and minor must fit in 8 bits; patch gets 16.
		bits := 8
		if i == 2 {
			bits = 16
		}
		if _, err := strconv.ParseUint(n, 10, bits); err != nil {
			return false
		}
	}
	return true
}
tt.want) + } + } +} diff --git a/wgengine/bench/wg.go b/wgengine/bench/wg.go index 45823dd56825f..9b195bdb78fde 100644 --- a/wgengine/bench/wg.go +++ b/wgengine/bench/wg.go @@ -46,7 +46,7 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. logf: logger.WithPrefix(logf, "tun1: "), traf: traf, } - s1 := new(tsd.System) + s1 := tsd.NewSystem() e1, err := wgengine.NewUserspaceEngine(l1, wgengine.Config{ Router: router.NewFake(l1), NetMon: nil, @@ -73,7 +73,7 @@ func setupWGTest(b *testing.B, logf logger.Logf, traf *TrafficGen, a1, a2 netip. logf: logger.WithPrefix(logf, "tun2: "), traf: traf, } - s2 := new(tsd.System) + s2 := tsd.NewSystem() e2, err := wgengine.NewUserspaceEngine(l2, wgengine.Config{ Router: router.NewFake(l2), NetMon: nil, diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index 56224ac5d3fbc..987fcee0153a6 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -24,6 +24,7 @@ import ( "tailscale.com/types/views" "tailscale.com/util/mak" "tailscale.com/util/slicesx" + "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter/filtertype" ) @@ -202,16 +203,17 @@ func New(matches []Match, capTest CapTestFunc, localNets, logIPs *netipx.IPSet, } f := &Filter{ - logf: logf, - matches4: matchesFamily(matches, netip.Addr.Is4), - matches6: matchesFamily(matches, netip.Addr.Is6), - cap4: capMatchesFunc(matches, netip.Addr.Is4), - cap6: capMatchesFunc(matches, netip.Addr.Is6), - local4: ipset.FalseContainsIPFunc(), - local6: ipset.FalseContainsIPFunc(), - logIPs4: ipset.FalseContainsIPFunc(), - logIPs6: ipset.FalseContainsIPFunc(), - state: state, + logf: logf, + matches4: matchesFamily(matches, netip.Addr.Is4), + matches6: matchesFamily(matches, netip.Addr.Is6), + cap4: capMatchesFunc(matches, netip.Addr.Is4), + cap6: capMatchesFunc(matches, netip.Addr.Is6), + local4: ipset.FalseContainsIPFunc(), + local6: ipset.FalseContainsIPFunc(), + logIPs4: ipset.FalseContainsIPFunc(), + logIPs6: 
// pre runs the direction-agnostic filter logic. dir is only used for
// logging.
//
// It returns the verdict and, for Drop verdicts, the usermetric reason for
// the drop ("" otherwise). A noVerdict result means pre made no decision and
// the caller (RunIn/RunOut) must evaluate the main filter rules.
//
// NOTE: the check order below is load-bearing. The length check runs before
// the unknown-protocol check, so a short packet with an unknown protocol is
// reported as ReasonTooShort, not ReasonUnknownProtocol (see the
// "short-junk" vs "long-junk" cases in TestPreFilter).
func (f *Filter) pre(q *packet.Parsed, rf RunFlags, dir direction) (Response, usermetric.DropReason) {
	if len(q.Buffer()) == 0 {
		// wireguard keepalive packet, always permit.
		return Accept, ""
	}
	if len(q.Buffer()) < 20 {
		// Too short to be a valid IP packet (20 bytes is the minimum
		// IPv4 header size).
		f.logRateLimit(rf, q, dir, Drop, "too short")
		return Drop, usermetric.ReasonTooShort
	}

	if q.IPProto == ipproto.Unknown {
		f.logRateLimit(rf, q, dir, Drop, "unknown proto")
		return Drop, usermetric.ReasonUnknownProtocol
	}

	if q.Dst.Addr().IsMulticast() {
		f.logRateLimit(rf, q, dir, Drop, "multicast")
		return Drop, usermetric.ReasonMulticast
	}
	if q.Dst.Addr().IsLinkLocalUnicast() && q.Dst.Addr() != gcpDNSAddr {
		// Link-local unicast is dropped, except for the GCP metadata/DNS
		// address (169.254.169.254), which is explicitly exempted.
		f.logRateLimit(rf, q, dir, Drop, "link-local-unicast")
		return Drop, usermetric.ReasonLinkLocalUnicast
	}

	if q.IPProto == ipproto.Fragment {
		// Fragments after the first always need to be passed through.
		// Very small fragments are considered Junk by Parsed.
		f.logRateLimit(rf, q, dir, Accept, "fragment")
		return Accept, ""
	}

	// No early decision; let the caller run the main filter rules.
	return noVerdict, ""
}
@@ -227,7 +228,7 @@ func TestUDPState(t *testing.T) { t.Fatalf("incoming initial packet not dropped: %v", a4) } // We talk to that peer - if got := acl.RunOut(&b6, flags); got != Accept { + if got, _ := acl.RunOut(&b6, flags); got != Accept { t.Fatalf("outbound packet didn't egress: %v", b4) } // Now, the same packet as before is allowed back. @@ -382,25 +383,27 @@ func BenchmarkFilter(b *testing.B) { func TestPreFilter(t *testing.T) { packets := []struct { - desc string - want Response - b []byte + desc string + want Response + wantReason usermetric.DropReason + b []byte }{ - {"empty", Accept, []byte{}}, - {"short", Drop, []byte("short")}, - {"junk", Drop, raw4default(ipproto.Unknown, 10)}, - {"fragment", Accept, raw4default(ipproto.Fragment, 40)}, - {"tcp", noVerdict, raw4default(ipproto.TCP, 0)}, - {"udp", noVerdict, raw4default(ipproto.UDP, 0)}, - {"icmp", noVerdict, raw4default(ipproto.ICMPv4, 0)}, + {"empty", Accept, "", []byte{}}, + {"short", Drop, usermetric.ReasonTooShort, []byte("short")}, + {"short-junk", Drop, usermetric.ReasonTooShort, raw4default(ipproto.Unknown, 10)}, + {"long-junk", Drop, usermetric.ReasonUnknownProtocol, raw4default(ipproto.Unknown, 21)}, + {"fragment", Accept, "", raw4default(ipproto.Fragment, 40)}, + {"tcp", noVerdict, "", raw4default(ipproto.TCP, 0)}, + {"udp", noVerdict, "", raw4default(ipproto.UDP, 0)}, + {"icmp", noVerdict, "", raw4default(ipproto.ICMPv4, 0)}, } f := NewAllowNone(t.Logf, &netipx.IPSet{}) for _, testPacket := range packets { p := &packet.Parsed{} p.Decode(testPacket.b) - got := f.pre(p, LogDrops|LogAccepts, in) - if got != testPacket.want { - t.Errorf("%q got=%v want=%v packet:\n%s", testPacket.desc, got, testPacket.want, packet.Hexdump(testPacket.b)) + got, gotReason := f.pre(p, LogDrops|LogAccepts, in) + if got != testPacket.want || gotReason != testPacket.wantReason { + t.Errorf("%q got=%v want=%v gotReason=%s wantReason=%s packet:\n%s", testPacket.desc, got, testPacket.want, gotReason, 
testPacket.wantReason, packet.Hexdump(testPacket.b)) } } } @@ -768,7 +771,7 @@ func ports(s string) PortRange { if err != nil { panic(fmt.Sprintf("invalid NetPortRange %q", s)) } - return PortRange{uint16(first), uint16(last)} + return PortRange{First: uint16(first), Last: uint16(last)} } func netports(netPorts ...string) (ret []NetPortRange) { @@ -814,11 +817,11 @@ func TestMatchesFromFilterRules(t *testing.T) { Dsts: []NetPortRange{ { Net: netip.MustParsePrefix("0.0.0.0/0"), - Ports: PortRange{22, 22}, + Ports: PortRange{First: 22, Last: 22}, }, { Net: netip.MustParsePrefix("::0/0"), - Ports: PortRange{22, 22}, + Ports: PortRange{First: 22, Last: 22}, }, }, Srcs: []netip.Prefix{ @@ -848,7 +851,7 @@ func TestMatchesFromFilterRules(t *testing.T) { Dsts: []NetPortRange{ { Net: netip.MustParsePrefix("1.2.0.0/16"), - Ports: PortRange{22, 22}, + Ports: PortRange{First: 22, Last: 22}, }, }, Srcs: []netip.Prefix{ @@ -997,7 +1000,7 @@ func TestPeerCaps(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := xmaps.Keys(filt.CapsWithValues(netip.MustParseAddr(tt.src), netip.MustParseAddr(tt.dst))) + got := slicesx.MapKeys(filt.CapsWithValues(netip.MustParseAddr(tt.src), netip.MustParseAddr(tt.dst))) slices.Sort(got) slices.Sort(tt.want) if !slices.Equal(got, tt.want) { diff --git a/wgengine/magicsock/debughttp.go b/wgengine/magicsock/debughttp.go index 6c07b0d5eaa83..aa109c242e27c 100644 --- a/wgengine/magicsock/debughttp.go +++ b/wgengine/magicsock/debughttp.go @@ -102,8 +102,7 @@ func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { sort.Slice(ent, func(i, j int) bool { return ent[i].pub.Less(ent[j].pub) }) peers := map[key.NodePublic]tailcfg.NodeView{} - for i := range c.peers.Len() { - p := c.peers.At(i) + for _, p := range c.peers.All() { peers[p.Key()] = p } diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index bfee02f6e87da..ffdff14a11d47 100644 --- a/wgengine/magicsock/derp.go +++ 
b/wgengine/magicsock/derp.go @@ -64,10 +64,30 @@ func (c *Conn) removeDerpPeerRoute(peer key.NodePublic, regionID int, dc *derpht // addDerpPeerRoute adds a DERP route entry, noting that peer was seen // on DERP node derpID, at least on the connection identified by dc. // See issue 150 for details. -func (c *Conn) addDerpPeerRoute(peer key.NodePublic, derpID int, dc *derphttp.Client) { +func (c *Conn) addDerpPeerRoute(peer key.NodePublic, regionID int, dc *derphttp.Client) { c.mu.Lock() defer c.mu.Unlock() - mak.Set(&c.derpRoute, peer, derpRoute{derpID, dc}) + mak.Set(&c.derpRoute, peer, derpRoute{regionID, dc}) +} + +// fallbackDERPRegionForPeer returns the DERP region ID we might be able to use +// to contact peer, learned from observing recent DERP traffic from them. +// +// This is used as a fallback when a peer receives a packet from a peer +// over DERP but doesn't known that peer's home DERP or any UDP endpoints. +// This is particularly useful for large one-way nodes (such as hello.ts.net) +// that don't actively reach out to other nodes, so don't need to be told +// the DERP home of peers. They can instead learn the DERP home upon getting the +// first connection. +// +// This can also help nodes from a slow or misbehaving control plane. +func (c *Conn) fallbackDERPRegionForPeer(peer key.NodePublic) (regionID int) { + c.mu.Lock() + defer c.mu.Unlock() + if dr, ok := c.derpRoute[peer]; ok { + return dr.regionID + } + return 0 } // activeDerp contains fields for an active DERP connection. @@ -158,10 +178,10 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) } else { connectedToControl = c.health.GetInPollNetMap() } + c.mu.Lock() + myDerp := c.myDerp + c.mu.Unlock() if !connectedToControl { - c.mu.Lock() - myDerp := c.myDerp - c.mu.Unlock() if myDerp != 0 { metricDERPHomeNoChangeNoControl.Add(1) return myDerp @@ -178,6 +198,11 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) // one. 
preferredDERP = c.pickDERPFallback() } + if preferredDERP != myDerp { + c.logf( + "magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms]", + c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds()) + } if !c.setNearestDERP(preferredDERP) { preferredDERP = 0 } @@ -627,7 +652,7 @@ func (c *Conn) runDerpReader(ctx context.Context, regionID int, dc *derphttp.Cli // Do nothing. case derp.PeerGoneReasonNotHere: metricRecvDiscoDERPPeerNotHere.Add(1) - c.logf("[unexpected] magicsock: derp-%d does not know about peer %s, removing route", + c.logf("magicsock: derp-%d does not know about peer %s, removing route", regionID, key.NodePublic(m.Peer).ShortString()) default: metricRecvDiscoDERPPeerGoneUnknown.Add(1) @@ -644,9 +669,10 @@ func (c *Conn) runDerpReader(ctx context.Context, regionID int, dc *derphttp.Cli } type derpWriteRequest struct { - addr netip.AddrPort - pubKey key.NodePublic - b []byte // copied; ownership passed to receiver + addr netip.AddrPort + pubKey key.NodePublic + b []byte // copied; ownership passed to receiver + isDisco bool } // runDerpWriter runs in a goroutine for the life of a DERP @@ -668,7 +694,10 @@ func (c *Conn) runDerpWriter(ctx context.Context, dc *derphttp.Client, ch <-chan if err != nil { c.logf("magicsock: derp.Send(%v): %v", wr.addr, err) metricSendDERPError.Add(1) - } else { + if !wr.isDisco { + c.metrics.outboundPacketsDroppedErrors.Add(1) + } + } else if !wr.isDisco { c.metrics.outboundPacketsDERPTotal.Add(1) c.metrics.outboundBytesDERPTotal.Add(int64(len(wr.b))) } @@ -691,8 +720,6 @@ func (c *connBind) receiveDERP(buffs [][]byte, sizes []int, eps []conn.Endpoint) // No data read occurred. Wait for another packet. 
continue } - c.metrics.inboundPacketsDERPTotal.Add(1) - c.metrics.inboundBytesDERPTotal.Add(int64(n)) sizes[0] = n eps[0] = ep return 1, nil @@ -732,6 +759,9 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en if stats := c.stats.Load(); stats != nil { stats.UpdateRxPhysical(ep.nodeAddr, ipp, 1, dm.n) } + + c.metrics.inboundPacketsDERPTotal.Add(1) + c.metrics.inboundBytesDERPTotal.Add(int64(n)) return n, ep } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index ab9f3d47dd033..3788708a8260f 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -9,6 +9,7 @@ import ( "encoding/binary" "errors" "fmt" + "iter" "math" "math/rand/v2" "net" @@ -20,7 +21,6 @@ import ( "sync/atomic" "time" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" "tailscale.com/disco" @@ -33,6 +33,7 @@ import ( "tailscale.com/types/logger" "tailscale.com/util/mak" "tailscale.com/util/ringbuffer" + "tailscale.com/util/slicesx" ) var mtuProbePingSizesV4 []int @@ -94,6 +95,7 @@ type endpoint struct { expired bool // whether the node has expired isWireguardOnly bool // whether the endpoint is WireGuard only + relayCapable bool // whether the node is capable of speaking via a [tailscale.com/net/udprelay.Server] } func (de *endpoint) setBestAddrLocked(v addrQuality) { @@ -550,7 +552,7 @@ func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.Ad // addrForWireGuardSendLocked returns the address that should be used for // sending the next packet. If a packet has never or not recently been sent to // the endpoint, then a randomly selected address for the endpoint is returned, -// as well as a bool indiciating that WireGuard discovery pings should be started. +// as well as a bool indicating that WireGuard discovery pings should be started. // If the addresses have latency information available, then the address with the // best latency is used. 
// @@ -586,7 +588,7 @@ func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.Add needPing := len(de.endpointState) > 1 && now.Sub(oldestPing) > wireguardPingInterval if !udpAddr.IsValid() { - candidates := xmaps.Keys(de.endpointState) + candidates := slicesx.MapKeys(de.endpointState) // Randomly select an address to use until we retrieve latency information // and give it a short trustBestAddrUntil time so we avoid flapping between @@ -818,7 +820,7 @@ func (de *endpoint) heartbeat() { udpAddr, _, _ := de.addrForSendLocked(now) if udpAddr.IsValid() { - // We have a preferred path. Ping that every 2 seconds. + // We have a preferred path. Ping that every 'heartbeatInterval'. de.startDiscoPingLocked(udpAddr, now, pingHeartbeat, 0, nil) } @@ -947,7 +949,15 @@ func (de *endpoint) send(buffs [][]byte) error { de.mu.Unlock() if !udpAddr.IsValid() && !derpAddr.IsValid() { - return errNoUDPOrDERP + // Make a last ditch effort to see if we have a DERP route for them. If + // they contacted us over DERP and we don't know their UDP endpoints or + // their DERP home, we can at least assume they're reachable over the + // DERP they used to contact us. 
+ if rid := de.c.fallbackDERPRegionForPeer(de.publicKey); rid != 0 { + derpAddr = netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(rid)) + } else { + return errNoUDPOrDERP + } } var err error if udpAddr.IsValid() { @@ -983,7 +993,8 @@ func (de *endpoint) send(buffs [][]byte) error { allOk := true var txBytes int for _, buff := range buffs { - ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff) + const isDisco = false + ok, _ := de.c.sendAddr(derpAddr, de.publicKey, buff, isDisco) txBytes += len(buff) if !ok { allOk = false @@ -991,7 +1002,7 @@ func (de *endpoint) send(buffs [][]byte) error { } if stats := de.c.stats.Load(); stats != nil { - stats.UpdateTxPhysical(de.nodeAddr, derpAddr, 1, txBytes) + stats.UpdateTxPhysical(de.nodeAddr, derpAddr, len(buffs), txBytes) } if allOk { return nil @@ -1101,7 +1112,7 @@ func (de *endpoint) sendDiscoPing(ep netip.AddrPort, discoKey key.DiscoPublic, t size = min(size, MaxDiscoPingSize) padding := max(size-discoPingSize, 0) - sent, _ := de.c.sendDiscoMessage(ep, de.publicKey, discoKey, &disco.Ping{ + sent, _ := de.c.sendDiscoMessage(ep, virtualNetworkID{}, de.publicKey, discoKey, &disco.Ping{ TxID: [12]byte(txid), NodeKey: de.c.publicKeyAtomic.Load(), Padding: padding, @@ -1239,11 +1250,18 @@ func (de *endpoint) sendDiscoPingsLocked(now mono.Time, sendCallMeMaybe bool) { // sent so our firewall ports are probably open and now // would be a good time for them to connect. go de.c.enqueueCallMeMaybe(derpAddr, de) + + // Schedule allocation of relay endpoints. We make no considerations for + // current relay endpoints or best UDP path state for now, keep it + // simple. + if de.relayCapable { + go de.c.relayManager.allocateAndHandshakeAllServers(de) + } } } // sendWireGuardOnlyPingsLocked evaluates all available addresses for -// a WireGuard only endpoint and initates an ICMP ping for useable +// a WireGuard only endpoint and initiates an ICMP ping for useable // addresses. 
func (de *endpoint) sendWireGuardOnlyPingsLocked(now mono.Time) { if runtime.GOOS == "js" { @@ -1357,7 +1375,7 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p }) de.resetLocked() } - if n.DERP() == "" { + if n.HomeDERP() == 0 { if de.derpAddr.IsValid() { de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -1367,7 +1385,7 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p } de.derpAddr = netip.AddrPort{} } else { - newDerp, _ := netip.ParseAddrPort(n.DERP()) + newDerp := netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(n.HomeDERP())) if de.derpAddr != newDerp { de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -1383,20 +1401,18 @@ func (de *endpoint) updateFromNode(n tailcfg.NodeView, heartbeatDisabled bool, p } func (de *endpoint) setEndpointsLocked(eps interface { - Len() int - At(i int) netip.AddrPort + All() iter.Seq2[int, netip.AddrPort] }) { for _, st := range de.endpointState { st.index = indexSentinelDeleted // assume deleted until updated in next loop } var newIpps []netip.AddrPort - for i := range eps.Len() { + for i, ipp := range eps.All() { if i > math.MaxInt16 { // Seems unlikely. 
break } - ipp := eps.At(i) if !ipp.IsValid() { de.c.logf("magicsock: bogus netmap endpoint from %v", eps) continue @@ -1464,7 +1480,7 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T } } size2 := len(de.endpointState) - de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) + de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v (%s) candidate set from %v to %v entries", de.discoShort(), de.publicKey.ShortString(), size, size2) } return false } @@ -1613,7 +1629,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip de.c.logf("magicsock: disco: node %v %v now using %v mtu=%v tx=%x", de.publicKey.ShortString(), de.discoShort(), sp.to, thisPong.wireMTU, m.TxID[:6]) de.debugUpdates.Add(EndpointChange{ When: time.Now(), - What: "handlePingLocked-bestAddr-update", + What: "handlePongConnLocked-bestAddr-update", From: de.bestAddr, To: thisPong, }) @@ -1622,7 +1638,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip if de.bestAddr.AddrPort == thisPong.AddrPort { de.debugUpdates.Add(EndpointChange{ When: time.Now(), - What: "handlePingLocked-bestAddr-latency", + What: "handlePongConnLocked-bestAddr-latency", From: de.bestAddr, To: thisPong, }) @@ -1855,6 +1871,7 @@ func (de *endpoint) resetLocked() { } } de.probeUDPLifetime.resetCycleEndpointLocked() + de.c.relayManager.stopWork(de) } func (de *endpoint) numStopAndReset() int64 { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 72e59a2e72c62..61cdf49543493 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -9,6 +9,7 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "errors" "expvar" "fmt" @@ -17,11 +18,11 @@ import ( "net/netip" "reflect" "runtime" + "slices" "strconv" "strings" "sync" "sync/atomic" - "syscall" "time" "github.com/tailscale/wireguard-go/conn" @@ -56,13 +57,12 
@@ import ( "tailscale.com/types/nettype" "tailscale.com/types/views" "tailscale.com/util/clientmetric" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" "tailscale.com/util/ringbuffer" "tailscale.com/util/set" "tailscale.com/util/testenv" - "tailscale.com/util/uniq" "tailscale.com/util/usermetric" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/wgint" ) @@ -127,6 +127,10 @@ type metrics struct { outboundBytesIPv4Total expvar.Int outboundBytesIPv6Total expvar.Int outboundBytesDERPTotal expvar.Int + + // outboundPacketsDroppedErrors is the total number of outbound packets + // dropped due to errors. + outboundPacketsDroppedErrors expvar.Int } // A Conn routes UDP packets and actively manages a list of its endpoints. @@ -134,6 +138,8 @@ type Conn struct { // This block mirrors the contents and field order of the Options // struct. Initialized once at construction, then constant. + eventBus *eventbus.Bus + eventClient *eventbus.Client logf logger.Logf epFunc func([]tailcfg.Endpoint) derpActiveFunc func() @@ -175,6 +181,10 @@ type Conn struct { // port mappings from NAT devices. portMapper *portmapper.Client + // portMapperLogfUnregister is the function to call to unregister + // the portmapper log limiter. + portMapperLogfUnregister func() + // derpRecvCh is used by receiveDERP to read DERP messages. // It must have buffer size > 0; see issue 3736. derpRecvCh chan derpReadResult @@ -235,7 +245,7 @@ type Conn struct { stats atomic.Pointer[connstats.Statistics] // captureHook, if non-nil, is the pcap logging callback when capturing. - captureHook syncs.AtomicValue[capture.Callback] + captureHook syncs.AtomicValue[packet.CaptureCallback] // discoPrivate is the private naclbox key used for active // discovery traffic. It is always present, and immutable. @@ -307,7 +317,11 @@ type Conn struct { // by node key, node ID, and discovery key. peerMap peerMap - // discoInfo is the state for an active DiscoKey. 
+ // relayManager manages allocation and handshaking of + // [tailscale.com/net/udprelay.Server] endpoints. + relayManager relayManager + + // discoInfo is the state for an active peer DiscoKey. discoInfo map[key.DiscoPublic]*discoInfo // netInfoFunc is a callback that provides a tailcfg.NetInfo when @@ -361,9 +375,9 @@ type Conn struct { // wireguard state by its public key. If nil, it's not used. getPeerByKey func(key.NodePublic) (_ wgint.Peer, ok bool) - // lastEPERMRebind tracks the last time a rebind was performed - // after experiencing a syscall.EPERM. - lastEPERMRebind syncs.AtomicValue[time.Time] + // lastErrRebind tracks the last time a rebind was performed after + // experiencing a write error, and is used to throttle the rate of rebinds. + lastErrRebind syncs.AtomicValue[time.Time] // staticEndpoints are user set endpoints that this node should // advertise amongst its wireguard endpoints. It is user's @@ -395,8 +409,15 @@ func (c *Conn) dlogf(format string, a ...any) { // Options contains options for Listen. type Options struct { - // Logf optionally provides a log function to use. - // Must not be nil. + // EventBus, if non-nil, is used for event publication and subscription by + // each Conn created from these Options. + // + // TODO(creachadair): As of 2025-03-19 this is optional, but is intended to + // become required non-nil. + EventBus *eventbus.Bus + + // Logf provides a log function to use. It must not be nil. + // Use [logger.Discard] to disrcard logs. Logf logger.Logf // Port is the port to listen on. 
@@ -523,6 +544,7 @@ func NewConn(opts Options) (*Conn, error) { } c := newConn(opts.logf()) + c.eventBus = opts.EventBus c.port.Store(uint32(opts.Port)) c.controlKnobs = opts.ControlKnobs c.epFunc = opts.endpointsFunc() @@ -530,10 +552,47 @@ func NewConn(opts Options) (*Conn, error) { c.idleFunc = opts.IdleFunc c.testOnlyPacketListener = opts.TestOnlyPacketListener c.noteRecvActivity = opts.NoteRecvActivity + + // If an event bus is enabled, subscribe to portmapping changes; otherwise + // use the callback mechanism of portmapper.Client. + // + // TODO(creachadair): Remove the switch once the event bus is mandatory. + onPortMapChanged := c.onPortMapChanged + if c.eventBus != nil { + c.eventClient = c.eventBus.Client("magicsock.Conn") + + pmSub := eventbus.Subscribe[portmapper.Mapping](c.eventClient) + go func() { + defer pmSub.Close() + for { + select { + case <-pmSub.Events(): + c.onPortMapChanged() + case <-pmSub.Done(): + return + } + } + }() + + // Disable the explicit callback from the portmapper, the subscriber handles it. + onPortMapChanged = nil + } + + // Don't log the same log messages possibly every few seconds in our + // portmapper. 
+ portmapperLogf := logger.WithPrefix(c.logf, "portmapper: ") + portmapperLogf, c.portMapperLogfUnregister = netmon.LinkChangeLogLimiter(portmapperLogf, opts.NetMon) portMapOpts := &portmapper.DebugKnobs{ DisableAll: func() bool { return opts.DisablePortMapper || c.onlyTCP443.Load() }, } - c.portMapper = portmapper.NewClient(logger.WithPrefix(c.logf, "portmapper: "), opts.NetMon, portMapOpts, opts.ControlKnobs, c.onPortMapChanged) + c.portMapper = portmapper.NewClient(portmapper.Config{ + EventBus: c.eventBus, + Logf: portmapperLogf, + NetMon: opts.NetMon, + DebugKnobs: portMapOpts, + ControlKnobs: opts.ControlKnobs, + OnChange: onPortMapChanged, + }) c.portMapper.SetGatewayLookupFunc(opts.NetMon.GatewayAndSelfIP) c.netMon = opts.NetMon c.health = opts.HealthTracker @@ -605,6 +664,8 @@ func registerMetrics(reg *usermetric.Registry) *metrics { "counter", "Counts the number of bytes sent to other peers", ) + outboundPacketsDroppedErrors := reg.DroppedPacketsOutbound() + m := new(metrics) // Map clientmetrics to the usermetric counters. @@ -631,6 +692,8 @@ func registerMetrics(reg *usermetric.Registry) *metrics { outboundBytesTotal.Set(pathDirectV6, &m.outboundBytesIPv6Total) outboundBytesTotal.Set(pathDERP, &m.outboundBytesDERPTotal) + outboundPacketsDroppedErrors.Set(usermetric.DropLabels{Reason: usermetric.ReasonError}, &m.outboundPacketsDroppedErrors) + return m } @@ -648,7 +711,7 @@ func deregisterMetrics(m *metrics) { // log debug information into the pcap stream. This function // can be called with a nil argument to uninstall the capture // hook. 
-func (c *Conn) InstallCaptureHook(cb capture.Callback) { +func (c *Conn) InstallCaptureHook(cb packet.CaptureCallback) { c.captureHook.Store(cb) } @@ -704,7 +767,7 @@ func (c *Conn) updateEndpoints(why string) { c.muCond.Broadcast() }() c.dlogf("[v1] magicsock: starting endpoint update (%s)", why) - if c.noV4Send.Load() && runtime.GOOS != "js" && !c.onlyTCP443.Load() { + if c.noV4Send.Load() && runtime.GOOS != "js" && !c.onlyTCP443.Load() && !hostinfo.IsInVM86() { c.mu.Lock() closed := c.closed c.mu.Unlock() @@ -1112,8 +1175,8 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro // re-run. eps = c.endpointTracker.update(time.Now(), eps) - for i := range c.staticEndpoints.Len() { - addAddr(c.staticEndpoints.At(i), tailcfg.EndpointExplicitConf) + for _, ep := range c.staticEndpoints.All() { + addAddr(ep, tailcfg.EndpointExplicitConf) } if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() { @@ -1202,8 +1265,13 @@ func (c *Conn) networkDown() bool { return !c.networkUp.Load() } // Send implements conn.Bind. // // See https://pkg.go.dev/golang.zx2c4.com/wireguard/conn#Bind.Send -func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) error { +func (c *Conn) Send(buffs [][]byte, ep conn.Endpoint) (err error) { n := int64(len(buffs)) + defer func() { + if err != nil { + c.metrics.outboundPacketsDroppedErrors.Add(n) + } + }() metricSendData.Add(n) if c.networkDown() { metricSendDataNetworkDown.Add(n) @@ -1246,7 +1314,7 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err c.logf("magicsock: %s", errGSO.Error()) err = errGSO.RetryErr } else { - _ = c.maybeRebindOnError(runtime.GOOS, err) + c.maybeRebindOnError(err) } } return err == nil, err @@ -1254,16 +1322,16 @@ func (c *Conn) sendUDPBatch(addr netip.AddrPort, buffs [][]byte) (sent bool, err // sendUDP sends UDP packet b to ipp. // See sendAddr's docs on the return value meanings. 
-func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { +func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte, isDisco bool) (sent bool, err error) { if runtime.GOOS == "js" { return false, errNoUDP } sent, err = c.sendUDPStd(ipp, b) if err != nil { metricSendUDPError.Add(1) - _ = c.maybeRebindOnError(runtime.GOOS, err) + c.maybeRebindOnError(err) } else { - if sent { + if sent && !isDisco { switch { case ipp.Addr().Is4(): c.metrics.outboundPacketsIPv4Total.Add(1) @@ -1277,32 +1345,21 @@ func (c *Conn) sendUDP(ipp netip.AddrPort, b []byte) (sent bool, err error) { return } -// maybeRebindOnError performs a rebind and restun if the error is defined and -// any conditionals are met. -func (c *Conn) maybeRebindOnError(os string, err error) bool { - switch { - case errors.Is(err, syscall.EPERM): - why := "operation-not-permitted-rebind" - switch os { - // We currently will only rebind and restun on a syscall.EPERM if it is experienced - // on a client running darwin. - // TODO(charlotte, raggi): expand os options if required. - case "darwin": - // TODO(charlotte): implement a backoff, so we don't end up in a rebind loop for persistent - // EPERMs. - if c.lastEPERMRebind.Load().Before(time.Now().Add(-5 * time.Second)) { - c.logf("magicsock: performing %q", why) - c.lastEPERMRebind.Store(time.Now()) - c.Rebind() - go c.ReSTUN(why) - return true - } - default: - c.logf("magicsock: not performing %q", why) - return false - } +// maybeRebindOnError performs a rebind and restun if the error is one that is +// known to be healed by a rebind, and the rebind is not throttled. 
+func (c *Conn) maybeRebindOnError(err error) { + ok, reason := shouldRebind(err) + if !ok { + return + } + + if c.lastErrRebind.Load().Before(time.Now().Add(-5 * time.Second)) { + c.logf("magicsock: performing rebind due to %q", reason) + c.Rebind() + go c.ReSTUN(reason) + } else { + c.logf("magicsock: not performing %q rebind due to throttle", reason) } - return false } // sendUDPNetcheck sends b via UDP to addr. It is used exclusively by netcheck. @@ -1356,9 +1413,9 @@ func (c *Conn) sendUDPStd(addr netip.AddrPort, b []byte) (sent bool, err error) // An example of when they might be different: sending to an // IPv6 address when the local machine doesn't have IPv6 support // returns (false, nil); it's not an error, but nothing was sent. -func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (sent bool, err error) { +func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte, isDisco bool) (sent bool, err error) { if addr.Addr() != tailcfg.DerpMagicIPAddr { - return c.sendUDP(addr, b) + return c.sendUDP(addr, b, isDisco) } regionID := int(addr.Port()) @@ -1379,7 +1436,7 @@ func (c *Conn) sendAddr(addr netip.AddrPort, pubKey key.NodePublic, b []byte) (s case <-c.donec: metricSendDERPErrorClosed.Add(1) return false, errConnClosed - case ch <- derpWriteRequest{addr, pubKey, pkt}: + case ch <- derpWriteRequest{addr, pubKey, pkt, isDisco}: metricSendDERPQueued.Add(1) return true, nil default: @@ -1546,28 +1603,96 @@ const ( // speeds. var debugIPv4DiscoPingPenalty = envknob.RegisterDuration("TS_DISCO_PONG_IPV4_DELAY") +// virtualNetworkID is a Geneve header (RFC8926) 3-byte virtual network +// identifier. Its field must only ever be accessed via its methods. +type virtualNetworkID struct { + _vni uint32 +} + +const ( + vniSetMask uint32 = 0xFF000000 + vniGetMask uint32 = ^vniSetMask +) + +// isSet returns true if set() had been called previously, otherwise false. 
+func (v *virtualNetworkID) isSet() bool { + return v._vni&vniSetMask != 0 +} + +// set sets the provided VNI. If VNI exceeds the 3-byte storage it will be +// clamped. +func (v *virtualNetworkID) set(vni uint32) { + v._vni = vni | vniSetMask +} + +// get returns the VNI value. +func (v *virtualNetworkID) get() uint32 { + return v._vni & vniGetMask +} + // sendDiscoMessage sends discovery message m to dstDisco at dst. // // If dst is a DERP IP:port, then dstKey must be non-zero. // +// If vni.isSet(), the [disco.Message] will be preceded by a Geneve header with +// the VNI field set to the value returned by vni.get(). +// // The dstKey should only be non-zero if the dstDisco key // unambiguously maps to exactly one peer. -func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { +func (c *Conn) sendDiscoMessage(dst netip.AddrPort, vni virtualNetworkID, dstKey key.NodePublic, dstDisco key.DiscoPublic, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { isDERP := dst.Addr() == tailcfg.DerpMagicIPAddr if _, isPong := m.(*disco.Pong); isPong && !isDERP && dst.Addr().Is4() { time.Sleep(debugIPv4DiscoPingPenalty()) } + isRelayHandshakeMsg := false + switch m.(type) { + case *disco.BindUDPRelayEndpoint, *disco.BindUDPRelayEndpointAnswer: + isRelayHandshakeMsg = true + } + c.mu.Lock() if c.closed { c.mu.Unlock() return false, errConnClosed } + var di *discoInfo + switch { + case isRelayHandshakeMsg: + var ok bool + di, ok = c.relayManager.discoInfo(dstDisco) + if !ok { + c.mu.Unlock() + return false, errors.New("unknown relay server") + } + case c.peerMap.knownPeerDiscoKey(dstDisco): + di = c.discoInfoForKnownPeerLocked(dstDisco) + default: + // This is an attempt to send to an unknown peer that is not a relay + // server. 
This can happen when a call to the current function, which is + // often via a new goroutine, races with applying a change in the + // netmap, e.g. the associated peer(s) for dstDisco goes away. + c.mu.Unlock() + return false, errors.New("unknown peer") + } + c.mu.Unlock() + pkt := make([]byte, 0, 512) // TODO: size it correctly? pool? if it matters. + if vni.isSet() { + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: vni.get(), + Control: isRelayHandshakeMsg, + } + pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...) + err := gh.Encode(pkt) + if err != nil { + return false, err + } + } pkt = append(pkt, disco.Magic...) pkt = c.discoPublic.AppendTo(pkt) - di := c.discoInfoLocked(dstDisco) - c.mu.Unlock() if isDERP { metricSendDiscoDERP.Add(1) @@ -1577,7 +1702,8 @@ func (c *Conn) sendDiscoMessage(dst netip.AddrPort, dstKey key.NodePublic, dstDi box := di.sharedKey.Seal(m.AppendMarshal(nil)) pkt = append(pkt, box...) - sent, err = c.sendAddr(dst, dstKey, pkt) + const isDisco = true + sent, err = c.sendAddr(dst, dstKey, pkt, isDisco) if sent { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco()) { node := "?" @@ -1617,6 +1743,45 @@ const ( discoRXPathRawSocket discoRXPath = "raw socket" ) +const discoHeaderLen = len(disco.Magic) + key.DiscoPublicRawLen + +// isDiscoMaybeGeneve reports whether msg is a Tailscale Disco protocol +// message, and if true, whether it is encapsulated by a Geneve header. +// +// isGeneveEncap is only relevant when isDiscoMsg is true. +// +// Naked Disco, Geneve followed by Disco, and naked WireGuard can be confidently +// distinguished based on the following: +// 1. [disco.Magic] is sufficiently non-overlapping with a Geneve protocol +// field value of [packet.GeneveProtocolDisco]. +// 2. [disco.Magic] is sufficiently non-overlapping with the first 4 bytes of +// a WireGuard packet. +// 3. 
[packet.GeneveHeader] with a Geneve protocol field value of +// [packet.GeneveProtocolDisco] is sufficiently non-overlapping with the +// first 4 bytes of a WireGuard packet. +func isDiscoMaybeGeneve(msg []byte) (isDiscoMsg bool, isGeneveEncap bool) { + if len(msg) < discoHeaderLen { + return false, false + } + if string(msg[:len(disco.Magic)]) == disco.Magic { + return true, false + } + if len(msg) < packet.GeneveFixedHeaderLength+discoHeaderLen { + return false, false + } + if msg[0]&0xC0 != 0 || // version bits that we always transmit as 0s + msg[1]&0x3F != 0 || // reserved bits that we always transmit as 0s + binary.BigEndian.Uint16(msg[2:4]) != packet.GeneveProtocolDisco || + msg[7] != 0 { // reserved byte that we always transmit as 0 + return false, false + } + msg = msg[packet.GeneveFixedHeaderLength:] + if string(msg[:len(disco.Magic)]) == disco.Magic { + return true, true + } + return false, false +} + // handleDiscoMessage handles a discovery message and reports whether // msg was a Tailscale inter-node discovery message. // @@ -1632,18 +1797,28 @@ const ( // it was received from at the DERP layer. derpNodeSrc is zero when received // over UDP. func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc key.NodePublic, via discoRXPath) (isDiscoMsg bool) { - const headerLen = len(disco.Magic) + key.DiscoPublicRawLen - if len(msg) < headerLen || string(msg[:len(disco.Magic)]) != disco.Magic { - return false + isDiscoMsg, isGeneveEncap := isDiscoMaybeGeneve(msg) + if !isDiscoMsg { + return } + var geneve packet.GeneveHeader + if isGeneveEncap { + err := geneve.Decode(msg) + if err != nil { + // Decode only returns an error when 'msg' is too short, and + // 'isGeneveEncap' indicates it's a sufficient length. 
+ c.logf("[unexpected] geneve header decoding error: %v", err) + return + } + msg = msg[packet.GeneveFixedHeaderLength:] + } + // The control bit should only be set for relay handshake messages + // terminating on or originating from a UDP relay server. We have yet to + // open the encrypted payload to determine the [disco.MessageType], but + // we assert it should be handshake-related. + shouldBeRelayHandshakeMsg := isGeneveEncap && geneve.Control - // If the first four parts are the prefix of disco.Magic - // (0x5453f09f) then it's definitely not a valid WireGuard - // packet (which starts with little-endian uint32 1, 2, 3, 4). - // Use naked returns for all following paths. - isDiscoMsg = true - - sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):headerLen])) + sender := key.DiscoPublicFromRaw32(mem.B(msg[len(disco.Magic):discoHeaderLen])) c.mu.Lock() defer c.mu.Unlock() @@ -1660,7 +1835,20 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke return } - if !c.peerMap.knownPeerDiscoKey(sender) { + var di *discoInfo + switch { + case shouldBeRelayHandshakeMsg: + var ok bool + di, ok = c.relayManager.discoInfo(sender) + if !ok { + if debugDisco() { + c.logf("magicsock: disco: ignoring disco-looking relay handshake frame, no active handshakes with key %v over VNI %d", sender.ShortString(), geneve.VNI) + } + return + } + case c.peerMap.knownPeerDiscoKey(sender): + di = c.discoInfoForKnownPeerLocked(sender) + default: metricRecvDiscoBadPeer.Add(1) if debugDisco() { c.logf("magicsock: disco: ignoring disco-looking frame, don't know of key %v", sender.ShortString()) @@ -1669,7 +1857,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } isDERP := src.Addr() == tailcfg.DerpMagicIPAddr - if !isDERP { + if !isDERP && !shouldBeRelayHandshakeMsg { // Record receive time for UDP transport packets. 
pi, ok := c.peerMap.byIPPort[src] if ok { @@ -1677,17 +1865,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } } - // We're now reasonably sure we're expecting communication from - // this peer, do the heavy crypto lifting to see what they want. - // - // From here on, peerNode and de are non-nil. - - di := c.discoInfoLocked(sender) + // We're now reasonably sure we're expecting communication from 'sender', + // do the heavy crypto lifting to see what they want. - sealedBox := msg[headerLen:] + sealedBox := msg[discoHeaderLen:] payload, ok := di.sharedKey.Open(sealedBox) if !ok { - // This might be have been intended for a previous + // This might have been intended for a previous // disco key. When we restart we get a new disco key // and old packets might've still been in flight (or // scheduled). This is particularly the case for LANs @@ -1707,7 +1891,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // Emit information about the disco frame into the pcap stream // if a capture hook is installed. if cb := c.captureHook.Load(); cb != nil { - cb(capture.PathDisco, time.Now(), disco.ToPCAPFrame(src, derpNodeSrc, payload), packet.CaptureMeta{}) + cb(packet.PathDisco, time.Now(), disco.ToPCAPFrame(src, derpNodeSrc, payload), packet.CaptureMeta{}) } dm, err := disco.Parse(payload) @@ -1730,12 +1914,33 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke metricRecvDiscoUDP.Add(1) } + if shouldBeRelayHandshakeMsg { + challenge, ok := dm.(*disco.BindUDPRelayEndpointChallenge) + if !ok { + // We successfully parsed the disco message, but it wasn't a + // challenge. We should never receive other message types + // from a relay server with the Geneve header control bit set. 
+ c.logf("[unexpected] %T packets should not come from a relay server with Geneve control bit set", dm) + return + } + c.relayManager.handleBindUDPRelayEndpointChallenge(challenge, di, src, geneve.VNI) + return + } + switch dm := dm.(type) { case *disco.Ping: metricRecvDiscoPing.Add(1) + if isGeneveEncap { + // TODO(jwhited): handle Geneve-encapsulated disco ping. + return + } c.handlePingLocked(dm, src, di, derpNodeSrc) case *disco.Pong: metricRecvDiscoPong.Add(1) + if isGeneveEncap { + // TODO(jwhited): handle Geneve-encapsulated disco pong. + return + } // There might be multiple nodes for the sender's DiscoKey. // Ask each to handle it, stopping once one reports that // the Pong's TxID was theirs. @@ -1745,18 +1950,35 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } return true }) - case *disco.CallMeMaybe: + case *disco.CallMeMaybe, *disco.CallMeMaybeVia: + var via *disco.CallMeMaybeVia + isVia := false + msgType := "CallMeMaybe" + cmm, ok := dm.(*disco.CallMeMaybe) + if !ok { + via = dm.(*disco.CallMeMaybeVia) + msgType = "CallMeMaybeVia" + isVia = true + } + metricRecvDiscoCallMeMaybe.Add(1) if !isDERP || derpNodeSrc.IsZero() { - // CallMeMaybe messages should only come via DERP. - c.logf("[unexpected] CallMeMaybe packets should only come via DERP") + // CallMeMaybe{Via} messages should only come via DERP. 
+ c.logf("[unexpected] %s packets should only come via DERP", msgType) return } nodeKey := derpNodeSrc ep, ok := c.peerMap.endpointForNodeKey(nodeKey) if !ok { metricRecvDiscoCallMeMaybeBadNode.Add(1) - c.logf("magicsock: disco: ignoring CallMeMaybe from %v; %v is unknown", sender.ShortString(), derpNodeSrc.ShortString()) + c.logf("magicsock: disco: ignoring %s from %v; %v is unknown", msgType, sender.ShortString(), derpNodeSrc.ShortString()) + return + } + ep.mu.Lock() + relayCapable := ep.relayCapable + ep.mu.Unlock() + if isVia && !relayCapable { + c.logf("magicsock: disco: ignoring %s from %v; %v is not known to be relay capable", msgType, sender.ShortString(), sender.ShortString()) return } epDisco := ep.disco.Load() @@ -1765,14 +1987,23 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke } if epDisco.key != di.discoKey { metricRecvDiscoCallMeMaybeBadDisco.Add(1) - c.logf("[unexpected] CallMeMaybe from peer via DERP whose netmap discokey != disco source") + c.logf("[unexpected] %s from peer via DERP whose netmap discokey != disco source", msgType) return } - c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", - c.discoShort, epDisco.short, - ep.publicKey.ShortString(), derpStr(src.String()), - len(dm.MyNumber)) - go ep.handleCallMeMaybe(dm) + if isVia { + c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints", + c.discoShort, epDisco.short, via.ServerDisco.ShortString(), + ep.publicKey.ShortString(), derpStr(src.String()), + len(via.AddrPorts)) + c.relayManager.handleCallMeMaybeVia(ep, via) + } else { + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", + c.discoShort, epDisco.short, + ep.publicKey.ShortString(), derpStr(src.String()), + len(cmm.MyNumber)) + go ep.handleCallMeMaybe(cmm) + } + } return } @@ -1824,12 +2055,12 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf isDerp := src.Addr() == 
tailcfg.DerpMagicIPAddr // If we can figure out with certainty which node key this disco - // message is for, eagerly update our IP<>node and disco<>node + // message is for, eagerly update our IP:port<>node and disco<>node // mappings to make p2p path discovery faster in simple // cases. Without this, disco would still work, but would be // reliant on DERP call-me-maybe to establish the disco<>node // mapping, and on subsequent disco handlePongConnLocked to establish - // the IP<>disco mapping. + // the IP:port<>disco mapping. if nk, ok := c.unambiguousNodeKeyOfPingLocked(dm, di.discoKey, derpNodeSrc); ok { if !isDerp { c.peerMap.setNodeKeyForIPPort(src, nk) @@ -1890,7 +2121,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf ipDst := src discoDest := di.discoKey - go c.sendDiscoMessage(ipDst, dstKey, discoDest, &disco.Pong{ + go c.sendDiscoMessage(ipDst, virtualNetworkID{}, dstKey, discoDest, &disco.Pong{ TxID: dm.TxID, Src: src, }, discoVerboseLog) @@ -1935,19 +2166,24 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) { for _, ep := range c.lastEndpoints { eps = append(eps, ep.Addr) } - go de.c.sendDiscoMessage(derpAddr, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(derpAddr, virtualNetworkID{}, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) if debugSendCallMeUnknownPeer() { // Send a callMeMaybe packet to a non-existent peer unknownKey := key.NewNode().Public() c.logf("magicsock: sending CallMeMaybe to unknown peer per TS_DEBUG_SEND_CALLME_UNKNOWN_PEER") - go de.c.sendDiscoMessage(derpAddr, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(derpAddr, virtualNetworkID{}, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) } } -// discoInfoLocked returns the previous or new discoInfo for k. 
+// discoInfoForKnownPeerLocked returns the previous or new discoInfo for k. +// +// Callers must only pass key.DiscoPublic's that are present in and +// lifetime-managed via [Conn].peerMap. UDP relay server disco keys are discovered +// at relay endpoint allocation time or [disco.CallMeMaybeVia] reception time +// and therefore must never pass through this method. // // c.mu must be held. -func (c *Conn) discoInfoLocked(k key.DiscoPublic) *discoInfo { +func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo { di, ok := c.discoInfo[k] if !ok { di = &discoInfo{ @@ -2335,10 +2571,7 @@ func devPanicf(format string, a ...any) { func (c *Conn) logEndpointCreated(n tailcfg.NodeView) { c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key().ShortString(), n.DiscoKey().ShortString(), logger.ArgWriter(func(w *bufio.Writer) { - const derpPrefix = "127.3.3.40:" - if strings.HasPrefix(n.DERP(), derpPrefix) { - ipp, _ := netip.ParseAddrPort(n.DERP()) - regionID := int(ipp.Port()) + if regionID := n.HomeDERP(); regionID != 0 { code := c.derpRegionCodeLocked(regionID) if code != "" { code = "(" + code + ")" @@ -2346,16 +2579,14 @@ func (c *Conn) logEndpointCreated(n tailcfg.NodeView) { fmt.Fprintf(w, "derp=%v%s ", regionID, code) } - for i := range n.AllowedIPs().Len() { - a := n.AllowedIPs().At(i) + for _, a := range n.AllowedIPs().All() { if a.IsSingleIP() { fmt.Fprintf(w, "aip=%v ", a.Addr()) } else { fmt.Fprintf(w, "aip=%v ", a) } } - for i := range n.Endpoints().Len() { - ep := n.Endpoints().At(i) + for _, ep := range n.Endpoints().All() { fmt.Fprintf(w, "ep=%v ", ep) } })) @@ -2456,6 +2687,9 @@ func (c *connBind) Close() error { if c.closeDisco6 != nil { c.closeDisco6.Close() } + if c.eventClient != nil { + c.eventClient.Close() + } // Send an empty read result to unblock receiveDERP, // which will then check connBind.Closed. // connBind.Closed takes c.mu, but c.derpRecvCh is buffered. 
@@ -2485,6 +2719,7 @@ func (c *Conn) Close() error { } c.stopPeriodicReSTUNTimerLocked() c.portMapper.Close() + c.portMapperLogfUnregister() c.peerMap.forEachEndpoint(func(ep *endpoint) { ep.stopAndReset() @@ -2666,7 +2901,7 @@ func (c *Conn) bindSocket(ruc *RebindingUDPConn, network string, curPortFate cur } ports = append(ports, 0) // Remove duplicates. (All duplicates are consecutive.) - uniq.ModifySlice(&ports) + ports = slices.Compact(ports) if debugBindSocket() { c.logf("magicsock: bindSocket: candidate ports: %+v", ports) @@ -2761,7 +2996,9 @@ func (c *Conn) Rebind() { c.logf("Rebind; defIf=%q, ips=%v", defIf, ifIPs) } - c.maybeCloseDERPsOnRebind(ifIPs) + if len(ifIPs) > 0 { + c.maybeCloseDERPsOnRebind(ifIPs) + } c.resetEndpointStates() } @@ -3001,9 +3238,21 @@ func (c *Conn) DebugPickNewDERP() error { return errors.New("too few regions") } +func (c *Conn) DebugForcePreferDERP(n int) { + c.mu.Lock() + defer c.mu.Unlock() + + c.logf("magicsock: [debug] force preferred DERP set to: %d", n) + c.netChecker.SetForcePreferredDERP(n) +} + // portableTrySetSocketBuffer sets SO_SNDBUF and SO_RECVBUF on pconn to socketBufferSize, // logging an error if it occurs. func portableTrySetSocketBuffer(pconn nettype.PacketConn, logf logger.Logf) { + if runtime.GOOS == "plan9" { + // Not supported. Don't try. Avoid logspam. + return + } if c, ok := pconn.(*net.UDPConn); ok { // Attempt to increase the buffer size, and allow failures. 
if err := c.SetReadBuffer(socketBufferSize); err != nil { diff --git a/wgengine/magicsock/magicsock_notplan9.go b/wgengine/magicsock/magicsock_notplan9.go new file mode 100644 index 0000000000000..86d099ee7f48c --- /dev/null +++ b/wgengine/magicsock/magicsock_notplan9.go @@ -0,0 +1,31 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package magicsock + +import ( + "errors" + "syscall" +) + +// shouldRebind returns if the error is one that is known to be healed by a +// rebind, and if so also returns a resason string for the rebind. +func shouldRebind(err error) (ok bool, reason string) { + switch { + // EPIPE/ENOTCONN are common errors when a send fails due to a closed + // socket. There is some platform and version inconsistency in which + // error is returned, but the meaning is the same. + case errors.Is(err, syscall.EPIPE), errors.Is(err, syscall.ENOTCONN): + return true, "broken-pipe" + + // EPERM is typically caused by EDR software, and has been observed to be + // transient, it seems that some versions of some EDR lose track of sockets + // at times, and return EPERM, but reconnects will establish appropriate + // rights associated with a new socket. + case errors.Is(err, syscall.EPERM): + return true, "operation-not-permitted" + } + return false, "" +} diff --git a/wgengine/magicsock/magicsock_plan9.go b/wgengine/magicsock/magicsock_plan9.go new file mode 100644 index 0000000000000..65714c3e13c33 --- /dev/null +++ b/wgengine/magicsock/magicsock_plan9.go @@ -0,0 +1,12 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build plan9 + +package magicsock + +// shouldRebind returns if the error is one that is known to be healed by a +// rebind, and if so also returns a resason string for the rebind. 
+func shouldRebind(err error) (ok bool, reason string) { + return false, "" +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index c1b8eef223257..ddbf3e3940efe 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -12,6 +12,7 @@ import ( "errors" "fmt" "io" + "math" "math/rand" "net" "net/http" @@ -33,7 +34,6 @@ import ( "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun/tuntest" "go4.org/mem" - xmaps "golang.org/x/exp/maps" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "tailscale.com/cmd/testwrapper/flakytest" @@ -63,8 +63,11 @@ import ( "tailscale.com/types/nettype" "tailscale.com/types/ptr" "tailscale.com/util/cibuild" + "tailscale.com/util/eventbus" + "tailscale.com/util/must" "tailscale.com/util/racebuild" "tailscale.com/util/set" + "tailscale.com/util/slicesx" "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/wgcfg" @@ -172,10 +175,14 @@ func newMagicStack(t testing.TB, logf logger.Logf, l nettype.PacketListener, der func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListener, derpMap *tailcfg.DERPMap, privateKey key.NodePrivate) *magicStack { t.Helper() - netMon, err := netmon.New(logf) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logf) if err != nil { t.Fatalf("netmon.New: %v", err) } + ht := new(health.Tracker) var reg usermetric.Registry epCh := make(chan []tailcfg.Endpoint, 100) // arbitrary @@ -183,6 +190,7 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, l nettype.PacketListen NetMon: netMon, Metrics: ®, Logf: logf, + HealthTracker: ht, DisablePortMapper: true, TestOnlyPacketListener: l, EndpointsFunc: func(eps []tailcfg.Endpoint) { @@ -311,7 +319,7 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM Addresses: addrs, AllowedIPs: addrs, Endpoints: epFromTyped(eps[i]), - DERP: "127.3.3.40:1", + 
HomeDERP: 1, } nm.Peers = append(nm.Peers, peer.View()) } @@ -387,7 +395,10 @@ func TestNewConn(t *testing.T) { } } - netMon, err := netmon.New(logger.WithPrefix(t.Logf, "... netmon: ")) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, "... netmon: ")) if err != nil { t.Fatalf("netmon.New: %v", err) } @@ -520,7 +531,10 @@ func TestDeviceStartStop(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) - netMon, err := netmon.New(logger.WithPrefix(t.Logf, "... netmon: ")) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, "... netmon: ")) if err != nil { t.Fatalf("netmon.New: %v", err) } @@ -1130,7 +1144,7 @@ func testTwoDevicePing(t *testing.T, d *devices) { } } t.Helper() - t.Errorf("missing any connection to %s from %s", wantConns, xmaps.Keys(stats)) + t.Errorf("missing any connection to %s from %s", wantConns, slicesx.MapKeys(stats)) } addrPort := netip.MustParseAddrPort @@ -1359,7 +1373,10 @@ func newTestConn(t testing.TB) *Conn { t.Helper() port := pickPort(t) - netMon, err := netmon.New(logger.WithPrefix(t.Logf, "... netmon: ")) + bus := eventbus.New() + defer bus.Close() + + netMon, err := netmon.New(bus, logger.WithPrefix(t.Logf, "... 
netmon: ")) if err != nil { t.Fatalf("netmon.New: %v", err) } @@ -3047,37 +3064,306 @@ func TestMaybeSetNearestDERP(t *testing.T) { } } +func TestShouldRebind(t *testing.T) { + tests := []struct { + err error + ok bool + reason string + }{ + {nil, false, ""}, + {io.EOF, false, ""}, + {io.ErrUnexpectedEOF, false, ""}, + {io.ErrShortBuffer, false, ""}, + {&net.OpError{Err: syscall.EPERM}, true, "operation-not-permitted"}, + {&net.OpError{Err: syscall.EPIPE}, true, "broken-pipe"}, + {&net.OpError{Err: syscall.ENOTCONN}, true, "broken-pipe"}, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("%s-%v", tt.err, tt.ok), func(t *testing.T) { + if got, reason := shouldRebind(tt.err); got != tt.ok || reason != tt.reason { + t.Errorf("errShouldRebind(%v) = %v, %q; want %v, %q", tt.err, got, reason, tt.ok, tt.reason) + } + }) + } +} + func TestMaybeRebindOnError(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) - err := fmt.Errorf("outer err: %w", syscall.EPERM) + var rebindErrs []error + if runtime.GOOS != "plan9" { + rebindErrs = append(rebindErrs, + &net.OpError{Err: syscall.EPERM}, + &net.OpError{Err: syscall.EPIPE}, + &net.OpError{Err: syscall.ENOTCONN}, + ) + } + + for _, rebindErr := range rebindErrs { + t.Run(fmt.Sprintf("rebind-%s", rebindErr), func(t *testing.T) { + conn := newTestConn(t) + defer conn.Close() + + before := metricRebindCalls.Value() + conn.maybeRebindOnError(rebindErr) + after := metricRebindCalls.Value() + if before+1 != after { + t.Errorf("should rebind on %#v", rebindErr) + } + }) + } - t.Run("darwin-rebind", func(t *testing.T) { - conn := newTestConn(t) - defer conn.Close() - rebound := conn.maybeRebindOnError("darwin", err) - if !rebound { - t.Errorf("darwin should rebind on syscall.EPERM") + t.Run("no-frequent-rebind", func(t *testing.T) { + if runtime.GOOS != "plan9" { + err := fmt.Errorf("outer err: %w", syscall.EPERM) + conn := newTestConn(t) + defer conn.Close() + conn.lastErrRebind.Store(time.Now().Add(-1 * time.Second)) + 
before := metricRebindCalls.Value() + conn.maybeRebindOnError(err) + after := metricRebindCalls.Value() + if before != after { + t.Errorf("should not rebind within 5 seconds of last") + } } }) +} - t.Run("linux-not-rebind", func(t *testing.T) { - conn := newTestConn(t) - defer conn.Close() - rebound := conn.maybeRebindOnError("linux", err) - if rebound { - t.Errorf("linux should not rebind on syscall.EPERM") - } - }) +func TestNetworkDownSendErrors(t *testing.T) { + bus := eventbus.New() + defer bus.Close() - t.Run("no-frequent-rebind", func(t *testing.T) { - conn := newTestConn(t) - defer conn.Close() - conn.lastEPERMRebind.Store(time.Now().Add(-1 * time.Second)) - rebound := conn.maybeRebindOnError("darwin", err) - if rebound { - t.Errorf("darwin should not rebind on syscall.EPERM within 5 seconds of last") - } - }) + netMon := must.Get(netmon.New(bus, t.Logf)) + defer netMon.Close() + + reg := new(usermetric.Registry) + conn := must.Get(NewConn(Options{ + DisablePortMapper: true, + Logf: t.Logf, + NetMon: netMon, + Metrics: reg, + })) + defer conn.Close() + + conn.SetNetworkUp(false) + if err := conn.Send([][]byte{{00}}, &lazyEndpoint{}); err == nil { + t.Error("expected error, got nil") + } + resp := httptest.NewRecorder() + reg.Handler(resp, new(http.Request)) + if !strings.Contains(resp.Body.String(), `tailscaled_outbound_dropped_packets_total{reason="error"} 1`) { + t.Errorf("expected NetworkDown to increment packet dropped metric; got %q", resp.Body.String()) + } +} + +func Test_isDiscoMaybeGeneve(t *testing.T) { + discoPub := key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 30: 30, 31: 31})) + nakedDisco := make([]byte, 0, 512) + nakedDisco = append(nakedDisco, disco.Magic...) 
+ nakedDisco = discoPub.AppendTo(nakedDisco) + + geneveEncapDisco := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err := gh.Encode(geneveEncapDisco) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDisco[packet.GeneveFixedHeaderLength:], nakedDisco) + + nakedWireGuardInitiation := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardInitiation, device.MessageInitiationType) + nakedWireGuardResponse := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardResponse, device.MessageResponseType) + nakedWireGuardCookieReply := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardCookieReply, device.MessageCookieReplyType) + nakedWireGuardTransport := make([]byte, len(geneveEncapDisco)) + binary.LittleEndian.PutUint32(nakedWireGuardTransport, device.MessageTransportType) + + geneveEncapWireGuard := make([]byte, packet.GeneveFixedHeaderLength+len(nakedWireGuardInitiation)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolWireGuard, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapWireGuard) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapWireGuard[packet.GeneveFixedHeaderLength:], nakedWireGuardInitiation) + + geneveEncapDiscoNonZeroGeneveVersion := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 1, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVersion) + if err != nil { + t.Fatal(err) + } + copy(geneveEncapDiscoNonZeroGeneveVersion[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveReservedBits := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: 
true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveReservedBits) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveReservedBits[1] |= 0x3F + copy(geneveEncapDiscoNonZeroGeneveReservedBits[packet.GeneveFixedHeaderLength:], nakedDisco) + + geneveEncapDiscoNonZeroGeneveVNILSB := make([]byte, packet.GeneveFixedHeaderLength+len(nakedDisco)) + gh = packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: 1, + Control: true, + } + err = gh.Encode(geneveEncapDiscoNonZeroGeneveVNILSB) + if err != nil { + t.Fatal(err) + } + geneveEncapDiscoNonZeroGeneveVNILSB[7] |= 0xFF + copy(geneveEncapDiscoNonZeroGeneveVNILSB[packet.GeneveFixedHeaderLength:], nakedDisco) + + tests := []struct { + name string + msg []byte + wantIsDiscoMsg bool + wantIsGeneveEncap bool + }{ + { + name: "naked disco", + msg: nakedDisco, + wantIsDiscoMsg: true, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco", + msg: geneveEncapDisco, + wantIsDiscoMsg: true, + wantIsGeneveEncap: true, + }, + { + name: "geneve encap disco nonzero geneve version", + msg: geneveEncapDiscoNonZeroGeneveVersion, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve reserved bits", + msg: geneveEncapDiscoNonZeroGeneveReservedBits, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap disco nonzero geneve vni lsb", + msg: geneveEncapDiscoNonZeroGeneveVNILSB, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "geneve encap wireguard", + msg: geneveEncapWireGuard, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Initiation type", + msg: nakedWireGuardInitiation, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Response type", + msg: nakedWireGuardResponse, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Cookie Reply type", + msg: nakedWireGuardCookieReply, + 
wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + { + name: "naked WireGuard Transport type", + msg: nakedWireGuardTransport, + wantIsDiscoMsg: false, + wantIsGeneveEncap: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotIsDiscoMsg, gotIsGeneveEncap := isDiscoMaybeGeneve(tt.msg) + if gotIsDiscoMsg != tt.wantIsDiscoMsg { + t.Errorf("isDiscoMaybeGeneve() gotIsDiscoMsg = %v, want %v", gotIsDiscoMsg, tt.wantIsDiscoMsg) + } + if gotIsGeneveEncap != tt.wantIsGeneveEncap { + t.Errorf("isDiscoMaybeGeneve() gotIsGeneveEncap = %v, want %v", gotIsGeneveEncap, tt.wantIsGeneveEncap) + } + }) + } +} + +func Test_virtualNetworkID(t *testing.T) { + tests := []struct { + name string + set *uint32 + want uint32 + }{ + { + "don't set", + nil, + 0, + }, + { + "set 0", + ptr.To(uint32(0)), + 0, + }, + { + "set 1", + ptr.To(uint32(1)), + 1, + }, + { + "set math.MaxUint32", + ptr.To(uint32(math.MaxUint32)), + 1<<24 - 1, + }, + { + "set max 3-byte value", + ptr.To(uint32(1<<24 - 1)), + 1<<24 - 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v := virtualNetworkID{} + if tt.set != nil { + v.set(*tt.set) + } + if v.isSet() != (tt.set != nil) { + t.Fatalf("isSet: %v != wantIsSet: %v", v.isSet(), tt.set != nil) + } + if v.get() != tt.want { + t.Fatalf("get(): %v != want: %v", v.get(), tt.want) + } + }) + } } diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go new file mode 100644 index 0000000000000..0b19bb83fcc1a --- /dev/null +++ b/wgengine/magicsock/relaymanager.go @@ -0,0 +1,624 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/netip" + "sync" + "time" + + "tailscale.com/disco" + udprelay "tailscale.com/net/udprelay/endpoint" + "tailscale.com/types/key" + "tailscale.com/util/httpm" + "tailscale.com/util/set" +) + +// relayManager manages 
allocation and handshaking of +// [tailscale.com/net/udprelay.Server] endpoints. The zero value is ready for +// use. +type relayManager struct { + initOnce sync.Once + + // =================================================================== + // The following fields are owned by a single goroutine, runLoop(). + serversByAddrPort map[netip.AddrPort]key.DiscoPublic + serversByDisco map[key.DiscoPublic]netip.AddrPort + allocWorkByEndpoint map[*endpoint]*relayEndpointAllocWork + handshakeWorkByEndpointByServerDisco map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork + handshakeWorkByServerDiscoVNI map[serverDiscoVNI]*relayHandshakeWork + + // =================================================================== + // The following chan fields serve event inputs to a single goroutine, + // runLoop(). + allocateHandshakeCh chan *endpoint + allocateWorkDoneCh chan relayEndpointAllocWorkDoneEvent + handshakeWorkDoneCh chan relayEndpointHandshakeWorkDoneEvent + cancelWorkCh chan *endpoint + newServerEndpointCh chan newRelayServerEndpointEvent + rxChallengeCh chan relayHandshakeChallengeEvent + + discoInfoMu sync.Mutex // guards the following field + discoInfoByServerDisco map[key.DiscoPublic]*relayHandshakeDiscoInfo + + // runLoopStoppedCh is written to by runLoop() upon return, enabling event + // writers to restart it when they are blocked (see + // relayManagerInputEvent()). + runLoopStoppedCh chan struct{} +} + +// serverDiscoVNI represents a [tailscale.com/net/udprelay.Server] disco key +// and Geneve header VNI value for a given [udprelay.ServerEndpoint]. +type serverDiscoVNI struct { + serverDisco key.DiscoPublic + vni uint32 +} + +// relayHandshakeWork serves to track in-progress relay handshake work for a +// [udprelay.ServerEndpoint]. This structure is immutable once initialized. 
+type relayHandshakeWork struct { + ep *endpoint + se udprelay.ServerEndpoint + + // In order to not deadlock, runLoop() must select{} read doneCh when + // attempting to write into rxChallengeCh, and the handshake work goroutine + // must close(doneCh) before attempting to write to + // relayManager.handshakeWorkDoneCh. + rxChallengeCh chan relayHandshakeChallengeEvent + doneCh chan struct{} + + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup +} + +// newRelayServerEndpointEvent indicates a new [udprelay.ServerEndpoint] has +// become known either via allocation with a relay server, or via +// [disco.CallMeMaybeVia] reception. This structure is immutable once +// initialized. +type newRelayServerEndpointEvent struct { + ep *endpoint + se udprelay.ServerEndpoint + server netip.AddrPort // zero value if learned via [disco.CallMeMaybeVia] +} + +// relayEndpointAllocWorkDoneEvent indicates relay server endpoint allocation +// work for an [*endpoint] has completed. This structure is immutable once +// initialized. +type relayEndpointAllocWorkDoneEvent struct { + work *relayEndpointAllocWork +} + +// relayEndpointHandshakeWorkDoneEvent indicates relay server endpoint handshake +// work for an [*endpoint] has completed. This structure is immutable once +// initialized. +type relayEndpointHandshakeWorkDoneEvent struct { + work *relayHandshakeWork + answerSentTo netip.AddrPort // zero value if answer was not transmitted +} + +// activeWorkRunLoop returns true if there is outstanding allocation or +// handshaking work, otherwise it returns false. +func (r *relayManager) activeWorkRunLoop() bool { + return len(r.allocWorkByEndpoint) > 0 || len(r.handshakeWorkByEndpointByServerDisco) > 0 +} + +// runLoop is a form of event loop. It ensures exclusive access to most of +// [relayManager] state. 
+func (r *relayManager) runLoop() { + defer func() { + r.runLoopStoppedCh <- struct{}{} + }() + + for { + select { + case ep := <-r.allocateHandshakeCh: + r.stopWorkRunLoop(ep, stopHandshakeWorkOnlyKnownServers) + r.allocateAllServersRunLoop(ep) + if !r.activeWorkRunLoop() { + return + } + case done := <-r.allocateWorkDoneCh: + work, ok := r.allocWorkByEndpoint[done.work.ep] + if ok && work == done.work { + // Verify the work in the map is the same as the one that we're + // cleaning up. New events on r.allocateHandshakeCh can + // overwrite pre-existing keys. + delete(r.allocWorkByEndpoint, done.work.ep) + } + if !r.activeWorkRunLoop() { + return + } + case ep := <-r.cancelWorkCh: + r.stopWorkRunLoop(ep, stopHandshakeWorkAllServers) + if !r.activeWorkRunLoop() { + return + } + case newServerEndpoint := <-r.newServerEndpointCh: + r.handleNewServerEndpointRunLoop(newServerEndpoint) + if !r.activeWorkRunLoop() { + return + } + case done := <-r.handshakeWorkDoneCh: + r.handleHandshakeWorkDoneRunLoop(done) + if !r.activeWorkRunLoop() { + return + } + case challenge := <-r.rxChallengeCh: + r.handleRxChallengeRunLoop(challenge) + if !r.activeWorkRunLoop() { + return + } + } + } +} + +type relayHandshakeChallengeEvent struct { + challenge [32]byte + disco key.DiscoPublic + from netip.AddrPort + vni uint32 + at time.Time +} + +// relayEndpointAllocWork serves to track in-progress relay endpoint allocation +// for an [*endpoint]. This structure is immutable once initialized. +type relayEndpointAllocWork struct { + // ep is the [*endpoint] associated with the work + ep *endpoint + // cancel() will signal all associated goroutines to return + cancel context.CancelFunc + // wg.Wait() will return once all associated goroutines have returned + wg *sync.WaitGroup +} + +// init initializes [relayManager] if it is not already initialized. 
+func (r *relayManager) init() {
+	r.initOnce.Do(func() {
+		r.discoInfoByServerDisco = make(map[key.DiscoPublic]*relayHandshakeDiscoInfo)
+		r.serversByDisco = make(map[key.DiscoPublic]netip.AddrPort)
+		r.serversByAddrPort = make(map[netip.AddrPort]key.DiscoPublic)
+		r.allocWorkByEndpoint = make(map[*endpoint]*relayEndpointAllocWork)
+		r.handshakeWorkByEndpointByServerDisco = make(map[*endpoint]map[key.DiscoPublic]*relayHandshakeWork)
+		r.handshakeWorkByServerDiscoVNI = make(map[serverDiscoVNI]*relayHandshakeWork)
+		r.allocateHandshakeCh = make(chan *endpoint)
+		r.allocateWorkDoneCh = make(chan relayEndpointAllocWorkDoneEvent)
+		r.handshakeWorkDoneCh = make(chan relayEndpointHandshakeWorkDoneEvent)
+		r.cancelWorkCh = make(chan *endpoint)
+		r.newServerEndpointCh = make(chan newRelayServerEndpointEvent)
+		r.rxChallengeCh = make(chan relayHandshakeChallengeEvent)
+		// Buffered (1): at most one runLoop instance exists at a time, so its
+		// deferred stop signal never blocks when no relayManagerInputEvent()
+		// caller happens to be waiting to restart it.
+		r.runLoopStoppedCh = make(chan struct{}, 1)
+		go r.runLoop()
+	})
+}
+
+// relayHandshakeDiscoInfo serves to cache a [*discoInfo] for outstanding
+// [*relayHandshakeWork] against a given relay server.
+type relayHandshakeDiscoInfo struct {
+	work set.Set[*relayHandshakeWork] // guarded by relayManager.discoInfoMu
+	di   *discoInfo                   // immutable once initialized
+}
+
+// ensureDiscoInfoFor ensures a [*discoInfo] will be returned by discoInfo() for
+// the server disco key associated with 'work'. Callers must also call
+// derefDiscoInfoFor() when 'work' is complete.
+func (r *relayManager) ensureDiscoInfoFor(work *relayHandshakeWork) {
+	r.discoInfoMu.Lock()
+	defer r.discoInfoMu.Unlock()
+	di, ok := r.discoInfoByServerDisco[work.se.ServerDisco]
+	if !ok {
+		di = &relayHandshakeDiscoInfo{}
+		di.work.Make()
+		r.discoInfoByServerDisco[work.se.ServerDisco] = di
+	}
+	di.work.Add(work)
+	if di.di == nil {
+		// Lazily initialize the cached [*discoInfo] on first reference for
+		// this server disco key.
+		di.di = &discoInfo{
+			discoKey:   work.se.ServerDisco,
+			discoShort: work.se.ServerDisco.ShortString(),
+			sharedKey:  work.ep.c.discoPrivate.Shared(work.se.ServerDisco),
+		}
+	}
+}
+
+// derefDiscoInfoFor decrements the reference count of the [*discoInfo]
+// associated with 'work'. Once no outstanding work references the server disco
+// key the cached entry is dropped, and discoInfo() stops returning it.
+func (r *relayManager) derefDiscoInfoFor(work *relayHandshakeWork) {
+	r.discoInfoMu.Lock()
+	defer r.discoInfoMu.Unlock()
+	di, ok := r.discoInfoByServerDisco[work.se.ServerDisco]
+	if !ok {
+		// TODO(jwhited): unexpected
+		return
+	}
+	di.work.Delete(work)
+	if di.work.Len() == 0 {
+		delete(r.discoInfoByServerDisco, work.se.ServerDisco)
+	}
+}
+
+// discoInfo returns a [*discoInfo] for 'serverDisco' if there is an
+// active/ongoing handshake with it, otherwise it returns nil, false.
+func (r *relayManager) discoInfo(serverDisco key.DiscoPublic) (_ *discoInfo, ok bool) {
+	r.discoInfoMu.Lock()
+	defer r.discoInfoMu.Unlock()
+	di, ok := r.discoInfoByServerDisco[serverDisco]
+	if ok {
+		return di.di, ok
+	}
+	return nil, false
+}
+
+// handleCallMeMaybeVia handles reception of a [disco.CallMeMaybeVia] for 'ep',
+// converting it to a [udprelay.ServerEndpoint] and queueing it onto the event
+// loop as a [newRelayServerEndpointEvent].
+func (r *relayManager) handleCallMeMaybeVia(ep *endpoint, dm *disco.CallMeMaybeVia) {
+	se := udprelay.ServerEndpoint{
+		ServerDisco: dm.ServerDisco,
+		LamportID:   dm.LamportID,
+		AddrPorts:   dm.AddrPorts,
+		VNI:         dm.VNI,
+	}
+	se.BindLifetime.Duration = dm.BindLifetime
+	se.SteadyStateLifetime.Duration = dm.SteadyStateLifetime
+	relayManagerInputEvent(r, nil, &r.newServerEndpointCh, newRelayServerEndpointEvent{
+		ep: ep,
+		se: se,
+	})
+}
+
+// handleBindUDPRelayEndpointChallenge handles reception of a
+// [disco.BindUDPRelayEndpointChallenge], queueing it onto the event loop for
+// delivery to the matching in-progress handshake work, if any.
+func (r *relayManager) handleBindUDPRelayEndpointChallenge(dm *disco.BindUDPRelayEndpointChallenge, di *discoInfo, src netip.AddrPort, vni uint32) {
+	relayManagerInputEvent(r, nil, &r.rxChallengeCh, relayHandshakeChallengeEvent{challenge: dm.Challenge, disco: di.discoKey, from: src, vni: vni, at: time.Now()})
+}
+
+// relayManagerInputEvent initializes [relayManager] if necessary, starts
+// relayManager.runLoop() if it is not running, and writes 'event' on 'eventCh'.
+//
+// [relayManager] initialization will make `*eventCh`, so it must be passed as
+// a pointer to a channel.
+//
+// 'ctx' can be used for returning when runLoop is waiting for the calling
+// goroutine to return, i.e. the calling goroutine was birthed by runLoop and is
+// cancelable via 'ctx'. 'ctx' may be nil.
+func relayManagerInputEvent[T any](r *relayManager, ctx context.Context, eventCh *chan T, event T) {
+	r.init()
+	var ctxDoneCh <-chan struct{}
+	if ctx != nil {
+		ctxDoneCh = ctx.Done()
+	}
+	for {
+		select {
+		case <-ctxDoneCh:
+			return
+		case *eventCh <- event:
+			return
+		case <-r.runLoopStoppedCh:
+			// runLoop had parked due to no active work; restart it, then
+			// retry the send on the next loop iteration.
+			go r.runLoop()
+		}
+	}
+}
+
+// allocateAndHandshakeAllServers kicks off allocation and handshaking of relay
+// endpoints for 'ep' on all known relay servers, canceling any existing
+// in-progress work.
+func (r *relayManager) allocateAndHandshakeAllServers(ep *endpoint) {
+	relayManagerInputEvent(r, nil, &r.allocateHandshakeCh, ep)
+}
+
+// stopWork stops all outstanding allocation & handshaking work for 'ep'.
+func (r *relayManager) stopWork(ep *endpoint) {
+	relayManagerInputEvent(r, nil, &r.cancelWorkCh, ep)
+}
+
+// stopHandshakeWorkFilter represents filters for handshake work cancellation.
+type stopHandshakeWorkFilter bool
+
+const (
+	// stopHandshakeWorkAllServers cancels handshake work regardless of
+	// whether the relay server is currently known.
+	stopHandshakeWorkAllServers stopHandshakeWorkFilter = false
+	// stopHandshakeWorkOnlyKnownServers only cancels handshake work tied to
+	// servers present in relayManager.serversByDisco.
+	stopHandshakeWorkOnlyKnownServers = true
+)
+
+// stopWorkRunLoop cancels & clears outstanding allocation and handshaking
+// work for 'ep'. Handshake work cancellation is subject to the filter supplied
+// in 'f'. It is only called from runLoop().
+func (r *relayManager) stopWorkRunLoop(ep *endpoint, f stopHandshakeWorkFilter) {
+	allocWork, ok := r.allocWorkByEndpoint[ep]
+	if ok {
+		allocWork.cancel()
+		allocWork.wg.Wait()
+		delete(r.allocWorkByEndpoint, ep)
+	}
+	byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[ep]
+	if ok {
+		for disco, handshakeWork := range byServerDisco {
+			_, knownServer := r.serversByDisco[disco]
+			if knownServer || f == stopHandshakeWorkAllServers {
+				handshakeWork.cancel()
+				handshakeWork.wg.Wait()
+				delete(byServerDisco, disco)
+				delete(r.handshakeWorkByServerDiscoVNI, serverDiscoVNI{handshakeWork.se.ServerDisco, handshakeWork.se.VNI})
+			}
+		}
+		if len(byServerDisco) == 0 {
+			delete(r.handshakeWorkByEndpointByServerDisco, ep)
+		}
+	}
+}
+
+// handleRxChallengeRunLoop routes a received challenge to the in-progress
+// handshake work matching its server disco key + VNI, if any. It is only
+// called from runLoop().
+func (r *relayManager) handleRxChallengeRunLoop(challenge relayHandshakeChallengeEvent) {
+	work, ok := r.handshakeWorkByServerDiscoVNI[serverDiscoVNI{challenge.disco, challenge.vni}]
+	if !ok {
+		return
+	}
+	// Select against doneCh to avoid deadlocking with a completing handshake
+	// goroutine, see [relayHandshakeWork].
+	select {
+	case <-work.doneCh:
+		return
+	case work.rxChallengeCh <- challenge:
+		return
+	}
+}
+
+// handleHandshakeWorkDoneRunLoop clears tracking state for completed handshake
+// work. It is only called from runLoop().
+func (r *relayManager) handleHandshakeWorkDoneRunLoop(done relayEndpointHandshakeWorkDoneEvent) {
+	byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[done.work.ep]
+	if !ok {
+		return
+	}
+	work, ok := byServerDisco[done.work.se.ServerDisco]
+	if !ok || work != done.work {
+		return
+	}
+	delete(byServerDisco, done.work.se.ServerDisco)
+	if len(byServerDisco) == 0 {
+		delete(r.handshakeWorkByEndpointByServerDisco, done.work.ep)
+	}
+	delete(r.handshakeWorkByServerDiscoVNI, serverDiscoVNI{done.work.se.ServerDisco, done.work.se.VNI})
+	if !done.answerSentTo.IsValid() {
+		// The handshake timed out.
+		return
+	}
+	// We received a challenge from and transmitted an answer towards the relay
+	// server.
+	// TODO(jwhited): Make the associated [*endpoint] aware of this
+	// [tailscale.com/net/udprelay.ServerEndpoint].
+}
+
+// handleNewServerEndpointRunLoop deduplicates 'newServerEndpoint' against
+// in-progress handshake work and, if it is the newest known description of
+// the endpoint, starts a handshake for it. It is only called from runLoop().
+func (r *relayManager) handleNewServerEndpointRunLoop(newServerEndpoint newRelayServerEndpointEvent) {
+	// Check for duplicate work by server disco + VNI.
+	sdv := serverDiscoVNI{newServerEndpoint.se.ServerDisco, newServerEndpoint.se.VNI}
+	existingWork, ok := r.handshakeWorkByServerDiscoVNI[sdv]
+	if ok {
+		// There's in-progress handshake work for the server disco + VNI, which
+		// uniquely identify a [udprelay.ServerEndpoint]. Compare Lamport
+		// IDs to determine which is newer.
+		if existingWork.se.LamportID >= newServerEndpoint.se.LamportID {
+			// The existing work is a duplicate or newer. Return early.
+			return
+		}
+
+		// The existing work is no longer valid, clean it up. Be sure to lookup
+		// by the existing work's [*endpoint], not the incoming "new" work as
+		// they are not necessarily matching.
+		existingWork.cancel()
+		existingWork.wg.Wait()
+		delete(r.handshakeWorkByServerDiscoVNI, sdv)
+		byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[existingWork.ep]
+		if ok {
+			delete(byServerDisco, sdv.serverDisco)
+			if len(byServerDisco) == 0 {
+				delete(r.handshakeWorkByEndpointByServerDisco, existingWork.ep)
+			}
+		}
+	}
+
+	// Check for duplicate work by [*endpoint] + server disco.
+	byServerDisco, ok := r.handshakeWorkByEndpointByServerDisco[newServerEndpoint.ep]
+	if ok {
+		existingWork, ok := byServerDisco[newServerEndpoint.se.ServerDisco]
+		if ok {
+			if newServerEndpoint.se.LamportID <= existingWork.se.LamportID {
+				// The "new" server endpoint is outdated or a duplicate of the
+				// existing handshake work. Return early.
+				return
+			}
+			// Cancel the existing handshake, which has a lower Lamport ID.
+			existingWork.cancel()
+			existingWork.wg.Wait()
+			// NOTE(review): 'sdv' carries the NEW endpoint's VNI; if
+			// existingWork.se.VNI differs, existingWork's entry in
+			// handshakeWorkByServerDiscoVNI is left behind after cancellation.
+			// Confirm whether this delete should key off existingWork.se
+			// instead.
+			delete(r.handshakeWorkByServerDiscoVNI, sdv)
+			delete(byServerDisco, sdv.serverDisco)
+			if len(byServerDisco) == 0 {
+				delete(r.handshakeWorkByEndpointByServerDisco, existingWork.ep)
+			}
+		}
+	}
+
+	// We're now reasonably sure we're dealing with the latest
+	// [udprelay.ServerEndpoint] from a server event order perspective
+	// (LamportID). Update server disco key tracking if appropriate.
+	if newServerEndpoint.server.IsValid() {
+		serverDisco, ok := r.serversByAddrPort[newServerEndpoint.server]
+		if !ok {
+			// Allocation raced with an update to our known servers set. This
+			// server is no longer known. Return early.
+			return
+		}
+		if serverDisco.Compare(newServerEndpoint.se.ServerDisco) != 0 {
+			// The server's disco key has either changed, or simply become
+			// known for the first time. In the former case we end up detaching
+			// any in-progress handshake work from a "known" relay server.
+			// Practically speaking we expect the detached work to fail
+			// if the server key did in fact change (server restart) while we
+			// were attempting to handshake with it. It is possible, though
+			// unlikely, for a server addr:port to effectively move between
+			// nodes. Either way, there is no harm in detaching existing work,
+			// and we explicitly let that happen for the rare case the detached
+			// handshake would complete and remain functional.
+			delete(r.serversByDisco, serverDisco)
+			delete(r.serversByAddrPort, newServerEndpoint.server)
+			// Track the server under its latest disco key, not the stale
+			// 'serverDisco' we just deleted; re-inserting the stale key would
+			// mean the new key never becomes known, and future lookups by
+			// either map would be wrong.
+			r.serversByDisco[newServerEndpoint.se.ServerDisco] = newServerEndpoint.server
+			r.serversByAddrPort[newServerEndpoint.server] = newServerEndpoint.se.ServerDisco
+		}
+	}
+
+	// We're ready to start a new handshake.
+	ctx, cancel := context.WithCancel(context.Background())
+	wg := &sync.WaitGroup{}
+	work := &relayHandshakeWork{
+		ep: newServerEndpoint.ep,
+		se: newServerEndpoint.se,
+		// rxChallengeCh must be non-nil: handleRxChallengeRunLoop() selects
+		// over it against doneCh. A nil channel can never deliver, so every
+		// handshake would time out, and runLoop would block in that select
+		// until this work's doneCh closed.
+		rxChallengeCh: make(chan relayHandshakeChallengeEvent),
+		doneCh:        make(chan struct{}),
+		ctx:           ctx,
+		cancel:        cancel,
+		wg:            wg,
+	}
+	if byServerDisco == nil {
+		byServerDisco = make(map[key.DiscoPublic]*relayHandshakeWork)
+		r.handshakeWorkByEndpointByServerDisco[newServerEndpoint.ep] = byServerDisco
+	}
+	byServerDisco[newServerEndpoint.se.ServerDisco] = work
+	r.handshakeWorkByServerDiscoVNI[sdv] = work
+
+	wg.Add(1)
+	go r.handshakeServerEndpoint(work)
+}
+
+// handshakeServerEndpoint handshakes the relay server endpoint carried by
+// 'work' on behalf of work.ep: it transmits a [disco.BindUDPRelayEndpoint]
+// towards every valid advertised AddrPort, then waits for a challenge,
+// cancellation, or timeout. It always runs in its own goroutine; completion
+// is reported to runLoop via relayManager.handshakeWorkDoneCh.
+func (r *relayManager) handshakeServerEndpoint(work *relayHandshakeWork) {
+	defer work.wg.Done()
+
+	done := relayEndpointHandshakeWorkDoneEvent{work: work}
+	r.ensureDiscoInfoFor(work)
+
+	defer func() {
+		r.derefDiscoInfoFor(work)
+		// close(doneCh) must precede the handshakeWorkDoneCh write, see
+		// [relayHandshakeWork] for the deadlock rationale.
+		close(work.doneCh)
+		relayManagerInputEvent(r, work.ctx, &r.handshakeWorkDoneCh, done)
+		work.cancel()
+	}()
+
+	sentBindAny := false
+	bind := &disco.BindUDPRelayEndpoint{}
+	vni := virtualNetworkID{}
+	vni.set(work.se.VNI)
+	for _, addrPort := range work.se.AddrPorts {
+		if addrPort.IsValid() {
+			sentBindAny = true
+			go work.ep.c.sendDiscoMessage(addrPort, vni, key.NodePublic{}, work.se.ServerDisco, bind, discoLog)
+		}
+	}
+	if !sentBindAny {
+		return
+	}
+
+	// Limit goroutine lifetime to a reasonable duration. This is intentionally
+	// detached and independent of 'BindLifetime' to prevent relay server
+	// (mis)configuration from negatively impacting client resource usage.
+	const maxHandshakeLifetime = time.Second * 30
+	timer := time.NewTimer(min(work.se.BindLifetime.Duration, maxHandshakeLifetime))
+	defer timer.Stop()
+
+	// Wait for cancellation, a challenge to be rx'd, or handshake lifetime to
+	// expire. Our initial implementation values simplicity over other aspects,
+	// e.g. it is not resilient to any packet loss.
+	//
+	// We may want to eventually consider [disco.BindUDPRelayEndpoint]
+	// retransmission lacking challenge rx, and
+	// [disco.BindUDPRelayEndpointAnswer] duplication in front of
+	// [disco.Ping] until [disco.Ping] or [disco.Pong] is received.
+	select {
+	case <-work.ctx.Done():
+		return
+	case challenge := <-work.rxChallengeCh:
+		answer := &disco.BindUDPRelayEndpointAnswer{Answer: challenge.challenge}
+		done.answerSentTo = challenge.from
+		// Send answer back to relay server. Typically sendDiscoMessage() calls
+		// are invoked via a new goroutine in attempt to limit crypto+syscall
+		// time contributing to system backpressure, and to fire roundtrip
+		// latency-relevant messages as closely together as possible. We
+		// intentionally don't do that here, because:
+		//  1. The primary backpressure concern is around the work.rxChallengeCh
+		//     writer on the [Conn] packet rx path, who is already unblocked
+		//     since we read from the channel. Relay servers only ever tx one
+		//     challenge per rx'd bind message for a given (the first seen) src.
+		//  2. runLoop() may be waiting for this 'work' to complete if
+		//     explicitly canceled for some reason elsewhere, but this is
+		//     typically only around [*endpoint] and/or [Conn] shutdown.
+		//  3. It complicates the defer()'d [*discoInfo] deref and 'work'
+		//     completion event order. sendDiscoMessage() assumes the related
+		//     [*discoInfo] is still available. We also don't want the
+		//     [*endpoint] to send a [disco.Ping] before the
+		//     [disco.BindUDPRelayEndpointAnswer] has gone out, otherwise the
+		//     remote side will never see the ping, delaying/preventing the
+		//     [udprelay.ServerEndpoint] from becoming fully operational.
+		//  4. This is a singular tx with no roundtrip latency measurements
+		//     involved.
+		work.ep.c.sendDiscoMessage(challenge.from, vni, key.NodePublic{}, work.se.ServerDisco, answer, discoLog)
+		return
+	case <-timer.C:
+		// The handshake timed out.
+		return
+	}
+}
+
+// allocateAllServersRunLoop requests allocation of a relay endpoint for 'ep'
+// from every known relay server, tracking the work in allocWorkByEndpoint.
+// It is only called from runLoop().
+func (r *relayManager) allocateAllServersRunLoop(ep *endpoint) {
+	if len(r.serversByAddrPort) == 0 {
+		return
+	}
+	ctx, cancel := context.WithCancel(context.Background())
+	started := &relayEndpointAllocWork{ep: ep, cancel: cancel, wg: &sync.WaitGroup{}}
+	for k := range r.serversByAddrPort {
+		started.wg.Add(1)
+		go r.allocateSingleServer(ctx, started.wg, k, ep)
+	}
+	r.allocWorkByEndpoint[ep] = started
+	go func() {
+		started.wg.Wait()
+		started.cancel()
+		// NOTE(review): ctx is already canceled at this point, so
+		// relayManagerInputEvent() may return via ctx.Done() without
+		// delivering the done event. runLoop tolerates stale map entries
+		// (its work == done.work check), but confirm delivery is intended
+		// to be best-effort here.
+		relayManagerInputEvent(r, ctx, &r.allocateWorkDoneCh, relayEndpointAllocWorkDoneEvent{work: started})
+	}()
+}
+
+// allocateSingleServer performs a blocking HTTP request to 'server' to
+// allocate a [udprelay.ServerEndpoint] for the disco keys of 'ep' and the
+// local [Conn], queueing the result onto the event loop on success. All
+// failures are silently dropped.
+func (r *relayManager) allocateSingleServer(ctx context.Context, wg *sync.WaitGroup, server netip.AddrPort, ep *endpoint) {
+	// TODO(jwhited): introduce client metrics counters for notable failures
+	defer wg.Done()
+	var b bytes.Buffer
+	remoteDisco := ep.disco.Load()
+	if remoteDisco == nil {
+		return
+	}
+	type allocateRelayEndpointReq struct {
+		DiscoKeys []key.DiscoPublic
+	}
+	a := &allocateRelayEndpointReq{
+		DiscoKeys: []key.DiscoPublic{ep.c.discoPublic, remoteDisco.key},
+	}
+	err := json.NewEncoder(&b).Encode(a)
+	if err != nil {
+		return
+	}
+	const reqTimeout = time.Second * 10
+	reqCtx, cancel := context.WithTimeout(ctx, reqTimeout)
+	defer cancel()
+	req, err := http.NewRequestWithContext(reqCtx, httpm.POST, "http://"+server.String()+"/relay/endpoint", &b)
+	if err != nil {
+		return
+	}
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		return
+ } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return + } + var se udprelay.ServerEndpoint + err = json.NewDecoder(io.LimitReader(resp.Body, 4096)).Decode(&se) + if err != nil { + return + } + relayManagerInputEvent(r, ctx, &r.newServerEndpointCh, newRelayServerEndpointEvent{ + ep: ep, + se: se, + }) +} diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go new file mode 100644 index 0000000000000..3b75db9f6e1f9 --- /dev/null +++ b/wgengine/magicsock/relaymanager_test.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "testing" + + "tailscale.com/disco" + "tailscale.com/types/key" +) + +func TestRelayManagerInitAndIdle(t *testing.T) { + rm := relayManager{} + rm.allocateAndHandshakeAllServers(&endpoint{}) + <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.stopWork(&endpoint{}) + <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.handleCallMeMaybeVia(&endpoint{c: &Conn{discoPrivate: key.NewDisco()}}, &disco.CallMeMaybeVia{ServerDisco: key.NewDisco().Public()}) + <-rm.runLoopStoppedCh + + rm = relayManager{} + rm.handleBindUDPRelayEndpointChallenge(&disco.BindUDPRelayEndpointChallenge{}, &discoInfo{}, netip.AddrPort{}, 0) + <-rm.runLoopStoppedCh +} diff --git a/wgengine/netstack/gro/gro.go b/wgengine/netstack/gro/gro.go index b268534eb46c8..654d170566f0d 100644 --- a/wgengine/netstack/gro/gro.go +++ b/wgengine/netstack/gro/gro.go @@ -6,6 +6,7 @@ package gro import ( "bytes" + "github.com/tailscale/wireguard-go/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" diff --git a/wgengine/netstack/link_endpoint.go b/wgengine/netstack/link_endpoint.go index 485d829a3b8e5..39da64b5503cc 100644 --- a/wgengine/netstack/link_endpoint.go +++ b/wgengine/netstack/link_endpoint.go @@ -16,19 +16,27 @@ import ( ) type queue struct { - // TODO(jwhited): evaluate performance with mu as Mutex and/or 
alternative - // non-channel buffer. - c chan *stack.PacketBuffer - mu sync.RWMutex // mu guards closed + // TODO(jwhited): evaluate performance with a non-channel buffer. + c chan *stack.PacketBuffer + + closeOnce sync.Once + closedCh chan struct{} + + mu sync.RWMutex closed bool } func (q *queue) Close() { + q.closeOnce.Do(func() { + close(q.closedCh) + }) + q.mu.Lock() defer q.mu.Unlock() - if !q.closed { - close(q.c) + if q.closed { + return } + close(q.c) q.closed = true } @@ -51,26 +59,27 @@ func (q *queue) ReadContext(ctx context.Context) *stack.PacketBuffer { } func (q *queue) Write(pkt *stack.PacketBuffer) tcpip.Error { - // q holds the PacketBuffer. q.mu.RLock() defer q.mu.RUnlock() if q.closed { return &tcpip.ErrClosedForSend{} } - - wrote := false select { case q.c <- pkt.IncRef(): - wrote = true - default: - // TODO(jwhited): reconsider/count + return nil + case <-q.closedCh: pkt.DecRef() + return &tcpip.ErrClosedForSend{} } +} - if wrote { - return nil +func (q *queue) Drain() int { + c := 0 + for pkt := range q.c { + pkt.DecRef() + c++ } - return &tcpip.ErrNoBufferSpace{} + return c } func (q *queue) Num() int { @@ -107,7 +116,8 @@ func newLinkEndpoint(size int, mtu uint32, linkAddr tcpip.LinkAddress, supported le := &linkEndpoint{ supportedGRO: supportedGRO, q: &queue{ - c: make(chan *stack.PacketBuffer, size), + c: make(chan *stack.PacketBuffer, size), + closedCh: make(chan struct{}), }, mtu: mtu, linkAddr: linkAddr, @@ -164,12 +174,7 @@ func (l *linkEndpoint) ReadContext(ctx context.Context) *stack.PacketBuffer { // Drain removes all outbound packets from the channel and counts them. func (l *linkEndpoint) Drain() int { - c := 0 - for pkt := l.Read(); pkt != nil; pkt = l.Read() { - pkt.DecRef() - c++ - } - return c + return l.q.Drain() } // NumQueued returns the number of packets queued for outbound. 
diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 3185c5d556aa9..dab692ead4aa7 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -32,13 +32,13 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" - "tailscale.com/drive" "tailscale.com/envknob" "tailscale.com/ipn/ipnlocal" "tailscale.com/metrics" "tailscale.com/net/dns" "tailscale.com/net/ipset" "tailscale.com/net/netaddr" + "tailscale.com/net/netx" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" @@ -51,6 +51,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/util/clientmetric" + "tailscale.com/util/set" "tailscale.com/version" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" @@ -174,19 +175,18 @@ type Impl struct { // It can only be set before calling Start. ProcessSubnets bool - ipstack *stack.Stack - linkEP *linkEndpoint - tundev *tstun.Wrapper - e wgengine.Engine - pm *proxymap.Mapper - mc *magicsock.Conn - logf logger.Logf - dialer *tsdial.Dialer - ctx context.Context // alive until Close - ctxCancel context.CancelFunc // called on Close - lb *ipnlocal.LocalBackend // or nil - dns *dns.Manager - driveForLocal drive.FileSystemForLocal // or nil + ipstack *stack.Stack + linkEP *linkEndpoint + tundev *tstun.Wrapper + e wgengine.Engine + pm *proxymap.Mapper + mc *magicsock.Conn + logf logger.Logf + dialer *tsdial.Dialer + ctx context.Context // alive until Close + ctxCancel context.CancelFunc // called on Close + lb *ipnlocal.LocalBackend // or nil + dns *dns.Manager // loopbackPort, if non-nil, will enable Impl to loop back (dnat to // :loopbackPort) TCP & UDP flows originally @@ -202,12 +202,14 @@ type Impl struct { // updates. 
atomicIsLocalIPFunc syncs.AtomicValue[func(netip.Addr) bool] + atomicIsVIPServiceIPFunc syncs.AtomicValue[func(netip.Addr) bool] + // forwardDialFunc, if non-nil, is the net.Dialer.DialContext-style // function that is used to make outgoing connections when forwarding a // TCP connection to another host (e.g. in subnet router mode). // // This is currently only used in tests. - forwardDialFunc func(context.Context, string, string) (net.Conn, error) + forwardDialFunc netx.DialFunc // forwardInFlightPerClientDropped is a metric that tracks how many // in-flight TCP forward requests were dropped due to the per-client @@ -288,7 +290,7 @@ func setTCPBufSizes(ipstack *stack.Stack) error { } // Create creates and populates a new Impl. -func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper, driveForLocal drive.FileSystemForLocal) (*Impl, error) { +func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magicsock.Conn, dialer *tsdial.Dialer, dns *dns.Manager, pm *proxymap.Mapper) (*Impl, error) { if mc == nil { return nil, errors.New("nil magicsock.Conn") } @@ -316,16 +318,24 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi if tcpipErr != nil { return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) } - if runtime.GOOS == "windows" { - // See https://github.com/tailscale/tailscale/issues/9707 - // Windows w/RACK performs poorly. ACKs do not appear to be handled in a - // timely manner, leading to spurious retransmissions and a reduced - // congestion window. - tcpRecoveryOpt := tcpip.TCPRecovery(0) - tcpipErr = ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt) - if tcpipErr != nil { - return nil, fmt.Errorf("could not disable TCP RACK: %v", tcpipErr) - } + // See https://github.com/tailscale/tailscale/issues/9707 + // gVisor's RACK performs poorly. 
ACKs do not appear to be handled in a + // timely manner, leading to spurious retransmissions and a reduced + // congestion window. + tcpRecoveryOpt := tcpip.TCPRecovery(0) + tcpipErr = ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &tcpRecoveryOpt) + if tcpipErr != nil { + return nil, fmt.Errorf("could not disable TCP RACK: %v", tcpipErr) + } + // gVisor defaults to reno at the time of writing. We explicitly set reno + // congestion control in order to prevent unexpected changes. Netstack + // has an int overflow in sender congestion window arithmetic that is more + // prone to trigger with cubic congestion control. + // See https://github.com/google/gvisor/issues/11632 + renoOpt := tcpip.CongestionControlOption("reno") + tcpipErr = ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &renoOpt) + if tcpipErr != nil { + return nil, fmt.Errorf("could not set reno congestion control: %v", tcpipErr) } err := setTCPBufSizes(ipstack) if err != nil { @@ -382,7 +392,6 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi connsInFlightByClient: make(map[netip.Addr]int), packetsInFlight: make(map[stack.TransportEndpointID]struct{}), dns: dns, - driveForLocal: driveForLocal, } loopbackPort, ok := envknob.LookupInt("TS_DEBUG_NETSTACK_LOOPBACK_PORT") if ok && loopbackPort >= 0 && loopbackPort <= math.MaxUint16 { @@ -390,6 +399,7 @@ func Create(logf logger.Logf, tundev *tstun.Wrapper, e wgengine.Engine, mc *magi } ns.ctx, ns.ctxCancel = context.WithCancel(context.Background()) ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc()) + ns.atomicIsVIPServiceIPFunc.Store(ipset.FalseContainsIPFunc()) ns.tundev.PostFilterPacketInboundFromWireGuard = ns.injectInbound ns.tundev.PreFilterPacketOutboundToWireGuardNetstackIntercept = ns.handleLocalPackets stacksForMetrics.Store(ns, struct{}{}) @@ -404,6 +414,14 @@ func (ns *Impl) Close() error { return nil } +// SetTransportProtocolOption forwards to the underlying +// 
[stack.Stack.SetTransportProtocolOption]. Callers are responsible for +// ensuring that the options are valid, compatible and appropriate for their use +// case. Compatibility may change at any version. +func (ns *Impl) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option tcpip.SettableTransportProtocolOption) tcpip.Error { + return ns.ipstack.SetTransportProtocolOption(transport, option) +} + // A single process might have several netstacks running at the same time. // Exported clientmetric counters will have a sum of counters of all of them. var stacksForMetrics syncs.Map[*Impl, struct{}] @@ -535,7 +553,7 @@ func (ns *Impl) wrapTCPProtocolHandler(h protocolHandlerFunc) protocolHandlerFun // Dynamically reconfigure ns's subnet addresses as needed for // outbound traffic. - if !ns.isLocalIP(localIP) { + if !ns.isLocalIP(localIP) && !ns.isVIPServiceIP(localIP) { ns.addSubnetAddress(localIP) } @@ -623,11 +641,19 @@ var v4broadcast = netaddr.IPv4(255, 255, 255, 255) // address slice views. 
func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) { var selfNode tailcfg.NodeView + var serviceAddrSet set.Set[netip.Addr] if nm != nil { + vipServiceIPMap := nm.GetVIPServiceIPMap() + serviceAddrSet = make(set.Set[netip.Addr], len(vipServiceIPMap)*2) + for _, addrs := range vipServiceIPMap { + serviceAddrSet.AddSlice(addrs) + } ns.atomicIsLocalIPFunc.Store(ipset.NewContainsIPFunc(nm.GetAddresses())) + ns.atomicIsVIPServiceIPFunc.Store(serviceAddrSet.Contains) selfNode = nm.SelfNode } else { ns.atomicIsLocalIPFunc.Store(ipset.FalseContainsIPFunc()) + ns.atomicIsVIPServiceIPFunc.Store(ipset.FalseContainsIPFunc()) } oldPfx := make(map[netip.Prefix]bool) @@ -646,18 +672,21 @@ func (ns *Impl) UpdateNetstackIPs(nm *netmap.NetworkMap) { newPfx := make(map[netip.Prefix]bool) if selfNode.Valid() { - for i := range selfNode.Addresses().Len() { - p := selfNode.Addresses().At(i) + for _, p := range selfNode.Addresses().All() { newPfx[p] = true } if ns.ProcessSubnets { - for i := range selfNode.AllowedIPs().Len() { - p := selfNode.AllowedIPs().At(i) + for _, p := range selfNode.AllowedIPs().All() { newPfx[p] = true } } } + for addr := range serviceAddrSet { + p := netip.PrefixFrom(addr, addr.BitLen()) + newPfx[p] = true + } + pfxToAdd := make(map[netip.Prefix]bool) for p := range newPfx { if !oldPfx[p] { @@ -820,6 +849,27 @@ func (ns *Impl) DialContextTCP(ctx context.Context, ipp netip.AddrPort) (*gonet. return gonet.DialContextTCP(ctx, ns.ipstack, remoteAddress, ipType) } +// DialContextTCPWithBind creates a new gonet.TCPConn connected to the specified +// remoteAddress with its local address bound to localAddr on an available port. 
+func (ns *Impl) DialContextTCPWithBind(ctx context.Context, localAddr netip.Addr, remoteAddr netip.AddrPort) (*gonet.TCPConn, error) { + remoteAddress := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(remoteAddr.Addr().AsSlice()), + Port: remoteAddr.Port(), + } + localAddress := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(localAddr.AsSlice()), + } + var ipType tcpip.NetworkProtocolNumber + if remoteAddr.Addr().Is4() { + ipType = ipv4.ProtocolNumber + } else { + ipType = ipv6.ProtocolNumber + } + return gonet.DialTCPWithBind(ctx, ns.ipstack, localAddress, remoteAddress, ipType) +} + func (ns *Impl) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet.UDPConn, error) { remoteAddress := &tcpip.FullAddress{ NIC: nicID, @@ -836,6 +886,28 @@ func (ns *Impl) DialContextUDP(ctx context.Context, ipp netip.AddrPort) (*gonet. return gonet.DialUDP(ns.ipstack, nil, remoteAddress, ipType) } +// DialContextUDPWithBind creates a new gonet.UDPConn. Connected to remoteAddr. +// With its local address bound to localAddr on an available port. +func (ns *Impl) DialContextUDPWithBind(ctx context.Context, localAddr netip.Addr, remoteAddr netip.AddrPort) (*gonet.UDPConn, error) { + remoteAddress := &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(remoteAddr.Addr().AsSlice()), + Port: remoteAddr.Port(), + } + localAddress := &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(localAddr.AsSlice()), + } + var ipType tcpip.NetworkProtocolNumber + if remoteAddr.Addr().Is4() { + ipType = ipv4.ProtocolNumber + } else { + ipType = ipv6.ProtocolNumber + } + + return gonet.DialUDP(ns.ipstack, localAddress, remoteAddress, ipType) +} + // getInjectInboundBuffsSizes returns packet memory and a sizes slice for usage // when calling tstun.Wrapper.InjectInboundPacketBuffer(). These are sized with // consideration for MTU and GSO support on ns.linkEP. 
They should be recycled @@ -957,6 +1029,12 @@ func (ns *Impl) isLocalIP(ip netip.Addr) bool { return ns.atomicIsLocalIPFunc.Load()(ip) } +// isVIPServiceIP reports whether ip is an IP address that's +// assigned to a VIP service. +func (ns *Impl) isVIPServiceIP(ip netip.Addr) bool { + return ns.atomicIsVIPServiceIPFunc.Load()(ip) +} + func (ns *Impl) peerAPIPortAtomic(ip netip.Addr) *atomic.Uint32 { if ip.Is4() { return &ns.peerapiPort4Atomic @@ -973,6 +1051,7 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { // Handle incoming peerapi connections in netstack. dstIP := p.Dst.Addr() isLocal := ns.isLocalIP(dstIP) + isService := ns.isVIPServiceIP(dstIP) // Handle TCP connection to the Tailscale IP(s) in some cases: if ns.lb != nil && p.IPProto == ipproto.TCP && isLocal { @@ -995,6 +1074,19 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool { return true } } + if isService { + if p.IsEchoRequest() { + return true + } + if ns.lb != nil && p.IPProto == ipproto.TCP { + // An assumption holds for this to work: when tun mode is on for a service, + // its tcp and web are not set. This is enforced in b.setServeConfigLocked. + if ns.lb.ShouldInterceptVIPServiceTCPPort(p.Dst) { + return true + } + } + return false + } if p.IPVersion == 6 && !isLocal && viaRange.Contains(dstIP) { return ns.lb != nil && ns.lb.ShouldHandleViaIP(dstIP) } @@ -1371,7 +1463,7 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet. }() // Attempt to dial the outbound connection before we accept the inbound one. 
- var dialFunc func(context.Context, string, string) (net.Conn, error) + var dialFunc netx.DialFunc if ns.forwardDialFunc != nil { dialFunc = ns.forwardDialFunc } else { diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index 1bfc76fef097f..584b3babc6004 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -22,6 +22,7 @@ import ( "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/metrics" + "tailscale.com/net/netx" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tsdial" @@ -44,13 +45,14 @@ func TestInjectInboundLeak(t *testing.T) { t.Logf(format, args...) } } - sys := new(tsd.System) + sys := tsd.NewSystem() eng, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{ Tun: tunDev, Dialer: dialer, SetSubsystem: sys.Set, HealthTracker: sys.HealthTracker(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }) if err != nil { t.Fatal(err) @@ -64,8 +66,9 @@ func TestInjectInboundLeak(t *testing.T) { if err != nil { t.Fatal(err) } + t.Cleanup(lb.Shutdown) - ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := Create(logf, tunWrap, eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { t.Fatal(err) } @@ -99,7 +102,7 @@ func getMemStats() (ms runtime.MemStats) { func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { tunDev := tstun.NewFake() - sys := &tsd.System{} + sys := tsd.NewSystem() sys.Set(new(mem.Store)) dialer := new(tsdial.Dialer) logf := tstest.WhileTestRunningLogger(tb) @@ -109,6 +112,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { SetSubsystem: sys.Set, HealthTracker: sys.HealthTracker(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }) if err != nil { tb.Fatal(err) @@ -116,7 +120,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { tb.Cleanup(func() { eng.Close() 
}) sys.Set(eng) - ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper(), nil) + ns, err := Create(logf, sys.Tun.Get(), eng, sys.MagicSock.Get(), dialer, sys.DNSManager.Get(), sys.ProxyMapper()) if err != nil { tb.Fatal(err) } @@ -126,6 +130,7 @@ func makeNetstack(tb testing.TB, config func(*Impl)) *Impl { if err != nil { tb.Fatalf("NewLocalBackend: %v", err) } + tb.Cleanup(lb.Shutdown) ns.atomicIsLocalIPFunc.Store(func(netip.Addr) bool { return true }) if config != nil { @@ -510,7 +515,7 @@ func tcp4syn(tb testing.TB, src, dst netip.Addr, sport, dport uint16) []byte { // makeHangDialer returns a dialer that notifies the returned channel when a // connection is dialed and then hangs until the test finishes. -func makeHangDialer(tb testing.TB) (func(context.Context, string, string) (net.Conn, error), chan struct{}) { +func makeHangDialer(tb testing.TB) (netx.DialFunc, chan struct{}) { done := make(chan struct{}) tb.Cleanup(func() { close(done) diff --git a/wgengine/pendopen.go b/wgengine/pendopen.go index 340c7e0f3f7be..28d1f4f9d59e4 100644 --- a/wgengine/pendopen.go +++ b/wgengine/pendopen.go @@ -8,6 +8,7 @@ import ( "net/netip" "runtime" "strings" + "sync" "time" "github.com/gaissmai/bart" @@ -15,7 +16,6 @@ import ( "tailscale.com/net/packet" "tailscale.com/net/tstun" "tailscale.com/types/ipproto" - "tailscale.com/types/lazy" "tailscale.com/util/mak" "tailscale.com/wgengine/filter" ) @@ -91,7 +91,7 @@ func (e *userspaceEngine) trackOpenPreFilterIn(pp *packet.Parsed, t *tstun.Wrapp var ( appleIPRange = netip.MustParsePrefix("17.0.0.0/8") - canonicalIPs = lazy.SyncFunc(func() (checkIPFunc func(netip.Addr) bool) { + canonicalIPs = sync.OnceValue(func() (checkIPFunc func(netip.Addr) bool) { // https://bgp.he.net/AS41231#_prefixes t := &bart.Table[bool]{} for _, s := range strings.Fields(` @@ -198,7 +198,7 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { e.logf("open-conn-track: timeout opening 
%v; peer node %v running pre-0.100", flow, n.Key().ShortString()) return } - if n.DERP() == "" { + if n.HomeDERP() == 0 { e.logf("open-conn-track: timeout opening %v; peer node %v not connected to any DERP relay", flow, n.Key().ShortString()) return } @@ -207,8 +207,7 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { ps, found := e.getPeerStatusLite(n.Key()) if !found { onlyZeroRoute := true // whether peerForIP returned n only because its /0 route matched - for i := range n.AllowedIPs().Len() { - r := n.AllowedIPs().At(i) + for _, r := range n.AllowedIPs().All() { if r.Bits() != 0 && r.Contains(flow.DstAddr()) { onlyZeroRoute = false break @@ -240,15 +239,15 @@ func (e *userspaceEngine) onOpenTimeout(flow flowtrack.Tuple) { if n.IsWireGuardOnly() { online = "wg" } else { - if v := n.Online(); v != nil { - if *v { + if v, ok := n.Online().GetOk(); ok { + if v { online = "yes" } else { online = "no" } } - if n.LastSeen() != nil && online != "yes" { - online += fmt.Sprintf(", lastseen=%v", durFmt(*n.LastSeen())) + if lastSeen, ok := n.LastSeen().GetOk(); ok && online != "yes" { + online += fmt.Sprintf(", lastseen=%v", durFmt(lastSeen)) } } e.logf("open-conn-track: timeout opening %v to node %v; online=%v, lastRecv=%v", diff --git a/wgengine/router/router_android.go b/wgengine/router/router_android.go new file mode 100644 index 0000000000000..deeccda4a7028 --- /dev/null +++ b/wgengine/router/router_android.go @@ -0,0 +1,29 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build android + +package router + +import ( + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/health" + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { + // Note, this codepath is _not_ used when building the android app + // from github.com/tailscale/tailscale-android. 
The android app + // constructs its own wgengine with a custom router implementation + // that plugs into Android networking APIs. + // + // In practice, the only place this fake router gets used is when + // you build a tsnet app for android, in which case we don't want + // to touch the OS network stack and a no-op router is correct. + return NewFake(logf), nil +} + +func cleanUp(logf logger.Logf, interfaceName string) { + // Nothing to do here. +} diff --git a/wgengine/router/router_default.go b/wgengine/router/router_default.go index 1e675d1fc4d42..8dcbd36d0a7a2 100644 --- a/wgengine/router/router_default.go +++ b/wgengine/router/router_default.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows && !linux && !darwin && !openbsd && !freebsd +//go:build !windows && !linux && !darwin && !openbsd && !freebsd && !plan9 package router diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index 2af73e26d2f28..adc54c88dad1c 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -1,6 +1,8 @@ // Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause +//go:build !android + package router import ( @@ -32,6 +34,8 @@ import ( "tailscale.com/version/distro" ) +var getDistroFunc = distro.Get + const ( netfilterOff = preftype.NetfilterOff netfilterNoDivert = preftype.NetfilterNoDivert @@ -222,7 +226,7 @@ func busyboxParseVersion(output string) (major, minor, patch int, err error) { } func useAmbientCaps() bool { - if distro.Get() != distro.Synology { + if getDistroFunc() != distro.Synology { return false } return distro.DSMVersion() >= 7 @@ -438,7 +442,7 @@ func (r *linuxRouter) Set(cfg *Config) error { // Issue 11405: enable IP forwarding on gokrazy. 
advertisingRoutes := len(cfg.SubnetRoutes) > 0 - if distro.Get() == distro.Gokrazy && advertisingRoutes { + if getDistroFunc() == distro.Gokrazy && advertisingRoutes { r.enableIPForwarding() } @@ -1181,7 +1185,9 @@ var ( tailscaleRouteTable = newRouteTable("tailscale", 52) ) -// ipRules are the policy routing rules that Tailscale uses. +// baseIPRules are the policy routing rules that Tailscale uses, when not +// running on a UBNT device. +// // The priority is the value represented here added to r.ipPolicyPrefBase, // which is usually 5200. // @@ -1196,7 +1202,7 @@ var ( // and 'ip rule' implementations (including busybox), don't support // checking for the lack of a fwmark, only the presence. The technique // below works even on very old kernels. -var ipRules = []netlink.Rule{ +var baseIPRules = []netlink.Rule{ // Packets from us, tagged with our fwmark, first try the kernel's // main routing table. { @@ -1232,6 +1238,34 @@ var ipRules = []netlink.Rule{ // usual rules (pref 32766 and 32767, ie. main and default). } +// ubntIPRules are the policy routing rules that Tailscale uses, when running +// on a UBNT device. +// +// The priority is the value represented here added to +// r.ipPolicyPrefBase, which is usually 5200. +// +// This represents an experiment that will be used to gather more information. +// If this goes well, Tailscale may opt to use this for all of Linux. +var ubntIPRules = []netlink.Rule{ + // non-fwmark packets fall through to the usual rules (pref 32766 and 32767, + // ie. main and default). + { + Priority: 70, + Invert: true, + Mark: linuxfw.TailscaleBypassMarkNum, + Table: tailscaleRouteTable.Num, + }, +} + +// ipRules returns the appropriate list of ip rules to be used by Tailscale. See +// comments on baseIPRules and ubntIPRules for more details. +func ipRules() []netlink.Rule { + if getDistroFunc() == distro.UBNT { + return ubntIPRules + } + return baseIPRules +} + // justAddIPRules adds policy routing rule without deleting any first. 
func (r *linuxRouter) justAddIPRules() error { if !r.ipRuleAvailable { @@ -1243,7 +1277,7 @@ func (r *linuxRouter) justAddIPRules() error { var errAcc error for _, family := range r.addrFamilies() { - for _, ru := range ipRules { + for _, ru := range ipRules() { // Note: r is a value type here; safe to mutate it. ru.Family = family.netlinkInt() if ru.Mark != 0 { @@ -1272,7 +1306,7 @@ func (r *linuxRouter) addIPRulesWithIPCommand() error { rg := newRunGroup(nil, r.cmd) for _, family := range r.addrFamilies() { - for _, rule := range ipRules { + for _, rule := range ipRules() { args := []string{ "ip", family.dashArg(), "rule", "add", @@ -1320,7 +1354,7 @@ func (r *linuxRouter) delIPRules() error { } var errAcc error for _, family := range r.addrFamilies() { - for _, ru := range ipRules { + for _, ru := range ipRules() { // Note: r is a value type here; safe to mutate it. // When deleting rules, we want to be a bit specific (mention which // table we were routing to) but not *too* specific (fwmarks, etc). @@ -1363,7 +1397,7 @@ func (r *linuxRouter) delIPRulesWithIPCommand() error { // That leaves us some flexibility to change these values in later // versions without having ongoing hacks for every possible // combination. - for _, rule := range ipRules { + for _, rule := range ipRules() { args := []string{ "ip", family.dashArg(), "rule", "del", @@ -1500,7 +1534,7 @@ func normalizeCIDR(cidr netip.Prefix) string { // platformCanNetfilter reports whether the current distro/environment supports // running iptables/nftables commands. func platformCanNetfilter() bool { - switch distro.Get() { + switch getDistroFunc() { case distro.Synology: // Synology doesn't support iptables or nftables. Attempting to run it // just blocks for a long time while it logs about failures. 
@@ -1526,7 +1560,7 @@ func cleanUp(logf logger.Logf, interfaceName string) { // of the config file being present as well as a policy rule with a specific // priority (2000 + 1 - first interface mwan3 manages) and non-zero mark. func checkOpenWRTUsingMWAN3() (bool, error) { - if distro.Get() != distro.OpenWrt { + if getDistroFunc() != distro.OpenWrt { return false, nil } diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index dce69550d909a..a289fb0ac4aae 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -27,7 +27,9 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tstest" "tailscale.com/types/logger" + "tailscale.com/util/eventbus" "tailscale.com/util/linuxfw" + "tailscale.com/version/distro" ) func TestRouterStates(t *testing.T) { @@ -362,7 +364,9 @@ ip route add throw 192.168.0.0/24 table 52` + basic, }, } - mon, err := netmon.New(logger.Discard) + bus := eventbus.New() + defer bus.Close() + mon, err := netmon.New(bus, logger.Discard) if err != nil { t.Fatal(err) } @@ -553,6 +557,14 @@ func (n *fakeIPTablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return errors.New("not implemented") } +func (n *fakeIPTablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + return errors.New("not implemented") +} + +func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { + return errors.New("not implemented") +} + func (n *fakeIPTablesRunner) addBase4(tunname string) error { curIPT := n.ipt4 newRules := []struct{ chain, rule string }{ @@ -972,7 +984,10 @@ func newLinuxRootTest(t *testing.T) *linuxTest { logf := lt.logOutput.Logf - mon, err := netmon.New(logger.Discard) + bus := eventbus.New() + defer bus.Close() + + mon, err := netmon.New(bus, logger.Discard) if err != nil { lt.Close() t.Fatal(err) @@ -1231,3 +1246,24 @@ func adjustFwmask(t *testing.T, s string) string { return 
fwmaskAdjustRe.ReplaceAllString(s, "$1") } + +func TestIPRulesForUBNT(t *testing.T) { + // Override the global getDistroFunc + getDistroFunc = func() distro.Distro { + return distro.UBNT + } + defer func() { getDistroFunc = distro.Get }() // Restore original after the test + + expected := ubntIPRules + actual := ipRules() + + if len(expected) != len(actual) { + t.Fatalf("Expected %d rules, got %d", len(expected), len(actual)) + } + + for i, rule := range expected { + if rule != actual[i] { + t.Errorf("Rule mismatch at index %d: expected %+v, got %+v", i, rule, actual[i]) + } + } +} diff --git a/wgengine/router/router_plan9.go b/wgengine/router/router_plan9.go new file mode 100644 index 0000000000000..7ed7686d9e33f --- /dev/null +++ b/wgengine/router/router_plan9.go @@ -0,0 +1,156 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package router + +import ( + "bufio" + "bytes" + "fmt" + "net/netip" + "os" + "strings" + + "github.com/tailscale/wireguard-go/tun" + "tailscale.com/health" + "tailscale.com/net/netmon" + "tailscale.com/types/logger" +) + +func newUserspaceRouter(logf logger.Logf, tundev tun.Device, netMon *netmon.Monitor, health *health.Tracker) (Router, error) { + r := &plan9Router{ + logf: logf, + tundev: tundev, + netMon: netMon, + } + cleanAllTailscaleRoutes(logf) + return r, nil +} + +type plan9Router struct { + logf logger.Logf + tundev tun.Device + netMon *netmon.Monitor + health *health.Tracker +} + +func (r *plan9Router) Up() error { + return nil +} + +func (r *plan9Router) Set(cfg *Config) error { + if cfg == nil { + cleanAllTailscaleRoutes(r.logf) + return nil + } + + var self4, self6 netip.Addr + for _, addr := range cfg.LocalAddrs { + ctl := r.tundev.File() + maskBits := addr.Bits() + if addr.Addr().Is4() { + // The mask sizes in Plan9 are in IPv6 bits, even for IPv4. 
+ maskBits += (128 - 32) + self4 = addr.Addr() + } + if addr.Addr().Is6() { + self6 = addr.Addr() + } + _, err := fmt.Fprintf(ctl, "add %s /%d\n", addr.Addr().String(), maskBits) + r.logf("route/plan9: add %s /%d = %v", addr.Addr().String(), maskBits, err) + } + + ipr, err := os.OpenFile("/net/iproute", os.O_RDWR, 0) + if err != nil { + return fmt.Errorf("open /net/iproute: %w", err) + } + defer ipr.Close() + + // TODO(bradfitz): read existing routes, delete ones tagged "tail" + // that aren't in cfg.LocalRoutes. + + if _, err := fmt.Fprintf(ipr, "tag tail\n"); err != nil { + return fmt.Errorf("tag tail: %w", err) + } + + for _, route := range cfg.Routes { + maskBits := route.Bits() + if route.Addr().Is4() { + // The mask sizes in Plan9 are in IPv6 bits, even for IPv4. + maskBits += (128 - 32) + } + var nextHop netip.Addr + if route.Addr().Is4() { + nextHop = self4 + } else if route.Addr().Is6() { + nextHop = self6 + } + if !nextHop.IsValid() { + r.logf("route/plan9: skipping route %s: no next hop (no self addr)", route.String()) + continue + } + r.logf("route/plan9: plan9.router: add %s /%d %s", route.Addr(), maskBits, nextHop) + if _, err := fmt.Fprintf(ipr, "add %s /%d %s\n", route.Addr(), maskBits, nextHop); err != nil { + return fmt.Errorf("add %s: %w", route.String(), err) + } + } + + if len(cfg.LocalRoutes) > 0 { + r.logf("route/plan9: TODO: Set LocalRoutes %v", cfg.LocalRoutes) + } + if len(cfg.SubnetRoutes) > 0 { + r.logf("route/plan9: TODO: Set SubnetRoutes %v", cfg.SubnetRoutes) + } + + return nil +} + +// UpdateMagicsockPort implements the Router interface. This implementation +// does nothing and returns nil because this router does not currently need +// to know what the magicsock UDP port is. 
+func (r *plan9Router) UpdateMagicsockPort(_ uint16, _ string) error { + return nil +} + +func (r *plan9Router) Close() error { + // TODO(bradfitz): unbind + return nil +} + +func cleanUp(logf logger.Logf, _ string) { + cleanAllTailscaleRoutes(logf) +} + +func cleanAllTailscaleRoutes(logf logger.Logf) { + routes, err := os.OpenFile("/net/iproute", os.O_RDWR, 0) + if err != nil { + logf("cleaning routes: %v", err) + return + } + defer routes.Close() + + // Using io.ReadAll or os.ReadFile on /net/iproute fails; it results in a + // 511 byte result when the actual /net/iproute contents are over 1k. + // So do it in one big read instead. Who knows. + routeBuf := make([]byte, 1<<20) + n, err := routes.Read(routeBuf) + if err != nil { + logf("cleaning routes: %v", err) + return + } + routeBuf = routeBuf[:n] + + bs := bufio.NewScanner(bytes.NewReader(routeBuf)) + for bs.Scan() { + f := strings.Fields(bs.Text()) + if len(f) < 6 { + continue + } + tag := f[4] + if tag != "tail" { + continue + } + _, err := fmt.Fprintf(routes, "remove %s %s\n", f[0], f[1]) + logf("router: cleaning route %s %s: %v", f[0], f[1], err) + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index fc204736a1da2..b1b82032b2a6c 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -26,6 +26,7 @@ import ( "tailscale.com/health" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" + "tailscale.com/net/dns/resolver" "tailscale.com/net/flowtrack" "tailscale.com/net/ipset" "tailscale.com/net/netmon" @@ -46,12 +47,12 @@ import ( "tailscale.com/types/views" "tailscale.com/util/clientmetric" "tailscale.com/util/deephash" + "tailscale.com/util/eventbus" "tailscale.com/util/mak" "tailscale.com/util/set" "tailscale.com/util/testenv" "tailscale.com/util/usermetric" "tailscale.com/version" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/magicsock" "tailscale.com/wgengine/netlog" @@ -90,14 +91,19 @@ const statusPollInterval = 1 * time.Minute const 
networkLoggerUploadTimeout = 5 * time.Second type userspaceEngine struct { + // eventBus will eventually become required, but for now may be nil. + // TODO(creachadair): Enforce that this is non-nil at construction. + eventBus *eventbus.Bus + logf logger.Logf - wgLogger *wglog.Logger //a wireguard-go logging wrapper + wgLogger *wglog.Logger // a wireguard-go logging wrapper reqCh chan struct{} waitCh chan struct{} // chan is closed when first Close call completes; contrast with closing bool timeNow func() mono.Time tundev *tstun.Wrapper wgdev *device.Device router router.Router + dialer *tsdial.Dialer confListenPort uint16 // original conf.ListenPort dns *dns.Manager magicConn *magicsock.Conn @@ -228,6 +234,13 @@ type Config struct { // DriveForLocal, if populated, will cause the engine to expose a Taildrive // listener at 100.100.100.100:8080. DriveForLocal drive.FileSystemForLocal + + // EventBus, if non-nil, is used for event publication and subscription by + // the Engine and its subsystems. + // + // TODO(creachadair): As of 2025-03-19 this is optional, but is intended to + // become required non-nil. + EventBus *eventbus.Bus } // NewFakeUserspaceEngine returns a new userspace engine for testing. 
@@ -256,6 +269,8 @@ func NewFakeUserspaceEngine(logf logger.Logf, opts ...any) (Engine, error) { conf.HealthTracker = v case *usermetric.Registry: conf.Metrics = v + case *eventbus.Bus: + conf.EventBus = v default: return nil, fmt.Errorf("unknown option type %T", v) } @@ -324,12 +339,14 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } e := &userspaceEngine{ + eventBus: conf.EventBus, timeNow: mono.Now, logf: logf, reqCh: make(chan struct{}, 1), waitCh: make(chan struct{}), tundev: tsTUNDev, router: rtr, + dialer: conf.Dialer, confListenPort: conf.ListenPort, birdClient: conf.BIRDClient, controlKnobs: conf.ControlKnobs, @@ -349,7 +366,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) if conf.NetMon != nil { e.netMon = conf.NetMon } else { - mon, err := netmon.New(logf) + mon, err := netmon.New(conf.EventBus, logf) if err != nil { return nil, err } @@ -390,6 +407,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } } magicsockOpts := magicsock.Options{ + EventBus: e.eventBus, Logf: logf, Port: conf.ListenPort, EndpointsFunc: endpointsFn, @@ -570,6 +588,17 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) return filter.Drop } } + if runtime.GOOS == "plan9" { + isLocalAddr, ok := e.isLocalAddr.LoadOk() + if ok { + if isLocalAddr(p.Dst.Addr()) { + // On Plan9's "tun" equivalent, everything goes back in and out + // the tun, even when the kernel's replying to itself. + t.InjectInboundCopy(p.Buffer()) + return filter.Drop + } + } + } return filter.Accept } @@ -852,8 +881,7 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, // hasOverlap checks if there is a IPPrefix which is common amongst the two // provided slices. 
func hasOverlap(aips, rips views.Slice[netip.Prefix]) bool { - for i := range aips.Len() { - aip := aips.At(i) + for _, aip := range aips.All() { if views.SliceContains(rips, aip) { return true } @@ -1003,6 +1031,14 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, if err != nil { return err } + + if resolver.ShouldUseRoutes(e.controlKnobs) { + e.logf("wgengine: Reconfig: user dialer") + e.dialer.SetRoutes(routerCfg.Routes, routerCfg.LocalRoutes) + } else { + e.dialer.SetRoutes(nil, nil) + } + // Keep DNS configuration after router configuration, as some // DNS managers refuse to apply settings if the device has no // assigned address. @@ -1236,7 +1272,7 @@ func (e *userspaceEngine) linkChange(delta *netmon.ChangeDelta) { // and Apple platforms. if changed { switch runtime.GOOS { - case "linux", "android", "ios", "darwin": + case "linux", "android", "ios", "darwin", "openbsd": e.wgLock.Lock() dnsCfg := e.lastDNSConfig e.wgLock.Unlock() @@ -1329,9 +1365,9 @@ func (e *userspaceEngine) mySelfIPMatchingFamily(dst netip.Addr) (src netip.Addr if addrs.Len() == 0 { return zero, errors.New("no self address in netmap") } - for i := range addrs.Len() { - if a := addrs.At(i); a.IsSingleIP() && a.Addr().BitLen() == dst.BitLen() { - return a.Addr(), nil + for _, p := range addrs.All() { + if p.IsSingleIP() && p.Addr().BitLen() == dst.BitLen() { + return p.Addr(), nil } } return zero, errors.New("no self address in netmap matching address family") @@ -1582,6 +1618,12 @@ type fwdDNSLinkSelector struct { } func (ls fwdDNSLinkSelector) PickLink(ip netip.Addr) (linkName string) { + // sandboxed macOS does not automatically bind to the loopback interface so + // we must be explicit about it. 
+ if runtime.GOOS == "darwin" && ip.IsLoopback() { + return "lo0" + } + if ls.ue.isDNSIPOverTailscale.Load()(ip) { return ls.tunName } @@ -1595,7 +1637,7 @@ var ( metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes") ) -func (e *userspaceEngine) InstallCaptureHook(cb capture.Callback) { +func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) { e.tundev.InstallCaptureHook(cb) e.magicConn.InstallCaptureHook(cb) } diff --git a/wgengine/userspace_ext_test.go b/wgengine/userspace_ext_test.go index cc29be234d4ea..5e7d1ce6a517d 100644 --- a/wgengine/userspace_ext_test.go +++ b/wgengine/userspace_ext_test.go @@ -16,13 +16,14 @@ import ( ) func TestIsNetstack(t *testing.T) { - sys := new(tsd.System) + sys := tsd.NewSystem() e, err := wgengine.NewUserspaceEngine( tstest.WhileTestRunningLogger(t), wgengine.Config{ SetSubsystem: sys.Set, HealthTracker: sys.HealthTracker(), Metrics: sys.UserMetricsRegistry(), + EventBus: sys.Bus.Get(), }, ) if err != nil { @@ -66,7 +67,7 @@ func TestIsNetstackRouter(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - sys := &tsd.System{} + sys := tsd.NewSystem() if tt.setNetstackRouter { sys.NetstackRouter.Set(true) } @@ -74,6 +75,7 @@ func TestIsNetstackRouter(t *testing.T) { conf.SetSubsystem = sys.Set conf.HealthTracker = sys.HealthTracker() conf.Metrics = sys.UserMetricsRegistry() + conf.EventBus = sys.Bus.Get() e, err := wgengine.NewUserspaceEngine(logger.Discard, conf) if err != nil { t.Fatal(err) diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 0514218625a60..87a36c6734f08 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -25,6 +25,7 @@ import ( "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/opt" + "tailscale.com/util/eventbus" "tailscale.com/util/usermetric" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" @@ -100,9 +101,12 @@ func nodeViews(v []*tailcfg.Node) 
[]tailcfg.NodeView { } func TestUserspaceEngineReconfig(t *testing.T) { + bus := eventbus.New() + defer bus.Close() + ht := new(health.Tracker) reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) if err != nil { t.Fatal(err) } @@ -166,13 +170,16 @@ func TestUserspaceEnginePortReconfig(t *testing.T) { var knobs controlknobs.Knobs + bus := eventbus.New() + defer bus.Close() + // Keep making a wgengine until we find an unused port var ue *userspaceEngine ht := new(health.Tracker) reg := new(usermetric.Registry) for i := range 100 { attempt := uint16(defaultPort + i) - e, err := NewFakeUserspaceEngine(t.Logf, attempt, &knobs, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, attempt, &knobs, ht, reg, bus) if err != nil { t.Fatal(err) } @@ -251,9 +258,11 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) { var knobs controlknobs.Knobs + bus := eventbus.New() + defer bus.Close() ht := new(health.Tracker) reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(t.Logf, 0, &knobs, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, 0, &knobs, ht, reg, bus) if err != nil { t.Fatal(err) } diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 232591f5eca60..74a1917488dd8 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -17,10 +17,10 @@ import ( "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" + "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/netmap" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" @@ -162,7 +162,7 @@ func (e *watchdogEngine) Done() <-chan struct{} { return e.wrap.Done() } -func (e *watchdogEngine) InstallCaptureHook(cb capture.Callback) { +func (e *watchdogEngine) InstallCaptureHook(cb packet.CaptureCallback) { e.wrap.InstallCaptureHook(cb) } diff --git 
a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go index b05cd421fe309..a54a0d3fa1e13 100644 --- a/wgengine/watchdog_test.go +++ b/wgengine/watchdog_test.go @@ -9,6 +9,7 @@ import ( "time" "tailscale.com/health" + "tailscale.com/util/eventbus" "tailscale.com/util/usermetric" ) @@ -24,9 +25,11 @@ func TestWatchdog(t *testing.T) { t.Run("default watchdog does not fire", func(t *testing.T) { t.Parallel() + bus := eventbus.New() + defer bus.Close() ht := new(health.Tracker) reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) if err != nil { t.Fatal(err) } diff --git a/wgengine/wgcfg/nmcfg/nmcfg.go b/wgengine/wgcfg/nmcfg/nmcfg.go index d156f7fcb0ef2..1add608e4496c 100644 --- a/wgengine/wgcfg/nmcfg/nmcfg.go +++ b/wgengine/wgcfg/nmcfg/nmcfg.go @@ -40,8 +40,7 @@ func cidrIsSubnet(node tailcfg.NodeView, cidr netip.Prefix) bool { if !cidr.IsSingleIP() { return true } - for i := range node.Addresses().Len() { - selfCIDR := node.Addresses().At(i) + for _, selfCIDR := range node.Addresses().All() { if cidr == selfCIDR { return false } @@ -82,11 +81,11 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, // Logging buffers skippedUnselected := new(bytes.Buffer) - skippedIPs := new(bytes.Buffer) skippedSubnets := new(bytes.Buffer) + skippedExpired := new(bytes.Buffer) for _, peer := range nm.Peers { - if peer.DiscoKey().IsZero() && peer.DERP() == "" && !peer.IsWireGuardOnly() { + if peer.DiscoKey().IsZero() && peer.HomeDERP() == 0 && !peer.IsWireGuardOnly() { // Peer predates both DERP and active discovery, we cannot // communicate with it. logf("[v1] wgcfg: skipped peer %s, doesn't offer DERP or disco", peer.Key().ShortString()) @@ -96,7 +95,16 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, // anyway, since control intentionally breaks node keys for // expired peers so that we can't discover endpoints via DERP. 
if peer.Expired() { - logf("[v1] wgcfg: skipped expired peer %s", peer.Key().ShortString()) + if skippedExpired.Len() >= 1<<10 { + if !bytes.HasSuffix(skippedExpired.Bytes(), []byte("...")) { + skippedExpired.WriteString("...") + } + } else { + if skippedExpired.Len() > 0 { + skippedExpired.WriteString(", ") + } + fmt.Fprintf(skippedExpired, "%s/%v", peer.StableID(), peer.Key().ShortString()) + } continue } @@ -107,11 +115,10 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, cpeer := &cfg.Peers[len(cfg.Peers)-1] didExitNodeWarn := false - cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer() - cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer() + cpeer.V4MasqAddr = peer.SelfNodeV4MasqAddrForThisPeer().Clone() + cpeer.V6MasqAddr = peer.SelfNodeV6MasqAddrForThisPeer().Clone() cpeer.IsJailed = peer.IsJailed() - for i := range peer.AllowedIPs().Len() { - allowedIP := peer.AllowedIPs().At(i) + for _, allowedIP := range peer.AllowedIPs().All() { if allowedIP.Bits() == 0 && peer.StableID() != exitNode { if didExitNodeWarn { // Don't log about both the IPv4 /0 and IPv6 /0. 
@@ -139,12 +146,11 @@ func WGCfg(nm *netmap.NetworkMap, logf logger.Logf, flags netmap.WGConfigFlags, if skippedUnselected.Len() > 0 { logf("[v1] wgcfg: skipped unselected default routes from: %s", skippedUnselected.Bytes()) } - if skippedIPs.Len() > 0 { - logf("[v1] wgcfg: skipped node IPs: %s", skippedIPs) - } if skippedSubnets.Len() > 0 { logf("[v1] wgcfg: did not accept subnet routes: %s", skippedSubnets) } - + if skippedExpired.Len() > 0 { + logf("[v1] wgcfg: skipped expired peer: %s", skippedExpired) + } return cfg, nil } diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index c165ccdf3c3aa..6aaf567ad01ee 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -11,10 +11,10 @@ import ( "tailscale.com/ipn/ipnstate" "tailscale.com/net/dns" + "tailscale.com/net/packet" "tailscale.com/tailcfg" "tailscale.com/types/key" "tailscale.com/types/netmap" - "tailscale.com/wgengine/capture" "tailscale.com/wgengine/filter" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" @@ -129,5 +129,5 @@ type Engine interface { // InstallCaptureHook registers a function to be called to capture // packets traversing the data path. The hook can be uninstalled by // calling this function with a nil value. 
- InstallCaptureHook(capture.Callback) + InstallCaptureHook(packet.CaptureCallback) } diff --git a/words/scales.txt b/words/scales.txt index f27dfc5c4aa36..532734f6dcf8a 100644 --- a/words/scales.txt +++ b/words/scales.txt @@ -391,3 +391,54 @@ godzilla sirius vector cherimoya +shilling +kettle +kitchen +fahrenheit +rankine +piano +ruler +scoville +oratrice +teeth +cliff +degree +company +economy +court +justitia +themis +carat +carob +karat +barley +corn +penny +pound +mark +pence +mine +stairs +escalator +elevator +skilift +gondola +firefighter +newton +smoot +city +truck +everest +wall +fence +fort +trench +matrix +census +likert +sidemirror +wage +salary +fujita +caiman +cichlid diff --git a/words/tails.txt b/words/tails.txt index 4975332419855..7e35c69702d5d 100644 --- a/words/tails.txt +++ b/words/tails.txt @@ -694,3 +694,31 @@ ussuri kitty tanuki neko +wind +airplane +time +gumiho +eel +moray +twin +hair +braid +gate +end +queue +miku +at +fin +solarflare +asymptote +reverse +bone +stern +quaver +note +mining +coat +follow +stalk +caudal +chronicle