Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sftp-server): public key login #7668

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var db *gorm.DB

func Init(d *gorm.DB) {
db = d
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem))
err := AutoMigrate(new(model.Storage), new(model.User), new(model.Meta), new(model.SettingItem), new(model.SearchNode), new(model.TaskItem), new(model.SSHPublicKey))
if err != nil {
log.Fatalf("failed migrate database: %s", err.Error())
}
Expand Down
57 changes: 57 additions & 0 deletions internal/db/sshkey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package db

import (
"github.com/alist-org/alist/v3/internal/model"
"github.com/pkg/errors"
)

func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
keyDB := db.Model(&model.SSHPublicKey{})
query := model.SSHPublicKey{UserId: userId}
if err := keyDB.Where(query).Count(&count).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed get user's keys count")
}
if err := keyDB.Where(query).Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed get find user's keys")
}
return keys, count, nil
}

func GetSSHPublicKeyById(id uint) (*model.SSHPublicKey, error) {
var k model.SSHPublicKey
if err := db.First(&k, id).Error; err != nil {
return nil, errors.Wrapf(err, "failed get old key")
}
return &k, nil
}

func GetSSHPublicKeyByUserTitle(userId uint, title string) (*model.SSHPublicKey, error) {
key := model.SSHPublicKey{UserId: userId, Title: title}
if err := db.Where(key).First(&key).Error; err != nil {
return nil, errors.Wrapf(err, "failed find key with title of user")
}
return &key, nil
}

func CreateSSHPublicKey(k *model.SSHPublicKey) error {
return errors.WithStack(db.Create(k).Error)
}

func UpdateSSHPublicKey(k *model.SSHPublicKey) error {
return errors.WithStack(db.Save(k).Error)
}

func GetSSHPublicKeys(pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
keyDB := db.Model(&model.SSHPublicKey{})
if err := keyDB.Count(&count).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed get keys count")
}
if err := keyDB.Order(columnName("id")).Offset((pageIndex - 1) * pageSize).Limit(pageSize).Find(&keys).Error; err != nil {
return nil, 0, errors.Wrapf(err, "failed get find keys")
}
return keys, count, nil
}

func DeleteSSHPublicKeyById(id uint) error {
return errors.WithStack(db.Delete(&model.SSHPublicKey{}, id).Error)
}
28 changes: 28 additions & 0 deletions internal/model/sshkey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package model

import (
"golang.org/x/crypto/ssh"
"time"
)

type SSHPublicKey struct {
ID uint `json:"id" gorm:"primaryKey"`
UserId uint `json:"-"`
Title string `json:"title"`
Fingerprint string `json:"fingerprint"`
KeyStr string `gorm:"type:text" json:"-"`
AddedTime time.Time `json:"added_time"`
LastUsedTime time.Time `json:"last_used_time"`
}

func (k *SSHPublicKey) GetKey() (ssh.PublicKey, error) {
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr))
if err != nil {
return nil, err
}
return pubKey, nil
}

func (k *SSHPublicKey) UpdateLastUsedTime() {
k.LastUsedTime = time.Now()
}
48 changes: 48 additions & 0 deletions internal/op/sshkey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package op

import (
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/model"
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"time"
)

func CreateSSHPublicKey(k *model.SSHPublicKey) (error, bool) {
_, err := db.GetSSHPublicKeyByUserTitle(k.UserId, k.Title)
if err == nil {
return errors.New("key with the same title already exists"), true
}
pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(k.KeyStr))
if err != nil {
return err, false
}
k.KeyStr = string(pubKey.Marshal())
k.Fingerprint = ssh.FingerprintSHA256(pubKey)
k.AddedTime = time.Now()
k.LastUsedTime = k.AddedTime
return db.CreateSSHPublicKey(k), true
}

func GetSSHPublicKeyByUserId(userId uint, pageIndex, pageSize int) (keys []model.SSHPublicKey, count int64, err error) {
return db.GetSSHPublicKeyByUserId(userId, pageIndex, pageSize)
}

func GetSSHPublicKeyByIdAndUserId(id uint, userId uint) (*model.SSHPublicKey, error) {
key, err := db.GetSSHPublicKeyById(id)
if err != nil {
return nil, err
}
if key.UserId != userId {
return nil, errors.Wrapf(err, "failed get old key")
}
return key, nil
}

func UpdateSSHPublicKey(k *model.SSHPublicKey) error {
return db.UpdateSSHPublicKey(k)
}

func DeleteSSHPublicKeyById(keyId uint) error {
return db.DeleteSSHPublicKeyById(keyId)
}
124 changes: 124 additions & 0 deletions server/handles/sshkey.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
package handles

import (
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
"strconv"
)

type SSHKeyAddReq struct {
Title string `json:"title" binding:"required"`
Key string `json:"key" binding:"required"`
}

func AddMyPublicKey(c *gin.Context) {
userObj, ok := c.Value("user").(*model.User)
if !ok || userObj.IsGuest() {
common.ErrorStrResp(c, "user invalid", 401)
return
}
var req SSHKeyAddReq
if err := c.ShouldBind(&req); err != nil {
common.ErrorStrResp(c, "request invalid", 400)
return
}
if req.Title == "" {
common.ErrorStrResp(c, "request invalid", 400)
return
}
key := &model.SSHPublicKey{
Title: req.Title,
KeyStr: req.Key,
UserId: userObj.ID,
}
err, parsed := op.CreateSSHPublicKey(key)
if !parsed {
common.ErrorStrResp(c, "provided key invalid", 400)
return
} else if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
common.SuccessResp(c)
}

