diff --git a/common/config/config.go b/common/config/config.go index 11da0b967d..43f5686273 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -35,6 +35,7 @@ var PasswordLoginEnabled = true var PasswordRegisterEnabled = true var EmailVerificationEnabled = false var GitHubOAuthEnabled = false +var OidcEnabled = false var WeChatAuthEnabled = false var TurnstileCheckEnabled = false var RegisterEnabled = true @@ -70,6 +71,13 @@ var GitHubClientSecret = "" var LarkClientId = "" var LarkClientSecret = "" +var OidcClientId = "" +var OidcClientSecret = "" +var OidcWellKnown = "" +var OidcAuthorizationEndpoint = "" +var OidcTokenEndpoint = "" +var OidcUserinfoEndpoint = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" diff --git a/controller/auth/oidc.go b/controller/auth/oidc.go new file mode 100644 index 0000000000..7b4ad4b9ee --- /dev/null +++ b/controller/auth/oidc.go @@ -0,0 +1,225 @@ +package auth + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" + "net/http" + "strconv" + "time" +) + +type OidcResponse struct { + AccessToken string `json:"access_token"` + IDToken string `json:"id_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + Scope string `json:"scope"` +} + +type OidcUser struct { + OpenID string `json:"sub"` + Email string `json:"email"` + Name string `json:"name"` + PreferredUsername string `json:"preferred_username"` + Picture string `json:"picture"` +} + +func getOidcUserInfoByCode(code string) (*OidcUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.OidcClientId, + "client_secret": config.OidcClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/oidc", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", config.OidcTokenEndpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + defer res.Body.Close() + var oidcResponse OidcResponse + err = json.NewDecoder(res.Body).Decode(&oidcResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", config.OidcUserinfoEndpoint, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+oidcResponse.AccessToken) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至 OIDC 服务器,请稍后重试!") + } + var oidcUser OidcUser + err = json.NewDecoder(res2.Body).Decode(&oidcUser) + if err != nil { + return nil, err + } + return &oidcUser, nil +} + +func OidcAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + OidcBind(c) + return + } + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + err := user.FillUserByOidcId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Email = oidcUser.Email + if oidcUser.PreferredUsername != "" { + user.Username = oidcUser.PreferredUsername + } else { + user.Username = "oidc_" + strconv.Itoa(model.GetMaxUserId()+1) + } + if oidcUser.Name != "" { + user.DisplayName = oidcUser.Name + } else { + user.DisplayName = "OIDC User" + } + err := user.Insert(0) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func OidcBind(c *gin.Context) { + if !config.OidcEnabled { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员未开启通过 OIDC 登录以及注册", + }) + return + } + code := c.Query("code") + oidcUser, err := getOidcUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + OidcId: oidcUser.OpenID, + } + if model.IsOidcIdAlreadyTaken(user.OidcId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该 OIDC 账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.OidcId = oidcUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/misc.go b/controller/misc.go index 2928b8fb33..ae90087017 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -18,24 +18,30 @@ func GetStatus(c *gin.Context) { "success": true, "message": "", "data": gin.H{ - "version": common.Version, - "start_time": common.StartTime, - "email_verification": config.EmailVerificationEnabled, - "github_oauth": config.GitHubOAuthEnabled, - "github_client_id": config.GitHubClientId, - "lark_client_id": config.LarkClientId, - "system_name": config.SystemName, - "logo": config.Logo, - "footer_html": config.Footer, - "wechat_qrcode": config.WeChatAccountQRCodeImageURL, - "wechat_login": config.WeChatAuthEnabled, - "server_address": config.ServerAddress, - "turnstile_check": config.TurnstileCheckEnabled, - "turnstile_site_key": config.TurnstileSiteKey, - "top_up_link": config.TopUpLink, - "chat_link": config.ChatLink, - "quota_per_unit": config.QuotaPerUnit, - "display_in_currency": config.DisplayInCurrencyEnabled, + "version": common.Version, + "start_time": common.StartTime, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, + "oidc": config.OidcEnabled, + "oidc_client_id": config.OidcClientId, + "oidc_well_known": config.OidcWellKnown, + "oidc_authorization_endpoint": config.OidcAuthorizationEndpoint, + "oidc_token_endpoint": config.OidcTokenEndpoint, + "oidc_userinfo_endpoint": config.OidcUserinfoEndpoint, }, }) return diff --git a/model/option.go b/model/option.go index bed8d4c37d..8fd30aee2a 100644 --- a/model/option.go +++ b/model/option.go @@ -28,6 +28,7 @@ func InitOptionMap() { config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["OidcEnabled"] = strconv.FormatBool(config.OidcEnabled) config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) @@ -130,6 +131,8 @@ func updateOptionMap(key string, value string) (err error) { config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": config.GitHubOAuthEnabled = boolValue + case "OidcEnabled": + config.OidcEnabled = boolValue case "WeChatAuthEnabled": config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": @@ -176,6 +179,18 @@ func updateOptionMap(key string, value string) (err error) { config.LarkClientId = value case "LarkClientSecret": config.LarkClientSecret = value + case "OidcClientId": + config.OidcClientId = value + case "OidcClientSecret": + config.OidcClientSecret = value + case "OidcWellKnown": + config.OidcWellKnown = value + case "OidcAuthorizationEndpoint": + config.OidcAuthorizationEndpoint = value + case "OidcTokenEndpoint": + config.OidcTokenEndpoint = value + case "OidcUserinfoEndpoint": + config.OidcUserinfoEndpoint = value case "Footer": config.Footer = value case "SystemName": diff --git a/model/user.go b/model/user.go index 924d72f940..a964a0d7d2 100644 --- a/model/user.go +++ b/model/user.go @@ -39,6 +39,7 @@ type User struct { GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` LarkId string `json:"lark_id" gorm:"column:lark_id;index"` + OidcId string `json:"oidc_id" gorm:"column:oidc_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -245,6 +246,14 @@ func (user *User) FillUserByLarkId() error { return nil } +func (user *User) FillUserByOidcId() error { + if user.OidcId == "" { + return errors.New("oidc id 为空!") + } + DB.Where(User{OidcId: user.OidcId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -277,6 +286,10 @@ func IsLarkIdAlreadyTaken(githubId string) bool { return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsOidcIdAlreadyTaken(oidcId string) bool { + return DB.Where("oidc_id = ?", oidcId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } diff --git a/router/api.go b/router/api.go index d2ada4ebda..6d00c6eaa1 100644 --- a/router/api.go +++ b/router/api.go @@ -23,6 +23,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/oidc", middleware.CriticalRateLimit(), auth.OidcAuth) apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) diff --git a/web/berry/src/assets/images/icons/lark.svg b/web/berry/src/assets/images/icons/lark.svg index 239e1bef65..79688e2aae 100644 --- a/web/berry/src/assets/images/icons/lark.svg +++ b/web/berry/src/assets/images/icons/lark.svg @@ -1 +1,5 @@ - \ No newline at end of file + + + diff --git a/web/berry/src/assets/images/icons/oidc.svg b/web/berry/src/assets/images/icons/oidc.svg new file mode 100644 index 0000000000..96e01f814d --- /dev/null +++ b/web/berry/src/assets/images/icons/oidc.svg @@ -0,0 +1,7 @@ + + + + diff --git a/web/berry/src/config.js b/web/berry/src/config.js index eeeda99a28..8c1faf9bb6 100644 --- a/web/berry/src/config.js +++ b/web/berry/src/config.js @@ -22,7 +22,12 @@ const config = { turnstile_site_key: '', version: '', wechat_login: false, - wechat_qrcode: '' + wechat_qrcode: '', + oidc: false, + oidc_client_id: '', + oidc_authorization_endpoint: '', + oidc_token_endpoint: '', + oidc_userinfo_endpoint: '', } }; diff --git a/web/berry/src/hooks/useLogin.js b/web/berry/src/hooks/useLogin.js index 39d8b40741..6d89727d8e 100644 --- a/web/berry/src/hooks/useLogin.js +++ b/web/berry/src/hooks/useLogin.js @@ -70,6 +70,28 @@ const useLogin = () => { } }; + const oidcLogin = async (code, state) => { + try { + const res = await API.get(`/api/oauth/oidc?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/panel'); + } else { + dispatch({ type: LOGIN, payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/panel'); + } + } + return { success, message }; + } catch (err) { + // 请求失败,设置错误信息 + return { success: false, message: '' }; + } + } + const wechatLogin = async (code) => { try { const res = await API.get(`/api/oauth/wechat?code=${code}`); @@ -94,7 +116,7 @@ const useLogin = () => { navigate('/'); }; - return { login, logout, githubLogin, wechatLogin, larkLogin }; + return { login, logout, githubLogin, wechatLogin, larkLogin,oidcLogin }; }; export default useLogin; diff --git a/web/berry/src/routes/OtherRoutes.js b/web/berry/src/routes/OtherRoutes.js index 58c0b660e6..a4bdb5d304 100644 --- a/web/berry/src/routes/OtherRoutes.js +++ b/web/berry/src/routes/OtherRoutes.js @@ -9,6 +9,7 @@ const AuthLogin = Loadable(lazy(() => import('views/Authentication/Auth/Login')) const AuthRegister = Loadable(lazy(() => import('views/Authentication/Auth/Register'))); const GitHubOAuth = Loadable(lazy(() => import('views/Authentication/Auth/GitHubOAuth'))); const LarkOAuth = Loadable(lazy(() => import('views/Authentication/Auth/LarkOAuth'))); +const OidcOAuth = Loadable(lazy(() => import('views/Authentication/Auth/OidcOAuth'))); const ForgetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ForgetPassword'))); const ResetPassword = Loadable(lazy(() => import('views/Authentication/Auth/ResetPassword'))); const Home = Loadable(lazy(() => import('views/Home'))); @@ -53,6 +54,10 @@ const OtherRoutes = { path: '/oauth/lark', element: }, + { + path: 'oauth/oidc', + element: + }, { path: '/404', element: diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index d74d032e58..f9c2896cda 100644 --- a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -98,6 +98,21 @@ export async function onLarkOAuthClicked(lark_client_id) { window.open(`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`); } +export async function onOidcClicked(auth_url, client_id, openInNewTab = false) { + const state = await getOAuthState(); + if (!state) return; + const redirect_uri = `${window.location.origin}/oauth/oidc`; + const response_type = "code"; + const scope = "openid profile email"; + const url = `${auth_url}?client_id=${client_id}&redirect_uri=${redirect_uri}&response_type=${response_type}&scope=${scope}&state=${state}`; + if (openInNewTab) { + window.open(url); + } else + { + window.location.href = url; + } +} + export function isAdmin() { let user = localStorage.getItem('user'); if (!user) return false; diff --git a/web/berry/src/views/Authentication/Auth/OidcOAuth.js b/web/berry/src/views/Authentication/Auth/OidcOAuth.js new file mode 100644 index 0000000000..55d9372d15 --- /dev/null +++ b/web/berry/src/views/Authentication/Auth/OidcOAuth.js @@ -0,0 +1,94 @@ +import { Link, useNavigate, useSearchParams } from 'react-router-dom'; +import React, { useEffect, useState } from 'react'; +import { showError } from 'utils/common'; +import useLogin from 'hooks/useLogin'; + +// material-ui +import { useTheme } from '@mui/material/styles'; +import { Grid, Stack, Typography, useMediaQuery, CircularProgress } from '@mui/material'; + +// project imports +import AuthWrapper from '../AuthWrapper'; +import AuthCardWrapper from '../AuthCardWrapper'; +import Logo from 'ui-component/Logo'; + +// assets + +// ================================|| AUTH3 - LOGIN ||================================ // + +const OidcOAuth = () => { + const theme = useTheme(); + const matchDownSM = useMediaQuery(theme.breakpoints.down('md')); + + const [searchParams] = useSearchParams(); + const [prompt, setPrompt] = useState('处理中...'); + const { oidcLogin } = useLogin(); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const { success, message } = await oidcLogin(code, state); + if (!success) { + if (message) { + showError(message); + } + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + await new Promise((resolve) => setTimeout(resolve, 2000)); + navigate('/login'); + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + + + + + + + + + + + + + + + + OIDC 登录 + + + + + + + + + {prompt} + + + + + + + + + + ); +}; + +export default OidcOAuth; diff --git a/web/berry/src/views/Authentication/AuthForms/AuthLogin.js b/web/berry/src/views/Authentication/AuthForms/AuthLogin.js index bc7a35c0af..7efd036233 100644 --- a/web/berry/src/views/Authentication/AuthForms/AuthLogin.js +++ b/web/berry/src/views/Authentication/AuthForms/AuthLogin.js @@ -36,7 +36,8 @@ import VisibilityOff from '@mui/icons-material/VisibilityOff'; import Github from 'assets/images/icons/github.svg'; import Wechat from 'assets/images/icons/wechat.svg'; import Lark from 'assets/images/icons/lark.svg'; -import { onGitHubOAuthClicked, onLarkOAuthClicked } from 'utils/common'; +import OIDC from 'assets/images/icons/oidc.svg'; +import { onGitHubOAuthClicked, onLarkOAuthClicked, onOidcClicked } from 'utils/common'; // ============================|| FIREBASE - LOGIN ||============================ // @@ -50,7 +51,7 @@ const LoginForm = ({ ...others }) => { // const [checked, setChecked] = useState(true); let tripartiteLogin = false; - if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id) { + if (siteInfo.github_oauth || siteInfo.wechat_login || siteInfo.lark_client_id || siteInfo.oidc) { tripartiteLogin = true; } @@ -145,6 +146,29 @@ const LoginForm = ({ ...others }) => { )} + {siteInfo.oidc && ( + + + + + + )} 8) { + oidc_id = inputs.oidc_id.slice(0, 6) + '...' + inputs.oidc_id.slice(-6); + } + return oidc_id; + } + return ( <> @@ -141,6 +151,9 @@ export default function Profile() { + @@ -216,6 +229,13 @@ export default function Profile() { )} + {status.oidc && !inputs.oidc_id && ( + + + + )} + + + +