diff --git a/app/appearance/langs/en_US.json b/app/appearance/langs/en_US.json index a99ef37dbd7..809491b6db6 100644 --- a/app/appearance/langs/en_US.json +++ b/app/appearance/langs/en_US.json @@ -87,7 +87,7 @@ "removeButKeepRelationField": "Remove only this field, keep bidirectional relation field", "exportPDFLowMemory": "Insufficient available memory to export this PDF, please reduce the content or increase available memory and try exporting again", "exportConf": "Export settings", - "exportConfTip": "Account, access authorization code, synchronization, API token and data repo key will not be exported", + "exportConfTip": "Account, auth settings, synchronization, API token and data repo key will not be exported", "importConf": "Import settings", "importConfTip": "After importing, the current settings will be overwritten and the application will be automatically closed, please restart manually", "jumpToPage": "Jump to the specified page: 1 ~ ${x}", @@ -1281,7 +1281,55 @@ "about3": "Please use the Chrome browser and keep it in the same network as the computer, port ${port}(In addition to the random port, the first started workspace will also automatically listen to 6806 as a fixed port, so that it is convenient for the browser to clip extensions or other external programs to call the kernel interface), the addresses that may be connected are as follows: ", "about4": "Open browser", "about5": "Access authorization code", - "about6": "After configuration, it will be used as the access authentication password, leave it blank to close the authentication", + "about6": "After configuration, it enables access-code authentication; leave empty to disable this method", + "accessAuthCodeDisableWarning": "If you want to disable access-code authentication and only keep OIDC, make sure OIDC is configured first, otherwise you may lose a fallback access method", + "accessAuthBypass": "Bypass access authentication", + "accessAuthBypassTip": "Skip all access code, OIDC checks and mandatory safety checks (not recommended)", + "accessAuthBypassConfirm": "Are you sure you want to bypass all access authentication and security checks? This will allow anyone to access without credentials.", + "accessAuthBypassAuthCodeDisabledTip": "Access authorization code is disabled because bypass is enabled.", + "accessAuthBypassOIDCDisabledTip": "OIDC settings are disabled because bypass is enabled.", + "oidc": "OIDC Login", + "oidcTip": "Configure OIDC login authentication", + "oidcEnabled": "Enable OIDC", + "oidcEnabledTip": "Keep configuration even when disabled", + "oidcProvider": "Provider", + "oidcProviderTip": "Select a provider or choose Disable to turn off OIDC", + "oidcProviderCustom": "Custom OIDC", + "oidcProviderConfig": "Provider configuration", + "oidcProviderLabel": "Button label", + "oidcProviderLabelTip": "Text shown on the login button", + "oidcClientID": "Client ID", + "oidcClientIDTip": "From your provider application settings", + "oidcClientSecret": "Client Secret", + "oidcClientSecretTip": "From your provider application settings", + "oidcPKCETip": "Use PKCE instead of client secret", + "oidcIssuerURL": "Issuer URL", + "oidcIssuerURLTip": "Issuer/Discovery URL from the provider", + "oidcRedirectURL": "Redirect URL", + "oidcRedirectURLTip": "Optional; leave empty to use the local callback", + "oidcScopes": "Scopes", + "oidcScopesTip": "Separate with commas or spaces", + "oidcTenant": "Tenant", + "oidcTenantTip": "Tenant ID for Microsoft (optional)", + "oidcClaimMap": "Claim map", + "oidcClaimMapTip": "Optional: map provider claim fields to SiYuan standard claim fields", + "oidcClaimMapInvalid": "Invalid claim map line: ${line}", + "oidcClaimMapRowInvalid": "Claim map rule is incomplete", + "oidcClaimMapAdd": "Add mapping", + "oidcClaimMapValuePlaceholder": "Provider claim key", + "oidcFilters": "Access filters", + "oidcFiltersTip": "Add rules to restrict which accounts can sign in.", + "oidcFiltersTipLine1": "1. Any rule under the same item can match", + "oidcFiltersTipLine2": "2. Different items must all match", + "oidcFiltersInvalid": "Invalid filter line: ${line}", + "oidcFiltersRowInvalid": "Filter rule is incomplete", + "oidcFilterAdd": "Add rule", + "oidcFilterClaimPlaceholder": "Claim", + "oidcFilterPatternPlaceholder": "Pattern", + "oidcFilterOpRegex": "Regex", + "oidcFilterOpRegexI": "Regex (ignore case)", + "oidcFilterOpString": "Equals (ignore case)", + "oidcFilterOpExact": "Equals", "about7": "Follow system lock screen", "about8": "When enabled, the application will be automatically locked when locking the system screen", "about11": "Network serving", @@ -1610,7 +1658,7 @@ "169": "Uploading data repo file %v/%v", "170": "Uploading data repo chunk %v/%v", "171": "Uploading data repo reference %s", - "172": "If you forget the authorization code, please find help here", + "172": "If you forget the authorization, please find help here", "173": "Please enter the access auth code", "174": "Unlock access", "175": "Please enter the verification code", @@ -1713,6 +1761,22 @@ "272": "Unnamed field", "273": "Do not create the workspace in the partition root path, please create a new folder as the workspace", "274": "This folder contains other files, please create a new folder as the workspace", - "275": "Cannot open documents created by a newer version. Please upgrade to the latest version and try again" + "275": "Cannot open documents created by a newer version. Please upgrade to the latest version and try again", + "276": "OIDC authentication failed, please start login again", + "277": "OIDC session is invalid, please start login again", + "278": "OIDC authentication was rejected by policy", + "279": "In Docker mode, you must enable at least one auth method, or enable access auth bypass", + "280": "OIDC is not enabled. Please enable it in Settings and try again", + "281": "OIDC provider [%s] failed to initialize. Please check the configuration", + "282": "Authentication Successful, you can now close this window and return to the application", + "283": "Authentication Failed, please return to SiYuan", + "284": "OIDC session has expired, please start login again", + "285": "OIDC session already handled, no further action is required", + "286": "OIDC callback parameters are missing, please start login again", + "287": "OIDC provider mismatch, please check the configuration", + "288": "Failed to save OIDC session, please retry or check storage permissions", + "289": "OIDC login failed, please retry", + "290": "OIDC login timed out, please retry", + "291": "Failed to fetch OIDC login status, please make sure your network connection is normal" } } diff --git a/app/appearance/langs/zh_CN.json b/app/appearance/langs/zh_CN.json index 61bcb8713ed..5db1c266024 100644 --- a/app/appearance/langs/zh_CN.json +++ b/app/appearance/langs/zh_CN.json @@ -87,7 +87,7 @@ "removeButKeepRelationField": "仅删除本字段,保留双向关联字段", "exportPDFLowMemory": "系统可用内存不足,无法导出该 PDF,请减少内容或者增加可用内存后再尝试导出", "exportConf": "导出设置", - "exportConfTip": "账号、访问授权码、同步、API token 和数据仓库密钥不会被导出", + "exportConfTip": "账号、认证设置、同步、API token 和数据仓库密钥不会被导出", "importConf": "导入设置", "importConfTip": "导入后会覆盖当前设置并自动关闭应用,请手动重启", "jumpToPage": "跳转到指定页:1 ~ ${x}", @@ -1281,7 +1281,55 @@ "about3": "请使用 Chrome 浏览器并保持和电脑在同一个网络内,端口 ${port}(第一个启动的工作空间除了随机端口外也会自动监听 6806 作为固定端口,以方便浏览器剪藏扩展或者其他外部程序调用内核接口),可能连通的网络地址:", "about4": "打开浏览器", "about5": "访问授权码", - "about6": "配置后作为访问鉴权密码,留空则关闭鉴权", + "about6": "配置后开启授权码访问认证方式,留空则关闭该认证方式", + "accessAuthCodeDisableWarning": "如需关闭访问授权码方式,仅保留 OIDC 认证,请先确保 OIDC 已配置完成,否则可能失去后备访问方式", + "accessAuthBypass": "绕过访问认证", + "accessAuthBypassTip": "跳过访问授权码、OIDC 等所有认证和强制安全检查(不推荐)", + "accessAuthBypassConfirm": "确定要绕过所有访问认证和安全检查吗?这将允许任何人无需凭据即可访问。", + "accessAuthBypassAuthCodeDisabledTip": "已开启绕过访问认证,访问授权码设置已禁用。", + "accessAuthBypassOIDCDisabledTip": "已开启绕过访问认证,OIDC 设置已禁用。", + "oidc": "OIDC 登录", + "oidcTip": "配置 OIDC 登录认证方式", + "oidcEnabled": "启用 OIDC", + "oidcEnabledTip": "关闭后仍保留配置", + "oidcProvider": "提供方", + "oidcProviderTip": "选择提供方,选择禁用可关闭 OIDC 登录", + "oidcProviderCustom": "自定义 OIDC", + "oidcProviderConfig": "提供方配置", + "oidcProviderLabel": "按钮名称", + "oidcProviderLabelTip": "显示在登录页按钮上的文字", + "oidcClientID": "Client ID", + "oidcClientIDTip": "提供方应用配置中的 Client ID", + "oidcClientSecret": "Client Secret", + "oidcClientSecretTip": "提供方应用配置中的 Client Secret", + "oidcPKCETip": "启用 PKCE 代替 Client Secret", + "oidcIssuerURL": "Issuer URL", + "oidcIssuerURLTip": "提供方的 Issuer/Discovery 地址", + "oidcRedirectURL": "Redirect URL", + "oidcRedirectURLTip": "可选,留空默认回调本地", + "oidcScopes": "Scopes", + "oidcScopesTip": "可用逗号或空格分隔", + "oidcTenant": "Tenant", + "oidcTenantTip": "Microsoft 的租户 ID(可选)", + "oidcClaimMap": "Claim 映射", + "oidcClaimMapTip": "可选,将提供方的 Claim 字段映射到思源内部标准 Claim 字段", + "oidcClaimMapInvalid": "Claim 映射格式错误:${line}", + "oidcClaimMapRowInvalid": "Claim 映射未填写完整", + "oidcClaimMapAdd": "添加映射", + "oidcClaimMapValuePlaceholder": "提供方 Claim 字段", + "oidcFilters": "访问过滤规则", + "oidcFiltersTip": "按需添加规则,限制可登录账户。", + "oidcFiltersTipLine1": "1. 同一项多条规则,匹配任意一条即可", + "oidcFiltersTipLine2": "2. 不同项需要同时匹配", + "oidcFiltersInvalid": "过滤规则格式错误:${line}", + "oidcFiltersRowInvalid": "过滤规则未填写完整", + "oidcFilterAdd": "添加规则", + "oidcFilterClaimPlaceholder": "Claim", + "oidcFilterPatternPlaceholder": "规则", + "oidcFilterOpRegex": "正则", + "oidcFilterOpRegexI": "正则(忽略大小写)", + "oidcFilterOpString": "相等(忽略大小写)", + "oidcFilterOpExact": "相等", "about7": "跟随系统锁屏", "about8": "启用后将会在系统锁屏时自动锁定应用", "about11": "网络伺服", @@ -1610,7 +1658,7 @@ "169": "正在上传数据仓库文件 %v/%v", "170": "正在上传数据仓库分块 %v/%v", "171": "正在上传数据仓库引用 %s", - "172": "如果你忘记了授权码,请在这里寻求帮助", + "172": "如果你忘记了认证方式,请在这里寻求帮助", "173": "请输入访问授权码", "174": "解锁访问", "175": "请输入验证码", @@ -1713,6 +1761,22 @@ "272": "未命名字段", "273": "请勿在分区根路径上创建工作空间,请新建一个文件夹作为工作空间", "274": "该文件夹包含了其他文件,请新建一个文件夹作为工作空间", - "275": "无法打开新版本创建的文档,请升级到最新版本后再试" + "275": "无法打开新版本创建的文档,请升级到最新版本后再试", + "276": "OIDC 认证失败,请重新发起登录", + "277": "OIDC 会话无效,请重新发起登录", + "278": "OIDC 认证被策略拒绝", + "279": "Docker 模式下必须至少启用一种认证方式,或开启绕过访问认证", + "280": "OIDC 未启用,请在设置中开启后重试", + "281": "OIDC 提供方 [%s] 初始化失败,请检查配置", + "282": "认证成功,您现在可以关闭此窗口并返回应用", + "283": "认证失败,请返回思源笔记", + "284": "OIDC 会话已过期,请重新发起登录", + "285": "OIDC 会话已处理,无需重复操作", + "286": "OIDC 回调参数缺失,请重新发起登录", + "287": "OIDC 提供方不匹配,请检查配置", + "288": "OIDC 会话保存失败,请重试或检查存储权限", + "289": "OIDC 登录失败,请重试", + "290": "OIDC 登录超时,请重试", + "291": "OIDC 获取登录状态失败,请确保网络连接正常" } } diff --git a/app/src/config/about.ts b/app/src/config/about.ts index f876207c489..c9ecb2d598c 100644 --- a/app/src/config/about.ts +++ b/app/src/config/about.ts @@ -5,6 +5,7 @@ import {ipcRenderer, shell} from "electron"; import {isBrowser} from "../util/functions"; import {fetchPost} from "../util/fetch"; import {setAccessAuthCode} from "./util/about"; +import {setOIDCConfig} from "./util/oidc"; import {exportLayout} from "../layout/util"; import {exitSiYuan, processSync} from "../dialog/processSystem"; import {isInAndroid, isInHarmony, isInIOS, isIPad, isMac, openByMobile, writeText} from "../protyle/util/compatibility"; @@ -64,26 +65,44 @@ export const about = {
-
-
-
- ${window.siyuan.languages.about5} -
${window.siyuan.languages.about6}
-
-
- + +
+
+ ${window.siyuan.languages.about5} +
${window.siyuan.languages.about6}
+
${window.siyuan.languages.accessAuthBypassAuthCodeDisabledTip}
- +
+
+ +
${window.siyuan.languages.about2} @@ -242,6 +261,35 @@ ${checkUpdateHTML} if (window.siyuan.config.system.isInsider) { about.element.querySelector("#isInsider").innerHTML = "Insider Preview"; } + const authCodeElement = about.element.querySelector("#authCode") as HTMLButtonElement; + const oidcSettingElement = about.element.querySelector("#oidcSetting") as HTMLButtonElement; + const lockScreenModeElement = about.element.querySelector("#lockScreenMode") as HTMLInputElement; + const setAuthControlsDisabled = (disabled: boolean) => { + authCodeElement?.toggleAttribute("disabled", disabled); + oidcSettingElement?.toggleAttribute("disabled", disabled); + lockScreenModeElement?.toggleAttribute("disabled", disabled); + }; + const applyAccessAuthBypassChange = (enabled: boolean) => { + fetchPost("/api/system/setAccessAuthBypass", {accessAuthBypass: enabled}, () => { + window.siyuan.config.accessAuthBypass = enabled; + setAuthControlsDisabled(enabled); + }); + }; + setAuthControlsDisabled(window.siyuan.config.accessAuthBypass); + const accessAuthBypassElement = about.element.querySelector("#accessAuthBypass") as HTMLInputElement; + if (accessAuthBypassElement) { + accessAuthBypassElement.addEventListener("change", () => { + if (accessAuthBypassElement.checked) { + accessAuthBypassElement.checked = false; + confirmDialog("⚠️ " + window.siyuan.languages.accessAuthBypass, window.siyuan.languages.accessAuthBypassConfirm, () => { + accessAuthBypassElement.checked = true; + applyAccessAuthBypassChange(true); + }); + } else { + applyAccessAuthBypassChange(false); + } + }); + } const indexRetentionDaysElement = about.element.querySelector("#indexRetentionDays") as HTMLInputElement; indexRetentionDaysElement.addEventListener("change", () => { fetchPost("/api/repo/setRepoIndexRetentionDays", {days: parseInt(indexRetentionDaysElement.value)}, () => { @@ -303,6 +351,9 @@ ${checkUpdateHTML} about.element.querySelector("#authCode").addEventListener("click", () => { setAccessAuthCode(); }); + about.element.querySelector("#oidcSetting").addEventListener("click", () => { + setOIDCConfig(); + }); const importKeyElement = about.element.querySelector("#importKey"); importKeyElement.addEventListener("click", () => { const passwordDialog = new Dialog({ @@ -376,8 +427,7 @@ ${checkUpdateHTML} }); }); }); - const lockScreenModeElement = about.element.querySelector("#lockScreenMode") as HTMLInputElement; - lockScreenModeElement.addEventListener("change", () => { + lockScreenModeElement?.addEventListener("change", () => { fetchPost("/api/system/setFollowSystemLockScreen", {lockScreenMode: lockScreenModeElement.checked ? 1 : 0}, () => { window.siyuan.config.system.lockScreenMode = lockScreenModeElement.checked ? 1 : 0; }); diff --git a/app/src/config/util/about.ts b/app/src/config/util/about.ts index 65ecc8a628c..cff40e513f3 100644 --- a/app/src/config/util/about.ts +++ b/app/src/config/util/about.ts @@ -9,6 +9,7 @@ export const setAccessAuthCode = () => { content: `
${window.siyuan.languages.about6}
+
${window.siyuan.languages.accessAuthCodeDisableWarning}
diff --git a/app/src/config/util/oidc.ts b/app/src/config/util/oidc.ts new file mode 100644 index 00000000000..b93f87e1909 --- /dev/null +++ b/app/src/config/util/oidc.ts @@ -0,0 +1,639 @@ +import {Dialog} from "../../dialog"; +import {fetchPost} from "../../util/fetch"; +import {isMobile} from "../../util/functions"; +import {showMessage} from "../../dialog/message"; + +const defaultProviders = ["custom", "google", "microsoft", "github"]; +const claimOptions = [ + "provider", + "subject", + "email", + "email_verified", + "preferred_username", + "name", + "issuer", + "audience", + "hosted_domain", + "tenant_id", + "groups", +]; + +const cloneProviders = (providers: Record | undefined) => { + const cloned: Record = {}; + if (!providers) { + return cloned; + } + Object.keys(providers).forEach((id) => { + const provider = providers[id]; + if (!provider) { + return; + } + cloned[id] = { + clientID: provider.clientID || "", + clientSecret: provider.clientSecret || "", + pkce: !!provider.pkce, + redirectURL: provider.redirectURL || "", + issuerURL: provider.issuerURL || "", + scopes: provider.scopes ? [...provider.scopes] : [], + tenant: provider.tenant || "", + providerLabel: provider.providerLabel || "", + claimMap: provider.claimMap ? Object.assign({}, provider.claimMap) : {}, + }; + }); + return cloned; +}; + +const ensureProvider = (providers: Record, id: string) => { + if (!providers[id]) { + providers[id] = { + clientID: "", + clientSecret: "", + pkce: false, + redirectURL: "", + issuerURL: "", + scopes: [], + tenant: "", + providerLabel: "", + claimMap: {}, + }; + } +}; + +const parseScopes = (raw: string) => { + return raw + .split(/[, \t\r\n]+/) + .map((item) => item.trim()) + .filter((item) => item); +}; + +type OIDCClaimMapRow = { + claim: string; + field: string; +}; + +const claimMapToRows = (claimMap: Record | undefined) => { + const rows: OIDCClaimMapRow[] = []; + if (!claimMap) { + return rows; + } + Object.keys(claimMap).sort().forEach((claim) => { + rows.push({ + claim, + field: claimMap[claim] || "", + }); + }); + return rows; +}; + +const rowsToClaimMap = (rows: OIDCClaimMapRow[]) => { + const claimMap: Record = {}; + if (!rows.length) { + return {claimMap}; + } + for (const row of rows) { + const claim = row.claim.trim(); + const field = row.field.trim(); + if (!claim || !field) { + return {claimMap: null}; + } + claimMap[claim] = field; + } + return {claimMap}; +}; + +type OIDCFilterRow = { + claim: string; + op: string; + pattern: string; +}; + +const parseFilterPattern = (pattern: string) => { + const trimmed = pattern.trim(); + if (!trimmed) { + return null; + } + const sepIndex = trimmed.indexOf(":"); + if (sepIndex > 0) { + const prefix = trimmed.slice(0, sepIndex).trim().toLowerCase(); + const rest = trimmed.slice(sepIndex + 1).trim(); + if (prefix === "regex" || prefix === "re") { + return {op: "regex", pattern: rest}; + } + if (prefix === "regexi") { + return {op: "regexi", pattern: rest}; + } + if (prefix === "str" || prefix === "string") { + return {op: "str", pattern: rest}; + } + if (prefix === "exact") { + return {op: "exact", pattern: rest}; + } + } + return {op: "regexi", pattern: trimmed}; +}; + +const filterPatternFromRow = (row: OIDCFilterRow) => { + const pattern = row.pattern.trim(); + if (!pattern) { + return ""; + } + switch (row.op) { + case "regex": + return `regex:${pattern}`; + case "regexi": + return `regexi:${pattern}`; + case "str": + return `str:${pattern}`; + case "exact": + return `exact:${pattern}`; + default: + return pattern; + } +}; + +const filtersToRows = (filters: Record | undefined) => { + const rows: OIDCFilterRow[] = []; + if (!filters) { + return rows; + } + Object.keys(filters).sort().forEach((claim) => { + const patterns = filters[claim] || []; + patterns.forEach((pattern) => { + const parsed = parseFilterPattern(pattern || ""); + if (!parsed) { + return; + } + rows.push({ + claim, + op: parsed.op, + pattern: parsed.pattern, + }); + }); + }); + return rows; +}; + +const rowsToFilters = (rows: OIDCFilterRow[]) => { + const filters: Record = {}; + for (const row of rows) { + const claim = row.claim.trim(); + const pattern = row.pattern.trim(); + if (!claim || !pattern) { + return {filters: null}; + } + const encoded = filterPatternFromRow(row); + if (!encoded) { + return {filters: null}; + } + if (!filters[claim]) { + filters[claim] = []; + } + filters[claim].push(encoded); + } + return {filters}; +}; + +export const setOIDCConfig = () => { + const oidc = window.siyuan.config.oidc || {provider: "", providers: {}, filters: {}}; + const providers = cloneProviders(oidc.providers); + const providerIds = Array.from(new Set([...defaultProviders, ...Object.keys(providers)])).filter((id) => id); + if (oidc.provider && !providerIds.includes(oidc.provider)) { + providerIds.unshift(oidc.provider); + } + let enabledProvider = oidc.provider || ""; + let currentProvider = oidc.provider || providerIds[0]; + ensureProvider(providers, currentProvider); + const providerDisplayNames: Record = { + custom: window.siyuan.languages.oidcProviderCustom, + google: "Google", + microsoft: "Microsoft", + github: "GitHub", + }; + + const dialog = new Dialog({ + title: "\uD83D\uDD10 " + window.siyuan.languages.oidc, + width: isMobile() ? "92vw" : "640px", + height: isMobile() ? "80vh" : "70vh", + content: `
+
+
+ ${window.siyuan.languages.oidcProvider} +
${window.siyuan.languages.oidcProviderTip}
+
+ + +
+
+
+
+ ${window.siyuan.languages.oidcProviderLabel} +
${window.siyuan.languages.oidcProviderLabelTip}
+
+ + +
+
+
+ ${window.siyuan.languages.oidcClientID} +
${window.siyuan.languages.oidcClientIDTip}
+
+ + +
+
+
+ ${window.siyuan.languages.oidcClientSecret} +
${window.siyuan.languages.oidcClientSecretTip}
+
+ + +
+
+
+ PKCE +
${window.siyuan.languages.oidcPKCETip}
+
+ + +
+
+
+ ${window.siyuan.languages.oidcIssuerURL} +
${window.siyuan.languages.oidcIssuerURLTip}
+
+ + +
+
+
+ ${window.siyuan.languages.oidcRedirectURL} +
${window.siyuan.languages.oidcRedirectURLTip}
+
+ + +
+
+
+ ${window.siyuan.languages.oidcScopes} +
${window.siyuan.languages.oidcScopesTip}
+
+ + +
+
+
+ ${window.siyuan.languages.oidcTenant} +
${window.siyuan.languages.oidcTenantTip}
+
+ + +
+
+
+
+ ${window.siyuan.languages.oidcClaimMap} +
${window.siyuan.languages.oidcClaimMapTip}
+
+ + +
+
+
+
+
+
+
+ ${window.siyuan.languages.oidcFilters} +
${window.siyuan.languages.oidcFiltersTip}
+
${window.siyuan.languages.oidcFiltersTipLine1}
+
${window.siyuan.languages.oidcFiltersTipLine2}
+
+ + +
+
+
+
+
+
+ +
`, + }); + + const providerSelect = dialog.element.querySelector("#oidcProvider") as HTMLSelectElement; + const providerLabelInput = dialog.element.querySelector("#oidcProviderLabel") as HTMLInputElement; + const clientIDInput = dialog.element.querySelector("#oidcClientID") as HTMLInputElement; + const clientSecretInput = dialog.element.querySelector("#oidcClientSecret") as HTMLInputElement; + const pkceInput = dialog.element.querySelector("#oidcPKCE") as HTMLInputElement; + const issuerInput = dialog.element.querySelector("#oidcIssuerURL") as HTMLInputElement; + const redirectInput = dialog.element.querySelector("#oidcRedirectURL") as HTMLInputElement; + const scopesInput = dialog.element.querySelector("#oidcScopes") as HTMLInputElement; + const tenantInput = dialog.element.querySelector("#oidcTenant") as HTMLInputElement; + const claimMapList = dialog.element.querySelector("#oidcClaimMapList") as HTMLDivElement; + const claimMapAddButton = dialog.element.querySelector("#oidcClaimMapAdd") as HTMLButtonElement; + const filtersList = dialog.element.querySelector("#oidcFiltersList") as HTMLDivElement; + const filterAddButton = dialog.element.querySelector("#oidcFilterAdd") as HTMLButtonElement; + const filtersBlock = dialog.element.querySelector("#oidcFiltersBlock") as HTMLElement; + const providerConfigBlock = dialog.element.querySelector("#oidcProviderConfig") as HTMLElement; + const issuerRow = dialog.element.querySelector("#oidcIssuerRow") as HTMLElement; + const providerLabelRow = dialog.element.querySelector("#oidcProviderLabelRow") as HTMLElement; + const scopesRow = dialog.element.querySelector("#oidcScopesRow") as HTMLElement; + const tenantRow = dialog.element.querySelector("#oidcTenantRow") as HTMLElement; + const pkceRow = dialog.element.querySelector("#oidcPKCERow") as HTMLElement; + const claimMapRow = dialog.element.querySelector("#oidcClaimMapRow") as HTMLElement; + const buttons = dialog.element.querySelectorAll(".b3-dialog__action .b3-button"); + + const setProviderVisibility = (id: string) => { + const isKnownProvider = Object.prototype.hasOwnProperty.call(providerDisplayNames, id); + const showAll = !isKnownProvider; + const showIssuer = showAll || id === "custom"; + const showTenant = showAll || id === "microsoft"; + const showPKCE = id === "microsoft"; + const showClaimMap = showAll || id === "custom"; + const showScopes = showAll || id === "custom"; + const showProviderLabel = showAll || id === "custom"; + issuerRow.classList.toggle("fn__none", !showIssuer); + scopesRow.classList.toggle("fn__none", !showScopes); + providerLabelRow.classList.toggle("fn__none", !showProviderLabel); + tenantRow.classList.toggle("fn__none", !showTenant); + pkceRow.classList.toggle("fn__none", !showPKCE); + claimMapRow.classList.toggle("fn__none", !showClaimMap); + }; + + const setProviderConfigVisible = (visible: boolean) => { + providerConfigBlock.classList.toggle("fn__none", !visible); + }; + + const setFiltersVisible = (visible: boolean) => { + filtersBlock.classList.toggle("fn__none", !visible); + }; + + let claimMapRows: OIDCClaimMapRow[] = []; + + const syncPKCEState = (id: string) => { + const enablePKCE = id === "microsoft" && pkceInput.checked; + if (enablePKCE) { + clientSecretInput.value = ""; + } + clientSecretInput.disabled = enablePKCE; + }; + + const setProviderForm = (id: string) => { + const provider = providers[id]; + providerLabelInput.value = provider.providerLabel || ""; + clientIDInput.value = provider.clientID || ""; + clientSecretInput.value = provider.clientSecret || ""; + pkceInput.checked = id === "microsoft" && !!provider.pkce; + issuerInput.value = provider.issuerURL || ""; + redirectInput.value = provider.redirectURL || ""; + scopesInput.value = provider.scopes && provider.scopes.length ? provider.scopes.join(", ") : ""; + tenantInput.value = provider.tenant || ""; + claimMapRows = claimMapToRows(provider.claimMap); + renderClaimMapRows(); + setProviderVisibility(id); + syncPKCEState(id); + }; + + const readProviderForm = () => { + const provider = providers[currentProvider]; + let claimMap = provider.claimMap || {}; + if (currentProvider === "custom") { + const claimMapResult = rowsToClaimMap(claimMapRows); + if (!claimMapResult.claimMap) { + showMessage(window.siyuan.languages.oidcClaimMapRowInvalid); + return null; + } + claimMap = claimMapResult.claimMap || {}; + } + return { + clientID: clientIDInput.value.trim(), + clientSecret: currentProvider === "microsoft" && pkceInput.checked ? "" : clientSecretInput.value, + pkce: currentProvider === "microsoft" && pkceInput.checked, + redirectURL: redirectInput.value.trim(), + issuerURL: issuerInput.value.trim(), + scopes: parseScopes(scopesInput.value), + tenant: tenantInput.value.trim(), + providerLabel: currentProvider === "custom" ? providerLabelInput.value.trim() : "", + claimMap, + } as Config.IOIDCProviderConf; + }; + + const filterRows = filtersToRows(oidc.filters); + const operatorOptions = [ + {value: "regexi", label: window.siyuan.languages.oidcFilterOpRegexI}, + {value: "regex", label: window.siyuan.languages.oidcFilterOpRegex}, + {value: "str", label: window.siyuan.languages.oidcFilterOpString}, + {value: "exact", label: window.siyuan.languages.oidcFilterOpExact}, + ]; + + const renderClaimMapRows = () => { + const mobile = isMobile(); + if (!claimMapRows.length) { + claimMapList.innerHTML = ""; + return; + } + + claimMapList.innerHTML = `
    ${ + claimMapRows.map((row, index) => ` +
  • + + + + + ${mobile ? ` + ` : ` + + + `} +
  • `).join("") + }
`; + + claimMapList.querySelectorAll("input, select").forEach((input) => { + input.addEventListener("change", () => { + const li = input.closest("li"); + if (!li) { + return; + } + const index = parseInt(li.getAttribute("data-index") || "0", 10); + const field = (input as HTMLInputElement).dataset.field; + if (!claimMapRows[index] || !field) { + return; + } + if (field === "claim") { + claimMapRows[index].claim = (input as HTMLInputElement).value; + } else if (field === "field") { + claimMapRows[index].field = (input as HTMLInputElement).value; + } + }); + }); + + claimMapList.querySelectorAll('[data-action="remove"]').forEach((remove) => { + remove.addEventListener("click", () => { + const li = remove.closest("li"); + if (!li) { + return; + } + const index = parseInt(li.getAttribute("data-index") || "0", 10); + claimMapRows.splice(index, 1); + renderClaimMapRows(); + }); + }); + }; + + const renderFilterRows = () => { + filtersList.innerHTML = `
    ${ + filterRows + .map((row, index) => ` +
  • + + + + + + + ${isMobile() ? ` + ` : ` + + + `} +
  • `).join("") + }
`; + + filtersList.querySelectorAll("input, select").forEach((input) => { + input.addEventListener("change", () => { + const li = input.closest("li"); + if (!li) { + return; + } + const index = parseInt(li.getAttribute("data-index") || "0", 10); + const field = (input as HTMLInputElement).dataset.field; + if (!filterRows[index] || !field) { + return; + } + if (field === "op") { + filterRows[index].op = (input as HTMLSelectElement).value; + } else if (field === "claim") { + filterRows[index].claim = (input as HTMLInputElement).value; + } else if (field === "pattern") { + filterRows[index].pattern = (input as HTMLInputElement).value; + } + }); + }); + + filtersList.querySelectorAll('[data-action="remove"]').forEach((remove) => { + remove.addEventListener("click", () => { + const li = remove.closest("li"); + if (!li) { + return; + } + const index = parseInt(li.getAttribute("data-index") || "0", 10); + filterRows.splice(index, 1); + renderFilterRows(); + }); + }); + }; + + setProviderForm(currentProvider); + setProviderConfigVisible(!!enabledProvider); + setFiltersVisible(!!enabledProvider); + renderFilterRows(); + + pkceInput.addEventListener("change", () => { + syncPKCEState(currentProvider); + }); + + claimMapAddButton.addEventListener("click", () => { + const defaultClaim = claimOptions[0] || ""; + claimMapRows.push({ + claim: defaultClaim, + field: "", + }); + renderClaimMapRows(); + }); + + filterAddButton.addEventListener("click", () => { + const defaultClaim = claimOptions[0] || ""; + filterRows.push({ + claim: defaultClaim, + op: "regexi", + pattern: "", + }); + renderFilterRows(); + }); + + providerSelect.addEventListener("change", () => { + const nextProvider = providerSelect.value; + const updated = readProviderForm(); + if (!updated) { + providerSelect.value = enabledProvider; + return; + } + providers[currentProvider] = updated; + if (nextProvider) { + currentProvider = nextProvider; + ensureProvider(providers, currentProvider); + setProviderForm(currentProvider); + } + enabledProvider = nextProvider; + setProviderConfigVisible(!!enabledProvider); + setFiltersVisible(!!enabledProvider); + }); + + buttons[0].addEventListener("click", () => { + dialog.destroy(); + }); + + buttons[1].addEventListener("click", () => { + const updated = readProviderForm(); + if (!updated) { + return; + } + providers[currentProvider] = updated; + const filtersResult = rowsToFilters(filterRows); + if (!filtersResult.filters) { + showMessage(window.siyuan.languages.oidcFiltersRowInvalid); + return; + } + const payload = { + provider: enabledProvider, + providers, + filters: filtersResult.filters, + }; + fetchPost("/api/system/setOIDCConfig", {oidc: payload}, () => { + window.siyuan.config.oidc = payload; + dialog.destroy(); + }); + }); +}; diff --git a/app/src/mobile/settings/about.ts b/app/src/mobile/settings/about.ts index 727a56b98bd..aba760d24cd 100644 --- a/app/src/mobile/settings/about.ts +++ b/app/src/mobile/settings/about.ts @@ -1,5 +1,6 @@ import {Constants} from "../../constants"; import {setAccessAuthCode} from "../../config/util/about"; +import {setOIDCConfig} from "../../config/util/oidc"; import {Dialog} from "../../dialog"; import {fetchPost} from "../../util/fetch"; import {confirmDialog} from "../../dialog/confirmDialog"; @@ -55,12 +56,30 @@ export const initAbout = () => {
${window.siyuan.languages.about18}
+
+
+ ${window.siyuan.languages.accessAuthBypass} +
${window.siyuan.languages.accessAuthBypassTip}
+
+
+ +
+
${window.siyuan.languages.about5}
- -
${window.siyuan.languages.about6}
+
${window.siyuan.languages.about6}
+
${window.siyuan.languages.accessAuthBypassAuthCodeDisabledTip}
+
+ ${window.siyuan.languages.oidc} +
+ +
${window.siyuan.languages.oidcTip}
+
${window.siyuan.languages.accessAuthBypassOIDCDisabledTip}
${window.siyuan.languages.dataRepoKey} @@ -206,6 +225,34 @@ export const initAbout = () => { const workspaceDirElement = modelMainElement.querySelector("#workspaceDir"); genWorkspace(workspaceDirElement); const importKeyElement = modelMainElement.querySelector("#importKey"); + const authCodeElement = modelMainElement.querySelector("#authCode") as HTMLButtonElement; + const accessAuthBypassElement = modelMainElement.querySelector("#accessAuthBypass") as HTMLInputElement; + const oidcSettingElement = modelMainElement.querySelector("#oidcSetting") as HTMLButtonElement; + const setAuthControlsDisabled = (disabled: boolean) => { + authCodeElement?.toggleAttribute("disabled", disabled); + oidcSettingElement?.toggleAttribute("disabled", disabled); + }; + const applyAccessAuthBypassChange = (enabled: boolean) => { + fetchPost("/api/system/setAccessAuthBypass", {accessAuthBypass: enabled}, () => { + window.siyuan.config.accessAuthBypass = enabled; + setAuthControlsDisabled(enabled); + }); + }; + setAuthControlsDisabled(window.siyuan.config.accessAuthBypass); + if (accessAuthBypassElement) { + accessAuthBypassElement.addEventListener("change", () => { + const nextChecked = accessAuthBypassElement.checked; + if (nextChecked) { + accessAuthBypassElement.checked = false; + confirmDialog("⚠️ " + window.siyuan.languages.accessAuthBypass, window.siyuan.languages.accessAuthBypassConfirm, () => { + accessAuthBypassElement.checked = true; + applyAccessAuthBypassChange(true); + }); + } else { + applyAccessAuthBypassChange(false); + } + }); + } modelMainElement.firstElementChild.addEventListener("click", (event) => { let target = event.target as HTMLElement; while (target && (target !== modelMainElement)) { @@ -214,6 +261,11 @@ export const initAbout = () => { event.preventDefault(); event.stopPropagation(); break; + } else if (target.id === "oidcSetting") { + setOIDCConfig(); + event.preventDefault(); + event.stopPropagation(); + break; } else if (target.id === "importKey") { const passwordDialog = new Dialog({ title: "🔑 " + window.siyuan.languages.key, diff --git a/app/src/types/config.d.ts b/app/src/types/config.d.ts index 962f046821a..c2f6462ad8d 100644 --- a/app/src/types/config.d.ts +++ b/app/src/types/config.d.ts @@ -25,6 +25,14 @@ declare namespace Config { * Access authorization code */ accessAuthCode: TAccessAuthCode; + /** + * Whether to bypass all access authentication checks + */ + accessAuthBypass: boolean; + /** + * OIDC login configuration + */ + oidc: IOIDC; account: IAccount; ai: IAI; api: IAPI; @@ -96,6 +104,35 @@ declare namespace Config { */ export type TAccessAuthCode = "" | "*******"; + /** + * OIDC login configuration + */ + export interface IOIDC { + provider: string; + providers: Record; + filters: IOIDCFilter; + /** + * kernel used fields, frontend ignored + * + * providerHash?: string; + * filterHash?: string; + */ + } + + export interface IOIDCProviderConf { + clientID: string; + clientSecret: string; + pkce: boolean; + redirectURL: string; + issuerURL: string; + scopes: string[]; + tenant: string; + providerLabel: string; + claimMap: Record; + } + + export type IOIDCFilter = Record; + /** * Account configuration */ diff --git a/app/stage/auth.html b/app/stage/auth.html index 8d5d5094142..4f912777363 100644 --- a/app/stage/auth.html +++ b/app/stage/auth.html @@ -73,6 +73,30 @@ box-shadow: 0px 5px 5px -3px rgba(0, 0, 0, 0.2), 0px 8px 10px 1px rgba(0, 0, 0, 0.14), 0px 3px 14px 2px rgba(0, 0, 0, .12); } + .b3-button:disabled, + .b3-button.fn__disabled { + opacity: 0.6; + cursor: not-allowed; + box-shadow: none; + } + + .b3-button[data-loading="true"]::after { + content: ""; + width: 12px; + height: 12px; + margin-left: 8px; + border: 2px solid rgba(255, 255, 255, 0.6); + border-top-color: #fff; + border-radius: 50%; + animation: b3-loading-spin 0.8s linear infinite; + } + + @keyframes b3-loading-spin { + to { + transform: rotate(360deg); + } + } + .b3-button--white { color: rgb(13, 60, 97); background-color: #d6eaf9; @@ -173,28 +197,64 @@ font-size: 12px; margin: 16px 0; } + + .action-group, + .auth-buttons { + display: flex; + flex-direction: column; + align-items: center; + } + + .action-group { + gap: 10px; + margin-top: 8px; + } + + .auth-buttons { + gap: 6px; + width: 100%; + } + + .auth-buttons .b3-button { + margin: 6px 0; + }

{{.workspace}}

+ {{if .hasAccessAuthCode}}
+ {{end}}
- -
- {{.l2}} -
- -
- {{.l7}} +
+
+ {{if .hasAccessAuthCode}} + + {{end}} + {{if .oidcEnabled}} + + {{end}} +
+ {{.authError}} +
+
+ {{.l2}} +
+
+
+ +
+ {{.l7}} +
@@ -417,6 +477,26 @@

{{.workspace + + \ No newline at end of file diff --git a/kernel/api/router.go b/kernel/api/router.go index d36fe83a6f2..a4962dc15c4 100644 --- a/kernel/api/router.go +++ b/kernel/api/router.go @@ -40,6 +40,8 @@ func ServeAPI(ginServer *gin.Engine) { ginServer.Handle("POST", "/api/system/getEmojiConf", model.CheckAuth, getEmojiConf) ginServer.Handle("POST", "/api/system/setAPIToken", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setAPIToken) ginServer.Handle("POST", "/api/system/setAccessAuthCode", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setAccessAuthCode) + ginServer.Handle("POST", "/api/system/setAccessAuthBypass", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setAccessAuthBypass) + ginServer.Handle("POST", "/api/system/setOIDCConfig", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setOIDCConfig) ginServer.Handle("POST", "/api/system/setFollowSystemLockScreen", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setFollowSystemLockScreen) ginServer.Handle("POST", "/api/system/setNetworkServe", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setNetworkServe) ginServer.Handle("POST", "/api/system/setAutoLaunch", model.CheckAuth, model.CheckAdminRole, model.CheckReadonly, setAutoLaunch) diff --git a/kernel/api/system.go b/kernel/api/system.go index 83b9367b222..0258876f474 100644 --- a/kernel/api/system.go +++ b/kernel/api/system.go @@ -323,6 +323,8 @@ func exportConf(c *gin.Context) { clonedConf.UserData = "" clonedConf.Account = nil clonedConf.AccessAuthCode = "" + clonedConf.AccessAuthBypass = false + clonedConf.OIDC = nil if nil != clonedConf.System { clonedConf.System.ID = "" clonedConf.System.Name = "" @@ -609,6 +611,15 @@ func setAccessAuthCode(c *gin.Context) { aac = strings.TrimSpace(aac) aac = util.RemoveInvalid(aac) + if util.ContainerDocker == util.Container && !model.Conf.AccessAuthBypass { + if "" == aac && !model.OIDCIsEnabled(model.Conf.OIDC) { + ret.Code = -1 + // At least one auth method required in Docker mode + ret.Msg = model.Conf.Language(279) + return + } + } + model.Conf.AccessAuthCode = aac model.Conf.Save() @@ -623,6 +634,97 @@ func setAccessAuthCode(c *gin.Context) { return } +func setAccessAuthBypass(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + accessAuthBypass, ok := arg["accessAuthBypass"].(bool) + if !ok { + ret.Code = -1 + ret.Msg = "accessAuthBypass is required" + return + } + + model.Conf.AccessAuthBypass = accessAuthBypass + model.Conf.Save() + + go func() { + time.Sleep(200 * time.Millisecond) + util.ReloadUI() + }() +} + +func setOIDCConfig(c *gin.Context) { + ret := gulu.Ret.NewResult() + defer c.JSON(http.StatusOK, ret) + + arg, ok := util.JsonArg(c, ret) + if !ok { + return + } + + oidcArg, ok := arg["oidc"] + if !ok { + ret.Code = -1 + ret.Msg = "oidc is required" + return + } + + param, err := gulu.JSON.MarshalJSON(oidcArg) + if err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + oidcConf := &conf.OIDC{} + if err = gulu.JSON.UnmarshalJSON(param, oidcConf); err != nil { + ret.Code = -1 + ret.Msg = err.Error() + return + } + + if nil != oidcConf.Providers { + for providerID, provider := range oidcConf.Providers { + if nil == provider { + continue + } + // convert masked client secret to the existing one + if strings.TrimSpace(provider.ClientSecret) == model.MaskedSecret { + if nil != model.Conf.OIDC && nil != model.Conf.OIDC.Providers { + if existing := model.Conf.OIDC.Providers[providerID]; nil != existing { + provider.ClientSecret = existing.ClientSecret + continue + } + } + provider.ClientSecret = "" + } + } + } + + if util.ContainerDocker == util.Container && !model.Conf.AccessAuthBypass { + if "" == strings.TrimSpace(model.Conf.AccessAuthCode) && !model.OIDCIsEnabled(oidcConf) { + ret.Code = -1 + // At least one auth method required in Docker mode + ret.Msg = model.Conf.Language(279) + return + } + } + + model.Conf.UpdateOIDCConfig(oidcConf) + model.Conf.Save() + + go func() { + time.Sleep(200 * time.Millisecond) + util.ReloadUI() + }() +} + func setFollowSystemLockScreen(c *gin.Context) { ret := gulu.Ret.NewResult() defer c.JSON(http.StatusOK, ret) diff --git a/kernel/conf/oidc.go b/kernel/conf/oidc.go new file mode 100644 index 00000000000..fd785c0ae45 --- /dev/null +++ b/kernel/conf/oidc.go @@ -0,0 +1,50 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package conf + +type OIDC struct { + Provider string `json:"provider"` + Providers map[string]*OIDCProviderConf `json:"providers"` + Filters map[string][]string `json:"filters"` + ProviderHash string `json:"providerHash"` + FilterHash string `json:"filterHash"` +} + +type OIDCProviderConf struct { + ClientID string `json:"clientID"` + ClientSecret string `json:"clientSecret"` + PKCE bool `json:"pkce"` + RedirectURL string `json:"redirectURL"` + IssuerURL string `json:"issuerURL"` + Scopes []string `json:"scopes"` + Tenant string `json:"tenant"` + ProviderLabel string `json:"providerLabel"` + ClaimMap map[string]string `json:"claimMap"` +} + +func NewOIDC() *OIDC { + return &OIDC{ + Provider: "", + Providers: map[string]*OIDCProviderConf{ + "custom": {}, + "google": {}, + "microsoft": {}, + "github": {}, + }, + Filters: map[string][]string{}, + } +} diff --git a/kernel/entrypoint.sh b/kernel/entrypoint.sh index 6c8d65782a6..19600fd59b6 100644 --- a/kernel/entrypoint.sh +++ b/kernel/entrypoint.sh @@ -29,16 +29,32 @@ else fi # Parse command line arguments for --workspace option or SIYUAN_WORKSPACE_PATH env variable -# Store other arguments in ARGS for later use +# Delete --workspace argument for no duplication and keep other arguments for later exec if [[ -n "${SIYUAN_WORKSPACE_PATH}" ]]; then WORKSPACE_DIR="${SIYUAN_WORKSPACE_PATH}" fi -ARGS="" -while [[ "$#" -gt 0 ]]; do +# in POSIX sh, we don't have arrays, so we use a Argument Rotation trick +arg_count=$# +while [ "$arg_count" -gt 0 ]; do case $1 in - --workspace=*) WORKSPACE_DIR="${1#*=}"; shift ;; - *) ARGS="$ARGS $1"; shift ;; + --workspace=*) + WORKSPACE_DIR="${1#*=}" + shift + ;; + --workspace) + WORKSPACE_DIR="$2" + shift 2 + # there are 2 arguments, we need to decrease one more here + arg_count=$((arg_count - 1)) + ;; + *) + # Core idea of Argument Rotation: move argument to the end of the list + set -- "$@" "$1" + shift + ;; esac + + arg_count=$((arg_count - 1)) done # Change ownership of relevant directories, including the workspace directory @@ -49,4 +65,4 @@ chown -R "${PUID}:${PGID}" "${WORKSPACE_DIR}" # Switch to the newly created user and start the main process with all arguments echo "Starting Siyuan with UID:${PUID} and GID:${PGID} in workspace ${WORKSPACE_DIR}" -exec su-exec "${PUID}:${PGID}" /opt/siyuan/kernel --workspace="${WORKSPACE_DIR}" ${ARGS} +exec su-exec "${PUID}:${PGID}" /opt/siyuan/kernel --workspace="${WORKSPACE_DIR}" "$@" diff --git a/kernel/go.mod b/kernel/go.mod index 75084f988a5..551e179757c 100644 --- a/kernel/go.mod +++ b/kernel/go.mod @@ -18,6 +18,7 @@ require ( github.com/araddon/dateparse v0.0.0-20210429162001-6b43995a97de github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be + github.com/coreos/go-oidc/v3 v3.16.0 github.com/denisbrodbeck/machineid v1.0.1 github.com/dgraph-io/ristretto v0.2.0 github.com/disintegration/imaging v1.6.2 @@ -76,6 +77,7 @@ require ( golang.org/x/mobile v0.0.0-20251209145715-2553ed8ce294 golang.org/x/mod v0.31.0 golang.org/x/net v0.48.0 + golang.org/x/oauth2 v0.30.0 golang.org/x/sys v0.39.0 golang.org/x/text v0.32.0 golang.org/x/time v0.14.0 @@ -126,6 +128,7 @@ require ( github.com/fatih/set v0.2.1 // indirect github.com/gammazero/toposort v0.1.1 // indirect github.com/gigawattio/window v0.0.0-20180317192513-0f5467e35573 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.28.0 // indirect diff --git a/kernel/go.sum b/kernel/go.sum index a504ebc0cd4..2a3b1ed1f36 100644 --- a/kernel/go.sum +++ b/kernel/go.sum @@ -108,6 +108,8 @@ github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be h1:J5BL2kskAlV9ckgEsNQXscjIaLiOYiZ75d4e94E6dcQ= github.com/common-nighthawk/go-figure v0.0.0-20210622060536-734e95fb86be/go.mod h1:mk5IQ+Y0ZeO87b858TlA645sVcEcbiX6YqP98kt+7+w= +github.com/coreos/go-oidc/v3 v3.16.0 h1:qRQUCFstKpXwmEjDQTIbyY/5jF00+asXzSkmkoa/mow= +github.com/coreos/go-oidc/v3 v3.16.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/dave/jennifer v1.6.1/go.mod h1:nXbxhEmQfOZhWml3D1cDK5M1FLnMSozpbFN/m3RmGZc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -169,6 +171,8 @@ github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= @@ -492,6 +496,8 @@ golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/kernel/model/conf.go b/kernel/model/conf.go index fd3b3219619..2e9fa293007 100644 --- a/kernel/model/conf.go +++ b/kernel/model/conf.go @@ -19,6 +19,7 @@ package model import ( "bytes" "crypto/sha1" + "encoding/json" "fmt" "os" "path/filepath" @@ -50,39 +51,41 @@ var Conf *AppConf // AppConf 维护应用元数据,保存在 ~/.siyuan/conf.json。 type AppConf struct { - LogLevel string `json:"logLevel"` // 日志级别:off, trace, debug, info, warn, error, fatal - Appearance *conf.Appearance `json:"appearance"` // 外观 - Langs []*conf.Lang `json:"langs"` // 界面语言列表 - Lang string `json:"lang"` // 选择的界面语言,同 Appearance.Lang - FileTree *conf.FileTree `json:"fileTree"` // 文档面板 - Tag *conf.Tag `json:"tag"` // 标签面板 - Editor *conf.Editor `json:"editor"` // 编辑器配置 - Export *conf.Export `json:"export"` // 导出配置 - Graph *conf.Graph `json:"graph"` // 关系图配置 - UILayout *conf.UILayout `json:"uiLayout"` // 界面布局。不要直接使用,使用 GetUILayout() 和 SetUILayout() 方法 - UserData string `json:"userData"` // 社区用户信息,对 User 加密存储 - User *conf.User `json:"-"` // 社区用户内存结构,不持久化。不要直接使用,使用 GetUser() 和 SetUser() 方法 - Account *conf.Account `json:"account"` // 帐号配置 - ReadOnly bool `json:"readonly"` // 是否是以只读模式运行 - LocalIPs []string `json:"localIPs"` // 本地 IP 列表 - AccessAuthCode string `json:"accessAuthCode"` // 访问授权码 - System *conf.System `json:"system"` // 系统配置 - Keymap *conf.Keymap `json:"keymap"` // 快捷键配置 - Sync *conf.Sync `json:"sync"` // 同步配置 - Search *conf.Search `json:"search"` // 搜索配置 - Flashcard *conf.Flashcard `json:"flashcard"` // 闪卡配置 - AI *conf.AI `json:"ai"` // 人工智能配置 - Bazaar *conf.Bazaar `json:"bazaar"` // 集市配置 - Stat *conf.Stat `json:"stat"` // 统计 - Api *conf.API `json:"api"` // API - Repo *conf.Repo `json:"repo"` // 数据仓库 - Publish *conf.Publish `json:"publish"` // 发布服务 - OpenHelp bool `json:"openHelp"` // 启动后是否需要打开用户指南 - ShowChangelog bool `json:"showChangelog"` // 是否显示版本更新日志 - CloudRegion int `json:"cloudRegion"` // 云端区域,0:中国大陆,1:北美 - Snippet *conf.Snpt `json:"snippet"` // 代码片段 - DataIndexState int `json:"dataIndexState"` // 数据索引状态,0:已索引,1:未索引 - CookieKey string `json:"cookieKey"` // 用于加密 Cookie 的密钥 + LogLevel string `json:"logLevel"` // 日志级别:off, trace, debug, info, warn, error, fatal + Appearance *conf.Appearance `json:"appearance"` // 外观 + Langs []*conf.Lang `json:"langs"` // 界面语言列表 + Lang string `json:"lang"` // 选择的界面语言,同 Appearance.Lang + FileTree *conf.FileTree `json:"fileTree"` // 文档面板 + Tag *conf.Tag `json:"tag"` // 标签面板 + Editor *conf.Editor `json:"editor"` // 编辑器配置 + Export *conf.Export `json:"export"` // 导出配置 + Graph *conf.Graph `json:"graph"` // 关系图配置 + UILayout *conf.UILayout `json:"uiLayout"` // 界面布局。不要直接使用,使用 GetUILayout() 和 SetUILayout() 方法 + UserData string `json:"userData"` // 社区用户信息,对 User 加密存储 + User *conf.User `json:"-"` // 社区用户内存结构,不持久化。不要直接使用,使用 GetUser() 和 SetUser() 方法 + Account *conf.Account `json:"account"` // 帐号配置 + ReadOnly bool `json:"readonly"` // 是否是以只读模式运行 + LocalIPs []string `json:"localIPs"` // 本地 IP 列表 + AccessAuthBypass bool `json:"accessAuthBypass"` // 跳过一切访问认证和安全检查 + AccessAuthCode string `json:"accessAuthCode"` // 访问授权码 + OIDC *conf.OIDC `json:"oidc"` // OIDC 登录配置 + System *conf.System `json:"system"` // 系统配置 + Keymap *conf.Keymap `json:"keymap"` // 快捷键配置 + Sync *conf.Sync `json:"sync"` // 同步配置 + Search *conf.Search `json:"search"` // 搜索配置 + Flashcard *conf.Flashcard `json:"flashcard"` // 闪卡配置 + AI *conf.AI `json:"ai"` // 人工智能配置 + Bazaar *conf.Bazaar `json:"bazaar"` // 集市配置 + Stat *conf.Stat `json:"stat"` // 统计 + Api *conf.API `json:"api"` // API + Repo *conf.Repo `json:"repo"` // 数据仓库 + Publish *conf.Publish `json:"publish"` // 发布服务 + OpenHelp bool `json:"openHelp"` // 启动后是否需要打开用户指南 + ShowChangelog bool `json:"showChangelog"` // 是否显示版本更新日志 + CloudRegion int `json:"cloudRegion"` // 云端区域,0:中国大陆,1:北美 + Snippet *conf.Snpt `json:"snippet"` // 代码片段 + DataIndexState int `json:"dataIndexState"` // 数据索引状态,0:已索引,1:未索引 + CookieKey string `json:"cookieKey"` // 用于加密 Cookie 的密钥 m *sync.RWMutex // 配置数据锁 userLock *sync.RWMutex // 用户数据独立锁,避免与配置保存操作竞争 @@ -137,6 +140,61 @@ func InitConf() { } } + // 合并命令行和环境变量提供的认证配置项。 CLI 优先级高于 配置文件。 + // CLI 只存在于桌面端/容器化构建版本,移动平台不会设置这些全局变量。 + if util.AuthCLI.AccessCodeSet { + Conf.AccessAuthCode = util.RemoveInvalid(strings.TrimSpace(util.AuthCLI.AccessCode)) + } else { + Conf.AccessAuthCode = util.RemoveInvalid(strings.TrimSpace(Conf.AccessAuthCode)) + } + + if util.AuthCLI.AccessAuthBypassSet { + Conf.AccessAuthBypass = util.AuthCLI.AccessAuthBypass + } + + if nil == Conf.OIDC { + Conf.OIDC = conf.NewOIDC() + } + + if util.AuthCLI.OIDCProviderSet { + Conf.OIDC.Provider = util.AuthCLI.OIDCProvider + } + + if util.AuthCLI.OIDCProvidersSet { + if "" == util.AuthCLI.OIDCProviders { + Conf.OIDC.Providers = map[string]*conf.OIDCProviderConf{} + } else { + providers := map[string]*conf.OIDCProviderConf{} + if err := json.Unmarshal([]byte(util.AuthCLI.OIDCProviders), &providers); err != nil { + logging.LogErrorf("parse oidc providers from cli failed: %s", err) + } else { + Conf.OIDC.Providers = providers + } + } + } + + if util.AuthCLI.OIDCFiltersSet { + if "" == util.AuthCLI.OIDCFilters { + Conf.OIDC.Filters = map[string][]string{} + } else { + filters := map[string][]string{} + if err := json.Unmarshal([]byte(util.AuthCLI.OIDCFilters), &filters); err != nil { + logging.LogErrorf("parse oidc filters from cli failed: %s", err) + } else { + Conf.OIDC.Filters = filters + } + } + } + + Conf.UpdateOIDCConfig(Conf.OIDC) + + if util.ContainerDocker == util.Container && !Conf.AccessAuthBypass { + if "" == Conf.AccessAuthCode && !OIDCIsEnabled(Conf.OIDC) { + fmt.Println("in Docker mode, you must set (or set --accessAuthBypass [not recommended]) at least one auth method: accessAuthCode or OIDC") + os.Exit(logging.ExitCodeSecurityRisk) + } + } + if "" != util.Lang { initialized := false if util.ContainerAndroid == util.Container || util.ContainerIOS == util.Container || util.ContainerHarmony == util.Container { @@ -565,12 +623,6 @@ func InitConf() { Conf.ReadOnly = util.ReadOnly - if "" != util.AccessAuthCode { - Conf.AccessAuthCode = util.AccessAuthCode - } - Conf.AccessAuthCode = strings.TrimSpace(Conf.AccessAuthCode) - Conf.AccessAuthCode = util.RemoveInvalid(Conf.AccessAuthCode) - Conf.LocalIPs = util.GetLocalIPs() if 1 == Conf.DataIndexState { @@ -652,6 +704,77 @@ func initLang() { } } +func (appConf *AppConf) UpdateOIDCConfig(new *conf.OIDC) { + if nil == new { + new = conf.NewOIDC() + } + + new.Provider = strings.TrimSpace(new.Provider) + if nil == new.Providers { + new.Providers = map[string]*conf.OIDCProviderConf{} + } + if nil == new.Filters { + new.Filters = map[string][]string{} + } + + new.ProviderHash = oidcProviderHash(new) + new.FilterHash = oidcFilterHash(new.Filters) + appConf.OIDC = new +} + +func oidcProviderHash(oidcConf *conf.OIDC) string { + if nil == oidcConf { + return "" + } + providerID := strings.TrimSpace(oidcConf.Provider) + if "" == providerID { + return "" + } + providerConf := oidcConf.Providers[providerID] + if nil == providerConf { + return "" + } + + // json.Marshal 会自动按字典顺序序列化map,因此可以直接使用它来生成哈希值 + data, err := json.Marshal(providerConf) + if err != nil { + return "" + } + + input := append([]byte(providerID+":"), data...) + sum := sha1.Sum(input) + return fmt.Sprintf("%x", sum) +} + +func oidcFilterHash(filters map[string][]string) string { + if 0 == len(filters) { + return "" + } + normalized := map[string][]string{} + for key, patterns := range filters { + key = strings.TrimSpace(key) + if "" == key { + continue + } + normalizedPatterns := make([]string, 0, len(patterns)) + for _, pattern := range patterns { + pattern = strings.TrimSpace(pattern) + if "" == pattern { + continue + } + normalizedPatterns = append(normalizedPatterns, pattern) + } + sort.Strings(normalizedPatterns) + normalized[key] = normalizedPatterns + } + data, err := json.Marshal(normalized) + if err != nil { + return "" + } + sum := sha1.Sum(data) + return fmt.Sprintf("%x", sum) +} + func loadLangs() (ret []*conf.Lang) { for name, langMap := range util.Langs { lang := &conf.Lang{Label: langMap[-1], Name: name} @@ -979,6 +1102,7 @@ func IsPaidUser() bool { const ( MaskedUserData = "" MaskedAccessAuthCode = "*******" + MaskedSecret = "*******" ) func GetMaskedConf() (ret *AppConf, err error) { @@ -998,6 +1122,16 @@ func GetMaskedConf() (ret *AppConf, err error) { if "" != ret.AccessAuthCode { ret.AccessAuthCode = MaskedAccessAuthCode } + if nil != ret.OIDC && nil != ret.OIDC.Providers { + for _, provider := range ret.OIDC.Providers { + if nil == provider { + continue + } + if "" != provider.ClientSecret { + provider.ClientSecret = MaskedSecret + } + } + } return } @@ -1008,6 +1142,7 @@ func HideConfSecret(c *AppConf) { c.Api = &conf.API{} c.Flashcard = &conf.Flashcard{} c.LocalIPs = []string{} + c.OIDC = &conf.OIDC{} c.Publish = &conf.Publish{} c.Repo = &conf.Repo{} c.Sync = &conf.Sync{} diff --git a/kernel/model/oidc.go b/kernel/model/oidc.go new file mode 100644 index 00000000000..6be2f6561ec --- /dev/null +++ b/kernel/model/oidc.go @@ -0,0 +1,842 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package model + +import ( + "bytes" + "container/heap" + "context" + "errors" + "fmt" + "html/template" + "net/http" + "net/url" + "os" + "path/filepath" + "regexp" + "slices" + "strings" + "sync" + "time" + + "github.com/88250/gulu" + ginSessions "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/siyuan-note/logging" + "github.com/siyuan-note/siyuan/kernel/conf" + oidcprovider "github.com/siyuan-note/siyuan/kernel/model/oidc_provider" + "github.com/siyuan-note/siyuan/kernel/util" +) + +const ( + oidcFlowWeb = "web" + oidcFlowDesktop = "desktop" + oidcFlowMobile = "mobile" +) + +const ( + oidcStatusPending = "pending" + oidcStatusOK = "ok" + oidcStatusError = "error" +) + +type oidcStateReason int + +const ( + oidcStateReasonOK oidcStateReason = iota + oidcStateReasonSessionInvalid + oidcStateReasonSessionExpired + oidcStateReasonSessionHandled + oidcStateReasonCallbackParamsMissing + oidcStateReasonProviderMismatch + oidcStateReasonFilterRejected + oidcStateReasonAuthFailed + oidcStateReasonSessionSaveFailed + oidcStateReasonNotEnabled + oidcStateReasonProviderInitFailed +) + +func oidcReasonMessage(reason oidcStateReason, args ...any) string { + switch reason { + case oidcStateReasonSessionInvalid: + return Conf.Language(277) + case oidcStateReasonSessionExpired: + return Conf.Language(284) + case oidcStateReasonSessionHandled: + return Conf.Language(285) + case oidcStateReasonCallbackParamsMissing: + return Conf.Language(286) + case oidcStateReasonProviderMismatch: + return Conf.Language(287) + case oidcStateReasonFilterRejected: + return Conf.Language(278) + case oidcStateReasonSessionSaveFailed: + return Conf.Language(288) + case oidcStateReasonNotEnabled: + return Conf.Language(280) + case oidcStateReasonProviderInitFailed: + return fmt.Sprintf(Conf.Language(281), args...) + case oidcStateReasonAuthFailed: + return Conf.Language(276) + default: + return Conf.Language(276) + } +} + +func OIDCLogin(c *gin.Context) { + if !OIDCIsEnabled(Conf.OIDC) { + oidcLoginError(c, oidcStateReasonNotEnabled, "oidc not enabled", nil) + return + } + + p, err := oidcProvider(Conf.OIDC) + if err != nil { + oidcLoginError(c, oidcStateReasonProviderInitFailed, "init oidc provider failed", err, Conf.OIDC.Provider) + return + } + + entry := oidcChallenge(p.ID(), oidcFlowFromQuery(c), util.SanitizeRedirectPath(c.Query("to")), util.ParseBoolQuery(c.Query("rememberMe"))) + authURL, extra, err := p.AuthURL(entry.state, entry.nonce) + entry.extra = extra + if err != nil { + oidcLoginError(c, oidcStateReasonAuthFailed, "oidc auth url failed", err) + return + } else if "" == authURL { + oidcLoginError(c, oidcStateReasonAuthFailed, "oidc auth url is empty", nil) + return + } + stateStore().put(entry) + + oidcLoginSuccess(c, authURL, entry.state) +} + +func oidcFlowFromQuery(c *gin.Context) string { + flow := strings.ToLower(strings.TrimSpace(c.Query("flow"))) + + // Unknown flow, default to web flow for maximum compatibility + if oidcFlowDesktop != flow && oidcFlowMobile != flow && oidcFlowWeb != flow { + logging.LogWarnf("unknown oidc flow [%s], default to web flow", flow) + flow = oidcFlowWeb + } + + return flow +} + +func oidcLoginError(c *gin.Context, reason oidcStateReason, logMsg string, err error, args ...any) { + ret := util.NewResult() + ret.Code = -1 + ret.Msg = oidcReasonMessage(reason, args...) + c.JSON(http.StatusOK, ret) + logOIDCFailure(logMsg, err, c.Request) +} + +func oidcChallenge(providerID, flow, to string, rememberMe bool) *stateEntry { + if "" == flow { + flow = oidcFlowWeb + } + entry := &stateEntry{ + state: gulu.Rand.String(32), + nonce: gulu.Rand.String(32), + providerID: providerID, + to: to, + remember: rememberMe, + flow: flow, + status: oidcStatusPending, + expiresAt: time.Now().Add(oidcStateTTL), + } + return entry +} + +func oidcLoginSuccess(c *gin.Context, authURL, state string) { + ret := util.NewResult() + ret.Data = map[string]any{ + "authUrl": authURL, + "state": state, + } + c.JSON(http.StatusOK, ret) +} + +const oidcLoginTimeout = 10 * time.Second + +func OIDCCallback(c *gin.Context) { + if !OIDCIsEnabled(Conf.OIDC) { + c.Status(http.StatusNotFound) + return + } + + state := strings.TrimSpace(c.Query("state")) + if "" == state { + oidcCallbackError(c, nil, oidcStateReasonCallbackParamsMissing, "missing oidc state", nil) + return + } + + code := strings.TrimSpace(c.Query("code")) + if "" == code { + oidcCallbackError(c, nil, oidcStateReasonCallbackParamsMissing, "missing oidc code", nil) + return + } + + entry, reason := stateStore().do(state, func(entry *stateEntry) bool { + return oidcFlowDesktop != entry.flow + }) + if nil == entry { + oidcCallbackError(c, entry, reason, "get entry failed", nil) + return + } + + p, err := oidcProvider(Conf.OIDC) + if err != nil { + oidcCallbackError(c, entry, oidcStateReasonProviderInitFailed, "init oidc provider failed", err, Conf.OIDC.Provider) + return + } + + if entry.providerID != p.ID() { + oidcCallbackError(c, entry, oidcStateReasonProviderMismatch, "oidc provider mismatch", nil) + return + } + + ctx, cancel := context.WithTimeout(c.Request.Context(), oidcLoginTimeout) + defer cancel() + + claims, err := p.HandleCallback(ctx, code, entry.nonce, entry.extra) + if err != nil { + oidcCallbackError(c, entry, oidcStateReasonAuthFailed, "oidc callback failed", err) + return + } + + if !IsAllowed(Conf.OIDC.Filters, claims) { + oidcCallbackError(c, entry, oidcStateReasonFilterRejected, "oidc filter rejected", nil) + return + } + + oidcCallbackSuccess(c, entry) +} + +func oidcCallbackError(c *gin.Context, entry *stateEntry, reason oidcStateReason, logMsg string, err error, args ...any) { + userMsg := oidcReasonMessage(reason, args...) + flow := "" + to := "" + state := "" + if nil != entry { + flow = entry.flow + to = entry.to + state = entry.state + } + + if oidcFlowDesktop == flow { + oidcRenderCallbackResult(c, false, state, userMsg, logMsg, err) + return + } + + oidcRedirectToCheckAuthError(c, to, userMsg, logMsg, err) +} + +func oidcCallbackSuccess(c *gin.Context, entry *stateEntry) { + if oidcFlowDesktop == entry.flow { + oidcRenderCallbackResult(c, true, entry.state, "", "oidc auth success", nil) + return + } + + if err := applyOIDCSession(c, entry); nil != err { + oidcRedirectToCheckAuthError(c, entry.to, oidcReasonMessage(oidcStateReasonSessionSaveFailed), "save session failed", err) + return + } + + logging.LogInfof("oidc auth success [ip=%s, maxAge=%d]", util.GetRemoteAddr(c.Request), oidcRememberMaxAge(entry.remember)) + c.Redirect(http.StatusFound, entry.to) +} + +func OIDCCheck(c *gin.Context) { + ret := util.NewResult() + + state := strings.TrimSpace(c.Query("state")) + if "" == state { + oidcCheckInvalid(c, ret, oidcStateReasonSessionInvalid) + return + } + + entry, reason := stateStore().do(state, func(entry *stateEntry) bool { + return entry.status != oidcStatusPending + }) + if entry == nil { + oidcCheckInvalid(c, ret, reason) + return + } + + if entry.status == oidcStatusPending { + oidcCheckPending(c, ret) + return + } + + if oidcStatusError == entry.status { + oidcCheckError(c, ret, oidcStateReasonAuthFailed, entry.msg) + return + } + + if err := applyOIDCSession(c, entry); nil != err { + oidcCheckError(c, ret, oidcStateReasonSessionSaveFailed, "") + return + } + + logging.LogInfof("oidc auth success [ip=%s, maxAge=%d]", util.GetRemoteAddr(c.Request), oidcRememberMaxAge(entry.remember)) + oidcCheckOK(c, ret, entry.to) +} + +func oidcCheckInvalid(c *gin.Context, ret *util.Result, reason oidcStateReason) { + ret.Code = -1 + ret.Msg = oidcReasonMessage(reason) + ret.Data = map[string]any{ + "status": oidcStatusError, + } + c.JSON(http.StatusOK, ret) +} + +func oidcCheckPending(c *gin.Context, ret *util.Result) { + ret.Code = 1 + ret.Msg = "Pending" + ret.Data = map[string]any{ + "status": oidcStatusPending, + } + c.JSON(http.StatusOK, ret) +} + +func oidcCheckError(c *gin.Context, ret *util.Result, reason oidcStateReason, msg string) { + if msg == "" { + msg = oidcReasonMessage(reason) + } + ret.Code = -1 + ret.Msg = oidcDefaultErrorMsg(msg) + ret.Data = map[string]any{ + "status": oidcStatusError, + } + c.JSON(http.StatusOK, ret) +} + +func oidcCheckOK(c *gin.Context, ret *util.Result, to string) { + ret.Code = 0 + ret.Msg = "OK" + ret.Data = map[string]any{ + "status": oidcStatusOK, + "to": to, + } + c.JSON(http.StatusOK, ret) +} + +func applyOIDCSession(c *gin.Context, entry *stateEntry) error { + session := util.GetSession(c) + workspaceSession := util.GetWorkspaceSession(session) + + workspaceSession.OIDC.ProviderID = entry.providerID + workspaceSession.OIDC.ProviderHash = Conf.OIDC.ProviderHash + workspaceSession.OIDC.FilterHash = Conf.OIDC.FilterHash + + ginSessions.Default(c).Options(ginSessions.Options{ + Path: "/", + Secure: util.SSL, + MaxAge: oidcRememberMaxAge(entry.remember), + HttpOnly: true, + }) + + return session.Save(c) +} + +func oidcRememberMaxAge(remember bool) int { + if remember { + return 60 * 60 * 24 * 30 + } + return 0 +} + +func oidcRenderCallbackResult(c *gin.Context, ok bool, state, userMsg, logMsg string, err error) { + status := "" + + if ok { + status = oidcStatusOK + } else { + status = oidcStatusError + userMsg = oidcDefaultErrorMsg(userMsg) + logOIDCFailure(logMsg, err, c.Request) + } + + if "" != state { + stateStore().setResult(state, status, userMsg) + } + oidcRenderAppCallbackPage(c, ok, userMsg) +} + +func oidcDefaultErrorMsg(msg string) string { + if "" == msg { + return oidcReasonMessage(oidcStateReasonAuthFailed) + } + return msg +} + +func oidcRenderAppCallbackPage(c *gin.Context, ok bool, msg string) { + title := Conf.Language(283) + if ok { + title = Conf.Language(282) + } + + detail := strings.TrimSpace(msg) + + data, err := os.ReadFile(filepath.Join(util.WorkingDir, "stage/oidc-callback.html")) + if err != nil { + logging.LogErrorf("load oidc callback page failed: %s", err) + c.String(http.StatusOK, detail) + return + } + + tpl, err := template.New("oidc-callback").Parse(string(data)) + if err != nil { + logging.LogErrorf("parse oidc callback page failed: %s", err) + c.String(http.StatusOK, detail) + return + } + + safeDetail := template.HTMLEscapeString(detail) + safeDetail = strings.ReplaceAll(safeDetail, "\n", "
") + model := map[string]any{ + "title": title, + "ok": ok, + "detail": template.HTML(safeDetail), + "appearanceMode": Conf.Appearance.Mode, + "appearanceModeOS": Conf.Appearance.ModeOS, + } + + buf := &bytes.Buffer{} + if err = tpl.Execute(buf, model); err != nil { + logging.LogErrorf("execute oidc callback page failed: %s", err) + c.String(http.StatusOK, detail) + return + } + c.Data(http.StatusOK, "text/html; charset=utf-8", buf.Bytes()) +} + +func OIDCIsEnabled(oidcConf *conf.OIDC) bool { + if nil == oidcConf || "" == strings.TrimSpace(oidcConf.Provider) { + return false + } + + pc := providerConf(oidcConf) + if nil == pc { + return false + } + + return true +} + +const defaultProviderLabel = "Login with SSO" + +func OIDCProviderLabel(oidcConf *conf.OIDC) string { + if !OIDCIsEnabled(oidcConf) { + return defaultProviderLabel + } + p, err := oidcProvider(oidcConf) + if err != nil { + return defaultProviderLabel + } + return p.Label() +} + +func OIDCIsValid(oidcConf *conf.OIDC, workspaceSession *util.WorkspaceSession) bool { + if nil == workspaceSession || nil == workspaceSession.OIDC { + return false + } + if !OIDCIsEnabled(oidcConf) { + return false + } + if "" == workspaceSession.OIDC.ProviderID { + return false + } + if workspaceSession.OIDC.ProviderID != oidcConf.Provider { + return false + } + if workspaceSession.OIDC.ProviderHash != oidcConf.ProviderHash { + return false + } + if workspaceSession.OIDC.FilterHash != oidcConf.FilterHash { + return false + } + return true +} + +func providerConf(oidcConf *conf.OIDC) *conf.OIDCProviderConf { + providerID := strings.TrimSpace(oidcConf.Provider) + return oidcConf.Providers[providerID] +} + +func oidcProvider(oidcConf *conf.OIDC) (oidcprovider.Provider, error) { + pc := providerConf(oidcConf) + if nil == pc { + return nil, errors.New("OIDC provider config not found") + } + return oidcprovider.New(oidcConf.Provider, pc) +} + +func logOIDCFailure(logMsg string, err error, req *http.Request) { + if err != nil { + logging.LogWarnf("oidc auth failed: %s [err=%s, ip=%s]", logMsg, err, util.GetRemoteAddr(req)) + } else { + logging.LogWarnf("oidc auth failed: %s [ip=%s]", logMsg, util.GetRemoteAddr(req)) + } +} + +func oidcRedirectToCheckAuthError(c *gin.Context, redirectTo string, userMsg string, logMsg string, err error) { + userMsg = oidcDefaultErrorMsg(userMsg) + if 200 < len(userMsg) { + userMsg = userMsg[:200] + "..." + } + + logOIDCFailure(logMsg, err, c.Request) + + location := url.URL{Path: "/check-auth"} + queryParams := url.Values{} + if "" != redirectTo { + queryParams.Set("to", redirectTo) + } + queryParams.Set("error", userMsg) + location.RawQuery = queryParams.Encode() + c.Redirect(http.StatusFound, location.String()) +} + +func IsAllowed(filters map[string][]string, claims *oidcprovider.OIDCClaims) bool { + if 0 == len(filters) { + return true + } + + values := claims.FilterValues() + for key, patterns := range filters { + if 0 == len(patterns) { + continue + } + vals, ok := values[key] + if !ok || 0 == len(vals) { + return false + } + + if !matchAnyPattern(vals, patterns) { + return false + } + } + return true +} + +func matchAnyPattern(values []string, patterns []string) bool { + for _, pattern := range patterns { + pattern = strings.TrimSpace(pattern) + if "" == pattern { + continue + } + + matcher, err := buildMatcher(pattern) + if err != nil { + logging.LogErrorf("invalid oidc filter pattern [%s]: %s", pattern, err) + continue + } + if slices.ContainsFunc(values, matcher.Match) { + return true + } + } + return false +} + +type patternMatcher interface { + Match(value string) bool +} + +type matcherFactory func(pattern string) (patternMatcher, error) + +var matcherFactories = map[string]matcherFactory{ + "regex": func(pattern string) (patternMatcher, error) { return newRegexMatcher(pattern, false) }, + "re": func(pattern string) (patternMatcher, error) { return newRegexMatcher(pattern, false) }, + + "regexi": func(pattern string) (patternMatcher, error) { return newRegexMatcher(pattern, true) }, + + "str": func(pattern string) (patternMatcher, error) { return newStringMatcher(pattern, true) }, + "string": func(pattern string) (patternMatcher, error) { return newStringMatcher(pattern, true) }, + + "exact": func(pattern string) (patternMatcher, error) { return newStringMatcher(pattern, false) }, +} + +func buildMatcher(pattern string) (patternMatcher, error) { + if prefix, after, ok := strings.Cut(pattern, ":"); ok { + prefix = strings.ToLower(strings.TrimSpace(prefix)) + if factory, ok := matcherFactories[prefix]; ok { + return factory(strings.TrimSpace(after)) + } + } + + return newRegexMatcher(pattern, true) +} + +type regexMatcher struct { + re *regexp.Regexp +} + +func newRegexMatcher(pattern string, forceCaseInsensitive bool) (patternMatcher, error) { + if forceCaseInsensitive { + pattern = ensureCaseInsensitiveRegex(pattern) + } + re, err := regexp.Compile(pattern) + if err != nil { + return nil, err + } + return ®exMatcher{re: re}, nil +} + +func (m *regexMatcher) Match(value string) bool { + return m.re.MatchString(value) +} + +func ensureCaseInsensitiveRegex(pattern string) string { + if strings.HasPrefix(pattern, "(?i)") || strings.HasPrefix(pattern, "(?-i)") { + return pattern + } + return "(?i)" + pattern +} + +type stringMatcher struct { + pattern string + caseInsensitive bool +} + +func newStringMatcher(pattern string, caseInsensitive bool) (patternMatcher, error) { + return &stringMatcher{pattern: pattern, caseInsensitive: caseInsensitive}, nil +} + +func (m *stringMatcher) Match(value string) bool { + if m.caseInsensitive { + return strings.EqualFold(value, m.pattern) + } + return value == m.pattern +} + +// long enough for user to complete OIDC login, maybe? +const oidcStateTTL = 15 * time.Minute + +// stateEntry tracks the transient OIDC login state. +type stateEntry struct { + state string + nonce string + extra any + providerID string + to string + remember bool + flow string + status string + msg string + expiresAt time.Time + + // current position in heap, internally used by heap.Fix/Remove + // to reduce find complexity. + index int +} + +// oidcStateStore keeps short-lived OIDC login states in memory. +// It maps state -> entry and uses a min-heap to track the next expiration. +// The cleanup loop sleeps until the earliest expiry and stops when the store is empty. +// TTL expiration is enforced on both take() and in the cleanup loop. +type oidcStateStore struct { + mu sync.Mutex + + entries map[string]*stateEntry + heap stateHeap + + timer *time.Timer + wake chan struct{} +} + +var ( + oidcStateStoreOnce sync.Once + oidcStateStoreInst *oidcStateStore +) + +// stateStore returns the singleton store (lazy init). +func stateStore() *oidcStateStore { + oidcStateStoreOnce.Do(func() { + oidcStateStoreInst = &oidcStateStore{ + entries: map[string]*stateEntry{}, + wake: make(chan struct{}, 1), + } + go oidcStateStoreInst.loop() + }) + return oidcStateStoreInst +} + +// signal wakes the cleanup loop to recompute the next deadline. +func (s *oidcStateStore) signal() { + select { + case s.wake <- struct{}{}: + default: + } +} + +// put inserts or updates a state entry. +func (s *oidcStateStore) put(entry *stateEntry) { + defer s.signal() + + s.mu.Lock() + defer s.mu.Unlock() + + if existing, ok := s.entries[entry.state]; ok { + idx := existing.index + *existing = *entry + existing.index = idx + heap.Fix(&s.heap, existing.index) + return + } + + heap.Push(&s.heap, entry) + s.entries[entry.state] = entry +} + +// do returns a copy of the entry and removes it when decide returns true. +func (s *oidcStateStore) do(state string, decide func(entry *stateEntry) bool) (*stateEntry, oidcStateReason) { + remove := false + defer func() { + if remove { + s.signal() + } + }() + + s.mu.Lock() + defer s.mu.Unlock() + + entry, ok := s.entries[state] + if !ok { + return nil, oidcStateReasonSessionInvalid + } + if time.Now().After(entry.expiresAt) { + s.mu.Unlock() + return nil, oidcStateReasonSessionExpired + } + + copied := *entry + remove = decide(&copied) + if remove { + heap.Remove(&s.heap, entry.index) + delete(s.entries, state) + } + + return &copied, oidcStateReasonOK +} + +// setResult marks a state entry as finished and extends its TTL for polling. +func (s *oidcStateStore) setResult(state, status, msg string) { + s.mu.Lock() + defer s.mu.Unlock() + + entry, ok := s.entries[state] + if !ok { + logging.LogErrorf("oidc set result [state: %s, status: %s, msg: %s] failed: cannot find entry", state, status, msg) + return + } + + entry.status = status + entry.msg = msg +} + +// resetTimerLocked restarts the timer with the next delay. +func (s *oidcStateStore) resetTimerLocked(d time.Duration) { + if nil == s.timer { + s.timer = time.NewTimer(d) + return + } + s.timer.Stop() + s.timer.Reset(d) +} + +// stopTimerLocked stops the timer if it exists. +func (s *oidcStateStore) stopTimerLocked() { + if nil == s.timer { + return + } + s.timer.Stop() +} + +// purgeExpiredLocked pops expired entries from the heap. +func (s *oidcStateStore) purgeExpiredLocked(now time.Time) { + for len(s.heap) > 0 { + entry := s.heap[0] + if entry.expiresAt.After(now) { + return + } + heap.Pop(&s.heap) + delete(s.entries, entry.state) + } +} + +// loop sleeps until the next expiry, or wakes on updates. +func (s *oidcStateStore) loop() { + for { + s.mu.Lock() + if len(s.heap) == 0 { + s.stopTimerLocked() + s.mu.Unlock() + <-s.wake + continue + } + + now := time.Now() + s.purgeExpiredLocked(now) + if len(s.heap) == 0 { + s.stopTimerLocked() + s.mu.Unlock() + continue + } + + next := s.heap[0].expiresAt + delay := max(next.Sub(now), 0) + s.resetTimerLocked(delay) + s.mu.Unlock() + + select { + case <-s.timer.C: + case <-s.wake: + } + } +} + +// stateHeap is a min-heap ordered by expiration time. +type stateHeap []*stateEntry + +func (h stateHeap) Len() int { return len(h) } + +func (h stateHeap) Less(i, j int) bool { + return h[i].expiresAt.Before(h[j].expiresAt) +} + +func (h stateHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] + h[i].index = i + h[j].index = j +} + +func (h *stateHeap) Push(x any) { + entry := x.(*stateEntry) + entry.index = len(*h) + *h = append(*h, entry) +} + +func (h *stateHeap) Pop() any { + old := *h + n := len(old) + entry := old[n-1] + entry.index = -1 + *h = old[:n-1] + return entry +} diff --git a/kernel/model/oidc_provider/base_oidc.go b/kernel/model/oidc_provider/base_oidc.go new file mode 100644 index 00000000000..19685ff3e98 --- /dev/null +++ b/kernel/model/oidc_provider/base_oidc.go @@ -0,0 +1,174 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package oidcprovider + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/base64" + "errors" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "golang.org/x/oauth2" +) + +type ClaimNormalizer func(raw map[string]any, idToken *oidc.IDToken) (*OIDCClaims, error) + +type BaseOIDC struct { + IDStr string + ProviderLabel string + IssuerURLStr string + ClientIDStr string + ClientSecretStr string + RedirectURLStr string + ScopesList []string + PKCE bool + Normalizer ClaimNormalizer +} + +func (p *BaseOIDC) ID() string { + return p.IDStr +} + +func (p *BaseOIDC) Label() string { + if "" != p.ProviderLabel { + return p.ProviderLabel + } + return "Login with OIDC" +} + +func (p *BaseOIDC) AuthURL(state, nonce string) (string, any, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + oauthConfig, _, err := p.discover(ctx) + if err != nil { + return "", nil, err + } + + if p.PKCE { + pkceState, err := newPKCEState() + if err != nil { + return "", nil, err + } + + return oauthConfig.AuthCodeURL( + state, + oidc.Nonce(nonce), + oauth2.SetAuthURLParam("code_challenge", pkceState.challenge), + oauth2.SetAuthURLParam("code_challenge_method", "S256"), + ), pkceState, nil + } + + return oauthConfig.AuthCodeURL(state, oidc.Nonce(nonce)), nil, nil +} + +func (p *BaseOIDC) HandleCallback(ctx context.Context, code, nonce string, extra any) (*OIDCClaims, error) { + oauthConfig, oidcProvider, err := p.discover(ctx) + if err != nil { + return nil, err + } + + var token *oauth2.Token + if p.PKCE { + pkceState, ok := extra.(*pkceState) + if !ok || pkceState == nil || pkceState.Verifier == "" { + return nil, errors.New("oidc pkce verifier missing") + } + token, err = oauthConfig.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", pkceState.Verifier)) + } else { + token, err = oauthConfig.Exchange(ctx, code) + } + if err != nil { + return nil, err + } + + rawIDToken, ok := token.Extra("id_token").(string) + if !ok || "" == rawIDToken { + return nil, errors.New("oidc id_token missing") + } + + verifier := oidcProvider.Verifier(&oidc.Config{ClientID: p.ClientIDStr}) + idToken, err := verifier.Verify(ctx, rawIDToken) + if err != nil { + return nil, err + } + + if "" != nonce && idToken.Nonce != nonce { + return nil, errors.New("oidc nonce mismatch") + } + + rawClaims := map[string]any{} + if err = idToken.Claims(&rawClaims); err != nil { + return nil, err + } + + if p.Normalizer != nil { + return p.Normalizer(rawClaims, idToken) + } + + return DefaultNormalizeClaims(rawClaims, idToken) +} + +func (p *BaseOIDC) discover(ctx context.Context) (*oauth2.Config, *oidc.Provider, error) { + oidcProvider, err := oidc.NewProvider(ctx, p.IssuerURLStr) + if err != nil { + return nil, nil, err + } + + oauthConfig := &oauth2.Config{ + ClientID: p.ClientIDStr, + ClientSecret: p.ClientSecretStr, + Endpoint: oidcProvider.Endpoint(), + RedirectURL: p.RedirectURLStr, + Scopes: p.ScopesList, + } + return oauthConfig, oidcProvider, nil +} + +func DefaultNormalizeClaims(raw map[string]any, idToken *oidc.IDToken) (*OIDCClaims, error) { + claims := &OIDCClaims{ + Subject: idToken.Subject, + Issuer: idToken.Issuer, + Audience: idToken.Audience, + Email: claimString(raw, OIDCClaimEmail), + EmailVerified: claimBool(raw, OIDCClaimEmailVerified), + PreferredUsername: claimString(raw, OIDCClaimPreferredUsername), + Name: claimString(raw, OIDCClaimName), + } + return claims, nil +} + +type pkceState struct { + Verifier string + challenge string +} + +func newPKCEState() (*pkceState, error) { + entropy := make([]byte, 32) + if _, err := rand.Read(entropy); err != nil { + return nil, err + } + + verifier := base64.RawURLEncoding.EncodeToString(entropy) + sum := sha256.Sum256([]byte(verifier)) + challenge := base64.RawURLEncoding.EncodeToString(sum[:]) + + return &pkceState{Verifier: verifier, challenge: challenge}, nil +} diff --git a/kernel/model/oidc_provider/casdoor.go b/kernel/model/oidc_provider/casdoor.go new file mode 100644 index 00000000000..7847da6c7f0 --- /dev/null +++ b/kernel/model/oidc_provider/casdoor.go @@ -0,0 +1,78 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package oidcprovider + +import ( + "errors" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/siyuan-note/siyuan/kernel/conf" +) + +func NewCasdoor(cfg *conf.OIDCProviderConf) (Provider, error) { + issuerURL := strings.TrimSpace(cfg.IssuerURL) + if issuerURL == "" { + return nil, errors.New("Casdoor issuerURL is required") + } + + clientID := strings.TrimSpace(cfg.ClientID) + if clientID == "" { + return nil, errors.New("Casdoor clientID is required") + } + + clientSecret := strings.TrimSpace(cfg.ClientSecret) + if clientSecret == "" { + return nil, errors.New("Casdoor clientSecret is required") + } + + redirectURL := formatRedirectURL(cfg.RedirectURL) + if redirectURL == "" { + return nil, errors.New("Casdoor redirectURL is required") + } + + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + p := &BaseOIDC{ + IDStr: "casdoor", + ProviderLabel: "Login with Casdoor", + IssuerURLStr: issuerURL, + ClientIDStr: clientID, + ClientSecretStr: clientSecret, + RedirectURLStr: redirectURL, + ScopesList: scopes, + } + + p.Normalizer = func(raw map[string]any, idToken *oidc.IDToken) (*OIDCClaims, error) { + claims := &OIDCClaims{ + Provider: "casdoor", + Subject: idToken.Subject, + Issuer: idToken.Issuer, + Audience: idToken.Audience, + Email: claimString(raw, OIDCClaimEmail), + EmailVerified: claimBool(raw, OIDCClaimEmailVerified), + PreferredUsername: claimString(raw, OIDCClaimPreferredUsername), + Name: claimString(raw, OIDCClaimName), + } + return claims, nil + } + + return p, nil +} diff --git a/kernel/model/oidc_provider/custom.go b/kernel/model/oidc_provider/custom.go new file mode 100644 index 00000000000..e6f4eb18265 --- /dev/null +++ b/kernel/model/oidc_provider/custom.go @@ -0,0 +1,96 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package oidcprovider + +import ( + "errors" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/siyuan-note/siyuan/kernel/conf" +) + +func NewCustom(cfg *conf.OIDCProviderConf) (Provider, error) { + clientID := strings.TrimSpace(cfg.ClientID) + if "" == clientID { + return nil, errors.New("custom OIDC clientID is required") + } + + clientSecret := strings.TrimSpace(cfg.ClientSecret) + if "" == clientSecret { + return nil, errors.New("custom OIDC clientSecret is required") + } + + issuerURL := strings.TrimSpace(cfg.IssuerURL) + if "" == issuerURL { + return nil, errors.New("custom OIDC issuerURL is required") + } + + redirectURL := formatRedirectURL(cfg.RedirectURL) + if redirectURL == "" { + return nil, errors.New("custom OIDC redirectURL is required") + } + + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + label := strings.TrimSpace(cfg.ProviderLabel) + if "" == label { + label = "Login with SSO" + } + + p := &BaseOIDC{ + IDStr: "custom", + ProviderLabel: label, + IssuerURLStr: issuerURL, + ClientIDStr: clientID, + ClientSecretStr: clientSecret, + RedirectURLStr: redirectURL, + ScopesList: scopes, + } + + p.Normalizer = func(raw map[string]any, idToken *oidc.IDToken) (*OIDCClaims, error) { + claims := &OIDCClaims{ + Provider: "custom", + Subject: idToken.Subject, + Issuer: idToken.Issuer, + Audience: idToken.Audience, + Email: claimString(raw, claimKeyFromMap(cfg.ClaimMap, OIDCClaimEmail)), + EmailVerified: claimBool(raw, claimKeyFromMap(cfg.ClaimMap, OIDCClaimEmailVerified)), + PreferredUsername: claimString(raw, claimKeyFromMap(cfg.ClaimMap, OIDCClaimPreferredUsername)), + Name: claimString(raw, claimKeyFromMap(cfg.ClaimMap, OIDCClaimName)), + HostedDomain: claimString(raw, claimKeyFromMap(cfg.ClaimMap, OIDCClaimHostedDomain)), + TenantID: claimString(raw, claimKeyFromMap(cfg.ClaimMap, OIDCClaimTenantID)), + Groups: claimStringArray(raw, claimKeyFromMap(cfg.ClaimMap, OIDCClaimGroups)), + } + return claims, nil + } + + return p, nil +} + +func claimKeyFromMap(m map[string]string, key string) string { + if nil == m { + return key + } + if v, ok := m[key]; ok && "" != v { + return v + } + return key +} diff --git a/kernel/model/oidc_provider/github.go b/kernel/model/oidc_provider/github.go new file mode 100644 index 00000000000..06c7dee9285 --- /dev/null +++ b/kernel/model/oidc_provider/github.go @@ -0,0 +1,195 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package oidcprovider + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "strconv" + "strings" + + "github.com/siyuan-note/logging" + "github.com/siyuan-note/siyuan/kernel/conf" + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" +) + +type oidcProviderGitHub struct { + providerLabel string + scopes []string + clientID string + clientSecret string + redirectURL string +} + +func NewGitHub(cfg *conf.OIDCProviderConf) (Provider, error) { + clientID := strings.TrimSpace(cfg.ClientID) + if "" == clientID { + return nil, errors.New("GitHub clientID is required") + } + + clientSecret := strings.TrimSpace(cfg.ClientSecret) + if "" == clientSecret { + return nil, errors.New("GitHub clientSecret is required") + } + + redirectURL := formatRedirectURL(cfg.RedirectURL) + if "" == redirectURL { + return nil, errors.New("GitHub redirectURL is required") + } + + label := strings.TrimSpace(cfg.ProviderLabel) + if label == "" { + label = "Login with GitHub" + } + + return &oidcProviderGitHub{ + providerLabel: label, + scopes: cfg.Scopes, + clientID: clientID, + clientSecret: clientSecret, + redirectURL: redirectURL, + }, nil +} + +func (p *oidcProviderGitHub) ID() string { + return "github" +} + +func (p *oidcProviderGitHub) Label() string { + return p.providerLabel +} + +func (p *oidcProviderGitHub) AuthURL(state, nonce string) (string, any, error) { + conf := p.oauthConfig() + // GitHub does not support nonce in the standard way OIDC does, so we just use state + return conf.AuthCodeURL(state), nil, nil +} + +func (p *oidcProviderGitHub) HandleCallback(ctx context.Context, code, nonce string, extra any) (*OIDCClaims, error) { + conf := p.oauthConfig() + token, err := conf.Exchange(ctx, code) + if err != nil { + return nil, err + } + + client := conf.Client(ctx, token) + rawUser, err := fetchUser(ctx, client) + if err != nil { + return nil, err + } + + // GitHub user ID is unique per account, we use it as the subject + // user ID is an integer + subject := "" + if id, ok := rawUser["id"].(float64); ok { + subject = strconv.FormatFloat(id, 'f', 0, 64) + } + + email := claimString(rawUser, "email") + var emailVerified *bool + // fetchUser only exposes the public email, same as in your github profile. + // fetchPrimaryEmail includes the primary email even when it's private. + if primaryEmail, verified, err := fetchPrimaryEmail(ctx, client); err != nil { + logging.LogWarnf("failed to fetch GitHub primary email: %s", err) + } else if primaryEmail != "" { + email = primaryEmail + emailVerified = verified + } + + claims := &OIDCClaims{ + Provider: "github", + Subject: subject, + Issuer: "https://github.com", + Email: email, + EmailVerified: emailVerified, + PreferredUsername: claimString(rawUser, "login"), + Name: claimString(rawUser, "name"), + } + return claims, nil +} + +func (p *oidcProviderGitHub) oauthConfig() *oauth2.Config { + scopes := p.scopes + if len(scopes) == 0 { + scopes = []string{"read:user", "user:email"} + } + return &oauth2.Config{ + ClientID: p.clientID, + ClientSecret: p.clientSecret, + RedirectURL: p.redirectURL, + Scopes: scopes, + Endpoint: github.Endpoint, + } +} + +type githubEmailEntry struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` +} + +func fetchUser(ctx context.Context, client *http.Client) (map[string]any, error) { + var user map[string]any + const userURL = "https://api.github.com/user" + if err := fetchGithubJSON(ctx, client, userURL, &user); err != nil { + return nil, err + } + return user, nil +} + +func fetchPrimaryEmail(ctx context.Context, client *http.Client) (string, *bool, error) { + var entries []githubEmailEntry + const emailURL = "https://api.github.com/user/emails" + if err := fetchGithubJSON(ctx, client, emailURL, &entries); err != nil { + return "", nil, err + } + + for _, entry := range entries { + if entry.Primary && entry.Email != "" { + verified := entry.Verified + return entry.Email, &verified, nil + } + } + for _, entry := range entries { + if entry.Email != "" { + verified := entry.Verified + return entry.Email, &verified, nil + } + } + return "", nil, nil +} + +func fetchGithubJSON(ctx context.Context, client *http.Client, url string, out any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + req.Header.Set("Accept", "application/vnd.github+json") + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return errors.New("GitHub API returned " + resp.Status) + } + return json.NewDecoder(resp.Body).Decode(out) +} diff --git a/kernel/model/oidc_provider/google.go b/kernel/model/oidc_provider/google.go new file mode 100644 index 00000000000..0e639499e55 --- /dev/null +++ b/kernel/model/oidc_provider/google.go @@ -0,0 +1,83 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package oidcprovider + +import ( + "errors" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/siyuan-note/siyuan/kernel/conf" +) + +const oidcGoogleIssuer = "https://accounts.google.com" + +func NewGoogle(cfg *conf.OIDCProviderConf) (Provider, error) { + clientID := strings.TrimSpace(cfg.ClientID) + if clientID == "" { + return nil, errors.New("Google clientID is required") + } + + clientSecret := strings.TrimSpace(cfg.ClientSecret) + if clientSecret == "" { + return nil, errors.New("Google clientSecret is required") + } + + redirectURL := formatRedirectURL(cfg.RedirectURL) + if redirectURL == "" { + return nil, errors.New("Google redirectURL is required") + } else if strings.HasPrefix(redirectURL, "siyuan://") { + return nil, errors.New("Google does not support custom uri scheme (siyuan://) now") + } + + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + label := strings.TrimSpace(cfg.ProviderLabel) + if label == "" { + label = "Login with Google" + } + + p := &BaseOIDC{ + IDStr: "google", + ProviderLabel: label, + IssuerURLStr: oidcGoogleIssuer, + ClientIDStr: clientID, + ClientSecretStr: clientSecret, + RedirectURLStr: redirectURL, + ScopesList: scopes, + } + + p.Normalizer = func(raw map[string]any, idToken *oidc.IDToken) (*OIDCClaims, error) { + claims := &OIDCClaims{ + Provider: "google", + Subject: idToken.Subject, + Issuer: idToken.Issuer, + Audience: idToken.Audience, + Email: claimString(raw, OIDCClaimEmail), + EmailVerified: claimBool(raw, OIDCClaimEmailVerified), + PreferredUsername: claimString(raw, OIDCClaimPreferredUsername), + Name: claimString(raw, OIDCClaimName), + HostedDomain: claimString(raw, OIDCClaimHostedDomain), + } + return claims, nil + } + + return p, nil +} diff --git a/kernel/model/oidc_provider/microsoft.go b/kernel/model/oidc_provider/microsoft.go new file mode 100644 index 00000000000..3b5c8454694 --- /dev/null +++ b/kernel/model/oidc_provider/microsoft.go @@ -0,0 +1,132 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package oidcprovider + +import ( + "errors" + "fmt" + "strings" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/siyuan-note/siyuan/kernel/conf" +) + +const oidcMicrosoftIssuerFormat = "https://login.microsoftonline.com/%s/v2.0" + +func NewMicrosoft(cfg *conf.OIDCProviderConf) (Provider, error) { + clientID := strings.TrimSpace(cfg.ClientID) + if "" == clientID { + return nil, errors.New("Microsoft clientID is required") + } + + usePKCE := cfg.PKCE + clientSecret := strings.TrimSpace(cfg.ClientSecret) + if !usePKCE && "" == clientSecret { + return nil, errors.New("Microsoft clientSecret is required") + } + if usePKCE { + clientSecret = "" + } + + redirectURL := formatRedirectURL(cfg.RedirectURL) + if "" != redirectURL { + // OIDC best practice prefer 127.0.0.1 over localhost which has potencial dns hijack problem. + // but in Microsoft Azure portal, the UI only accepts "http://localhost:port/..." for http scheme. + // though you can passby UI restriction by directly modify manifest, see: https://learn.microsoft.com/en-us/entra/identity-platform/reply-url#prefer-127001-over-localhost + // but this is tedious to normal users, overall we sacrifice a little security for easy configuration. + if after, ok := strings.CutPrefix(redirectURL, "http://127.0.0.1"); ok { + redirectURL = "http://localhost" + after + } + } else { + return nil, errors.New("Microsoft redirectURL is required") + } + + /* + Microsoft Entra ID 的 OIDC 账号类型说明: + + 微软在 Portal 创建应用时支持 4 种 signInAudience / Supported account types: + + 1. Accounts in this organizational directory only (Default Directory - Single tenant) + 仅允许当前 Entra tenant 内的账号登录,issuer 与 discovery URL 稳定, + 是最符合 OpenID Connect 规范的模式。 + + 2. Accounts in any organizational directory (Any Microsoft Entra ID tenant - Multitenant) + 允许任意组织租户账号登录。 + + 3. Multitenant and personal Microsoft accounts (e.g. Skype, Xbox) + 同时支持组织账号与个人微软账号(MSA)。 + + 4. Personal Microsoft accounts only + 仅允许 MSA 个人账号登录。 + + 设计取舍: + + - 对个人用户和小团队来说,第 1 种已经足够: + 大多数用户的部署场景只需要在用户指定的Tenant内的账号登录即可, + 无需开放给任意组织或个人账号登录。 + + - 微软的多租户做法和 go-oidc 不一样,实现有复杂度: + 具体来说,微软在多租户 endpoint 里返回的issuer URL是 https://login.microsoftonline.com/{tenantid}/v2.0 即是动态的, + 然而 go-oidc 要求 issuer URL 必须和 discovery document 里的一致,会引发校验失败。 + + 因此当前仅实现第 1 种 Single tenant 方案; + 如日后确有真实的多租户需求,再实现。 + */ + tenant := strings.TrimSpace(cfg.Tenant) + if "" == tenant { + return nil, errors.New("Microsoft tenant is required") + } + issuerURL := fmt.Sprintf(oidcMicrosoftIssuerFormat, tenant) + + scopes := cfg.Scopes + if len(scopes) == 0 { + scopes = []string{"openid", "email", "profile"} + } + + label := strings.TrimSpace(cfg.ProviderLabel) + if label == "" { + label = "Login with Microsoft" + } + + p := &BaseOIDC{ + IDStr: "microsoft", + ProviderLabel: label, + IssuerURLStr: issuerURL, + ClientIDStr: clientID, + ClientSecretStr: clientSecret, + RedirectURLStr: redirectURL, + ScopesList: scopes, + PKCE: usePKCE, + } + + p.Normalizer = func(raw map[string]any, idToken *oidc.IDToken) (*OIDCClaims, error) { + claims := &OIDCClaims{ + Provider: "microsoft", + Subject: idToken.Subject, + Issuer: idToken.Issuer, + Audience: idToken.Audience, + Email: claimString(raw, OIDCClaimEmail), + EmailVerified: claimBool(raw, OIDCClaimEmailVerified), + PreferredUsername: claimString(raw, OIDCClaimPreferredUsername), + Name: claimString(raw, OIDCClaimName), + TenantID: claimString(raw, OIDCClaimTenantID), + } + return claims, nil + } + + return p, nil +} diff --git a/kernel/model/oidc_provider/provider.go b/kernel/model/oidc_provider/provider.go new file mode 100644 index 00000000000..b1faca05b38 --- /dev/null +++ b/kernel/model/oidc_provider/provider.go @@ -0,0 +1,201 @@ +// SiYuan - Refactor your thinking +// Copyright (c) 2020-present, b3log.org +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package oidcprovider + +import ( + "context" + "errors" + "net/url" + "strconv" + "strings" + + "github.com/siyuan-note/siyuan/kernel/conf" + "github.com/siyuan-note/siyuan/kernel/util" +) + +type Provider interface { + ID() string + Label() string + + // AuthURL generates the login URL. + // state: used for CSRF protection. + // nonce: used for OIDC replay protection (optional for pure OAuth2). + // extra: optional provider-specific data to be stored with the state. + AuthURL(state, nonce string) (authURL string, extra any, err error) + + // HandleCallback processes the code returned by the provider. + // It exchanges the code for a token and retrieves user claims. + // nonce: passed to verify OIDC ID Token (if applicable). + // extra: provider-specific data stored during AuthURL (optional). + HandleCallback(ctx context.Context, code, nonce string, extra any) (*OIDCClaims, error) +} + +func New(name string, cfg *conf.OIDCProviderConf) (Provider, error) { + switch name { + case "custom": + return NewCustom(cfg) + case "google": + return NewGoogle(cfg) + case "microsoft": + return NewMicrosoft(cfg) + case "github": + return NewGitHub(cfg) + default: + return nil, errors.New("oidc provider is not supported") + } +} + +func formatRedirectURL(rawURL string) string { + rawURL = strings.TrimSpace(rawURL) + if "" == rawURL { + return defaultRedirectURL() + } + + u, err := url.Parse(rawURL) + if nil != err { + // If parsing fails, try prepending http:// if it looks like a host:port + if !strings.HasPrefix(rawURL, "http") { + u, err = url.Parse("http://" + rawURL) + } + } + + if nil == err { + // For http/https scehme, If no path is specified (or just /), append the default callback path + if "http" == u.Scheme || "https" == u.Scheme { + if "" == u.Path || "/" == u.Path { + u.Path = "/auth/oidc/callback" + return u.String() + } + } + } + + return rawURL +} + +func defaultRedirectURL() string { + switch util.Container { + case util.ContainerAndroid, util.ContainerIOS, util.ContainerHarmony: + return "siyuan://oidc-callback" + case util.ContainerStd: + return "http://127.0.0.1:6806/auth/oidc/callback" + default: + return "" + } +} + +const ( + OIDCClaimProvider = "provider" + OIDCClaimSubject = "subject" + OIDCClaimEmail = "email" + OIDCClaimEmailVerified = "email_verified" + OIDCClaimPreferredUsername = "preferred_username" + OIDCClaimName = "name" + OIDCClaimIssuer = "issuer" + OIDCClaimAudience = "audience" + OIDCClaimHostedDomain = "hosted_domain" + OIDCClaimTenantID = "tenant_id" + OIDCClaimGroups = "groups" +) + +type OIDCClaims struct { + Provider string + Subject string + Email string + EmailVerified *bool + PreferredUsername string + Name string + Issuer string + Audience []string + HostedDomain string + TenantID string + Groups []string +} + +func (claims *OIDCClaims) FilterValues() map[string][]string { + values := map[string][]string{} + addValue(values, OIDCClaimProvider, claims.Provider) + addValue(values, OIDCClaimSubject, claims.Subject) + addValue(values, OIDCClaimEmail, claims.Email) + addValue(values, OIDCClaimPreferredUsername, claims.PreferredUsername) + addValue(values, OIDCClaimName, claims.Name) + addValue(values, OIDCClaimIssuer, claims.Issuer) + if nil != claims.EmailVerified { + addValue(values, OIDCClaimEmailVerified, strconv.FormatBool(*claims.EmailVerified)) + } + for _, aud := range claims.Audience { + addValue(values, OIDCClaimAudience, aud) + } + addValue(values, OIDCClaimHostedDomain, claims.HostedDomain) + addValue(values, OIDCClaimTenantID, claims.TenantID) + for _, group := range claims.Groups { + addValue(values, OIDCClaimGroups, group) + } + return values +} + +func addValue(values map[string][]string, key, value string) { + if "" == value { + return + } + values[key] = append(values[key], value) +} + +func claimString(raw map[string]any, key string) string { + val, ok := raw[key] + if !ok || nil == val { + return "" + } + switch typed := val.(type) { + case string: + return typed + } + return "" +} + +func claimStringArray(raw map[string]any, key string) []string { + val, ok := raw[key] + if !ok || nil == val { + return nil + } + switch typed := val.(type) { + case []string: + return typed + case []interface{}: + var out []string + for _, item := range typed { + if str, ok := item.(string); ok { + out = append(out, str) + } + } + return out + case string: + return []string{typed} + } + return nil +} + +func claimBool(raw map[string]any, key string) *bool { + val, ok := raw[key] + if !ok || nil == val { + return nil + } + switch typed := val.(type) { + case bool: + return &typed + } + return nil +} diff --git a/kernel/model/session.go b/kernel/model/session.go index c82c67580da..953372f85f2 100644 --- a/kernel/model/session.go +++ b/kernel/model/session.go @@ -194,183 +194,430 @@ func CheckReadonly(c *gin.Context) { } } -func CheckAuth(c *gin.Context) { - // 已通过 JWT 认证 - if role := GetGinContextRole(c); IsValidRole(role, []Role{ +type authAction int + +const ( + authActionContinue authAction = iota + authActionGrant + authActionPass + authActionDeny + authActionRedirect + authActionHeaderStatus +) + +type authResult struct { + action authAction + role Role + status int + payload map[string]interface{} + redirectTo string + headerKey string + headerValue string +} + +type authStep struct { + name string + handler func(*authContext) authResult +} + +type authContext struct { + ginCtx *gin.Context + session *util.SessionData + workspace *util.WorkspaceSession + isLocalhostConn bool + hasAccessCode bool + oidcEnabled bool + hasAnyAuth bool +} + +func newAuthContext(c *gin.Context) *authContext { + session := util.GetSession(c) + oidcEnabled := OIDCIsEnabled(Conf.OIDC) + hasAccessCode := "" != Conf.AccessAuthCode + return &authContext{ + ginCtx: c, + session: session, + workspace: util.GetWorkspaceSession(session), + isLocalhostConn: util.IsLocalHost(c.Request.RemoteAddr), + hasAccessCode: hasAccessCode, + oidcEnabled: oidcEnabled, + hasAnyAuth: hasAccessCode || oidcEnabled, + } +} + +// authContinue 继续下一个鉴权步骤 +func authContinue() authResult { + return authResult{action: authActionContinue} +} + +// authPass 放行请求,但不修改角色。 +// 也就是说保留当前已有的角色:如果前面 JWT 中间件写了管理员/访客,就保持那个; +// 如果没人写过,则是默认 RoleVisitor +func authPass() authResult { + return authResult{action: authActionPass} +} + +// authGrant 放行请求,并赋予指定角色 +func authGrant(role Role) authResult { + return authResult{action: authActionGrant, role: role} +} + +// authUnauthorized 拒绝请求,返回 401 状态码和指定消息 +func authUnauthorized(msg string) authResult { + return authResult{ + action: authActionDeny, + status: http.StatusUnauthorized, + payload: map[string]any{"code": -1, "msg": msg}, + } +} + +// authRedirect 重定向到指定路径 +func authRedirect(to string) authResult { + return authResult{action: authActionRedirect, redirectTo: to} +} + +// authRedirectToCheckAuth 重定向到 /check-auth 并携带当前请求路径作为参数 +func (ctx *authContext) authRedirectToCheckAuth() authResult { + location := url.URL{} + queryParams := url.Values{} + queryParams.Set("to", ctx.ginCtx.Request.URL.String()) + location.RawQuery = queryParams.Encode() + location.Path = "/check-auth" + return authRedirect(location.String()) +} + +// authHeaderStatus 返回指定 header 和状态码 +func authHeaderStatus(key, val string, status int) authResult { + return authResult{action: authActionHeaderStatus, headerKey: key, headerValue: val, status: status} +} + +// stepExistingRole 放行前面中间件已放行的请求,如已通过 JWT 的请求 +func (ctx *authContext) stepExistingRole() authResult { + if IsValidRole(GetGinContextRole(ctx.ginCtx), []Role{ RoleAdministrator, RoleEditor, RoleReader, }) { - c.Next() - return + return authPass() } + return authContinue() +} - // 通过 API token (header: Authorization) - if authHeader := c.GetHeader("Authorization"); "" != authHeader { - var token string - if strings.HasPrefix(authHeader, "Token ") { - token = strings.TrimPrefix(authHeader, "Token ") - } else if strings.HasPrefix(authHeader, "token ") { - token = strings.TrimPrefix(authHeader, "token ") - } else if strings.HasPrefix(authHeader, "Bearer ") { - token = strings.TrimPrefix(authHeader, "Bearer ") - } else if strings.HasPrefix(authHeader, "bearer ") { - token = strings.TrimPrefix(authHeader, "bearer ") - } +// stepSkipAuth 绕过所有认证步骤 +func (ctx *authContext) stepSkipAuth() authResult { + if Conf.AccessAuthBypass { + return authGrant(RoleAdministrator) + } + return authContinue() +} - if "" != token { - if Conf.Api.Token == token { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return - } +// stepAuthorizationHeaderToken 通过 API Token (header: Authorization) 认证 +func (ctx *authContext) stepAuthorizationHeaderToken() authResult { + authHeader := ctx.ginCtx.GetHeader("Authorization") + if "" == authHeader { + return authContinue() + } - c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed [header: Authorization]"}) - c.Abort() - return - } + token := "" + switch { + case strings.HasPrefix(authHeader, "Token "): + token = strings.TrimPrefix(authHeader, "Token ") + case strings.HasPrefix(authHeader, "token "): + token = strings.TrimPrefix(authHeader, "token ") + case strings.HasPrefix(authHeader, "Bearer "): + token = strings.TrimPrefix(authHeader, "Bearer ") + case strings.HasPrefix(authHeader, "bearer "): + token = strings.TrimPrefix(authHeader, "bearer ") } - // 通过 API token (query-params: token) - if token := c.Query("token"); "" != token { - if Conf.Api.Token == token { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return - } + if "" == token { + return authContinue() + } + if Conf.Api.Token == token { + return authGrant(RoleAdministrator) + } + return authUnauthorized("Auth failed [header: Authorization]") +} - c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed [query: token]"}) - c.Abort() - return +// stepQueryToken 通过 API Token (query: token) 认证 +func (ctx *authContext) stepQueryToken() authResult { + token := ctx.ginCtx.Query("token") + if "" == token { + return authContinue() + } + if Conf.Api.Token == token { + return authGrant(RoleAdministrator) } + return authUnauthorized("Auth failed [query: token]") +} - //logging.LogInfof("check auth for [%s]", c.Request.RequestURI) - localhost := util.IsLocalHost(c.Request.RemoteAddr) +// stepAuthPageWhitelist 放行鉴权相关页面(登录页/ OIDC 回调 Websocket连接) +func (ctx *authContext) stepAuthPageWhitelist() authResult { + reqURI := ctx.ginCtx.Request.RequestURI - // 未设置访问授权码 - if "" == Conf.AccessAuthCode { - // Skip the empty access authorization code check https://github.com/siyuan-note/siyuan/issues/9709 - if util.SiyuanAccessAuthCodeBypass { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return - } + switch { + case "/check-auth" == reqURI: + return authPass() + case strings.HasPrefix(reqURI, "/auth/oidc/"): + return authPass() + // 用于授权页保持连接,避免非常驻内存内核自动退出 https://github.com/siyuan-note/insider/issues/1099 + case strings.Contains(reqURI, "/ws?app=siyuan&id=auth"): + return authPass() + } - // Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180 - clientIP := c.ClientIP() - host := c.GetHeader("Host") - origin := c.GetHeader("Origin") - forwardedHost := c.GetHeader("X-Forwarded-Host") - if !localhost || - ("" != clientIP && !util.IsLocalHostname(clientIP)) || - ("" != host && !util.IsLocalHost(host)) || - ("" != origin && !util.IsLocalOrigin(origin) && !strings.HasPrefix(origin, "chrome-extension://")) || - ("" != forwardedHost && !util.IsLocalHost(forwardedHost)) { - c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed: for security reasons, please set [Access authorization code] when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请设置 [访问授权码]"}) - c.Abort() - return + return authContinue() +} + +// stepAuthLocalGuard 远程访问需开启至少一种认证 +func (ctx *authContext) stepAuthLocalGuard() authResult { + // Authenticate requests with the Origin header other than 127.0.0.1 https://github.com/siyuan-note/siyuan/issues/9180 + clientIP := ctx.ginCtx.ClientIP() + host := ctx.ginCtx.GetHeader("Host") + origin := ctx.ginCtx.GetHeader("Origin") + forwardedHost := ctx.ginCtx.GetHeader("X-Forwarded-Host") + + remote := !ctx.isLocalhostConn || + ("" != clientIP && !util.IsLocalHostname(clientIP)) || + ("" != host && !util.IsLocalHost(host)) || + ("" != origin && !util.IsLocalOrigin(origin) && !strings.HasPrefix(origin, "chrome-extension://")) || + ("" != forwardedHost && !util.IsLocalHost(forwardedHost)) + + if remote && !ctx.hasAnyAuth { + return authUnauthorized("Auth failed: for security reasons, please set at least one authentication method when using non-127.0.0.1 access\n\n为安全起见,使用非 127.0.0.1 访问时请至少设置一种认证方式") + } + + return authContinue() +} + +// stepLocalhostNoAuth 本地请求且无认证配置时直接放行 +func (ctx *authContext) stepLocalhostNoAuth() authResult { + if ctx.isLocalhostConn && !ctx.hasAnyAuth { + return authGrant(RoleAdministrator) + } + return authContinue() +} + +// stepSessionAccessCode 通过会话中的访问授权码 +func (ctx *authContext) stepSessionAccessCode() authResult { + if ctx.workspace.AccessAuthCode == Conf.AccessAuthCode && "" != Conf.AccessAuthCode { + return authGrant(RoleAdministrator) + } + return authContinue() +} + +// stepOIDCSession 通过 OIDC 会话 +func (ctx *authContext) stepOIDCSession() authResult { + if OIDCIsValid(Conf.OIDC, ctx.workspace) { + return authGrant(RoleAdministrator) + } + return authContinue() +} + +// stepBasicAuth 使用 BasicAuth 校验访问授权码 +func (ctx *authContext) stepBasicAuth() authResult { + if username, password, ok := ctx.ginCtx.Request.BasicAuth(); ok { + if util.WorkspaceName == username && Conf.AccessAuthCode == password { + return authGrant(RoleAdministrator) } + } + return authContinue() +} - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return +// stepLocalhostWhitelist 本机特定路径直接赋予管理员(满足后续权限校验) +func (ctx *authContext) stepLocalhostWhitelist() authResult { + if !ctx.isLocalhostConn { + return authContinue() + } + + reqURI := ctx.ginCtx.Request.RequestURI + switch { + case strings.HasPrefix(reqURI, "/assets/") || strings.HasPrefix(reqURI, "/export/"): + return authGrant(RoleAdministrator) + case strings.HasPrefix(reqURI, "/api/system/exit"): + return authGrant(RoleAdministrator) + case strings.HasPrefix(reqURI, "/api/system/getNetwork") || strings.HasPrefix(reqURI, "/api/system/getWorkspaceInfo"): + return authGrant(RoleAdministrator) + case strings.HasPrefix(reqURI, "/api/sync/performSync"): + if util.ContainerIOS == util.Container || util.ContainerAndroid == util.Container || util.ContainerHarmony == util.Container { + return authGrant(RoleAdministrator) + } } + return authContinue() +} +// stepStaticWhitelist 放行静态资源 +func (ctx *authContext) stepStaticWhitelist() authResult { + reqURI := ctx.ginCtx.Request.RequestURI // 放过 /appearance/ 等(不要扩大到 /stage/ 否则鉴权会有问题) - if strings.HasPrefix(c.Request.RequestURI, "/appearance/") || - strings.HasPrefix(c.Request.RequestURI, "/stage/build/export/") || - strings.HasPrefix(c.Request.RequestURI, "/stage/protyle/") { - c.Next() - return + if strings.HasPrefix(reqURI, "/appearance/") || + strings.HasPrefix(reqURI, "/stage/build/export/") || + strings.HasPrefix(reqURI, "/stage/protyle/") { + return authPass() } + return authContinue() +} - // 放过来自本机的某些请求 - if localhost { - if strings.HasPrefix(c.Request.RequestURI, "/assets/") || strings.HasPrefix(c.Request.RequestURI, "/export/") { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return +// stepFailWebDAV 确保 WebDAV 返回 Basic 401 +func (ctx *authContext) stepFailWebDAV() authResult { + reqURI := ctx.ginCtx.Request.RequestURI + if strings.HasPrefix(reqURI, "/webdav") || + strings.HasPrefix(reqURI, "/caldav") || + strings.HasPrefix(reqURI, "/carddav") { + return authHeaderStatus(BasicAuthHeaderKey, BasicAuthHeaderValue, http.StatusUnauthorized) + } + return authContinue() +} + +// stepFail 兜底处理:浏览器/客户端重定向,其余 401 +func (ctx *authContext) stepFail() authResult { + logging.LogWarnf("auth failed [ip=%s, path=%s]", util.GetRemoteAddr(ctx.ginCtx.Request), ctx.ginCtx.Request.URL.Path) + userAgentHeader := ctx.ginCtx.GetHeader("User-Agent") + if strings.HasPrefix(userAgentHeader, "SiYuan/") || strings.HasPrefix(userAgentHeader, "Mozilla/") { + if http.MethodGet != ctx.ginCtx.Request.Method || ctx.ginCtx.IsWebsocket() { + return authUnauthorized(Conf.Language(156)) } - if strings.HasPrefix(c.Request.RequestURI, "/api/system/exit") { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return + return ctx.authRedirectToCheckAuth() + } + + return authUnauthorized("Auth failed [session]") +} + +func handleAuthResult(c *gin.Context, res authResult) bool { + switch res.action { + case authActionGrant: + c.Set(RoleContextKey, res.role) + c.Next() + return true + case authActionPass: + c.Next() + return true + case authActionRedirect: + target := res.redirectTo + if "" == target { + target = "/" } - if strings.HasPrefix(c.Request.RequestURI, "/api/system/getNetwork") || strings.HasPrefix(c.Request.RequestURI, "/api/system/getWorkspaceInfo") { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return + c.Redirect(http.StatusFound, target) + c.Abort() + return true + case authActionDeny: + status := res.status + if 0 == status { + status = http.StatusUnauthorized } - if strings.HasPrefix(c.Request.RequestURI, "/api/sync/performSync") { - if util.ContainerIOS == util.Container || util.ContainerAndroid == util.Container || util.ContainerHarmony == util.Container { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return - } + if nil != res.payload { + c.JSON(status, res.payload) + c.Abort() + return true + } + c.AbortWithStatus(status) + return true + case authActionHeaderStatus: + if "" != res.headerKey { + c.Header(res.headerKey, res.headerValue) } + status := res.status + if 0 == status { + status = http.StatusUnauthorized + } + c.AbortWithStatus(status) + return true + default: + return false } +} - // 通过 Cookie - session := util.GetSession(c) - workspaceSession := util.GetWorkspaceSession(session) - if workspaceSession.AccessAuthCode == Conf.AccessAuthCode { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() - return +// CheckAuth 鉴权逻辑 +func CheckAuth(c *gin.Context) { + ctx := newAuthContext(c) + steps := []authStep{ + {name: "existing-role", handler: (*authContext).stepExistingRole}, + + {name: "skip-auth", handler: (*authContext).stepSkipAuth}, + + // API Token 认证 + {name: "authorization-header-token", handler: (*authContext).stepAuthorizationHeaderToken}, + {name: "query-token", handler: (*authContext).stepQueryToken}, + + // 提前放行鉴权相关页面,避免被 auth local guard阻断 + {name: "auth-page-whitelist", handler: (*authContext).stepAuthPageWhitelist}, + + {name: "auth-local-guard", handler: (*authContext).stepAuthLocalGuard}, + + // stepLocalhostNoAuth 务必放在 stepAuthLocalGuard 之后,以防 非Local的远程请求无认证配置时 放行 + // 并且 localGuard 检查是否是 本地请求 的逻辑更严格 + {name: "localhost-no-auth", handler: (*authContext).stepLocalhostNoAuth}, + {name: "session-access-code", handler: (*authContext).stepSessionAccessCode}, + {name: "session-oidc", handler: (*authContext).stepOIDCSession}, + {name: "basic-auth", handler: (*authContext).stepBasicAuth}, + + // 放行特定路径 + {name: "localhost-whitelist", handler: (*authContext).stepLocalhostWhitelist}, + {name: "static-whitelist", handler: (*authContext).stepStaticWhitelist}, + + // 错误处理 + {name: "webdav-auth-fail", handler: (*authContext).stepFailWebDAV}, + {name: "fail", handler: (*authContext).stepFail}, } - // 通过 BasicAuth (header: Authorization) - if username, password, ok := c.Request.BasicAuth(); ok { - // 使用访问授权码作为密码 - if util.WorkspaceName == username && Conf.AccessAuthCode == password { - c.Set(RoleContextKey, RoleAdministrator) - c.Next() + for _, step := range steps { + if handled := handleAuthResult(c, step.handler(ctx)); handled { return } } - // WebDAV BasicAuth Authenticate - if strings.HasPrefix(c.Request.RequestURI, "/webdav") || - strings.HasPrefix(c.Request.RequestURI, "/caldav") || - strings.HasPrefix(c.Request.RequestURI, "/carddav") { - c.Header(BasicAuthHeaderKey, BasicAuthHeaderValue) - c.AbortWithStatus(http.StatusUnauthorized) - return - } + // 不应该到达这里 + logging.LogErrorf("auth logic error") + c.AbortWithStatus(http.StatusUnauthorized) +} - // 跳过访问授权页 - if "/check-auth" == c.Request.URL.Path { - c.Next() - return +func handleAuthResultWebsocket(res authResult, pass *bool) bool { + switch res.action { + case authActionGrant: + *pass = true + return true + case authActionPass: + *pass = true + return true + case authActionRedirect: + panic("websocket auth cannot redirect") + case authActionDeny: + *pass = false + return true + case authActionHeaderStatus: + panic("websocket auth cannot return header status") + default: + return false } +} - if workspaceSession.AccessAuthCode != Conf.AccessAuthCode { - userAgentHeader := c.GetHeader("User-Agent") - if strings.HasPrefix(userAgentHeader, "SiYuan/") || strings.HasPrefix(userAgentHeader, "Mozilla/") { - if "GET" != c.Request.Method || c.IsWebsocket() { - c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": Conf.Language(156)}) - c.Abort() - return - } +// CheckWebsocketAuth WebSocket 鉴权逻辑 +func CheckWebsocketAuth(c *gin.Context) (pass bool) { + ctx := newAuthContext(c) + steps := []authStep{ + {name: "existing-role", handler: (*authContext).stepExistingRole}, - location := url.URL{} - queryParams := url.Values{} - queryParams.Set("to", c.Request.URL.String()) - location.RawQuery = queryParams.Encode() - location.Path = "/check-auth" + //{name: "skip-auth", handler: (*authContext).stepSkipAuth}, - c.Redirect(http.StatusFound, location.String()) - c.Abort() + // 提前放行鉴权相关页面,避免被 auth local guard阻断 + {name: "auth-page-whitelist", handler: (*authContext).stepAuthPageWhitelist}, + + {name: "auth-local-guard", handler: (*authContext).stepAuthLocalGuard}, + + // stepLocalhostNoAuth 务必放在 stepAuthLocalGuard 之后,以防 非Local的远程请求无认证配置时 放行 + // 并且 localGuard 检查是否是 本地请求 的逻辑更严格 + {name: "localhost-no-auth", handler: (*authContext).stepLocalhostNoAuth}, + {name: "session-access-code", handler: (*authContext).stepSessionAccessCode}, + {name: "session-oidc", handler: (*authContext).stepOIDCSession}, + } + + for _, step := range steps { + if handled := handleAuthResultWebsocket(step.handler(ctx), &pass); handled { return } - - c.JSON(http.StatusUnauthorized, map[string]interface{}{"code": -1, "msg": "Auth failed [session]"}) - c.Abort() - return } - c.Set(RoleContextKey, RoleAdministrator) - c.Next() + logging.LogWarnf("closed an unauthenticated session [%s]", util.GetRemoteAddr(c.Request)) + return false } func CheckAdminRole(c *gin.Context) { diff --git a/kernel/server/serve.go b/kernel/server/serve.go index e55ee92d2c9..c9ee7f84d79 100644 --- a/kernel/server/serve.go +++ b/kernel/server/serve.go @@ -168,6 +168,7 @@ func Serve(fastMode bool, cookieKey string) { serveSnippets(ginServer) serveRepoDiff(ginServer) serveCheckAuth(ginServer) + serveOIDC(ginServer) serveFixedStaticFiles(ginServer) api.ServeAPI(ginServer) @@ -416,6 +417,12 @@ func serveCheckAuth(ginServer *gin.Engine) { ginServer.GET("/check-auth", serveAuthPage) } +func serveOIDC(ginServer *gin.Engine) { + ginServer.GET("/auth/oidc/login", model.OIDCLogin) + ginServer.GET("/auth/oidc/callback", model.OIDCCallback) + ginServer.GET("/auth/oidc/check", model.OIDCCheck) +} + func serveAuthPage(c *gin.Context) { data, err := os.ReadFile(filepath.Join(util.WorkingDir, "stage/auth.html")) if err != nil { @@ -450,6 +457,9 @@ func serveAuthPage(c *gin.Context) { keymapHideWindow = "⌥M" } } + oidcEnabled := model.OIDCIsEnabled(model.Conf.OIDC) + oidcProviderName := model.OIDCProviderLabel(model.Conf.OIDC) + authError := strings.TrimSpace(c.Query("error")) model := map[string]interface{}{ "l0": model.Conf.Language(173), "l1": model.Conf.Language(174), @@ -462,6 +472,9 @@ func serveAuthPage(c *gin.Context) { "l8": model.Conf.Language(95), "l9": model.Conf.Language(83), "l10": model.Conf.Language(257), + "l11": model.Conf.Language(289), + "l12": model.Conf.Language(290), + "l13": model.Conf.Language(291), "appearanceMode": model.Conf.Appearance.Mode, "appearanceModeOS": model.Conf.Appearance.ModeOS, "workspace": util.WorkspaceName, @@ -469,6 +482,10 @@ func serveAuthPage(c *gin.Context) { "keymapGeneralToggleWin": keymapHideWindow, "trayMenuLangs": util.TrayMenuLangs[util.Lang], "workspaceDir": util.WorkspaceDir, + "oidcEnabled": oidcEnabled, + "oidcProviderName": oidcProviderName, + "hasAccessAuthCode": "" != model.Conf.AccessAuthCode, + "authError": authError, } buf := &bytes.Buffer{} if err = tpl.Execute(buf, model); err != nil { @@ -576,7 +593,12 @@ func serveWebSocket(ginServer *gin.Engine) { util.WebSocketServer.Config.MaxMessageSize = 1024 * 1024 * 8 ginServer.GET("/ws", func(c *gin.Context) { - if err := util.WebSocketServer.HandleRequest(c.Writer, c.Request); err != nil { + ctxKey := make(map[string]any) + // Websocket 前端 API 因安全设计原因,无法读取到握手时的 HTTP 信息 + // 因此即使鉴权失败,也必须先允许握手成功,再拒绝。 + ctxKey["auth"] = model.CheckWebsocketAuth(c) + + if err := util.WebSocketServer.HandleRequestWithKeys(c.Writer, c.Request, ctxKey); err != nil { logging.LogErrorf("handle command failed: %s", err) } }) @@ -587,50 +609,9 @@ func serveWebSocket(ginServer *gin.Engine) { util.WebSocketServer.HandleConnect(func(s *melody.Session) { //logging.LogInfof("ws check auth for [%s]", s.Request.RequestURI) - authOk := true - - if "" != model.Conf.AccessAuthCode { - session, err := sessionStore.Get(s.Request, "siyuan") - if err != nil { - authOk = false - logging.LogErrorf("get cookie failed: %s", err) - } else { - val := session.Values["data"] - if nil == val { - authOk = false - } else { - sess := &util.SessionData{} - err = gulu.JSON.UnmarshalJSON([]byte(val.(string)), sess) - if err != nil { - authOk = false - logging.LogErrorf("unmarshal cookie failed: %s", err) - } else { - workspaceSess := util.GetWorkspaceSession(sess) - authOk = workspaceSess.AccessAuthCode == model.Conf.AccessAuthCode - } - } - } - } - - // REF: https://github.com/siyuan-note/siyuan/issues/11364 - if !authOk { - if token := model.ParseXAuthToken(s.Request); token != nil { - authOk = token.Valid && model.IsValidRole(model.GetClaimRole(model.GetTokenClaims(token)), []model.Role{ - model.RoleAdministrator, - model.RoleEditor, - model.RoleReader, - }) - } - } - - if !authOk { - // 用于授权页保持连接,避免非常驻内存内核自动退出 https://github.com/siyuan-note/insider/issues/1099 - authOk = strings.Contains(s.Request.RequestURI, "/ws?app=siyuan&id=auth") - } - if !authOk { + if authOk, ok := s.Keys["auth"].(bool); !ok || !authOk { s.CloseWithMsg([]byte(" unauthenticated")) - logging.LogWarnf("closed an unauthenticated session [%s]", util.GetRemoteAddr(s.Request)) return } diff --git a/kernel/util/path.go b/kernel/util/path.go index e2f39a6ef7d..70bbca806c6 100644 --- a/kernel/util/path.go +++ b/kernel/util/path.go @@ -20,11 +20,13 @@ import ( "bytes" "io/fs" "net" + "net/url" "os" "path" "path/filepath" "runtime" "sort" + "strconv" "strings" "time" @@ -409,3 +411,40 @@ func IsPartitionRootPath(path string) bool { return cleanPath == "/" } } + +// SanitizeRedirectPath sanitizes the given redirect path to prevent open redirect vulnerabilities. +// it ensures that the path is relative and does not contain any malicious components. +func SanitizeRedirectPath(dest string) string { + if "" == dest { + return "/" + } + parsed, err := url.Parse(dest) + if err != nil || parsed.IsAbs() || "" != parsed.Host { + return "/" + } + if strings.HasPrefix(parsed.Path, "//") { + return "/" + } + if "" == parsed.Path { + parsed.Path = "/" + } + if !strings.HasPrefix(parsed.Path, "/") { + return "/" + } + parsed.Scheme = "" + parsed.Host = "" + parsed.User = nil + return parsed.String() +} + +func ParseBoolQuery(val string) bool { + val = strings.TrimSpace(val) + if "" == val { + return false + } + if "1" == val { + return true + } + parsed, err := strconv.ParseBool(val) + return err == nil && parsed +} diff --git a/kernel/util/session.go b/kernel/util/session.go index 32ff0b44542..a24430a613c 100644 --- a/kernel/util/session.go +++ b/kernel/util/session.go @@ -37,6 +37,13 @@ type SessionData struct { type WorkspaceSession struct { AccessAuthCode string Captcha string + OIDC *OIDCSession +} + +type OIDCSession struct { + ProviderID string + ProviderHash string + FilterHash string } func (sd *SessionData) Clear(c *gin.Context) { @@ -78,7 +85,6 @@ func GetSession(c *gin.Context) *SessionData { } func GetWorkspaceSession(session *SessionData) (ret *WorkspaceSession) { - ret = &WorkspaceSession{} if nil == session.Workspaces { session.Workspaces = map[string]*WorkspaceSession{} } @@ -87,6 +93,9 @@ func GetWorkspaceSession(session *SessionData) (ret *WorkspaceSession) { ret = &WorkspaceSession{} session.Workspaces[WorkspaceDir] = ret } + if nil == ret.OIDC { + ret.OIDC = &OIDCSession{} + } return } diff --git a/kernel/util/working.go b/kernel/util/working.go index 5244e9c4bec..8372f4e8207 100644 --- a/kernel/util/working.go +++ b/kernel/util/working.go @@ -53,24 +53,43 @@ const ( SIYUAN_ACCESS_AUTH_CODE = "SIYUAN_ACCESS_AUTH_CODE" SIYUAN_WORKSPACE = "SIYUAN_WORKSPACE_PATH" SIYUAN_LANG = "SIYUAN_LANG" -) -var ( - RunInContainer = false // 是否运行在容器中 - SiyuanAccessAuthCodeBypass = false // 是否跳过空访问授权码检查 - SiyuanAccessAuthCodeViaEnvvar = "" // Fallback auth code via env var (SIYUAN_ACCESS_AUTH_CODE) + ContainerStd = "std" // 桌面端 + ContainerDocker = "docker" // Docker 容器端 + ContainerAndroid = "android" // Android 端 + ContainerIOS = "ios" // iOS 端 + ContainerHarmony = "harmony" // 鸿蒙端 + + LocalHost = "127.0.0.1" // 伺服地址 + FixedPort = "6806" // 固定端口 ) -func initEnvVars() { - RunInContainer = isRunningInDockerContainer() - var err error - if SiyuanAccessAuthCodeBypass, err = strconv.ParseBool(os.Getenv("SIYUAN_ACCESS_AUTH_CODE_BYPASS")); err != nil { - SiyuanAccessAuthCodeBypass = false - } - SiyuanAccessAuthCodeViaEnvvar = os.Getenv("SIYUAN_ACCESS_AUTH_CODE") +type AuthCLIArgs struct { + AccessCode string + AccessCodeSet bool + AccessAuthBypass bool + AccessAuthBypassSet bool + OIDCProvider string + OIDCProviderSet bool + OIDCProviders string + OIDCProvidersSet bool + OIDCFilters string + OIDCFiltersSet bool } var ( + ServerURL *url.URL // 内核服务 URL + ServerPort = "0" // HTTP/WebSocket 端口,0 为使用随机端口 + + ReadOnly bool + Lang = "" + + AuthCLI = AuthCLIArgs{} + + Container string // docker, android, ios, harmony, std + RunInContainer = false // 是否运行在容器中 + ISMicrosoftStore bool // 桌面端是否是微软商店版 + bootProgress = atomic.Int32{} // 启动进度,从 0 到 100 bootDetails string // 启动细节描述 HttpServer *http.Server // HTTP 伺服器实例 @@ -90,8 +109,48 @@ func coalesceToEnvVar(fromCLI *string, envVarName string) *string { return fromCLI } +type stringFlag struct { + set bool + val string +} + +func (f *stringFlag) Set(s string) error { + f.val = s + f.set = true + return nil +} + +func (f *stringFlag) String() string { + return f.val +} + +type boolFlag struct { + set bool + val bool +} + +func (f *boolFlag) Set(s string) error { + if "" == s { + s = "true" + } + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + f.val = v + f.set = true + return nil +} + +func (f *boolFlag) String() string { + return strconv.FormatBool(f.val) +} + +func (f *boolFlag) IsBoolFlag() bool { + return true +} + func Boot() { - initEnvVars() IncBootProgress(3, "Booting kernel...") rand.Seed(time.Now().UTC().UnixNano()) initMime() @@ -101,18 +160,49 @@ func Boot() { wdPath := flag.String("wd", WorkingDir, "working directory of SiYuan") port := flag.String("port", "0", "port of the HTTP server") readOnly := flag.String("readonly", "false", "read-only mode") - accessAuthCode := flag.String("accessAuthCode", "", "access auth code") ssl := flag.Bool("ssl", false, "for https and wss") lang := flag.String("lang", "", "ar_SA/de_DE/en_US/es_ES/fr_FR/he_IL/it_IT/ja_JP/ko_KR/pl_PL/pt_BR/ru_RU/tr_TR/zh_CHT/zh_CN") mode := flag.String("mode", "prod", "dev/prod") + var accessAuthCode stringFlag + flag.Var(&accessAuthCode, "access-auth-code", "access auth code") + var accessAuthBypass boolFlag + flag.Var(&accessAuthBypass, "access-auth-bypass", "bypass all access authentication and security checks (not recommended)") + var oidcProvider stringFlag + flag.Var(&oidcProvider, "oidc-provider", "OIDC provider id") + var oidcProviders stringFlag + flag.Var(&oidcProviders, "oidc-providers", "OIDC providers configuration (JSON)") + var oidcFilters stringFlag + flag.Var(&oidcFilters, "oidc-filters", "OIDC filter configuration (JSON)") flag.Parse() // Fallback to env vars if commandline args are not set // valid only for CLI args that default to "", as the // others have explicit (sane) defaults workspacePath = coalesceToEnvVar(workspacePath, SIYUAN_WORKSPACE) - accessAuthCode = coalesceToEnvVar(accessAuthCode, SIYUAN_ACCESS_AUTH_CODE) lang = coalesceToEnvVar(lang, SIYUAN_LANG) + RunInContainer = isRunningInDockerContainer() + if v := strings.TrimSpace(os.Getenv(SIYUAN_ACCESS_AUTH_CODE)); "" != v { + accessAuthCode.val = v + accessAuthCode.set = true + } + if v := strings.TrimSpace(os.Getenv("SIYUAN_OIDC_PROVIDER")); "" != v { + oidcProvider.val = v + oidcProvider.set = true + } + if v := strings.TrimSpace(os.Getenv("SIYUAN_OIDC_PROVIDERS")); "" != v { + oidcProviders.val = v + oidcProviders.set = true + } + if v := strings.TrimSpace(os.Getenv("SIYUAN_OIDC_FILTERS")); "" != v { + oidcFilters.val = v + oidcFilters.set = true + } + if v := strings.TrimSpace(os.Getenv("SIYUAN_ACCESS_AUTH_BYPASS")); "" != v { + if parsed, err := strconv.ParseBool(v); err == nil { + accessAuthBypass.val = parsed + accessAuthBypass.set = true + } + } if "" != *wdPath { WorkingDir = *wdPath @@ -123,30 +213,23 @@ func Boot() { Mode = *mode ServerPort = *port ReadOnly, _ = strconv.ParseBool(*readOnly) - AccessAuthCode = *accessAuthCode - AccessAuthCode = strings.TrimSpace(AccessAuthCode) - AccessAuthCode = RemoveInvalid(AccessAuthCode) + + AuthCLI = AuthCLIArgs{ + AccessCode: RemoveInvalid(strings.TrimSpace(accessAuthCode.val)), + AccessCodeSet: accessAuthCode.set, + AccessAuthBypass: accessAuthBypass.val, + AccessAuthBypassSet: accessAuthBypass.set, + OIDCProvider: strings.TrimSpace(oidcProvider.val), + OIDCProviderSet: oidcProvider.set, + OIDCProviders: strings.TrimSpace(oidcProviders.val), + OIDCProvidersSet: oidcProviders.set, + OIDCFilters: strings.TrimSpace(oidcFilters.val), + OIDCFiltersSet: oidcFilters.set, + } + Container = ContainerStd if RunInContainer { Container = ContainerDocker - if "" == AccessAuthCode { // Still empty? - interruptBoot := true - - // Set the env `SIYUAN_ACCESS_AUTH_CODE_BYPASS=true` to skip checking empty access auth code https://github.com/siyuan-note/siyuan/issues/9709 - if SiyuanAccessAuthCodeBypass { - interruptBoot = false - fmt.Println("bypass access auth code check since the env [SIYUAN_ACCESS_AUTH_CODE_BYPASS] is set to [true]") - } - - if interruptBoot { - // The access authorization code command line parameter must be set when deploying via Docker https://github.com/siyuan-note/siyuan/issues/9328 - fmt.Printf("the access authorization code command line parameter (--accessAuthCode) must be set when deploying via Docker\n") - fmt.Printf("or you can set the SIYUAN_ACCESS_AUTH_CODE env var") - os.Exit(logging.ExitCodeSecurityRisk) - } - } - } - if ContainerStd != Container { ServerPort = FixedPort } @@ -387,29 +470,6 @@ func WriteWorkspacePaths(workspacePaths []string) (err error) { return } -var ( - ServerURL *url.URL // 内核服务 URL - ServerPort = "0" // HTTP/WebSocket 端口,0 为使用随机端口 - - ReadOnly bool - AccessAuthCode string - Lang = "" - - Container string // docker, android, ios, harmony, std - ISMicrosoftStore bool // 桌面端是否是微软商店版 -) - -const ( - ContainerStd = "std" // 桌面端 - ContainerDocker = "docker" // Docker 容器端 - ContainerAndroid = "android" // Android 端 - ContainerIOS = "ios" // iOS 端 - ContainerHarmony = "harmony" // 鸿蒙端 - - LocalHost = "127.0.0.1" // 伺服地址 - FixedPort = "6806" // 固定端口 -) - func initPathDir() { if err := os.MkdirAll(ConfDir, 0755); err != nil && !os.IsExist(err) { logging.LogFatalf(logging.ExitCodeInitWorkspaceErr, "create conf folder [%s] failed: %s", ConfDir, err)