Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 60 additions & 170 deletions index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,47 +2,62 @@ import type { ExtensionAPI } from "@mariozechner/pi-coding-agent";
import { DynamicBorder } from "@mariozechner/pi-coding-agent";
import { Container, Key, matchesKey, Text, truncateToWidth, type SelectItem } from "@mariozechner/pi-tui";
import { spawn, type ChildProcessWithoutNullStreams } from "node:child_process";
import fs from "node:fs";
import os from "node:os";
import path from "node:path";
import net from "node:net";

const SETTINGS_PATH = path.join(os.homedir(), ".pi", "agent", "settings.json");
const WIDGET_ID = "port-forwards";
const DEFAULT_LOCAL_HOST = "127.0.0.1";
const DEFAULT_REMOTE_HOST = "127.0.0.1";

type AllowedPort = number | { from: number; to: number };

type ProviderConfig = {
id: string;
remote: string;
label: string;
allowedPorts: AllowedPort[];
localHost: string;
remoteHost: string;
localPortOffset: number;
sshOptions: string[];
import {
DEFAULT_LOCAL_HOST,
DEFAULT_REMOTE_HOST,
type AllowedPort,
type ProviderConfig,
type PortForwardConfig,
type RemotePort,
type Forward,
defaultAllowedPorts,
defaultConfig,
shellQuote,
isValidPort,
normalizeAllowedPorts,
portLabel,
isAllowedPort,
normalizeString,
normalizeSshOptions,
normalizeProvider,
normalizeCommand,
loadConfig as utilsLoadConfig,
localPortFor,
forwardKey,
localEndpoint as utilsLocalEndpoint,
remoteEndpoint as utilsRemoteEndpoint,
parsePorts as utilsParsePorts,
isLocalPortFree as utilsIsLocalPortFree,
} from "./utils";

export type { AllowedPort, ProviderConfig, PortForwardConfig, RemotePort, Forward };
export {
shellQuote,
isValidPort,
normalizeAllowedPorts,
portLabel,
isAllowedPort,
normalizeString,
normalizeSshOptions,
normalizeProvider,
normalizeCommand,
loadConfig,
localPortFor,
forwardKey,
localEndpoint,
remoteEndpoint,
parsePorts,
isLocalPortFree,
};

type PortForwardConfig = {
command: string;
providers: ProviderConfig[];
maxVisible: number;
};

type RemotePort = {
provider: ProviderConfig;
key: string;
port: number;
localPort: number;
address: string;
processName: string;
pid?: number;
raw: string;
};
const SETTINGS_PATH = path.join(os.homedir(), ".pi", "agent", "settings.json");
const WIDGET_ID = "port-forwards";

type Forward = {
type InternalForward = {
key: string;
provider: ProviderConfig;
remotePort: number;
Expand All @@ -57,143 +72,28 @@ type Forward = {
exited: boolean;
};

const defaultAllowedPorts: AllowedPort[] = [3000, { from: 8080, to: 9000 }];
const defaultConfig: PortForwardConfig = {
command: "port",
providers: [],
maxVisible: 15,
};

const forwards = new Map<string, Forward>();
const forwards = new Map<string, InternalForward>();
let latestCtx: any;
let processHooksInstalled = false;

function shellQuote(value: string): string {
return `'${value.replace(/'/g, `'"'"'`)}'`;
}

function isValidPort(port: number): boolean {
return Number.isInteger(port) && port > 0 && port <= 65535;
}

function normalizeAllowedPorts(value: unknown): AllowedPort[] {
if (!Array.isArray(value)) return defaultAllowedPorts;
const allowed = value.flatMap((entry): AllowedPort[] => {
if (typeof entry === "number" && isValidPort(entry)) return [entry];
if (
entry && typeof entry === "object" &&
isValidPort((entry as any).from) && isValidPort((entry as any).to) &&
(entry as any).to >= (entry as any).from
) return [{ from: (entry as any).from, to: (entry as any).to }];
return [];
});
return allowed.length ? allowed : defaultAllowedPorts;
}

function portLabel(allowedPorts: AllowedPort[]): string {
return allowedPorts.map((entry) => typeof entry === "number" ? String(entry) : `${entry.from}-${entry.to}`).join(", ");
}

function isAllowedPort(port: number, allowedPorts: AllowedPort[]): boolean {
return allowedPorts.some((entry) => typeof entry === "number" ? port === entry : port >= entry.from && port <= entry.to);
}

function normalizeString(value: unknown, fallback: string): string {
return typeof value === "string" && value.trim() ? value.trim() : fallback;
}

function normalizeSshOptions(value: unknown): string[] {
return Array.isArray(value) ? value.filter((option): option is string => typeof option === "string" && option.trim().length > 0).map((option) => option.trim()) : [];
}

function normalizeProvider(raw: any, index: number): ProviderConfig | undefined {
const remote = typeof raw?.remote === "string" ? raw.remote.trim() : typeof raw?.ssh === "string" ? raw.ssh.trim() : "";
if (!remote) return undefined;
const label = typeof raw?.label === "string" && raw.label.trim()
? raw.label.trim()
: typeof raw?.hostLabel === "string" && raw.hostLabel.trim()
? raw.hostLabel.trim()
: remote.replace(/^[^@]+@/, "");
const id = typeof raw?.id === "string" && raw.id.trim() ? raw.id.trim() : label || `remote-${index + 1}`;
const localPortOffset = Number.isInteger(raw?.localPortOffset) ? raw.localPortOffset : 0;
return {
id,
remote,
label,
allowedPorts: normalizeAllowedPorts(raw?.allowedPorts),
localHost: normalizeString(raw?.localHost, DEFAULT_LOCAL_HOST),
remoteHost: normalizeString(raw?.remoteHost, DEFAULT_REMOTE_HOST),
localPortOffset,
sshOptions: normalizeSshOptions(raw?.sshOptions),
};
}

function normalizeCommand(value: unknown): string {
const command = typeof value === "string" ? value.trim().replace(/^\//, "") : "";
return command && !/\s/.test(command) ? command : defaultConfig.command;
}

function loadConfig(): PortForwardConfig {
try {
const settings = JSON.parse(fs.readFileSync(SETTINGS_PATH, "utf8"));
const raw = settings?.portForward ?? settings?.piPortForward;
const providersRaw = Array.isArray(raw?.providers) ? raw.providers : Array.isArray(raw?.remotes) ? raw.remotes : [];
const seenIds = new Set<string>();
const providers = (providersRaw.map(normalizeProvider).filter(Boolean) as ProviderConfig[]).map((provider, index) => {
let id = provider.id;
if (seenIds.has(id)) id = `${id}-${index + 1}`;
seenIds.add(id);
return { ...provider, id };
});
return {
command: normalizeCommand(raw?.command),
providers,
maxVisible: Number.isInteger(raw?.maxVisible) && raw.maxVisible > 0 ? raw.maxVisible : defaultConfig.maxVisible,
};
} catch {
return defaultConfig;
}
}

function localPortFor(provider: ProviderConfig, remotePort: number): number {
return remotePort + provider.localPortOffset;
return utilsLoadConfig(SETTINGS_PATH);
}

function forwardKey(provider: ProviderConfig, port: number): string {
return `${provider.id}:${port}`;
}

function localEndpoint(forward: Pick<Forward, "localHost" | "localPort">): string {
function localEndpoint(forward: Pick<InternalForward, "localHost" | "localPort">): string {
return `${forward.localHost}:${forward.localPort}`;
}

function remoteEndpoint(provider: ProviderConfig, port: number): string {
return `${provider.label}:${provider.remoteHost}:${port}`;
return utilsRemoteEndpoint(provider, port);
}

function parsePorts(provider: ProviderConfig, stdout: string): RemotePort[] {
const byPort = new Map<number, RemotePort>();
for (const raw of stdout.split(/\r?\n/).map((line) => line.trim()).filter(Boolean)) {
const tokens = raw.split(/\s+/);
const local = tokens.find((token) => /:\d+$/.test(token.replace(/^\[|\]$/g, "")));
if (!local) continue;

const portMatch = local.match(/:(\d+)$/);
if (!portMatch) continue;
const port = Number(portMatch[1]);
if (!isValidPort(port) || !isAllowedPort(port, provider.allowedPorts)) continue;
const localPort = localPortFor(provider, port);
if (!isValidPort(localPort)) continue;

const address = local.replace(/:(\d+)$/, "").replace(/^\[|\]$/g, "");
const procMatch = raw.match(/users:\(\(\"([^\"]+)\",pid=(\d+)/);
const processName = procMatch?.[1] ?? "unknown";
const pid = procMatch?.[2] ? Number(procMatch[2]) : undefined;
const key = forwardKey(provider, port);

if (!byPort.has(port)) byPort.set(port, { provider, key, port, localPort, address, processName, pid, raw });
}
return [...byPort.values()].sort((a, b) => a.port - b.port);
return utilsParsePorts(provider, stdout);
}

async function isLocalPortFree(host: string, port: number): Promise<boolean> {
return utilsIsLocalPortFree(host, port);
}

async function getRemotePorts(pi: ExtensionAPI, provider: ProviderConfig): Promise<RemotePort[]> {
Expand All @@ -209,16 +109,6 @@ async function getRemotePorts(pi: ExtensionAPI, provider: ProviderConfig): Promi
return parsePorts(provider, result.stdout);
}

function isLocalPortFree(host: string, port: number): Promise<boolean> {
return new Promise((resolve) => {
const server = net.createServer();
server.once("error", () => resolve(false));
server.listen({ host, port }, () => {
server.close(() => resolve(true));
});
});
}

function stopForward(key: string, updateUi = true): boolean {
const forward = forwards.get(key);
if (!forward) return false;
Expand Down Expand Up @@ -274,8 +164,8 @@ async function startForward(remotePort: RemotePort): Promise<Forward> {
startedAt: Date.now(),
notifyOnExit: false,
exited: false,
};
forwards.set(remotePort.key, forward);
} as Forward;
forwards.set(remotePort.key, forward as InternalForward);

let stderr = "";
let exitCode: number | null = null;
Expand Down
Loading