func ListMyPublicKey(c *gin.Context) {
userObj, ok := c.Value("user").(*model.User)
if !ok || userObj.IsGuest() {
common.ErrorStrResp(c, "user invalid", 401)
return
}
list(c, userObj)
}

func DeleteMyPublicKey(c *gin.Context) {
userObj, ok := c.Value("user").(*model.User)
if !ok || userObj.IsGuest() {
common.ErrorStrResp(c, "user invalid", 401)
return
}
keyId, err := strconv.Atoi(c.Query("id"))
if err != nil {
common.ErrorStrResp(c, "id format invalid", 400)
return
}
key, err := op.GetSSHPublicKeyByIdAndUserId(uint(keyId), userObj.ID)
if err != nil {
common.ErrorStrResp(c, "failed to get public key", 404)
return
}
err = op.DeleteSSHPublicKeyById(key.ID)
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
common.SuccessResp(c)
}

func ListPublicKeys(c *gin.Context) {
userId, err := strconv.Atoi(c.Query("uid"))
if err != nil {
common.ErrorStrResp(c, "user id format invalid", 400)
return
}
userObj, err := op.GetUserById(uint(userId))
if err != nil {
common.ErrorStrResp(c, "user invalid", 404)
return
}
list(c, userObj)
}

func DeletePublicKey(c *gin.Context) {
keyId, err := strconv.Atoi(c.Query("id"))
if err != nil {
common.ErrorStrResp(c, "id format invalid", 400)
return
}
err = op.DeleteSSHPublicKeyById(uint(keyId))
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
common.SuccessResp(c)
}

func list(c *gin.Context, userObj *model.User) {
var req model.PageReq
if err := c.ShouldBind(&req); err != nil {
common.ErrorResp(c, err, 400)
return
}
req.Validate()
keys, total, err := op.GetSSHPublicKeyByUserId(userObj.ID, req.Page, req.PerPage)
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
common.SuccessResp(c, common.PageResp{
Content: keys,
Total: total,
})
}
5 changes: 5 additions & 0 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ func Init(e *gin.Engine) {
api.POST("/auth/login/ldap", handles.LoginLdap)
auth.GET("/me", handles.CurrentUser)
auth.POST("/me/update", handles.UpdateCurrent)
auth.GET("/me/sshkey/list", handles.ListMyPublicKey)
auth.POST("/me/sshkey/add", handles.AddMyPublicKey)
auth.POST("/me/sshkey/delete", handles.DeleteMyPublicKey)
auth.POST("/auth/2fa/generate", handles.Generate2FA)
auth.POST("/auth/2fa/verify", handles.Verify2FA)
auth.GET("/auth/logout", handles.LogOut)
Expand Down Expand Up @@ -102,6 +105,8 @@ func admin(g *gin.RouterGroup) {
user.POST("/cancel_2fa", handles.Cancel2FAById)
user.POST("/delete", handles.DeleteUser)
user.POST("/del_cache", handles.DelUserCache)
user.GET("/sshkey/list", handles.ListPublicKeys)
user.POST("/sshkey/delete", handles.DeletePublicKey)

storage := g.Group("/storage")
storage.GET("/list", handles.ListStorages)
Expand Down
27 changes: 26 additions & 1 deletion server/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/pkg/errors"
"golang.org/x/crypto/ssh"
"net/http"
"time"
)

type SftpDriver struct {
Expand All @@ -35,6 +36,7 @@ func (d *SftpDriver) GetConfig() *sftpd.Config {
NoClientAuth: true,
NoClientAuthCallback: d.NoClientAuth,
PasswordCallback: d.PasswordAuth,
PublicKeyCallback: d.PublicKeyAuth,
AuthLogCallback: d.AuthLogCallback,
BannerCallback: d.GetBanner,
}
Expand Down Expand Up @@ -85,14 +87,37 @@ func (d *SftpDriver) PasswordAuth(conn ssh.ConnMetadata, password []byte) (*ssh.
if err != nil {
return nil, err
}
if userObj.Disabled || !userObj.CanFTPAccess() {
return nil, errors.New("user is not allowed to access via SFTP")
}
passHash := model.StaticHash(string(password))
if err = userObj.ValidatePwdStaticHash(passHash); err != nil {
return nil, err
}
return nil, nil
}

func (d *SftpDriver) PublicKeyAuth(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
userObj, err := op.GetUserByName(conn.User())
if err != nil {
return nil, err
}
if userObj.Disabled || !userObj.CanFTPAccess() {
return nil, errors.New("user is not allowed to access via SFTP")
}
return nil, nil
keys, _, err := op.GetSSHPublicKeyByUserId(userObj.ID, 1, -1)
if err != nil {
return nil, err
}
marshal := string(key.Marshal())
for _, sk := range keys {
if marshal == sk.KeyStr {
sk.LastUsedTime = time.Now()
_ = op.UpdateSSHPublicKey(&sk)
return nil, nil
}
}
return nil, errors.New("public key refused")
}

func (d *SftpDriver) AuthLogCallback(conn ssh.ConnMetadata, method string, err error) {
Expand Down
Loading