diff --git a/backend/pkg/api/data_access/notifications.go b/backend/pkg/api/data_access/notifications.go index fc91128ad..781414571 100644 --- a/backend/pkg/api/data_access/notifications.go +++ b/backend/pkg/api/data_access/notifications.go @@ -285,7 +285,6 @@ func (d *DataAccessService) GetDashboardNotifications(ctx context.Context, userI )). Where( goqu.Ex{"uvd.user_id": userId}, - goqu.L("uvd.network = ANY(?)", pq.Array(chainIds)), ). GroupBy( goqu.I("uvdnh.epoch"), @@ -295,6 +294,12 @@ func (d *DataAccessService) GetDashboardNotifications(ctx context.Context, userI goqu.I("uvdg.name"), ) + if chainIds != nil { + vdbQuery = vdbQuery.Where( + goqu.L("uvd.network = ANY(?)", pq.Array(chainIds)), + ) + } + // TODO account dashboards /*adbQuery := goqu.Dialect("postgres"). From(goqu.T("adb_notifications_history").As("anh")). diff --git a/backend/pkg/api/handlers/auth.go b/backend/pkg/api/handlers/auth.go index 777b84ca9..3f3c2f3b3 100644 --- a/backend/pkg/api/handlers/auth.go +++ b/backend/pkg/api/handlers/auth.go @@ -3,14 +3,10 @@ package handlers import ( "cmp" "context" - "crypto/sha256" - "encoding/binary" - "encoding/hex" "errors" "fmt" "html" "net/http" - "strconv" "strings" "time" @@ -41,7 +37,7 @@ const authEmailExpireTime = time.Minute * 30 type ctxKey string const ctxUserIdKey ctxKey = "user_id" -const ctxIsMockEnabledKey ctxKey = "is_mock_enabled" +const ctxIsMockedKey ctxKey = "is_mocked" var errBadCredentials = newUnauthorizedErr("invalid email or password") @@ -86,7 +82,7 @@ func (h *HandlerService) purgeAllSessionsForUser(ctx context.Context, userId uin // TODO move to service? func (h *HandlerService) sendConfirmationEmail(ctx context.Context, userId uint64, email string) error { // 1. check last confirmation time to enforce ratelimit - lastTs, err := h.dai.GetEmailConfirmationTime(ctx, userId) + lastTs, err := h.daService.GetEmailConfirmationTime(ctx, userId) if err != nil { return errors.New("error getting confirmation-ts") } @@ -96,7 +92,7 @@ func (h *HandlerService) sendConfirmationEmail(ctx context.Context, userId uint6 // 2. update confirmation hash (before sending so there's no hash mismatch on failure) confirmationHash := utils.RandomString(40) - err = h.dai.UpdateEmailConfirmationHash(ctx, userId, email, confirmationHash) + err = h.daService.UpdateEmailConfirmationHash(ctx, userId, email, confirmationHash) if err != nil { return errors.New("error updating confirmation hash") } @@ -117,7 +113,7 @@ Best regards, } // 4. update confirmation time (only after mail was sent) - err = h.dai.UpdateEmailConfirmationTime(ctx, userId) + err = h.daService.UpdateEmailConfirmationTime(ctx, userId) if err != nil { // shouldn't present this as error to user, confirmation works fine log.Error(err, "error updating email confirmation time, rate limiting won't be enforced", 0, nil) @@ -129,7 +125,7 @@ Best regards, func (h *HandlerService) sendPasswordResetEmail(ctx context.Context, userId uint64, email string) error { // 0. check if password resets are allowed // (can be forbidden by admin (not yet in v2)) - passwordResetAllowed, err := h.dai.IsPasswordResetAllowed(ctx, userId) + passwordResetAllowed, err := h.daService.IsPasswordResetAllowed(ctx, userId) if err != nil { return err } @@ -138,7 +134,7 @@ func (h *HandlerService) sendPasswordResetEmail(ctx context.Context, userId uint } // 1. check last confirmation time to enforce ratelimit - lastTs, err := h.dai.GetPasswordResetTime(ctx, userId) + lastTs, err := h.daService.GetPasswordResetTime(ctx, userId) if err != nil { return errors.New("error getting confirmation-ts") } @@ -148,7 +144,7 @@ func (h *HandlerService) sendPasswordResetEmail(ctx context.Context, userId uint // 2. update reset hash (before sending so there's no hash mismatch on failure) resetHash := utils.RandomString(40) - err = h.dai.UpdatePasswordResetHash(ctx, userId, resetHash) + err = h.daService.UpdatePasswordResetHash(ctx, userId, resetHash) if err != nil { return errors.New("error updating confirmation hash") } @@ -169,7 +165,7 @@ Best regards, } // 4. update reset time (only after mail was sent) - err = h.dai.UpdatePasswordResetTime(ctx, userId) + err = h.daService.UpdatePasswordResetTime(ctx, userId) if err != nil { // shouldn't present this as error to user, reset works fine log.Error(err, "error updating password reset time, rate limiting won't be enforced", 0, nil) @@ -198,7 +194,7 @@ func (h *HandlerService) GetUserIdByApiKey(r *http.Request) (uint64, error) { if apiKey == "" { return 0, newUnauthorizedErr("missing api key") } - userId, err := h.dai.GetUserIdByApiKey(r.Context(), apiKey) + userId, err := h.daService.GetUserIdByApiKey(r.Context(), apiKey) if errors.Is(err, dataaccess.ErrNotFound) { err = newUnauthorizedErr("api key not found") } @@ -247,7 +243,7 @@ func (h *HandlerService) InternalPostUsers(w http.ResponseWriter, r *http.Reques return } - _, err := h.dai.GetUserByEmail(r.Context(), email) + _, err := h.daService.GetUserByEmail(r.Context(), email) if !errors.Is(err, dataaccess.ErrNotFound) { if err == nil { returnConflict(w, r, errors.New("email already registered")) @@ -270,7 +266,7 @@ func (h *HandlerService) InternalPostUsers(w http.ResponseWriter, r *http.Reques } // add user - userId, err := h.dai.CreateUser(r.Context(), email, string(passwordHash)) + userId, err := h.daService.CreateUser(r.Context(), email, string(passwordHash)) if err != nil { handleErr(w, r, err) return @@ -295,12 +291,12 @@ func (h *HandlerService) InternalPostUserConfirm(w http.ResponseWriter, r *http. return } - userId, err := h.dai.GetUserIdByConfirmationHash(r.Context(), confirmationHash) + userId, err := h.daService.GetUserIdByConfirmationHash(r.Context(), confirmationHash) if err != nil { handleErr(w, r, err) return } - confirmationTime, err := h.dai.GetEmailConfirmationTime(r.Context(), userId) + confirmationTime, err := h.daService.GetEmailConfirmationTime(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -310,7 +306,7 @@ func (h *HandlerService) InternalPostUserConfirm(w http.ResponseWriter, r *http. return } - err = h.dai.UpdateUserEmail(r.Context(), userId) + err = h.daService.UpdateUserEmail(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -342,7 +338,7 @@ func (h *HandlerService) InternalPostUserPasswordReset(w http.ResponseWriter, r return } - userId, err := h.dai.GetUserByEmail(r.Context(), email) + userId, err := h.daService.GetUserByEmail(r.Context(), email) if err != nil { if err == dataaccess.ErrNotFound { // don't leak if email is registered @@ -380,12 +376,12 @@ func (h *HandlerService) InternalPostUserPasswordResetHash(w http.ResponseWriter } // check token validity - userId, err := h.dai.GetUserIdByResetHash(r.Context(), resetToken) + userId, err := h.daService.GetUserIdByResetHash(r.Context(), resetToken) if err != nil { handleErr(w, r, err) return } - resetTime, err := h.dai.GetPasswordResetTime(r.Context(), userId) + resetTime, err := h.daService.GetPasswordResetTime(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -401,20 +397,20 @@ func (h *HandlerService) InternalPostUserPasswordResetHash(w http.ResponseWriter handleErr(w, r, errors.New("error hashing password")) return } - err = h.dai.UpdateUserPassword(r.Context(), userId, string(passwordHash)) + err = h.daService.UpdateUserPassword(r.Context(), userId, string(passwordHash)) if err != nil { handleErr(w, r, err) return } // if email is not confirmed, confirm since they clicked a link emailed to them - userInfo, err := h.dai.GetUserCredentialInfo(r.Context(), userId) + userInfo, err := h.daService.GetUserCredentialInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return } if !userInfo.EmailConfirmed { - err = h.dai.UpdateUserEmail(r.Context(), userId) + err = h.daService.UpdateUserEmail(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -449,7 +445,7 @@ func (h *HandlerService) InternalPostLogin(w http.ResponseWriter, r *http.Reques } // fetch user - userId, err := h.dai.GetUserByEmail(r.Context(), email) + userId, err := h.daService.GetUserByEmail(r.Context(), email) if err != nil { if errors.Is(err, dataaccess.ErrNotFound) { err = errBadCredentials @@ -457,7 +453,7 @@ func (h *HandlerService) InternalPostLogin(w http.ResponseWriter, r *http.Reques handleErr(w, r, err) return } - user, err := h.dai.GetUserCredentialInfo(r.Context(), userId) + user, err := h.daService.GetUserCredentialInfo(r.Context(), userId) if err != nil { if errors.Is(err, dataaccess.ErrNotFound) { err = errBadCredentials @@ -532,7 +528,7 @@ func (h *HandlerService) InternalPostMobileAuthorize(w http.ResponseWriter, r *h } // check if oauth app exists to validate whether redirect uri is valid - appInfo, err := h.dai.GetAppDataFromRedirectUri(req.RedirectURI) + appInfo, err := h.daService.GetAppDataFromRedirectUri(req.RedirectURI) if err != nil { callback := req.RedirectURI + "?error=invalid_request&error_description=missing_redirect_uri" + state http.Redirect(w, r, callback, http.StatusSeeOther) @@ -549,7 +545,7 @@ func (h *HandlerService) InternalPostMobileAuthorize(w http.ResponseWriter, r *h session := h.scs.Token(r.Context()) sanitizedDeviceName := html.EscapeString(clientName) - err = h.dai.AddUserDevice(userInfo.Id, utils.HashAndEncode(session+session), clientID, sanitizedDeviceName, appInfo.ID) + err = h.daService.AddUserDevice(userInfo.Id, utils.HashAndEncode(session+session), clientID, sanitizedDeviceName, appInfo.ID) if err != nil { log.Warnf("Error adding user device: %v", err) callback := req.RedirectURI + "?error=invalid_request&error_description=server_error" + state @@ -589,7 +585,7 @@ func (h *HandlerService) InternalPostMobileEquivalentExchange(w http.ResponseWri } // Get user info - user, err := h.dai.GetUserCredentialInfo(r.Context(), userID) + user, err := h.daService.GetUserCredentialInfo(r.Context(), userID) if err != nil { if errors.Is(err, dataaccess.ErrNotFound) { err = errBadCredentials @@ -612,7 +608,7 @@ func (h *HandlerService) InternalPostMobileEquivalentExchange(w http.ResponseWri // invalidate old refresh token and replace with hashed session id sanitizedDeviceName := html.EscapeString(req.DeviceName) - err = h.dai.MigrateMobileSession(refreshTokenHashed, utils.HashAndEncode(session+session), req.DeviceID, sanitizedDeviceName) // salted with session + err = h.daService.MigrateMobileSession(refreshTokenHashed, utils.HashAndEncode(session+session), req.DeviceID, sanitizedDeviceName) // salted with session if err != nil { handleErr(w, r, err) return @@ -653,7 +649,7 @@ func (h *HandlerService) InternalPostUsersMeNotificationSettingsPairedDevicesTok return } - err = h.dai.AddMobileNotificationToken(user.Id, deviceID, req.Token) + err = h.daService.AddMobileNotificationToken(user.Id, deviceID, req.Token) if err != nil { handleErr(w, r, err) return @@ -693,7 +689,7 @@ func (h *HandlerService) InternalHandleMobilePurchase(w http.ResponseWriter, r * return } - subscriptionCount, err := h.dai.GetAppSubscriptionCount(user.Id) + subscriptionCount, err := h.daService.GetAppSubscriptionCount(user.Id) if err != nil { handleErr(w, r, err) return @@ -724,7 +720,7 @@ func (h *HandlerService) InternalHandleMobilePurchase(w http.ResponseWriter, r * } } - err = h.dai.AddMobilePurchase(nil, user.Id, req, validationResult, "") + err = h.daService.AddMobilePurchase(nil, user.Id, req, validationResult, "") if err != nil { handleErr(w, r, err) return @@ -755,7 +751,7 @@ func (h *HandlerService) InternalDeleteUser(w http.ResponseWriter, r *http.Reque } // TODO allow if user has any subsciptions etc? - err = h.dai.RemoveUser(r.Context(), user.Id) + err = h.daService.RemoveUser(r.Context(), user.Id) if err != nil { handleErr(w, r, err) return @@ -777,7 +773,7 @@ func (h *HandlerService) InternalPostUserEmail(w http.ResponseWriter, r *http.Re handleErr(w, r, err) return } - userInfo, err := h.dai.GetUserCredentialInfo(r.Context(), user.Id) + userInfo, err := h.daService.GetUserCredentialInfo(r.Context(), user.Id) if err != nil { handleErr(w, r, err) return @@ -809,7 +805,7 @@ func (h *HandlerService) InternalPostUserEmail(w http.ResponseWriter, r *http.Re return } - _, err = h.dai.GetUserByEmail(r.Context(), newEmail) + _, err = h.daService.GetUserByEmail(r.Context(), newEmail) if !errors.Is(err, dataaccess.ErrNotFound) { if err == nil { handleErr(w, r, newConflictErr("email already registered")) @@ -856,7 +852,7 @@ func (h *HandlerService) InternalPutUserPassword(w http.ResponseWriter, r *http. return } // user doesn't contain password, fetch from db - userData, err := h.dai.GetUserCredentialInfo(r.Context(), user.Id) + userData, err := h.daService.GetUserCredentialInfo(r.Context(), user.Id) if err != nil { handleErr(w, r, err) return @@ -892,7 +888,7 @@ func (h *HandlerService) InternalPutUserPassword(w http.ResponseWriter, r *http. } // change password - err = h.dai.UpdateUserPassword(r.Context(), user.Id, string(passwordHash)) + err = h.daService.UpdateUserPassword(r.Context(), user.Id, string(passwordHash)) if err != nil { handleErr(w, r, err) return @@ -906,187 +902,3 @@ func (h *HandlerService) InternalPutUserPassword(w http.ResponseWriter, r *http. returnNoContent(w, r) } - -// Middlewares - -func hashUint64(data uint64) [32]byte { - // Convert uint64 to a byte slice - buf := make([]byte, 8) - binary.LittleEndian.PutUint64(buf, data) - - // Compute SHA-256 hash - hash := sha256.Sum256(buf) - return hash -} - -func checkHash(data uint64, hashStr string) bool { - // Decode the hexadecimal string into a byte slice - hashToCheck, err := hex.DecodeString(hashStr) - if err != nil { - return false - } - - // Hash the uint64 value - computedHash := hashUint64(data) - - // Compare the computed hash with the provided hash - return string(computedHash[:]) == string(hashToCheck) -} - -// returns a middleware that stores user id in context, using the provided function -func StoreUserIdMiddleware(next http.Handler, userIdFunc func(r *http.Request) (uint64, error)) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - userId, err := userIdFunc(r) - if err != nil { - if errors.Is(err, errUnauthorized) { - // if next handler requires authentication, it should return 'unauthorized' itself - next.ServeHTTP(w, r) - } else { - handleErr(w, r, err) - } - return - } - - // if user id matches a given hash, allow access without checking dashboard access and return mock data - // TODO: move to config, exposing this in source code is a minor security risk for now - validHashes := []string{ - "2cab06069254b5555b617efa1d17f0748324270bb587b73422e6840d59ff322c", - "fc624cf355b84bc583661552982894621568b59c0a1c92ab0c1e03ed3bbf649b", - "03e7fb02cbc33eb45e98ab50b4bcad7fc338e5edfb5eca33ad9eb7d13d4ff106", - } - for _, hash := range validHashes { - if checkHash(userId, hash) { - ctx := r.Context() - ctx = context.WithValue(ctx, ctxIsMockEnabledKey, true) - r = r.WithContext(ctx) - } - } - ctx := r.Context() - ctx = context.WithValue(ctx, ctxUserIdKey, userId) - r = r.WithContext(ctx) - next.ServeHTTP(w, r) - }) -} - -func (h *HandlerService) StoreUserIdBySessionMiddleware(next http.Handler) http.Handler { - return StoreUserIdMiddleware(next, func(r *http.Request) (uint64, error) { - return h.GetUserIdBySession(r) - }) -} - -func (h *HandlerService) StoreUserIdByApiKeyMiddleware(next http.Handler) http.Handler { - return StoreUserIdMiddleware(next, func(r *http.Request) (uint64, error) { - return h.GetUserIdByApiKey(r) - }) -} - -// returns a middleware that checks if user has access to dashboard when a primary id is used -func (h *HandlerService) VDBAuthMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // if mock data is used, no need to check access - if isMockEnabled, ok := r.Context().Value(ctxIsMockEnabledKey).(bool); ok && isMockEnabled { - next.ServeHTTP(w, r) - return - } - var err error - dashboardId, err := strconv.ParseUint(mux.Vars(r)["dashboard_id"], 10, 64) - if err != nil { - // if primary id is not used, no need to check access - next.ServeHTTP(w, r) - return - } - // primary id is used -> user needs to have access to dashboard - - userId, err := GetUserIdByContext(r) - if err != nil { - handleErr(w, r, err) - return - } - - // store user id in context - ctx := r.Context() - ctx = context.WithValue(ctx, ctxUserIdKey, userId) - r = r.WithContext(ctx) - - dashboardUser, err := h.dai.GetValidatorDashboardUser(r.Context(), types.VDBIdPrimary(dashboardId)) - if err != nil { - handleErr(w, r, err) - return - } - - if dashboardUser.UserId != userId { - // user does not have access to dashboard - // the proper error would be 403 Forbidden, but we don't want to leak information so we return 404 Not Found - handleErr(w, r, newNotFoundErr("dashboard with id %v not found", dashboardId)) - return - } - - next.ServeHTTP(w, r) - }) -} - -// Common middleware logic for checking user premium perks -func (h *HandlerService) PremiumPerkCheckMiddleware(next http.Handler, hasRequiredPerk func(premiumPerks types.PremiumPerks) bool) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // get user id from context - userId, err := GetUserIdByContext(r) - if err != nil { - handleErr(w, r, err) - return - } - - // get user info - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) - if err != nil { - handleErr(w, r, err) - return - } - - // check if user has the required premium perk - if !hasRequiredPerk(userInfo.PremiumPerks) { - handleErr(w, r, newForbiddenErr("users premium perks do not allow usage of this endpoint")) - return - } - - next.ServeHTTP(w, r) - }) -} - -// Middleware for managing dashboards via API -func (h *HandlerService) ManageDashboardsViaApiCheckMiddleware(next http.Handler) http.Handler { - return h.PremiumPerkCheckMiddleware(next, func(premiumPerks types.PremiumPerks) bool { - return premiumPerks.ManageDashboardViaApi - }) -} - -// Middleware for managing notifications via API -func (h *HandlerService) ManageNotificationsViaApiCheckMiddleware(next http.Handler) http.Handler { - return h.PremiumPerkCheckMiddleware(next, func(premiumPerks types.PremiumPerks) bool { - return premiumPerks.ConfigureNotificationsViaApi - }) -} - -// middleware check to return if specified dashboard is not archived (and accessible) -func (h *HandlerService) VDBArchivedCheckMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - dashboardId, err := h.handleDashboardId(r.Context(), mux.Vars(r)["dashboard_id"]) - if err != nil { - handleErr(w, r, err) - return - } - if len(dashboardId.Validators) > 0 { - next.ServeHTTP(w, r) - return - } - dashboard, err := h.dai.GetValidatorDashboardInfo(r.Context(), dashboardId.Id) - if err != nil { - handleErr(w, r, err) - return - } - if dashboard.IsArchived { - handleErr(w, r, newForbiddenErr("dashboard with id %v is archived", dashboardId)) - return - } - next.ServeHTTP(w, r) - }) -} diff --git a/backend/pkg/api/handlers/backward_compat.go b/backend/pkg/api/handlers/backward_compat.go index f02161a00..e9d823c69 100644 --- a/backend/pkg/api/handlers/backward_compat.go +++ b/backend/pkg/api/handlers/backward_compat.go @@ -43,7 +43,7 @@ func (h *HandlerService) getTokenByRefresh(r *http.Request, refreshToken string) log.Infof("refresh token: %v, claims: %v, hashed refresh: %v", refreshToken, unsafeClaims, refreshTokenHashed) // confirm all claims via db lookup and refreshtoken check - userID, err := h.dai.GetUserIdByRefreshToken(unsafeClaims.UserID, unsafeClaims.AppID, unsafeClaims.DeviceID, refreshTokenHashed) + userID, err := h.daService.GetUserIdByRefreshToken(unsafeClaims.UserID, unsafeClaims.AppID, unsafeClaims.DeviceID, refreshTokenHashed) if err != nil { if err == sql.ErrNoRows { return 0, "", dataaccess.ErrNotFound diff --git a/backend/pkg/api/handlers/handler_service.go b/backend/pkg/api/handlers/handler_service.go new file mode 100644 index 000000000..4e7435351 --- /dev/null +++ b/backend/pkg/api/handlers/handler_service.go @@ -0,0 +1,541 @@ +package handlers + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + + "github.com/gobitfly/beaconchain/pkg/commons/log" + "github.com/invopop/jsonschema" + + "github.com/alexedwards/scs/v2" + dataaccess "github.com/gobitfly/beaconchain/pkg/api/data_access" + "github.com/gobitfly/beaconchain/pkg/api/enums" + "github.com/gobitfly/beaconchain/pkg/api/services" + types "github.com/gobitfly/beaconchain/pkg/api/types" +) + +type HandlerService struct { + daService dataaccess.DataAccessor + daDummy dataaccess.DataAccessor + scs *scs.SessionManager + isPostMachineMetricsEnabled bool // if more config options are needed, consider having the whole config in here +} + +func NewHandlerService(dataAccessor dataaccess.DataAccessor, dummy dataaccess.DataAccessor, sessionManager *scs.SessionManager, enablePostMachineMetrics bool) *HandlerService { + if allNetworks == nil { + networks, err := dataAccessor.GetAllNetworks() + if err != nil { + log.Fatal(err, "error getting networks for handler", 0, nil) + } + allNetworks = networks + } + + return &HandlerService{ + daService: dataAccessor, + daDummy: dummy, + scs: sessionManager, + isPostMachineMetricsEnabled: enablePostMachineMetrics, + } +} + +// getDataAccessor returns the correct data accessor based on the request context. +// if the request is mocked, the data access dummy is returned; otherwise the data access service. +// should only be used if getting mocked data for the endpoint is appropriate +func (h *HandlerService) getDataAccessor(r *http.Request) dataaccess.DataAccessor { + if isMocked(r) { + return h.daDummy + } + return h.daService +} + +// all networks available in the system, filled on startup in NewHandlerService +var allNetworks []types.NetworkInfo + +// -------------------------------------- +// errors + +var ( + errMsgParsingId = errors.New("error parsing parameter 'dashboard_id'") + errBadRequest = errors.New("bad request") + errInternalServer = errors.New("internal server error") + errUnauthorized = errors.New("unauthorized") + errForbidden = errors.New("forbidden") + errConflict = errors.New("conflict") + errTooManyRequests = errors.New("too many requests") + errGone = errors.New("gone") +) + +// -------------------------------------- +// utility functions + +type validatorSet struct { + Indexes []types.VDBValidator + PublicKeys []string +} + +// parseDashboardId is a helper function to validate the string dashboard id param. +func parseDashboardId(id string) (interface{}, error) { + var v validationError + if reInteger.MatchString(id) { + // given id is a normal id + id := v.checkUint(id, "dashboard_id") + if v.hasErrors() { + return nil, v + } + return types.VDBIdPrimary(id), nil + } + if reValidatorDashboardPublicId.MatchString(id) { + // given id is a public id + return types.VDBIdPublic(id), nil + } + // given id must be an encoded set of validators + decodedId, err := base64.RawURLEncoding.DecodeString(id) + if err != nil { + return nil, newBadRequestErr("given value '%s' is not a valid dashboard id", id) + } + indexes, publicKeys := v.checkValidatorList(string(decodedId), forbidEmpty) + if v.hasErrors() { + return nil, newBadRequestErr("given value '%s' is not a valid dashboard id", id) + } + return validatorSet{Indexes: indexes, PublicKeys: publicKeys}, nil +} + +// getDashboardId is a helper function to convert the dashboard id param to a VDBId. +// precondition: dashboardIdParam must be a valid dashboard id and either a primary id, public id, or list of validators. +func (h *HandlerService) getDashboardId(ctx context.Context, dashboardIdParam interface{}) (*types.VDBId, error) { + switch dashboardId := dashboardIdParam.(type) { + case types.VDBIdPrimary: + return &types.VDBId{Id: dashboardId, Validators: nil}, nil + case types.VDBIdPublic: + dashboardInfo, err := h.daService.GetValidatorDashboardPublicId(ctx, dashboardId) + if err != nil { + return nil, err + } + return &types.VDBId{Id: types.VDBIdPrimary(dashboardInfo.DashboardId), Validators: nil, AggregateGroups: !dashboardInfo.ShareSettings.ShareGroups}, nil + case validatorSet: + validators, err := h.daService.GetValidatorsFromSlices(dashboardId.Indexes, dashboardId.PublicKeys) + if err != nil { + return nil, err + } + if len(validators) == 0 { + return nil, newNotFoundErr("no validators found for given id") + } + if len(validators) > maxValidatorsInList { + return nil, newBadRequestErr("too many validators in list, maximum is %d", maxValidatorsInList) + } + return &types.VDBId{Validators: validators}, nil + } + return nil, errMsgParsingId +} + +// handleDashboardId is a helper function to both validate the dashboard id param and convert it to a VDBId. +// it should be used as the last validation step for all internal dashboard GET-handlers. +// Modifying handlers (POST, PUT, DELETE) should only accept primary dashboard ids and just use checkPrimaryDashboardId. +func (h *HandlerService) handleDashboardId(ctx context.Context, param string) (*types.VDBId, error) { + // validate dashboard id param + dashboardIdParam, err := parseDashboardId(param) + if err != nil { + return nil, err + } + // convert to VDBId + dashboardId, err := h.getDashboardId(ctx, dashboardIdParam) + if err != nil { + return nil, err + } + + return dashboardId, nil +} + +const chartDatapointLimit uint64 = 200 + +type ChartTimeDashboardLimits struct { + MinAllowedTs uint64 + LatestExportedTs uint64 + MaxAllowedInterval uint64 +} + +// helper function to retrieve allowed chart timestamp boundaries according to the users premium perks at the current point in time +func (h *HandlerService) getCurrentChartTimeLimitsForDashboard(ctx context.Context, dashboardId *types.VDBId, aggregation enums.ChartAggregation) (ChartTimeDashboardLimits, error) { + limits := ChartTimeDashboardLimits{} + var err error + premiumPerks, err := h.getDashboardPremiumPerks(ctx, *dashboardId) + if err != nil { + return limits, err + } + + maxAge := getMaxChartAge(aggregation, premiumPerks.ChartHistorySeconds) // can be max int for unlimited, always check for underflows + if maxAge == 0 { + return limits, newConflictErr("requested aggregation is not available for dashboard owner's premium subscription") + } + limits.LatestExportedTs, err = h.daService.GetLatestExportedChartTs(ctx, aggregation) + if err != nil { + return limits, err + } + limits.MinAllowedTs = limits.LatestExportedTs - min(maxAge, limits.LatestExportedTs) // min to prevent underflow + secondsPerEpoch := uint64(12 * 32) // TODO: fetch dashboards chain id and use correct value for network once available + limits.MaxAllowedInterval = chartDatapointLimit*uint64(aggregation.Duration(secondsPerEpoch).Seconds()) - 1 // -1 to make sure we don't go over the limit + + return limits, nil +} + +// getDashboardPremiumPerks gets the premium perks of the dashboard OWNER or if it's a guest dashboard, it returns free tier premium perks +func (h *HandlerService) getDashboardPremiumPerks(ctx context.Context, id types.VDBId) (*types.PremiumPerks, error) { + // for guest dashboards, return free tier perks + if id.Validators != nil { + perk, err := h.daService.GetFreeTierPerks(ctx) + if err != nil { + return nil, err + } + return perk, nil + } + // could be made into a single query if needed + dashboardUser, err := h.daService.GetValidatorDashboardUser(ctx, id.Id) + if err != nil { + return nil, err + } + userInfo, err := h.daService.GetUserInfo(ctx, dashboardUser.UserId) + if err != nil { + return nil, err + } + + return &userInfo.PremiumPerks, nil +} + +// getMaxChartAge returns the maximum age of a chart in seconds based on the given aggregation type and premium perks +func getMaxChartAge(aggregation enums.ChartAggregation, perkSeconds types.ChartHistorySeconds) uint64 { + aggregations := enums.ChartAggregations + switch aggregation { + case aggregations.Epoch: + return perkSeconds.Epoch + case aggregations.Hourly: + return perkSeconds.Hourly + case aggregations.Daily: + return perkSeconds.Daily + case aggregations.Weekly: + return perkSeconds.Weekly + default: + return 0 + } +} + +func isUserAdmin(user *types.UserInfo) bool { + if user == nil { + return false + } + return user.UserGroup == types.UserGroupAdmin +} + +// -------------------------------------- +// Response handling + +func writeResponse(w http.ResponseWriter, r *http.Request, statusCode int, response interface{}) { + w.Header().Set("Content-Type", "application/json") + if response == nil { + w.WriteHeader(statusCode) + return + } + jsonData, err := json.Marshal(response) + if err != nil { + logApiError(r, fmt.Errorf("error encoding json data: %w", err), 0, + log.Fields{ + "data": fmt.Sprintf("%+v", response), + }) + w.WriteHeader(http.StatusInternalServerError) + response = types.ApiErrorResponse{ + Error: "error encoding json data", + } + if err = json.NewEncoder(w).Encode(response); err != nil { + // there seems to be an error with the lib + logApiError(r, fmt.Errorf("error encoding error response after failed encoding: %w", err), 0) + } + return + } + w.WriteHeader(statusCode) + if _, err = w.Write(jsonData); err != nil { + // already returned wrong status code to user, can't prevent that + logApiError(r, fmt.Errorf("error writing response data: %w", err), 0) + } +} + +func returnError(w http.ResponseWriter, r *http.Request, code int, err error) { + response := types.ApiErrorResponse{ + Error: err.Error(), + } + writeResponse(w, r, code, response) +} + +func returnOk(w http.ResponseWriter, r *http.Request, data interface{}) { + writeResponse(w, r, http.StatusOK, data) +} + +func returnCreated(w http.ResponseWriter, r *http.Request, data interface{}) { + writeResponse(w, r, http.StatusCreated, data) +} + +func returnNoContent(w http.ResponseWriter, r *http.Request) { + writeResponse(w, r, http.StatusNoContent, nil) +} + +// Errors + +func returnBadRequest(w http.ResponseWriter, r *http.Request, err error) { + returnError(w, r, http.StatusBadRequest, err) +} + +func returnUnauthorized(w http.ResponseWriter, r *http.Request, err error) { + returnError(w, r, http.StatusUnauthorized, err) +} + +func returnNotFound(w http.ResponseWriter, r *http.Request, err error) { + returnError(w, r, http.StatusNotFound, err) +} + +func returnConflict(w http.ResponseWriter, r *http.Request, err error) { + returnError(w, r, http.StatusConflict, err) +} + +func returnForbidden(w http.ResponseWriter, r *http.Request, err error) { + returnError(w, r, http.StatusForbidden, err) +} + +func returnTooManyRequests(w http.ResponseWriter, r *http.Request, err error) { + returnError(w, r, http.StatusTooManyRequests, err) +} + +func returnGone(w http.ResponseWriter, r *http.Request, err error) { + returnError(w, r, http.StatusGone, err) +} + +const maxBodySize = 10 * 1024 + +func logApiError(r *http.Request, err error, callerSkip int, additionalInfos ...log.Fields) { + body, _ := io.ReadAll(io.LimitReader(r.Body, maxBodySize)) + requestFields := log.Fields{ + "request_endpoint": r.Method + " " + r.URL.Path, + "request_query": r.URL.RawQuery, + "request_body": string(body), + } + log.Error(err, "error handling request", callerSkip+1, append(additionalInfos, requestFields)...) +} + +func handleErr(w http.ResponseWriter, r *http.Request, err error) { + _, isValidationError := err.(validationError) + switch { + case isValidationError, errors.Is(err, errBadRequest): + returnBadRequest(w, r, err) + case errors.Is(err, dataaccess.ErrNotFound): + returnNotFound(w, r, err) + case errors.Is(err, errUnauthorized): + returnUnauthorized(w, r, err) + case errors.Is(err, errForbidden): + returnForbidden(w, r, err) + case errors.Is(err, errConflict): + returnConflict(w, r, err) + case errors.Is(err, services.ErrWaiting): + returnError(w, r, http.StatusServiceUnavailable, err) + case errors.Is(err, errTooManyRequests): + returnTooManyRequests(w, r, err) + case errors.Is(err, errGone): + returnGone(w, r, err) + case errors.Is(err, context.Canceled): + if r.Context().Err() != context.Canceled { // only return error if the request context was canceled + logApiError(r, err, 1) + returnError(w, r, http.StatusInternalServerError, err) + } + default: + logApiError(r, err, 1) + // TODO: don't return the error message to the user in production + returnError(w, r, http.StatusInternalServerError, err) + } +} + +// -------------------------------------- +// Error Helpers + +func errWithMsg(err error, format string, args ...interface{}) error { + return fmt.Errorf("%w: %s", err, fmt.Sprintf(format, args...)) +} + +//nolint:nolintlint +//nolint:unparam +func newBadRequestErr(format string, args ...interface{}) error { + return errWithMsg(errBadRequest, format, args...) +} + +//nolint:unparam +func newInternalServerErr(format string, args ...interface{}) error { + return errWithMsg(errInternalServer, format, args...) +} + +//nolint:unparam +func newUnauthorizedErr(format string, args ...interface{}) error { + return errWithMsg(errUnauthorized, format, args...) +} + +func newForbiddenErr(format string, args ...interface{}) error { + return errWithMsg(errForbidden, format, args...) +} + +//nolint:unparam +func newConflictErr(format string, args ...interface{}) error { + return errWithMsg(errConflict, format, args...) +} + +//nolint:nolintlint +//nolint:unparam +func newNotFoundErr(format string, args ...interface{}) error { + return errWithMsg(dataaccess.ErrNotFound, format, args...) +} + +func newTooManyRequestsErr(format string, args ...interface{}) error { + return errWithMsg(errTooManyRequests, format, args...) +} + +func newGoneErr(format string, args ...interface{}) error { + return errWithMsg(errGone, format, args...) +} + +// -------------------------------------- +// misc. helper functions + +// maps different types of validator dashboard summary validators to a common format +func mapVDBIndices(indices interface{}) ([]types.VDBSummaryValidatorsData, error) { + if indices == nil { + return nil, errors.New("no data found when mapping") + } + + switch v := indices.(type) { + case *types.VDBGeneralSummaryValidators: + // deposited, online, offline, slashing, slashed, exited, withdrawn, pending, exiting, withdrawing + return []types.VDBSummaryValidatorsData{ + mapUintSlice("deposited", v.Deposited), + mapUintSlice("online", v.Online), + mapUintSlice("offline", v.Offline), + mapUintSlice("slashing", v.Slashing), + mapUintSlice("slashed", v.Slashed), + mapUintSlice("exited", v.Exited), + mapUintSlice("withdrawn", v.Withdrawn), + mapIndexTimestampSlice("pending", v.Pending), + mapIndexTimestampSlice("exiting", v.Exiting), + mapIndexTimestampSlice("withdrawing", v.Withdrawing), + }, nil + + case *types.VDBSyncSummaryValidators: + return []types.VDBSummaryValidatorsData{ + mapUintSlice("sync_current", v.Current), + mapUintSlice("sync_upcoming", v.Upcoming), + mapSlice("sync_past", v.Past, + func(v types.VDBValidatorSyncPast) (uint64, []uint64) { return v.Index, []uint64{v.Count} }, + ), + }, nil + + case *types.VDBSlashingsSummaryValidators: + return []types.VDBSummaryValidatorsData{ + mapSlice("got_slashed", v.GotSlashed, + func(v types.VDBValidatorGotSlashed) (uint64, []uint64) { return v.Index, []uint64{v.SlashedBy} }, + ), + mapSlice("has_slashed", v.HasSlashed, + func(v types.VDBValidatorHasSlashed) (uint64, []uint64) { return v.Index, v.SlashedIndices }, + ), + }, nil + + case *types.VDBProposalSummaryValidators: + return []types.VDBSummaryValidatorsData{ + mapIndexBlocksSlice("proposal_proposed", v.Proposed), + mapIndexBlocksSlice("proposal_missed", v.Missed), + }, nil + + default: + return nil, fmt.Errorf("unsupported indices type") + } +} + +// maps different types of validator dashboard summary validators to a common format +func mapSlice[T any](category string, validators []T, getIndexAndDutyObjects func(validator T) (index uint64, dutyObjects []uint64)) types.VDBSummaryValidatorsData { + validatorsData := make([]types.VDBSummaryValidator, len(validators)) + for i, validator := range validators { + index, dutyObjects := getIndexAndDutyObjects(validator) + validatorsData[i] = types.VDBSummaryValidator{Index: index, DutyObjects: dutyObjects} + } + return types.VDBSummaryValidatorsData{ + Category: category, + Validators: validatorsData, + } +} +func mapUintSlice(category string, validators []uint64) types.VDBSummaryValidatorsData { + return mapSlice(category, validators, + func(v uint64) (uint64, []uint64) { return v, nil }, + ) +} + +func mapIndexTimestampSlice(category string, validators []types.IndexTimestamp) types.VDBSummaryValidatorsData { + return mapSlice(category, validators, + func(v types.IndexTimestamp) (uint64, []uint64) { return v.Index, []uint64{v.Timestamp} }, + ) +} + +func mapIndexBlocksSlice(category string, validators []types.IndexBlocks) types.VDBSummaryValidatorsData { + return mapSlice(category, validators, + func(v types.IndexBlocks) (uint64, []uint64) { return v.Index, v.Blocks }, + ) +} + +// -------------------------------------- +// intOrString is a custom type that can be unmarshalled from either an int or a string (strings will also be parsed to int if possible). +// if unmarshaling throws no errors one of the two fields will be set, the other will be nil. +type intOrString struct { + intValue *uint64 + strValue *string +} + +func (v *intOrString) UnmarshalJSON(data []byte) error { + // Attempt to unmarshal as uint64 first + var intValue uint64 + if err := json.Unmarshal(data, &intValue); err == nil { + v.intValue = &intValue + return nil + } + + // If unmarshalling as uint64 fails, try to unmarshal as string + var strValue string + if err := json.Unmarshal(data, &strValue); err == nil { + if parsedInt, err := strconv.ParseUint(strValue, 10, 64); err == nil { + v.intValue = &parsedInt + } else { + v.strValue = &strValue + } + return nil + } + + // If both unmarshalling attempts fail, return an error + return fmt.Errorf("failed to unmarshal intOrString from json: %s", string(data)) +} + +func (v intOrString) String() string { + if v.intValue != nil { + return strconv.FormatUint(*v.intValue, 10) + } + if v.strValue != nil { + return *v.strValue + } + return "" +} + +func (intOrString) JSONSchema() *jsonschema.Schema { + return &jsonschema.Schema{ + OneOf: []*jsonschema.Schema{ + {Type: "string"}, {Type: "integer"}, + }, + } +} + +func isMocked(r *http.Request) bool { + isMocked, ok := r.Context().Value(ctxIsMockedKey).(bool) + return ok && isMocked +} diff --git a/backend/pkg/api/handlers/common.go b/backend/pkg/api/handlers/input_validation.go similarity index 51% rename from backend/pkg/api/handlers/common.go rename to backend/pkg/api/handlers/input_validation.go index 053426d46..16478449a 100644 --- a/backend/pkg/api/handlers/common.go +++ b/backend/pkg/api/handlers/input_validation.go @@ -3,10 +3,7 @@ package handlers import ( "bytes" "cmp" - "context" - "encoding/base64" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -16,46 +13,14 @@ import ( "strings" "github.com/ethereum/go-ethereum/common/hexutil" - "github.com/gobitfly/beaconchain/pkg/commons/log" + "github.com/gobitfly/beaconchain/pkg/api/enums" + "github.com/gobitfly/beaconchain/pkg/api/types" "github.com/gorilla/mux" "github.com/invopop/jsonschema" "github.com/shopspring/decimal" "github.com/xeipuuv/gojsonschema" - - "github.com/alexedwards/scs/v2" - dataaccess "github.com/gobitfly/beaconchain/pkg/api/data_access" - "github.com/gobitfly/beaconchain/pkg/api/enums" - "github.com/gobitfly/beaconchain/pkg/api/services" - types "github.com/gobitfly/beaconchain/pkg/api/types" ) -type HandlerService struct { - dai dataaccess.DataAccessor - dummy dataaccess.DataAccessor - scs *scs.SessionManager - isPostMachineMetricsEnabled bool // if more config options are needed, consider having the whole config in here -} - -func NewHandlerService(dataAccessor dataaccess.DataAccessor, dummy dataaccess.DataAccessor, sessionManager *scs.SessionManager, enablePostMachineMetrics bool) *HandlerService { - if allNetworks == nil { - networks, err := dataAccessor.GetAllNetworks() - if err != nil { - log.Fatal(err, "error getting networks for handler", 0, nil) - } - allNetworks = networks - } - - return &HandlerService{ - dai: dataAccessor, - dummy: dummy, - scs: sessionManager, - isPostMachineMetricsEnabled: enablePostMachineMetrics, - } -} - -// all networks available in the system, filled on startup in NewHandlerService -var allNetworks []types.NetworkInfo - // -------------------------------------- var ( @@ -93,23 +58,6 @@ const ( MaxArchivedDashboardsCount = 10 ) -var ( - errMsgParsingId = errors.New("error parsing parameter 'dashboard_id'") - errBadRequest = errors.New("bad request") - errInternalServer = errors.New("internal server error") - errUnauthorized = errors.New("unauthorized") - errForbidden = errors.New("forbidden") - errConflict = errors.New("conflict") - errTooManyRequests = errors.New("too many requests") - errGone = errors.New("gone") -) - -type Paging struct { - cursor string - limit uint64 - search string -} - // All changes to common functions MUST NOT break any public handler behavior (not in effect yet) // -------------------------------------- @@ -146,6 +94,8 @@ func (v *validationError) hasErrors() bool { return v != nil && len(*v) > 0 } +// -------------------------------------- + func (v *validationError) checkRegex(regex *regexp.Regexp, param, paramName string) string { if !regex.MatchString(param) { v.add(paramName, fmt.Sprintf(`given value '%s' has incorrect format`, param)) @@ -307,143 +257,10 @@ func (v *validationError) checkAdConfigurationKeys(keysString string) []string { return keys } -type validatorSet struct { - Indexes []types.VDBValidator - PublicKeys []string -} - -// parseDashboardId is a helper function to validate the string dashboard id param. -func parseDashboardId(id string) (interface{}, error) { - var v validationError - if reInteger.MatchString(id) { - // given id is a normal id - id := v.checkUint(id, "dashboard_id") - if v.hasErrors() { - return nil, v - } - return types.VDBIdPrimary(id), nil - } - if reValidatorDashboardPublicId.MatchString(id) { - // given id is a public id - return types.VDBIdPublic(id), nil - } - // given id must be an encoded set of validators - decodedId, err := base64.RawURLEncoding.DecodeString(id) - if err != nil { - return nil, newBadRequestErr("given value '%s' is not a valid dashboard id", id) - } - indexes, publicKeys := v.checkValidatorList(string(decodedId), forbidEmpty) - if v.hasErrors() { - return nil, newBadRequestErr("given value '%s' is not a valid dashboard id", id) - } - return validatorSet{Indexes: indexes, PublicKeys: publicKeys}, nil -} - -// getDashboardId is a helper function to convert the dashboard id param to a VDBId. -// precondition: dashboardIdParam must be a valid dashboard id and either a primary id, public id, or list of validators. -func (h *HandlerService) getDashboardId(ctx context.Context, dashboardIdParam interface{}) (*types.VDBId, error) { - switch dashboardId := dashboardIdParam.(type) { - case types.VDBIdPrimary: - return &types.VDBId{Id: dashboardId, Validators: nil}, nil - case types.VDBIdPublic: - dashboardInfo, err := h.dai.GetValidatorDashboardPublicId(ctx, dashboardId) - if err != nil { - return nil, err - } - return &types.VDBId{Id: types.VDBIdPrimary(dashboardInfo.DashboardId), Validators: nil, AggregateGroups: !dashboardInfo.ShareSettings.ShareGroups}, nil - case validatorSet: - validators, err := h.dai.GetValidatorsFromSlices(dashboardId.Indexes, dashboardId.PublicKeys) - if err != nil { - return nil, err - } - if len(validators) == 0 { - return nil, newNotFoundErr("no validators found for given id") - } - if len(validators) > maxValidatorsInList { - return nil, newBadRequestErr("too many validators in list, maximum is %d", maxValidatorsInList) - } - return &types.VDBId{Validators: validators}, nil - } - return nil, errMsgParsingId -} - -// handleDashboardId is a helper function to both validate the dashboard id param and convert it to a VDBId. -// it should be used as the last validation step for all internal dashboard GET-handlers. -// Modifying handlers (POST, PUT, DELETE) should only accept primary dashboard ids and just use checkPrimaryDashboardId. -func (h *HandlerService) handleDashboardId(ctx context.Context, param string) (*types.VDBId, error) { - // validate dashboard id param - dashboardIdParam, err := parseDashboardId(param) - if err != nil { - return nil, err - } - // convert to VDBId - dashboardId, err := h.getDashboardId(ctx, dashboardIdParam) - if err != nil { - return nil, err - } - - return dashboardId, nil -} - -const chartDatapointLimit uint64 = 200 - -type ChartTimeDashboardLimits struct { - MinAllowedTs uint64 - LatestExportedTs uint64 - MaxAllowedInterval uint64 -} - -// helper function to retrieve allowed chart timestamp boundaries according to the users premium perks at the current point in time -func (h *HandlerService) getCurrentChartTimeLimitsForDashboard(ctx context.Context, dashboardId *types.VDBId, aggregation enums.ChartAggregation) (ChartTimeDashboardLimits, error) { - limits := ChartTimeDashboardLimits{} - var err error - premiumPerks, err := h.getDashboardPremiumPerks(ctx, *dashboardId) - if err != nil { - return limits, err - } - - maxAge := getMaxChartAge(aggregation, premiumPerks.ChartHistorySeconds) // can be max int for unlimited, always check for underflows - if maxAge == 0 { - return limits, newConflictErr("requested aggregation is not available for dashboard owner's premium subscription") - } - limits.LatestExportedTs, err = h.dai.GetLatestExportedChartTs(ctx, aggregation) - if err != nil { - return limits, err - } - limits.MinAllowedTs = limits.LatestExportedTs - min(maxAge, limits.LatestExportedTs) // min to prevent underflow - secondsPerEpoch := uint64(12 * 32) // TODO: fetch dashboards chain id and use correct value for network once available - limits.MaxAllowedInterval = chartDatapointLimit*uint64(aggregation.Duration(secondsPerEpoch).Seconds()) - 1 // -1 to make sure we don't go over the limit - - return limits, nil -} - func (v *validationError) checkPrimaryDashboardId(param string) types.VDBIdPrimary { return types.VDBIdPrimary(v.checkUint(param, "dashboard_id")) } -// getDashboardPremiumPerks gets the premium perks of the dashboard OWNER or if it's a guest dashboard, it returns free tier premium perks -func (h *HandlerService) getDashboardPremiumPerks(ctx context.Context, id types.VDBId) (*types.PremiumPerks, error) { - // for guest dashboards, return free tier perks - if id.Validators != nil { - perk, err := h.dai.GetFreeTierPerks(ctx) - if err != nil { - return nil, err - } - return perk, nil - } - // could be made into a single query if needed - dashboardUser, err := h.dai.GetValidatorDashboardUser(ctx, id.Id) - if err != nil { - return nil, err - } - userInfo, err := h.dai.GetUserInfo(ctx, dashboardUser.UserId) - if err != nil { - return nil, err - } - - return &userInfo.PremiumPerks, nil -} - // helper function to unify handling of block detail request validation func (h *HandlerService) validateBlockRequest(r *http.Request, paramName string) (uint64, uint64, error) { var v validationError @@ -454,9 +271,9 @@ func (h *HandlerService) validateBlockRequest(r *http.Request, paramName string) // possibly add other values like "genesis", "finalized", hardforks etc. later case "latest": if paramName == "block" { - value, err = h.dai.GetLatestBlock() + value, err = h.daService.GetLatestBlock() } else if paramName == "slot" { - value, err = h.dai.GetLatestSlot() + value, err = h.daService.GetLatestSlot() } if err != nil { return 0, 0, err @@ -484,7 +301,6 @@ func (v *validationError) checkExistingGroupId(param string) uint64 { return v.checkUint(param, "group_id") } -//nolint:unparam func splitParameters(params string, delim rune) []string { // This splits the string by delim and removes empty strings f := func(c rune) bool { @@ -531,6 +347,12 @@ func (v *validationError) checkUintMinMax(param string, min uint64, max uint64, return checkMinMax(v, v.checkUint(param, paramName), min, max, paramName) } +type Paging struct { + cursor string + limit uint64 + search string +} + func (v *validationError) checkPagingParams(q url.Values) Paging { paging := Paging{ cursor: q.Get("cursor"), @@ -579,7 +401,7 @@ func checkSort[T enums.EnumFactory[T]](v *validationError, sortString string) *t if sortString == "" { return &types.Sort[T]{Column: c, Desc: defaultDesc} } - sortSplit := strings.Split(sortString, ":") + sortSplit := splitParameters(sortString, ':') if len(sortSplit) > 2 { v.add("sort", fmt.Sprintf("given value '%s' for parameter 'sort' is not valid, expected format is '[:(asc|desc)]'", sortString)) return nil @@ -679,11 +501,7 @@ func (v *validationError) checkNetworkParameter(param string) uint64 { return v.checkNetwork(intOrString{strValue: ¶m}) } -//nolint:unused func (v *validationError) checkNetworksParameter(param string) []uint64 { - if param == "" { - v.add("networks", "list of networks must not be empty") - } var chainIds []uint64 for _, network := range splitParameters(param, ',') { chainIds = append(chainIds, v.checkNetworkParameter(network)) @@ -738,334 +556,3 @@ func (v *validationError) checkTimestamps(r *http.Request, chartLimits ChartTime return afterTs, beforeTs } } - -// getMaxChartAge returns the maximum age of a chart in seconds based on the given aggregation type and premium perks -func getMaxChartAge(aggregation enums.ChartAggregation, perkSeconds types.ChartHistorySeconds) uint64 { - aggregations := enums.ChartAggregations - switch aggregation { - case aggregations.Epoch: - return perkSeconds.Epoch - case aggregations.Hourly: - return perkSeconds.Hourly - case aggregations.Daily: - return perkSeconds.Daily - case aggregations.Weekly: - return perkSeconds.Weekly - default: - return 0 - } -} - -func isUserAdmin(user *types.UserInfo) bool { - if user == nil { - return false - } - return user.UserGroup == types.UserGroupAdmin -} - -// -------------------------------------- -// Response handling - -func writeResponse(w http.ResponseWriter, r *http.Request, statusCode int, response interface{}) { - w.Header().Set("Content-Type", "application/json") - if response == nil { - w.WriteHeader(statusCode) - return - } - jsonData, err := json.Marshal(response) - if err != nil { - logApiError(r, fmt.Errorf("error encoding json data: %w", err), 0, - log.Fields{ - "data": fmt.Sprintf("%+v", response), - }) - w.WriteHeader(http.StatusInternalServerError) - response = types.ApiErrorResponse{ - Error: "error encoding json data", - } - if err = json.NewEncoder(w).Encode(response); err != nil { - // there seems to be an error with the lib - logApiError(r, fmt.Errorf("error encoding error response after failed encoding: %w", err), 0) - } - return - } - w.WriteHeader(statusCode) - if _, err = w.Write(jsonData); err != nil { - // already returned wrong status code to user, can't prevent that - logApiError(r, fmt.Errorf("error writing response data: %w", err), 0) - } -} - -func returnError(w http.ResponseWriter, r *http.Request, code int, err error) { - response := types.ApiErrorResponse{ - Error: err.Error(), - } - writeResponse(w, r, code, response) -} - -func returnOk(w http.ResponseWriter, r *http.Request, data interface{}) { - writeResponse(w, r, http.StatusOK, data) -} - -func returnCreated(w http.ResponseWriter, r *http.Request, data interface{}) { - writeResponse(w, r, http.StatusCreated, data) -} - -func returnNoContent(w http.ResponseWriter, r *http.Request) { - writeResponse(w, r, http.StatusNoContent, nil) -} - -// Errors - -func returnBadRequest(w http.ResponseWriter, r *http.Request, err error) { - returnError(w, r, http.StatusBadRequest, err) -} - -func returnUnauthorized(w http.ResponseWriter, r *http.Request, err error) { - returnError(w, r, http.StatusUnauthorized, err) -} - -func returnNotFound(w http.ResponseWriter, r *http.Request, err error) { - returnError(w, r, http.StatusNotFound, err) -} - -func returnConflict(w http.ResponseWriter, r *http.Request, err error) { - returnError(w, r, http.StatusConflict, err) -} - -func returnForbidden(w http.ResponseWriter, r *http.Request, err error) { - returnError(w, r, http.StatusForbidden, err) -} - -func returnTooManyRequests(w http.ResponseWriter, r *http.Request, err error) { - returnError(w, r, http.StatusTooManyRequests, err) -} - -func returnGone(w http.ResponseWriter, r *http.Request, err error) { - returnError(w, r, http.StatusGone, err) -} - -const maxBodySize = 10 * 1024 - -func logApiError(r *http.Request, err error, callerSkip int, additionalInfos ...log.Fields) { - body, _ := io.ReadAll(io.LimitReader(r.Body, maxBodySize)) - requestFields := log.Fields{ - "request_endpoint": r.Method + " " + r.URL.Path, - "request_query": r.URL.RawQuery, - "request_body": string(body), - } - log.Error(err, "error handling request", callerSkip+1, append(additionalInfos, requestFields)...) -} - -func handleErr(w http.ResponseWriter, r *http.Request, err error) { - _, isValidationError := err.(validationError) - switch { - case isValidationError, errors.Is(err, errBadRequest): - returnBadRequest(w, r, err) - case errors.Is(err, dataaccess.ErrNotFound): - returnNotFound(w, r, err) - case errors.Is(err, errUnauthorized): - returnUnauthorized(w, r, err) - case errors.Is(err, errForbidden): - returnForbidden(w, r, err) - case errors.Is(err, errConflict): - returnConflict(w, r, err) - case errors.Is(err, services.ErrWaiting): - returnError(w, r, http.StatusServiceUnavailable, err) - case errors.Is(err, errTooManyRequests): - returnTooManyRequests(w, r, err) - case errors.Is(err, errGone): - returnGone(w, r, err) - default: - logApiError(r, err, 1) - // TODO: don't return the error message to the user in production - returnError(w, r, http.StatusInternalServerError, err) - } -} - -// -------------------------------------- -// Error Helpers - -func errWithMsg(err error, format string, args ...interface{}) error { - return fmt.Errorf("%w: %s", err, fmt.Sprintf(format, args...)) -} - -//nolint:nolintlint -//nolint:unparam -func newBadRequestErr(format string, args ...interface{}) error { - return errWithMsg(errBadRequest, format, args...) -} - -//nolint:unparam -func newInternalServerErr(format string, args ...interface{}) error { - return errWithMsg(errInternalServer, format, args...) -} - -//nolint:unparam -func newUnauthorizedErr(format string, args ...interface{}) error { - return errWithMsg(errUnauthorized, format, args...) -} - -func newForbiddenErr(format string, args ...interface{}) error { - return errWithMsg(errForbidden, format, args...) -} - -//nolint:unparam -func newConflictErr(format string, args ...interface{}) error { - return errWithMsg(errConflict, format, args...) -} - -//nolint:nolintlint -//nolint:unparam -func newNotFoundErr(format string, args ...interface{}) error { - return errWithMsg(dataaccess.ErrNotFound, format, args...) -} - -func newTooManyRequestsErr(format string, args ...interface{}) error { - return errWithMsg(errTooManyRequests, format, args...) -} - -func newGoneErr(format string, args ...interface{}) error { - return errWithMsg(errGone, format, args...) -} - -// -------------------------------------- -// misc. helper functions - -// maps different types of validator dashboard summary validators to a common format -func mapVDBIndices(indices interface{}) ([]types.VDBSummaryValidatorsData, error) { - if indices == nil { - return nil, errors.New("no data found when mapping") - } - - switch v := indices.(type) { - case *types.VDBGeneralSummaryValidators: - // deposited, online, offline, slashing, slashed, exited, withdrawn, pending, exiting, withdrawing - return []types.VDBSummaryValidatorsData{ - mapUintSlice("deposited", v.Deposited), - mapUintSlice("online", v.Online), - mapUintSlice("offline", v.Offline), - mapUintSlice("slashing", v.Slashing), - mapUintSlice("slashed", v.Slashed), - mapUintSlice("exited", v.Exited), - mapUintSlice("withdrawn", v.Withdrawn), - mapIndexTimestampSlice("pending", v.Pending), - mapIndexTimestampSlice("exiting", v.Exiting), - mapIndexTimestampSlice("withdrawing", v.Withdrawing), - }, nil - - case *types.VDBSyncSummaryValidators: - return []types.VDBSummaryValidatorsData{ - mapUintSlice("sync_current", v.Current), - mapUintSlice("sync_upcoming", v.Upcoming), - mapSlice("sync_past", v.Past, - func(v types.VDBValidatorSyncPast) (uint64, []uint64) { return v.Index, []uint64{v.Count} }, - ), - }, nil - - case *types.VDBSlashingsSummaryValidators: - return []types.VDBSummaryValidatorsData{ - mapSlice("got_slashed", v.GotSlashed, - func(v types.VDBValidatorGotSlashed) (uint64, []uint64) { return v.Index, []uint64{v.SlashedBy} }, - ), - mapSlice("has_slashed", v.HasSlashed, - func(v types.VDBValidatorHasSlashed) (uint64, []uint64) { return v.Index, v.SlashedIndices }, - ), - }, nil - - case *types.VDBProposalSummaryValidators: - return []types.VDBSummaryValidatorsData{ - mapIndexBlocksSlice("proposal_proposed", v.Proposed), - mapIndexBlocksSlice("proposal_missed", v.Missed), - }, nil - - default: - return nil, fmt.Errorf("unsupported indices type") - } -} - -// maps different types of validator dashboard summary validators to a common format -func mapSlice[T any](category string, validators []T, getIndexAndDutyObjects func(validator T) (index uint64, dutyObjects []uint64)) types.VDBSummaryValidatorsData { - validatorsData := make([]types.VDBSummaryValidator, len(validators)) - for i, validator := range validators { - index, dutyObjects := getIndexAndDutyObjects(validator) - validatorsData[i] = types.VDBSummaryValidator{Index: index, DutyObjects: dutyObjects} - } - return types.VDBSummaryValidatorsData{ - Category: category, - Validators: validatorsData, - } -} -func mapUintSlice(category string, validators []uint64) types.VDBSummaryValidatorsData { - return mapSlice(category, validators, - func(v uint64) (uint64, []uint64) { return v, nil }, - ) -} - -func mapIndexTimestampSlice(category string, validators []types.IndexTimestamp) types.VDBSummaryValidatorsData { - return mapSlice(category, validators, - func(v types.IndexTimestamp) (uint64, []uint64) { return v.Index, []uint64{v.Timestamp} }, - ) -} - -func mapIndexBlocksSlice(category string, validators []types.IndexBlocks) types.VDBSummaryValidatorsData { - return mapSlice(category, validators, - func(v types.IndexBlocks) (uint64, []uint64) { return v.Index, v.Blocks }, - ) -} - -// -------------------------------------- -// intOrString is a custom type that can be unmarshalled from either an int or a string (strings will also be parsed to int if possible). -// if unmarshaling throws no errors one of the two fields will be set, the other will be nil. -type intOrString struct { - intValue *uint64 - strValue *string -} - -func (v *intOrString) UnmarshalJSON(data []byte) error { - // Attempt to unmarshal as uint64 first - var intValue uint64 - if err := json.Unmarshal(data, &intValue); err == nil { - v.intValue = &intValue - return nil - } - - // If unmarshalling as uint64 fails, try to unmarshal as string - var strValue string - if err := json.Unmarshal(data, &strValue); err == nil { - if parsedInt, err := strconv.ParseUint(strValue, 10, 64); err == nil { - v.intValue = &parsedInt - } else { - v.strValue = &strValue - } - return nil - } - - // If both unmarshalling attempts fail, return an error - return fmt.Errorf("failed to unmarshal intOrString from json: %s", string(data)) -} - -func (v intOrString) String() string { - if v.intValue != nil { - return strconv.FormatUint(*v.intValue, 10) - } - if v.strValue != nil { - return *v.strValue - } - return "" -} - -func (intOrString) JSONSchema() *jsonschema.Schema { - return &jsonschema.Schema{ - OneOf: []*jsonschema.Schema{ - {Type: "string"}, {Type: "integer"}, - }, - } -} - -func isMockEnabled(r *http.Request) bool { - isMockEnabled, ok := r.Context().Value(ctxIsMockEnabledKey).(bool) - if !ok { - return false - } - return isMockEnabled -} diff --git a/backend/pkg/api/handlers/internal.go b/backend/pkg/api/handlers/internal.go index 41f549b8f..3093352f9 100644 --- a/backend/pkg/api/handlers/internal.go +++ b/backend/pkg/api/handlers/internal.go @@ -14,7 +14,7 @@ import ( // Premium Plans func (h *HandlerService) InternalGetProductSummary(w http.ResponseWriter, r *http.Request) { - data, err := h.dai.GetProductSummary(r.Context()) + data, err := h.daService.GetProductSummary(r.Context()) if err != nil { handleErr(w, r, err) return @@ -29,7 +29,7 @@ func (h *HandlerService) InternalGetProductSummary(w http.ResponseWriter, r *htt // API Ratelimit Weights func (h *HandlerService) InternalGetRatelimitWeights(w http.ResponseWriter, r *http.Request) { - data, err := h.dai.GetApiWeights(r.Context()) + data, err := h.daService.GetApiWeights(r.Context()) if err != nil { handleErr(w, r, err) return @@ -44,19 +44,19 @@ func (h *HandlerService) InternalGetRatelimitWeights(w http.ResponseWriter, r *h // Latest State func (h *HandlerService) InternalGetLatestState(w http.ResponseWriter, r *http.Request) { - latestSlot, err := h.dai.GetLatestSlot() + latestSlot, err := h.daService.GetLatestSlot() if err != nil { handleErr(w, r, err) return } - finalizedEpoch, err := h.dai.GetLatestFinalizedEpoch() + finalizedEpoch, err := h.daService.GetLatestFinalizedEpoch() if err != nil { handleErr(w, r, err) return } - exchangeRates, err := h.dai.GetLatestExchangeRates() + exchangeRates, err := h.daService.GetLatestExchangeRates() if err != nil { handleErr(w, r, err) return @@ -74,7 +74,7 @@ func (h *HandlerService) InternalGetLatestState(w http.ResponseWriter, r *http.R } func (h *HandlerService) InternalGetRocketPool(w http.ResponseWriter, r *http.Request) { - data, err := h.dai.GetRocketPoolOverview(r.Context()) + data, err := h.daService.GetRocketPoolOverview(r.Context()) if err != nil { handleErr(w, r, err) return @@ -126,7 +126,7 @@ func (h *HandlerService) InternalPostAdConfigurations(w http.ResponseWriter, r * return } - err = h.dai.CreateAdConfiguration(r.Context(), key, req.JQuerySelector, insertMode, req.RefreshInterval, req.ForAllUsers, req.BannerId, req.HtmlContent, req.Enabled) + err = h.daService.CreateAdConfiguration(r.Context(), key, req.JQuerySelector, insertMode, req.RefreshInterval, req.ForAllUsers, req.BannerId, req.HtmlContent, req.Enabled) if err != nil { handleErr(w, r, err) return @@ -156,7 +156,7 @@ func (h *HandlerService) InternalGetAdConfigurations(w http.ResponseWriter, r *h return } - data, err := h.dai.GetAdConfigurations(r.Context(), keys) + data, err := h.daService.GetAdConfigurations(r.Context(), keys) if err != nil { handleErr(w, r, err) return @@ -202,7 +202,7 @@ func (h *HandlerService) InternalPutAdConfiguration(w http.ResponseWriter, r *ht return } - err = h.dai.UpdateAdConfiguration(r.Context(), key, req.JQuerySelector, insertMode, req.RefreshInterval, req.ForAllUsers, req.BannerId, req.HtmlContent, req.Enabled) + err = h.daService.UpdateAdConfiguration(r.Context(), key, req.JQuerySelector, insertMode, req.RefreshInterval, req.ForAllUsers, req.BannerId, req.HtmlContent, req.Enabled) if err != nil { handleErr(w, r, err) return @@ -232,7 +232,7 @@ func (h *HandlerService) InternalDeleteAdConfiguration(w http.ResponseWriter, r return } - err = h.dai.RemoveAdConfiguration(r.Context(), key) + err = h.daService.RemoveAdConfiguration(r.Context(), key) if err != nil { handleErr(w, r, err) return @@ -251,7 +251,7 @@ func (h *HandlerService) InternalGetUserInfo(w http.ResponseWriter, r *http.Requ handleErr(w, r, err) return } - userInfo, err := h.dai.GetUserInfo(r.Context(), user.Id) + userInfo, err := h.daService.GetUserInfo(r.Context(), user.Id) if err != nil { handleErr(w, r, err) return @@ -383,7 +383,7 @@ func (h *HandlerService) InternalGetValidatorDashboardMobileValidators(w http.Re handleErr(w, r, v) return } - data, paging, err := h.dai.GetValidatorDashboardMobileValidators(r.Context(), *dashboardId, period, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.daService.GetValidatorDashboardMobileValidators(r.Context(), *dashboardId, period, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -512,7 +512,7 @@ func (h *HandlerService) InternalGetValidatorDashboardMobileWidget(w http.Respon handleErr(w, r, err) return } - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.daService.GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -521,7 +521,7 @@ func (h *HandlerService) InternalGetValidatorDashboardMobileWidget(w http.Respon returnForbidden(w, r, errors.New("user does not have access to mobile app widget")) return } - data, err := h.dai.GetValidatorDashboardMobileWidget(r.Context(), dashboardId) + data, err := h.daService.GetValidatorDashboardMobileWidget(r.Context(), dashboardId) if err != nil { handleErr(w, r, err) return @@ -545,7 +545,7 @@ func (h *HandlerService) InternalGetMobileLatestBundle(w http.ResponseWriter, r handleErr(w, r, v) return } - stats, err := h.dai.GetLatestBundleForNativeVersion(r.Context(), nativeVersion) + stats, err := h.daService.GetLatestBundleForNativeVersion(r.Context(), nativeVersion) if err != nil { handleErr(w, r, err) return @@ -570,7 +570,7 @@ func (h *HandlerService) InternalPostMobileBundleDeliveries(w http.ResponseWrite handleErr(w, r, v) return } - err := h.dai.IncrementBundleDeliveryCount(r.Context(), bundleVersion) + err := h.daService.IncrementBundleDeliveryCount(r.Context(), bundleVersion) if err != nil { handleErr(w, r, err) return @@ -671,7 +671,7 @@ func (h *HandlerService) InternalGetBlock(w http.ResponseWriter, r *http.Request return } - data, err := h.dai.GetBlock(r.Context(), chainId, block) + data, err := h.daService.GetBlock(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -690,7 +690,7 @@ func (h *HandlerService) InternalGetBlockOverview(w http.ResponseWriter, r *http return } - data, err := h.dai.GetBlockOverview(r.Context(), chainId, block) + data, err := h.daService.GetBlockOverview(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -708,7 +708,7 @@ func (h *HandlerService) InternalGetBlockTransactions(w http.ResponseWriter, r * return } - data, err := h.dai.GetBlockTransactions(r.Context(), chainId, block) + data, err := h.daService.GetBlockTransactions(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -727,7 +727,7 @@ func (h *HandlerService) InternalGetBlockVotes(w http.ResponseWriter, r *http.Re return } - data, err := h.dai.GetBlockVotes(r.Context(), chainId, block) + data, err := h.daService.GetBlockVotes(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -746,7 +746,7 @@ func (h *HandlerService) InternalGetBlockAttestations(w http.ResponseWriter, r * return } - data, err := h.dai.GetBlockAttestations(r.Context(), chainId, block) + data, err := h.daService.GetBlockAttestations(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -765,7 +765,7 @@ func (h *HandlerService) InternalGetBlockWithdrawals(w http.ResponseWriter, r *h return } - data, err := h.dai.GetBlockWithdrawals(r.Context(), chainId, block) + data, err := h.daService.GetBlockWithdrawals(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -784,7 +784,7 @@ func (h *HandlerService) InternalGetBlockBlsChanges(w http.ResponseWriter, r *ht return } - data, err := h.dai.GetBlockBlsChanges(r.Context(), chainId, block) + data, err := h.daService.GetBlockBlsChanges(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -803,7 +803,7 @@ func (h *HandlerService) InternalGetBlockVoluntaryExits(w http.ResponseWriter, r return } - data, err := h.dai.GetBlockVoluntaryExits(r.Context(), chainId, block) + data, err := h.daService.GetBlockVoluntaryExits(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -822,7 +822,7 @@ func (h *HandlerService) InternalGetBlockBlobs(w http.ResponseWriter, r *http.Re return } - data, err := h.dai.GetBlockBlobs(r.Context(), chainId, block) + data, err := h.daService.GetBlockBlobs(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -844,7 +844,7 @@ func (h *HandlerService) InternalGetSlot(w http.ResponseWriter, r *http.Request) return } - data, err := h.dai.GetSlot(r.Context(), chainId, block) + data, err := h.daService.GetSlot(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -863,7 +863,7 @@ func (h *HandlerService) InternalGetSlotOverview(w http.ResponseWriter, r *http. return } - data, err := h.dai.GetSlotOverview(r.Context(), chainId, block) + data, err := h.daService.GetSlotOverview(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -881,7 +881,7 @@ func (h *HandlerService) InternalGetSlotTransactions(w http.ResponseWriter, r *h return } - data, err := h.dai.GetSlotTransactions(r.Context(), chainId, block) + data, err := h.daService.GetSlotTransactions(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -900,7 +900,7 @@ func (h *HandlerService) InternalGetSlotVotes(w http.ResponseWriter, r *http.Req return } - data, err := h.dai.GetSlotVotes(r.Context(), chainId, block) + data, err := h.daService.GetSlotVotes(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -919,7 +919,7 @@ func (h *HandlerService) InternalGetSlotAttestations(w http.ResponseWriter, r *h return } - data, err := h.dai.GetSlotAttestations(r.Context(), chainId, block) + data, err := h.daService.GetSlotAttestations(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -938,7 +938,7 @@ func (h *HandlerService) InternalGetSlotWithdrawals(w http.ResponseWriter, r *ht return } - data, err := h.dai.GetSlotWithdrawals(r.Context(), chainId, block) + data, err := h.daService.GetSlotWithdrawals(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -957,7 +957,7 @@ func (h *HandlerService) InternalGetSlotBlsChanges(w http.ResponseWriter, r *htt return } - data, err := h.dai.GetSlotBlsChanges(r.Context(), chainId, block) + data, err := h.daService.GetSlotBlsChanges(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -976,7 +976,7 @@ func (h *HandlerService) InternalGetSlotVoluntaryExits(w http.ResponseWriter, r return } - data, err := h.dai.GetSlotVoluntaryExits(r.Context(), chainId, block) + data, err := h.daService.GetSlotVoluntaryExits(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return @@ -995,7 +995,7 @@ func (h *HandlerService) InternalGetSlotBlobs(w http.ResponseWriter, r *http.Req return } - data, err := h.dai.GetSlotBlobs(r.Context(), chainId, block) + data, err := h.daService.GetSlotBlobs(r.Context(), chainId, block) if err != nil { handleErr(w, r, err) return diff --git a/backend/pkg/api/handlers/machine_metrics.go b/backend/pkg/api/handlers/machine_metrics.go index 9ac1dd0cd..9bef843d5 100644 --- a/backend/pkg/api/handlers/machine_metrics.go +++ b/backend/pkg/api/handlers/machine_metrics.go @@ -30,7 +30,7 @@ func (h *HandlerService) PublicGetUserMachineMetrics(w http.ResponseWriter, r *h return } - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.daService.GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -51,7 +51,7 @@ func (h *HandlerService) PublicGetUserMachineMetrics(w http.ResponseWriter, r *h offset = 0 } - data, err := h.dai.GetUserMachineMetrics(r.Context(), userId, int(limit), int(offset)) + data, err := h.daService.GetUserMachineMetrics(r.Context(), userId, int(limit), int(offset)) if err != nil { handleErr(w, r, err) return @@ -77,13 +77,13 @@ func (h *HandlerService) LegacyPostUserMachineMetrics(w http.ResponseWriter, r * return } - userID, err := h.dai.GetUserIdByApiKey(r.Context(), apiKey) + userID, err := h.daService.GetUserIdByApiKey(r.Context(), apiKey) if err != nil { returnBadRequest(w, r, fmt.Errorf("no user found with api key")) return } - userInfo, err := h.dai.GetUserInfo(r.Context(), userID) + userInfo, err := h.daService.GetUserInfo(r.Context(), userID) if err != nil { handleErr(w, r, err) return @@ -200,7 +200,7 @@ func (h *HandlerService) internal_processMachine(context context.Context, machin } } - return h.dai.PostUserMachineMetrics(context, userInfo.Id, machine, parsedMeta.Process, data) + return h.daService.PostUserMachineMetrics(context, userInfo.Id, machine, parsedMeta.Process, data) } func DecodeMapStructure(input interface{}, output interface{}) error { diff --git a/backend/pkg/api/handlers/middlewares.go b/backend/pkg/api/handlers/middlewares.go new file mode 100644 index 000000000..54238cc59 --- /dev/null +++ b/backend/pkg/api/handlers/middlewares.go @@ -0,0 +1,200 @@ +package handlers + +import ( + "context" + "errors" + "net/http" + "slices" + "strconv" + + "github.com/gobitfly/beaconchain/pkg/api/types" + "github.com/gorilla/mux" +) + +// Middlewares + +// middleware that stores user id in context, using the provided function +func StoreUserIdMiddleware(next http.Handler, userIdFunc func(r *http.Request) (uint64, error)) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + userId, err := userIdFunc(r) + if err != nil { + if errors.Is(err, errUnauthorized) { + // if next handler requires authentication, it should return 'unauthorized' itself + next.ServeHTTP(w, r) + } else { + handleErr(w, r, err) + } + return + } + + // store user id in context + ctx := r.Context() + ctx = context.WithValue(ctx, ctxUserIdKey, userId) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) +} + +// middleware that stores user id in context, using the session to get the user id +func (h *HandlerService) StoreUserIdBySessionMiddleware(next http.Handler) http.Handler { + return StoreUserIdMiddleware(next, func(r *http.Request) (uint64, error) { + return h.GetUserIdBySession(r) + }) +} + +// middleware that stores user id in context, using the api key to get the user id +func (h *HandlerService) StoreUserIdByApiKeyMiddleware(next http.Handler) http.Handler { + return StoreUserIdMiddleware(next, func(r *http.Request) (uint64, error) { + return h.GetUserIdByApiKey(r) + }) +} + +// middleware that checks if user has access to dashboard when a primary id is used +func (h *HandlerService) VDBAuthMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // if mock data is used, no need to check access + if isMockEnabled, ok := r.Context().Value(ctxIsMockedKey).(bool); ok && isMockEnabled { + next.ServeHTTP(w, r) + return + } + var err error + dashboardId, err := strconv.ParseUint(mux.Vars(r)["dashboard_id"], 10, 64) + if err != nil { + // if primary id is not used, no need to check access + next.ServeHTTP(w, r) + return + } + // primary id is used -> user needs to have access to dashboard + + userId, err := GetUserIdByContext(r) + if err != nil { + handleErr(w, r, err) + return + } + + // store user id in context + ctx := r.Context() + ctx = context.WithValue(ctx, ctxUserIdKey, userId) + r = r.WithContext(ctx) + + dashboardUser, err := h.daService.GetValidatorDashboardUser(r.Context(), types.VDBIdPrimary(dashboardId)) + if err != nil { + handleErr(w, r, err) + return + } + + if dashboardUser.UserId != userId { + // user does not have access to dashboard + // the proper error would be 403 Forbidden, but we don't want to leak information so we return 404 Not Found + handleErr(w, r, newNotFoundErr("dashboard with id %v not found", dashboardId)) + return + } + + next.ServeHTTP(w, r) + }) +} + +// Common middleware logic for checking user premium perks +func (h *HandlerService) PremiumPerkCheckMiddleware(next http.Handler, hasRequiredPerk func(premiumPerks types.PremiumPerks) bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // get user id from context + userId, err := GetUserIdByContext(r) + if err != nil { + handleErr(w, r, err) + return + } + + // get user info + userInfo, err := h.daService.GetUserInfo(r.Context(), userId) + if err != nil { + handleErr(w, r, err) + return + } + + // check if user has the required premium perk + if !hasRequiredPerk(userInfo.PremiumPerks) { + handleErr(w, r, newForbiddenErr("users premium perks do not allow usage of this endpoint")) + return + } + + next.ServeHTTP(w, r) + }) +} + +// Middleware for managing dashboards via API +func (h *HandlerService) ManageDashboardsViaApiCheckMiddleware(next http.Handler) http.Handler { + return h.PremiumPerkCheckMiddleware(next, func(premiumPerks types.PremiumPerks) bool { + return premiumPerks.ManageDashboardViaApi + }) +} + +// Middleware for managing notifications via API +func (h *HandlerService) ManageNotificationsViaApiCheckMiddleware(next http.Handler) http.Handler { + return h.PremiumPerkCheckMiddleware(next, func(premiumPerks types.PremiumPerks) bool { + return premiumPerks.ConfigureNotificationsViaApi + }) +} + +// middleware check to return if specified dashboard is not archived (and accessible) +func (h *HandlerService) VDBArchivedCheckMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isMockEnabled, ok := r.Context().Value(ctxIsMockedKey).(bool); ok && isMockEnabled { + next.ServeHTTP(w, r) + return + } + dashboardId, err := h.handleDashboardId(r.Context(), mux.Vars(r)["dashboard_id"]) + if err != nil { + handleErr(w, r, err) + return + } + if len(dashboardId.Validators) > 0 { + next.ServeHTTP(w, r) + return + } + dashboard, err := h.daService.GetValidatorDashboardInfo(r.Context(), dashboardId.Id) + if err != nil { + handleErr(w, r, err) + return + } + if dashboard.IsArchived { + handleErr(w, r, newForbiddenErr("dashboard with id %v is archived", dashboardId)) + return + } + next.ServeHTTP(w, r) + }) +} + +// middleware that checks for `is_mocked` query param and stores it in the request context. +// should bypass auth checks if the flag is set and cause handlers to return mocked data. +// only allowed for users in the admin or dev group. +// note that mocked data is only returned by handlers that check for it. +func (h *HandlerService) StoreIsMockedFlagMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + isMocked, _ := strconv.ParseBool(r.URL.Query().Get("is_mocked")) + if !isMocked { + next.ServeHTTP(w, r) + return + } + // fetch user group + userId, err := h.GetUserIdBySession(r) + if err != nil { + handleErr(w, r, err) + return + } + userCredentials, err := h.daService.GetUserInfo(r.Context(), userId) + if err != nil { + handleErr(w, r, err) + return + } + allowedGroups := []string{types.UserGroupAdmin, types.UserGroupDev} + if !slices.Contains(allowedGroups, userCredentials.UserGroup) { + handleErr(w, r, newForbiddenErr("user is not allowed to use mock data")) + return + } + // store isMocked flag in context + ctx := r.Context() + ctx = context.WithValue(ctx, ctxIsMockedKey, true) + r = r.WithContext(ctx) + next.ServeHTTP(w, r) + }) +} diff --git a/backend/pkg/api/handlers/public.go b/backend/pkg/api/handlers/public.go index 53b84e270..98d1b8b55 100644 --- a/backend/pkg/api/handlers/public.go +++ b/backend/pkg/api/handlers/public.go @@ -47,7 +47,7 @@ func (h *HandlerService) PublicGetHealthz(w http.ResponseWriter, r *http.Request } ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) defer cancel() - data := h.dai.GetHealthz(ctx, showAll) + data := h.getDataAccessor(r).GetHealthz(ctx, showAll) responseCode := http.StatusOK if data.TotalOkPercentage != 1 { @@ -74,7 +74,7 @@ func (h *HandlerService) PublicGetUserDashboards(w http.ResponseWriter, r *http. handleErr(w, r, err) return } - data, err := h.dai.GetUserDashboards(r.Context(), userId) + data, err := h.getDataAccessor(r).GetUserDashboards(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -177,12 +177,12 @@ func (h *HandlerService) PublicPostValidatorDashboards(w http.ResponseWriter, r return } - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return } - dashboardCount, err := h.dai.GetUserValidatorDashboardCount(r.Context(), userId, true) + dashboardCount, err := h.getDataAccessor(r).GetUserValidatorDashboardCount(r.Context(), userId, true) if err != nil { handleErr(w, r, err) return @@ -192,7 +192,7 @@ func (h *HandlerService) PublicPostValidatorDashboards(w http.ResponseWriter, r return } - data, err := h.dai.CreateValidatorDashboard(r.Context(), userId, name, chainId) + data, err := h.getDataAccessor(r).CreateValidatorDashboard(r.Context(), userId, name, chainId) if err != nil { handleErr(w, r, err) return @@ -232,10 +232,10 @@ func (h *HandlerService) PublicGetValidatorDashboard(w http.ResponseWriter, r *h // set name depending on dashboard id var name string if reInteger.MatchString(dashboardIdParam) { - name, err = h.dai.GetValidatorDashboardName(r.Context(), dashboardId.Id) + name, err = h.getDataAccessor(r).GetValidatorDashboardName(r.Context(), dashboardId.Id) } else if reValidatorDashboardPublicId.MatchString(dashboardIdParam) { var publicIdInfo *types.VDBPublicId - publicIdInfo, err = h.dai.GetValidatorDashboardPublicId(r.Context(), types.VDBIdPublic(dashboardIdParam)) + publicIdInfo, err = h.getDataAccessor(r).GetValidatorDashboardPublicId(r.Context(), types.VDBIdPublic(dashboardIdParam)) name = publicIdInfo.Name } if err != nil { @@ -249,7 +249,7 @@ func (h *HandlerService) PublicGetValidatorDashboard(w http.ResponseWriter, r *h handleErr(w, r, err) return } - data, err := h.dai.GetValidatorDashboardOverview(r.Context(), *dashboardId, protocolModes) + data, err := h.getDataAccessor(r).GetValidatorDashboardOverview(r.Context(), *dashboardId, protocolModes) if err != nil { handleErr(w, r, err) return @@ -281,7 +281,7 @@ func (h *HandlerService) PublicDeleteValidatorDashboard(w http.ResponseWriter, r handleErr(w, r, v) return } - err := h.dai.RemoveValidatorDashboard(r.Context(), dashboardId) + err := h.getDataAccessor(r).RemoveValidatorDashboard(r.Context(), dashboardId) if err != nil { handleErr(w, r, err) return @@ -317,7 +317,7 @@ func (h *HandlerService) PublicPutValidatorDashboardName(w http.ResponseWriter, handleErr(w, r, v) return } - data, err := h.dai.UpdateValidatorDashboardName(r.Context(), dashboardId, name) + data, err := h.getDataAccessor(r).UpdateValidatorDashboardName(r.Context(), dashboardId, name) if err != nil { handleErr(w, r, err) return @@ -364,12 +364,12 @@ func (h *HandlerService) PublicPostValidatorDashboardGroups(w http.ResponseWrite handleErr(w, r, err) return } - userInfo, err := h.dai.GetUserInfo(ctx, userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(ctx, userId) if err != nil { handleErr(w, r, err) return } - groupCount, err := h.dai.GetValidatorDashboardGroupCount(ctx, dashboardId) + groupCount, err := h.getDataAccessor(r).GetValidatorDashboardGroupCount(ctx, dashboardId) if err != nil { handleErr(w, r, err) return @@ -379,7 +379,7 @@ func (h *HandlerService) PublicPostValidatorDashboardGroups(w http.ResponseWrite return } - data, err := h.dai.CreateValidatorDashboardGroup(ctx, dashboardId, name) + data, err := h.getDataAccessor(r).CreateValidatorDashboardGroup(ctx, dashboardId, name) if err != nil { handleErr(w, r, err) return @@ -423,7 +423,7 @@ func (h *HandlerService) PublicPutValidatorDashboardGroups(w http.ResponseWriter handleErr(w, r, v) return } - groupExists, err := h.dai.GetValidatorDashboardGroupExists(r.Context(), dashboardId, groupId) + groupExists, err := h.getDataAccessor(r).GetValidatorDashboardGroupExists(r.Context(), dashboardId, groupId) if err != nil { handleErr(w, r, err) return @@ -432,7 +432,7 @@ func (h *HandlerService) PublicPutValidatorDashboardGroups(w http.ResponseWriter returnNotFound(w, r, errors.New("group not found")) return } - data, err := h.dai.UpdateValidatorDashboardGroup(r.Context(), dashboardId, groupId, name) + data, err := h.getDataAccessor(r).UpdateValidatorDashboardGroup(r.Context(), dashboardId, groupId, name) if err != nil { handleErr(w, r, err) return @@ -470,7 +470,7 @@ func (h *HandlerService) PublicDeleteValidatorDashboardGroup(w http.ResponseWrit returnBadRequest(w, r, errors.New("cannot delete default group")) return } - groupExists, err := h.dai.GetValidatorDashboardGroupExists(r.Context(), dashboardId, groupId) + groupExists, err := h.getDataAccessor(r).GetValidatorDashboardGroupExists(r.Context(), dashboardId, groupId) if err != nil { handleErr(w, r, err) return @@ -479,7 +479,7 @@ func (h *HandlerService) PublicDeleteValidatorDashboardGroup(w http.ResponseWrit returnNotFound(w, r, errors.New("group not found")) return } - err = h.dai.RemoveValidatorDashboardGroup(r.Context(), dashboardId, groupId) + err = h.getDataAccessor(r).RemoveValidatorDashboardGroup(r.Context(), dashboardId, groupId) if err != nil { handleErr(w, r, err) return @@ -544,7 +544,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW } ctx := r.Context() - groupExists, err := h.dai.GetValidatorDashboardGroupExists(ctx, dashboardId, groupId) + groupExists, err := h.getDataAccessor(r).GetValidatorDashboardGroupExists(ctx, dashboardId, groupId) if err != nil { handleErr(w, r, err) return @@ -558,7 +558,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW handleErr(w, r, err) return } - userInfo, err := h.dai.GetUserInfo(ctx, userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(ctx, userId) if err != nil { handleErr(w, r, err) return @@ -568,7 +568,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW return } dashboardLimit := userInfo.PremiumPerks.ValidatorsPerDashboard - existingValidatorCount, err := h.dai.GetValidatorDashboardValidatorsCount(ctx, dashboardId) + existingValidatorCount, err := h.getDataAccessor(r).GetValidatorDashboardValidatorsCount(ctx, dashboardId) if err != nil { handleErr(w, r, err) return @@ -589,7 +589,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW handleErr(w, r, v) return } - validators, err := h.dai.GetValidatorsFromSlices(indices, pubkeys) + validators, err := h.getDataAccessor(r).GetValidatorsFromSlices(indices, pubkeys) if err != nil { handleErr(w, r, err) return @@ -597,7 +597,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW if len(validators) > int(limit) { validators = validators[:limit] } - data, dataErr = h.dai.AddValidatorDashboardValidators(ctx, dashboardId, groupId, validators) + data, dataErr = h.getDataAccessor(r).AddValidatorDashboardValidators(ctx, dashboardId, groupId, validators) case req.DepositAddress != "": depositAddress := v.checkRegex(reEthereumAddress, req.DepositAddress, "deposit_address") @@ -605,7 +605,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW handleErr(w, r, v) return } - data, dataErr = h.dai.AddValidatorDashboardValidatorsByDepositAddress(ctx, dashboardId, groupId, depositAddress, limit) + data, dataErr = h.getDataAccessor(r).AddValidatorDashboardValidatorsByDepositAddress(ctx, dashboardId, groupId, depositAddress, limit) case req.WithdrawalAddress != "": withdrawalAddress := v.checkRegex(reWithdrawalCredential, req.WithdrawalAddress, "withdrawal_address") @@ -613,7 +613,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW handleErr(w, r, v) return } - data, dataErr = h.dai.AddValidatorDashboardValidatorsByWithdrawalAddress(ctx, dashboardId, groupId, withdrawalAddress, limit) + data, dataErr = h.getDataAccessor(r).AddValidatorDashboardValidatorsByWithdrawalAddress(ctx, dashboardId, groupId, withdrawalAddress, limit) case req.Graffiti != "": graffiti := v.checkRegex(reGraffiti, req.Graffiti, "graffiti") @@ -621,7 +621,7 @@ func (h *HandlerService) PublicPostValidatorDashboardValidators(w http.ResponseW handleErr(w, r, v) return } - data, dataErr = h.dai.AddValidatorDashboardValidatorsByGraffiti(ctx, dashboardId, groupId, graffiti, limit) + data, dataErr = h.getDataAccessor(r).AddValidatorDashboardValidatorsByGraffiti(ctx, dashboardId, groupId, graffiti, limit) } if dataErr != nil { @@ -663,7 +663,7 @@ func (h *HandlerService) PublicGetValidatorDashboardValidators(w http.ResponseWr handleErr(w, r, v) return } - data, paging, err := h.dai.GetValidatorDashboardValidators(r.Context(), *dashboardId, groupId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardValidators(r.Context(), *dashboardId, groupId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -703,12 +703,12 @@ func (h *HandlerService) PublicDeleteValidatorDashboardValidators(w http.Respons handleErr(w, r, v) return } - validators, err := h.dai.GetValidatorsFromSlices(indices, publicKeys) + validators, err := h.getDataAccessor(r).GetValidatorsFromSlices(indices, publicKeys) if err != nil { handleErr(w, r, err) return } - err = h.dai.RemoveValidatorDashboardValidators(r.Context(), dashboardId, validators) + err = h.getDataAccessor(r).RemoveValidatorDashboardValidators(r.Context(), dashboardId, validators) if err != nil { handleErr(w, r, err) return @@ -749,7 +749,7 @@ func (h *HandlerService) PublicPostValidatorDashboardPublicIds(w http.ResponseWr handleErr(w, r, v) return } - publicIdCount, err := h.dai.GetValidatorDashboardPublicIdCount(r.Context(), dashboardId) + publicIdCount, err := h.getDataAccessor(r).GetValidatorDashboardPublicIdCount(r.Context(), dashboardId) if err != nil { handleErr(w, r, err) return @@ -759,7 +759,7 @@ func (h *HandlerService) PublicPostValidatorDashboardPublicIds(w http.ResponseWr return } - data, err := h.dai.CreateValidatorDashboardPublicId(r.Context(), dashboardId, name, req.ShareSettings.ShareGroups) + data, err := h.getDataAccessor(r).CreateValidatorDashboardPublicId(r.Context(), dashboardId, name, req.ShareSettings.ShareGroups) if err != nil { handleErr(w, r, err) return @@ -805,7 +805,7 @@ func (h *HandlerService) PublicPutValidatorDashboardPublicId(w http.ResponseWrit handleErr(w, r, v) return } - fetchedId, err := h.dai.GetValidatorDashboardIdByPublicId(r.Context(), publicDashboardId) + fetchedId, err := h.getDataAccessor(r).GetValidatorDashboardIdByPublicId(r.Context(), publicDashboardId) if err != nil { handleErr(w, r, err) return @@ -815,7 +815,7 @@ func (h *HandlerService) PublicPutValidatorDashboardPublicId(w http.ResponseWrit return } - data, err := h.dai.UpdateValidatorDashboardPublicId(r.Context(), publicDashboardId, name, req.ShareSettings.ShareGroups) + data, err := h.getDataAccessor(r).UpdateValidatorDashboardPublicId(r.Context(), publicDashboardId, name, req.ShareSettings.ShareGroups) if err != nil { handleErr(w, r, err) return @@ -847,7 +847,7 @@ func (h *HandlerService) PublicDeleteValidatorDashboardPublicId(w http.ResponseW handleErr(w, r, v) return } - fetchedId, err := h.dai.GetValidatorDashboardIdByPublicId(r.Context(), publicDashboardId) + fetchedId, err := h.getDataAccessor(r).GetValidatorDashboardIdByPublicId(r.Context(), publicDashboardId) if err != nil { handleErr(w, r, err) return @@ -857,7 +857,7 @@ func (h *HandlerService) PublicDeleteValidatorDashboardPublicId(w http.ResponseW return } - err = h.dai.RemoveValidatorDashboardPublicId(r.Context(), publicDashboardId) + err = h.getDataAccessor(r).RemoveValidatorDashboardPublicId(r.Context(), publicDashboardId) if err != nil { handleErr(w, r, err) return @@ -896,7 +896,7 @@ func (h *HandlerService) PublicPutValidatorDashboardArchiving(w http.ResponseWri } // check conditions for changing archival status - dashboardInfo, err := h.dai.GetValidatorDashboardInfo(r.Context(), dashboardId) + dashboardInfo, err := h.getDataAccessor(r).GetValidatorDashboardInfo(r.Context(), dashboardId) if err != nil { handleErr(w, r, err) return @@ -914,13 +914,13 @@ func (h *HandlerService) PublicPutValidatorDashboardArchiving(w http.ResponseWri handleErr(w, r, err) return } - dashboardCount, err := h.dai.GetUserValidatorDashboardCount(r.Context(), userId, !req.IsArchived) + dashboardCount, err := h.getDataAccessor(r).GetUserValidatorDashboardCount(r.Context(), userId, !req.IsArchived) if err != nil { handleErr(w, r, err) return } - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -950,7 +950,7 @@ func (h *HandlerService) PublicPutValidatorDashboardArchiving(w http.ResponseWri archivedReason = &enums.VDBArchivedReasons.User } - data, err := h.dai.UpdateValidatorDashboardArchiving(r.Context(), dashboardId, archivedReason) + data, err := h.getDataAccessor(r).UpdateValidatorDashboardArchiving(r.Context(), dashboardId, archivedReason) if err != nil { handleErr(w, r, err) return @@ -984,7 +984,7 @@ func (h *HandlerService) PublicGetValidatorDashboardSlotViz(w http.ResponseWrite handleErr(w, r, v) return } - data, err := h.dai.GetValidatorDashboardSlotViz(r.Context(), *dashboardId, groupIds) + data, err := h.getDataAccessor(r).GetValidatorDashboardSlotViz(r.Context(), *dashboardId, groupIds) if err != nil { handleErr(w, r, err) return @@ -1029,7 +1029,7 @@ func (h *HandlerService) PublicGetValidatorDashboardSummary(w http.ResponseWrite return } - data, paging, err := h.dai.GetValidatorDashboardSummary(r.Context(), *dashboardId, period, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardSummary(r.Context(), *dashboardId, period, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1074,7 +1074,7 @@ func (h *HandlerService) PublicGetValidatorDashboardGroupSummary(w http.Response return } - data, err := h.dai.GetValidatorDashboardGroupSummary(r.Context(), *dashboardId, groupId, period, protocolModes) + data, err := h.getDataAccessor(r).GetValidatorDashboardGroupSummary(r.Context(), *dashboardId, groupId, period, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1127,7 +1127,7 @@ func (h *HandlerService) PublicGetValidatorDashboardSummaryChart(w http.Response return } - data, err := h.dai.GetValidatorDashboardSummaryChart(ctx, *dashboardId, groupIds, efficiencyType, aggregation, afterTs, beforeTs) + data, err := h.getDataAccessor(r).GetValidatorDashboardSummaryChart(ctx, *dashboardId, groupIds, efficiencyType, aggregation, afterTs, beforeTs) if err != nil { handleErr(w, r, err) return @@ -1171,13 +1171,13 @@ func (h *HandlerService) PublicGetValidatorDashboardSummaryValidators(w http.Res duties := enums.ValidatorDuties switch duty { case duties.None: - indices, err = h.dai.GetValidatorDashboardSummaryValidators(r.Context(), *dashboardId, groupId) + indices, err = h.getDataAccessor(r).GetValidatorDashboardSummaryValidators(r.Context(), *dashboardId, groupId) case duties.Sync: - indices, err = h.dai.GetValidatorDashboardSyncSummaryValidators(r.Context(), *dashboardId, groupId, period) + indices, err = h.getDataAccessor(r).GetValidatorDashboardSyncSummaryValidators(r.Context(), *dashboardId, groupId, period) case duties.Slashed: - indices, err = h.dai.GetValidatorDashboardSlashingsSummaryValidators(r.Context(), *dashboardId, groupId, period) + indices, err = h.getDataAccessor(r).GetValidatorDashboardSlashingsSummaryValidators(r.Context(), *dashboardId, groupId, period) case duties.Proposal: - indices, err = h.dai.GetValidatorDashboardProposalSummaryValidators(r.Context(), *dashboardId, groupId, period) + indices, err = h.getDataAccessor(r).GetValidatorDashboardProposalSummaryValidators(r.Context(), *dashboardId, groupId, period) } if err != nil { handleErr(w, r, err) @@ -1227,7 +1227,7 @@ func (h *HandlerService) PublicGetValidatorDashboardRewards(w http.ResponseWrite return } - data, paging, err := h.dai.GetValidatorDashboardRewards(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardRewards(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1268,7 +1268,7 @@ func (h *HandlerService) PublicGetValidatorDashboardGroupRewards(w http.Response return } - data, err := h.dai.GetValidatorDashboardGroupRewards(r.Context(), *dashboardId, groupId, epoch, protocolModes) + data, err := h.getDataAccessor(r).GetValidatorDashboardGroupRewards(r.Context(), *dashboardId, groupId, epoch, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1304,7 +1304,7 @@ func (h *HandlerService) PublicGetValidatorDashboardRewardsChart(w http.Response return } - data, err := h.dai.GetValidatorDashboardRewardsChart(r.Context(), *dashboardId, protocolModes) + data, err := h.getDataAccessor(r).GetValidatorDashboardRewardsChart(r.Context(), *dashboardId, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1350,7 +1350,7 @@ func (h *HandlerService) PublicGetValidatorDashboardDuties(w http.ResponseWriter return } - data, paging, err := h.dai.GetValidatorDashboardDuties(r.Context(), *dashboardId, epoch, groupId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardDuties(r.Context(), *dashboardId, epoch, groupId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1392,7 +1392,7 @@ func (h *HandlerService) PublicGetValidatorDashboardBlocks(w http.ResponseWriter return } - data, paging, err := h.dai.GetValidatorDashboardBlocks(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardBlocks(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1442,7 +1442,7 @@ func (h *HandlerService) PublicGetValidatorDashboardHeatmap(w http.ResponseWrite return } - data, err := h.dai.GetValidatorDashboardHeatmap(r.Context(), *dashboardId, protocolModes, aggregation, afterTs, beforeTs) + data, err := h.getDataAccessor(r).GetValidatorDashboardHeatmap(r.Context(), *dashboardId, protocolModes, aggregation, afterTs, beforeTs) if err != nil { handleErr(w, r, err) return @@ -1492,7 +1492,7 @@ func (h *HandlerService) PublicGetValidatorDashboardGroupHeatmap(w http.Response return } - data, err := h.dai.GetValidatorDashboardGroupHeatmap(r.Context(), *dashboardId, groupId, protocolModes, aggregation, requestedTimestamp) + data, err := h.getDataAccessor(r).GetValidatorDashboardGroupHeatmap(r.Context(), *dashboardId, groupId, protocolModes, aggregation, requestedTimestamp) if err != nil { handleErr(w, r, err) return @@ -1527,7 +1527,7 @@ func (h *HandlerService) PublicGetValidatorDashboardExecutionLayerDeposits(w htt return } - data, paging, err := h.dai.GetValidatorDashboardElDeposits(r.Context(), *dashboardId, pagingParams.cursor, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardElDeposits(r.Context(), *dashboardId, pagingParams.cursor, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -1563,7 +1563,7 @@ func (h *HandlerService) PublicGetValidatorDashboardConsensusLayerDeposits(w htt return } - data, paging, err := h.dai.GetValidatorDashboardClDeposits(r.Context(), *dashboardId, pagingParams.cursor, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardClDeposits(r.Context(), *dashboardId, pagingParams.cursor, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -1592,7 +1592,7 @@ func (h *HandlerService) PublicGetValidatorDashboardTotalConsensusLayerDeposits( handleErr(w, r, err) return } - data, err := h.dai.GetValidatorDashboardTotalClDeposits(r.Context(), *dashboardId) + data, err := h.getDataAccessor(r).GetValidatorDashboardTotalClDeposits(r.Context(), *dashboardId) if err != nil { handleErr(w, r, err) return @@ -1620,7 +1620,7 @@ func (h *HandlerService) PublicGetValidatorDashboardTotalExecutionLayerDeposits( handleErr(w, r, err) return } - data, err := h.dai.GetValidatorDashboardTotalElDeposits(r.Context(), *dashboardId) + data, err := h.getDataAccessor(r).GetValidatorDashboardTotalElDeposits(r.Context(), *dashboardId) if err != nil { handleErr(w, r, err) return @@ -1662,7 +1662,7 @@ func (h *HandlerService) PublicGetValidatorDashboardWithdrawals(w http.ResponseW return } - data, paging, err := h.dai.GetValidatorDashboardWithdrawals(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardWithdrawals(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1699,7 +1699,7 @@ func (h *HandlerService) PublicGetValidatorDashboardTotalWithdrawals(w http.Resp return } - data, err := h.dai.GetValidatorDashboardTotalWithdrawals(r.Context(), *dashboardId, pagingParams.search, protocolModes) + data, err := h.getDataAccessor(r).GetValidatorDashboardTotalWithdrawals(r.Context(), *dashboardId, pagingParams.search, protocolModes) if err != nil { handleErr(w, r, err) return @@ -1739,7 +1739,7 @@ func (h *HandlerService) PublicGetValidatorDashboardRocketPool(w http.ResponseWr return } - data, paging, err := h.dai.GetValidatorDashboardRocketPool(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardRocketPool(r.Context(), *dashboardId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -1774,7 +1774,7 @@ func (h *HandlerService) PublicGetValidatorDashboardTotalRocketPool(w http.Respo return } - data, err := h.dai.GetValidatorDashboardTotalRocketPool(r.Context(), *dashboardId, pagingParams.search) + data, err := h.getDataAccessor(r).GetValidatorDashboardTotalRocketPool(r.Context(), *dashboardId, pagingParams.search) if err != nil { handleErr(w, r, err) return @@ -1817,7 +1817,7 @@ func (h *HandlerService) PublicGetValidatorDashboardRocketPoolMinipools(w http.R return } - data, paging, err := h.dai.GetValidatorDashboardRocketPoolMinipools(r.Context(), *dashboardId, nodeAddress, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetValidatorDashboardRocketPoolMinipools(r.Context(), *dashboardId, nodeAddress, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -1847,7 +1847,7 @@ func (h *HandlerService) PublicGetUserNotifications(w http.ResponseWriter, r *ht handleErr(w, r, err) return } - data, err := h.dai.GetNotificationOverview(r.Context(), userId) + data, err := h.getDataAccessor(r).GetNotificationOverview(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -1882,20 +1882,13 @@ func (h *HandlerService) PublicGetUserNotificationDashboards(w http.ResponseWrit q := r.URL.Query() pagingParams := v.checkPagingParams(q) sort := checkSort[enums.NotificationDashboardsColumn](&v, q.Get("sort")) - chainId := v.checkNetworkParameter(q.Get("network")) - chainIds := []uint64{chainId} - // TODO replace with "networks" once multiple networks are supported - //chainIds := v.checkNetworksParameter(q.Get("networks")) + chainIds := v.checkNetworksParameter(q.Get("networks")) if v.hasErrors() { handleErr(w, r, v) return } - dataAccessor := h.dai - if isMockEnabled(r) { - dataAccessor = h.dummy - } - data, paging, err := dataAccessor.GetDashboardNotifications(r.Context(), userId, chainIds, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetDashboardNotifications(r.Context(), userId, chainIds, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -1931,11 +1924,7 @@ func (h *HandlerService) PublicGetUserNotificationsValidatorDashboard(w http.Res handleErr(w, r, v) return } - dataAccessor := h.dai - if isMockEnabled(r) { - dataAccessor = h.dummy - } - data, err := dataAccessor.GetValidatorDashboardNotificationDetails(r.Context(), dashboardId, groupId, epoch, search) + data, err := h.getDataAccessor(r).GetValidatorDashboardNotificationDetails(r.Context(), dashboardId, groupId, epoch, search) if err != nil { handleErr(w, r, err) return @@ -1970,7 +1959,7 @@ func (h *HandlerService) PublicGetUserNotificationsAccountDashboard(w http.Respo handleErr(w, r, v) return } - data, err := h.dai.GetAccountDashboardNotificationDetails(r.Context(), dashboardId, groupId, epoch, search) + data, err := h.getDataAccessor(r).GetAccountDashboardNotificationDetails(r.Context(), dashboardId, groupId, epoch, search) if err != nil { handleErr(w, r, err) return @@ -2008,7 +1997,7 @@ func (h *HandlerService) PublicGetUserNotificationMachines(w http.ResponseWriter handleErr(w, r, v) return } - data, paging, err := h.dai.GetMachineNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetMachineNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -2047,7 +2036,7 @@ func (h *HandlerService) PublicGetUserNotificationClients(w http.ResponseWriter, handleErr(w, r, v) return } - data, paging, err := h.dai.GetClientNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetClientNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -2086,7 +2075,7 @@ func (h *HandlerService) PublicGetUserNotificationRocketPool(w http.ResponseWrit handleErr(w, r, v) return } - data, paging, err := h.dai.GetRocketPoolNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetRocketPoolNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -2124,7 +2113,7 @@ func (h *HandlerService) PublicGetUserNotificationNetworks(w http.ResponseWriter handleErr(w, r, v) return } - data, paging, err := h.dai.GetNetworkNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetNetworkNotifications(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.limit) if err != nil { handleErr(w, r, err) return @@ -2152,19 +2141,19 @@ func (h *HandlerService) PublicGetUserNotificationSettings(w http.ResponseWriter handleErr(w, r, err) return } - data, err := h.dai.GetNotificationSettings(r.Context(), userId) + data, err := h.getDataAccessor(r).GetNotificationSettings(r.Context(), userId) if err != nil { handleErr(w, r, err) return } // check premium perks - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return } - defaultSettings, err := h.dai.GetNotificationSettingsDefaultValues(r.Context()) + defaultSettings, err := h.getDataAccessor(r).GetNotificationSettingsDefaultValues(r.Context()) if err != nil { handleErr(w, r, err) return @@ -2227,12 +2216,12 @@ func (h *HandlerService) PublicPutUserNotificationSettingsGeneral(w http.Respons } // check premium perks - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return } - defaultSettings, err := h.dai.GetNotificationSettingsDefaultValues(r.Context()) + defaultSettings, err := h.getDataAccessor(r).GetNotificationSettingsDefaultValues(r.Context()) if err != nil { handleErr(w, r, err) return @@ -2248,7 +2237,7 @@ func (h *HandlerService) PublicPutUserNotificationSettingsGeneral(w http.Respons return } - err = h.dai.UpdateNotificationSettingsGeneral(r.Context(), userId, req) + err = h.getDataAccessor(r).UpdateNotificationSettingsGeneral(r.Context(), userId, req) if err != nil { handleErr(w, r, err) return @@ -2313,7 +2302,7 @@ func (h *HandlerService) PublicPutUserNotificationSettingsNetworks(w http.Respon IsNewRewardRoundSubscribed: req.IsNewRewardRoundSubscribed, } - err = h.dai.UpdateNotificationSettingsNetworks(r.Context(), userId, chainId, settings) + err = h.getDataAccessor(r).UpdateNotificationSettingsNetworks(r.Context(), userId, chainId, settings) if err != nil { handleErr(w, r, err) return @@ -2362,7 +2351,7 @@ func (h *HandlerService) PublicPutUserNotificationSettingsPairedDevices(w http.R handleErr(w, r, v) return } - err = h.dai.UpdateNotificationSettingsPairedDevice(r.Context(), userId, pairedDeviceId, name, req.IsNotificationsEnabled) + err = h.getDataAccessor(r).UpdateNotificationSettingsPairedDevice(r.Context(), userId, pairedDeviceId, name, req.IsNotificationsEnabled) if err != nil { handleErr(w, r, err) return @@ -2402,7 +2391,7 @@ func (h *HandlerService) PublicDeleteUserNotificationSettingsPairedDevices(w htt handleErr(w, r, v) return } - err = h.dai.DeleteNotificationSettingsPairedDevice(r.Context(), userId, pairedDeviceId) + err = h.getDataAccessor(r).DeleteNotificationSettingsPairedDevice(r.Context(), userId, pairedDeviceId) if err != nil { handleErr(w, r, err) return @@ -2442,7 +2431,7 @@ func (h *HandlerService) PublicPutUserNotificationSettingsClient(w http.Response handleErr(w, r, v) return } - data, err := h.dai.UpdateNotificationSettingsClients(r.Context(), userId, clientId, req.IsSubscribed) + data, err := h.getDataAccessor(r).UpdateNotificationSettingsClients(r.Context(), userId, clientId, req.IsSubscribed) if err != nil { handleErr(w, r, err) return @@ -2480,19 +2469,19 @@ func (h *HandlerService) PublicGetUserNotificationSettingsDashboards(w http.Resp handleErr(w, r, v) return } - data, paging, err := h.dai.GetNotificationSettingsDashboards(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) + data, paging, err := h.getDataAccessor(r).GetNotificationSettingsDashboards(r.Context(), userId, pagingParams.cursor, *sort, pagingParams.search, pagingParams.limit) if err != nil { handleErr(w, r, err) return } // if users premium perks do not allow subscriptions, set them to false in the response // TODO: once stripe payments run in v2, this should be removed and the notification settings should be updated upon a tier change instead - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return } - defaultSettings, err := h.dai.GetNotificationSettingsDefaultValues(r.Context()) + defaultSettings, err := h.getDataAccessor(r).GetNotificationSettingsDefaultValues(r.Context()) if err != nil { handleErr(w, r, err) return @@ -2559,7 +2548,7 @@ func (h *HandlerService) PublicPutUserNotificationSettingsValidatorDashboard(w h handleErr(w, r, v) return } - userInfo, err := h.dai.GetUserInfo(r.Context(), userId) + userInfo, err := h.getDataAccessor(r).GetUserInfo(r.Context(), userId) if err != nil { handleErr(w, r, err) return @@ -2573,7 +2562,7 @@ func (h *HandlerService) PublicPutUserNotificationSettingsValidatorDashboard(w h return } - err = h.dai.UpdateNotificationSettingsValidatorDashboard(r.Context(), userId, dashboardId, groupId, req) + err = h.getDataAccessor(r).UpdateNotificationSettingsValidatorDashboard(r.Context(), userId, dashboardId, groupId, req) if err != nil { handleErr(w, r, err) return @@ -2646,7 +2635,7 @@ func (h *HandlerService) PublicPutUserNotificationSettingsAccountDashboard(w htt IsERC721TokenTransfersSubscribed: req.IsERC721TokenTransfersSubscribed, IsERC1155TokenTransfersSubscribed: req.IsERC1155TokenTransfersSubscribed, } - err = h.dai.UpdateNotificationSettingsAccountDashboard(r.Context(), userId, dashboardId, groupId, settings) + err = h.getDataAccessor(r).UpdateNotificationSettingsAccountDashboard(r.Context(), userId, dashboardId, groupId, settings) if err != nil { handleErr(w, r, err) return diff --git a/backend/pkg/api/handlers/search_handlers.go b/backend/pkg/api/handlers/search_handlers.go index 3b8f058a1..49d43acdd 100644 --- a/backend/pkg/api/handlers/search_handlers.go +++ b/backend/pkg/api/handlers/search_handlers.go @@ -120,208 +120,163 @@ func (h *HandlerService) InternalPostSearch(w http.ResponseWriter, r *http.Reque // Search Helper Functions func (h *HandlerService) handleSearch(ctx context.Context, input string, searchType searchTypeKey, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil + switch searchType { + case validatorByIndex: + return h.handleSearchValidatorByIndex(ctx, input, chainId) + case validatorByPublicKey: + return h.handleSearchValidatorByPublicKey(ctx, input, chainId) + case validatorsByDepositAddress: + return h.handleSearchValidatorsByDepositAddress(ctx, input, chainId) + case validatorsByDepositEnsName: + return h.handleSearchValidatorsByDepositEnsName(ctx, input, chainId) + case validatorsByWithdrawalCredential: + return h.handleSearchValidatorsByWithdrawalCredential(ctx, input, chainId) + case validatorsByWithdrawalAddress: + return h.handleSearchValidatorsByWithdrawalAddress(ctx, input, chainId) + case validatorsByWithdrawalEns: + return h.handleSearchValidatorsByWithdrawalEnsName(ctx, input, chainId) + case validatorsByGraffiti: + return h.handleSearchValidatorsByGraffiti(ctx, input, chainId) default: - switch searchType { - case validatorByIndex: - return h.handleSearchValidatorByIndex(ctx, input, chainId) - case validatorByPublicKey: - return h.handleSearchValidatorByPublicKey(ctx, input, chainId) - case validatorsByDepositAddress: - return h.handleSearchValidatorsByDepositAddress(ctx, input, chainId) - case validatorsByDepositEnsName: - return h.handleSearchValidatorsByDepositEnsName(ctx, input, chainId) - case validatorsByWithdrawalCredential: - return h.handleSearchValidatorsByWithdrawalCredential(ctx, input, chainId) - case validatorsByWithdrawalAddress: - return h.handleSearchValidatorsByWithdrawalAddress(ctx, input, chainId) - case validatorsByWithdrawalEns: - return h.handleSearchValidatorsByWithdrawalEnsName(ctx, input, chainId) - case validatorsByGraffiti: - return h.handleSearchValidatorsByGraffiti(ctx, input, chainId) - default: - return nil, errors.New("invalid search type") - } + return nil, errors.New("invalid search type") } } func (h *HandlerService) handleSearchValidatorByIndex(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - index, err := strconv.ParseUint(input, 10, 64) - if err != nil { - // input should've been checked by the regex before, this should never happen - return nil, err - } - result, err := h.dai.GetSearchValidatorByIndex(ctx, chainId, index) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorByIndex), - ChainId: chainId, - HashValue: "0x" + hex.EncodeToString(result.PublicKey), - NumValue: &result.Index, - }, nil + index, err := strconv.ParseUint(input, 10, 64) + if err != nil { + // input should've been checked by the regex before, this should never happen + return nil, err + } + result, err := h.daService.GetSearchValidatorByIndex(ctx, chainId, index) + if err != nil { + return nil, err } + + return &types.SearchResult{ + Type: string(validatorByIndex), + ChainId: chainId, + HashValue: "0x" + hex.EncodeToString(result.PublicKey), + NumValue: &result.Index, + }, nil } func (h *HandlerService) handleSearchValidatorByPublicKey(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - publicKey, err := hex.DecodeString(strings.TrimPrefix(input, "0x")) - if err != nil { - // input should've been checked by the regex before, this should never happen - return nil, err - } - result, err := h.dai.GetSearchValidatorByPublicKey(ctx, chainId, publicKey) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorByPublicKey), - ChainId: chainId, - HashValue: "0x" + hex.EncodeToString(result.PublicKey), - NumValue: &result.Index, - }, nil + publicKey, err := hex.DecodeString(strings.TrimPrefix(input, "0x")) + if err != nil { + // input should've been checked by the regex before, this should never happen + return nil, err } + result, err := h.daService.GetSearchValidatorByPublicKey(ctx, chainId, publicKey) + if err != nil { + return nil, err + } + + return &types.SearchResult{ + Type: string(validatorByPublicKey), + ChainId: chainId, + HashValue: "0x" + hex.EncodeToString(result.PublicKey), + NumValue: &result.Index, + }, nil } func (h *HandlerService) handleSearchValidatorsByDepositAddress(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - address, err := hex.DecodeString(strings.TrimPrefix(input, "0x")) - if err != nil { - return nil, err - } - result, err := h.dai.GetSearchValidatorsByDepositAddress(ctx, chainId, address) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorsByDepositAddress), - ChainId: chainId, - HashValue: "0x" + hex.EncodeToString(result.Address), - NumValue: &result.Count, - }, nil + address, err := hex.DecodeString(strings.TrimPrefix(input, "0x")) + if err != nil { + return nil, err + } + result, err := h.daService.GetSearchValidatorsByDepositAddress(ctx, chainId, address) + if err != nil { + return nil, err } + + return &types.SearchResult{ + Type: string(validatorsByDepositAddress), + ChainId: chainId, + HashValue: "0x" + hex.EncodeToString(result.Address), + NumValue: &result.Count, + }, nil } func (h *HandlerService) handleSearchValidatorsByDepositEnsName(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - result, err := h.dai.GetSearchValidatorsByDepositEnsName(ctx, chainId, input) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorsByDepositEnsName), - ChainId: chainId, - StrValue: result.EnsName, - HashValue: "0x" + hex.EncodeToString(result.Address), - NumValue: &result.Count, - }, nil + result, err := h.daService.GetSearchValidatorsByDepositEnsName(ctx, chainId, input) + if err != nil { + return nil, err } + + return &types.SearchResult{ + Type: string(validatorsByDepositEnsName), + ChainId: chainId, + StrValue: result.EnsName, + HashValue: "0x" + hex.EncodeToString(result.Address), + NumValue: &result.Count, + }, nil } func (h *HandlerService) handleSearchValidatorsByWithdrawalCredential(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - withdrawalCredential, err := hex.DecodeString(strings.TrimPrefix(input, "0x")) - if err != nil { - return nil, err - } - result, err := h.dai.GetSearchValidatorsByWithdrawalCredential(ctx, chainId, withdrawalCredential) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorsByWithdrawalCredential), - ChainId: chainId, - HashValue: "0x" + hex.EncodeToString(result.WithdrawalCredential), - NumValue: &result.Count, - }, nil + withdrawalCredential, err := hex.DecodeString(strings.TrimPrefix(input, "0x")) + if err != nil { + return nil, err } + result, err := h.daService.GetSearchValidatorsByWithdrawalCredential(ctx, chainId, withdrawalCredential) + if err != nil { + return nil, err + } + + return &types.SearchResult{ + Type: string(validatorsByWithdrawalCredential), + ChainId: chainId, + HashValue: "0x" + hex.EncodeToString(result.WithdrawalCredential), + NumValue: &result.Count, + }, nil } func (h *HandlerService) handleSearchValidatorsByWithdrawalAddress(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - withdrawalString := "010000000000000000000000" + strings.TrimPrefix(input, "0x") - withdrawalCredential, err := hex.DecodeString(withdrawalString) - if err != nil { - return nil, err - } - result, err := h.dai.GetSearchValidatorsByWithdrawalCredential(ctx, chainId, withdrawalCredential) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorsByWithdrawalAddress), - ChainId: chainId, - HashValue: "0x" + hex.EncodeToString(result.WithdrawalCredential), - NumValue: &result.Count, - }, nil + withdrawalString := "010000000000000000000000" + strings.TrimPrefix(input, "0x") + withdrawalCredential, err := hex.DecodeString(withdrawalString) + if err != nil { + return nil, err + } + result, err := h.daService.GetSearchValidatorsByWithdrawalCredential(ctx, chainId, withdrawalCredential) + if err != nil { + return nil, err } + + return &types.SearchResult{ + Type: string(validatorsByWithdrawalAddress), + ChainId: chainId, + HashValue: "0x" + hex.EncodeToString(result.WithdrawalCredential), + NumValue: &result.Count, + }, nil } func (h *HandlerService) handleSearchValidatorsByWithdrawalEnsName(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - result, err := h.dai.GetSearchValidatorsByWithdrawalEnsName(ctx, chainId, input) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorsByWithdrawalEns), - ChainId: chainId, - StrValue: result.EnsName, - HashValue: "0x" + hex.EncodeToString(result.Address), - NumValue: &result.Count, - }, nil + result, err := h.daService.GetSearchValidatorsByWithdrawalEnsName(ctx, chainId, input) + if err != nil { + return nil, err } + + return &types.SearchResult{ + Type: string(validatorsByWithdrawalEns), + ChainId: chainId, + StrValue: result.EnsName, + HashValue: "0x" + hex.EncodeToString(result.Address), + NumValue: &result.Count, + }, nil } func (h *HandlerService) handleSearchValidatorsByGraffiti(ctx context.Context, input string, chainId uint64) (*types.SearchResult, error) { - select { - case <-ctx.Done(): - return nil, nil - default: - result, err := h.dai.GetSearchValidatorsByGraffiti(ctx, chainId, input) - if err != nil { - return nil, err - } - - return &types.SearchResult{ - Type: string(validatorsByGraffiti), - ChainId: chainId, - StrValue: result.Graffiti, - NumValue: &result.Count, - }, nil + result, err := h.daService.GetSearchValidatorsByGraffiti(ctx, chainId, input) + if err != nil { + return nil, err } + + return &types.SearchResult{ + Type: string(validatorsByGraffiti), + ChainId: chainId, + StrValue: result.Graffiti, + NumValue: &result.Count, + }, nil } // -------------------------------------- diff --git a/backend/pkg/api/router.go b/backend/pkg/api/router.go index 36beea764..946fca724 100644 --- a/backend/pkg/api/router.go +++ b/backend/pkg/api/router.go @@ -39,6 +39,11 @@ func NewApiRouter(dataAccessor dataaccess.DataAccessor, dummy dataaccess.DataAcc publicRouter.Use(handlerService.StoreUserIdByApiKeyMiddleware) internalRouter.Use(handlerService.StoreUserIdBySessionMiddleware) + if cfg.DeploymentType != "production" { + publicRouter.Use(handlerService.StoreIsMockedFlagMiddleware) + internalRouter.Use(handlerService.StoreIsMockedFlagMiddleware) + } + addRoutes(handlerService, publicRouter, internalRouter, cfg) addLegacyRoutes(handlerService, legacyRouter) diff --git a/backend/pkg/api/types/user.go b/backend/pkg/api/types/user.go index 47dd0b3fd..d31ac8e0e 100644 --- a/backend/pkg/api/types/user.go +++ b/backend/pkg/api/types/user.go @@ -1,6 +1,7 @@ package types const UserGroupAdmin = "ADMIN" +const UserGroupDev = "DEV" type UserInfo struct { Id uint64 `json:"id"` diff --git a/backend/pkg/exporter/modules/relays.go b/backend/pkg/exporter/modules/relays.go index 205586792..8c0049e98 100644 --- a/backend/pkg/exporter/modules/relays.go +++ b/backend/pkg/exporter/modules/relays.go @@ -107,9 +107,10 @@ func fetchDeliveredPayloads(r types.Relay, offset uint64) ([]BidTrace, error) { if offset != 0 { url += fmt.Sprintf("&cursor=%v", offset) } - - //nolint:gosec - resp, err := http.Get(url) + client := &http.Client{ + Timeout: time.Second * 30, + } + resp, err := client.Get(url) if err != nil { log.Error(err, "error retrieving delivered payloads", 0, map[string]interface{}{"relay": r.ID}) diff --git a/frontend/components/notifications/NotificationsDashboardDialogEntity.vue b/frontend/components/notifications/NotificationsDashboardDialogEntity.vue index c60a8ec19..8fa24bd43 100644 --- a/frontend/components/notifications/NotificationsDashboardDialogEntity.vue +++ b/frontend/components/notifications/NotificationsDashboardDialogEntity.vue @@ -6,7 +6,6 @@ import { faCube, faFileSignature, faGlobe, - faMoneyBill, faPowerOff, faRocket, faUserSlash, @@ -89,30 +88,6 @@ defineEmits<{ (e: 'filter-changed', value: string): void }>() - - - - - () > {{ proposal.index }} - - - - - - () > {{ upcomingProposal.index }} -