diff --git a/routes/index_test.go b/routes/index_test.go index 1dae41c0c..4768f6f8f 100644 --- a/routes/index_test.go +++ b/routes/index_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "net/url" + "os" "strings" "sync" "testing" @@ -15,6 +16,7 @@ import ( "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/rs/cors" + "github.com/stakwork/sphinx-tribes/auth" "github.com/stakwork/sphinx-tribes/config" "github.com/stakwork/sphinx-tribes/utils" "github.com/stretchr/testify/assert" @@ -1248,3 +1250,341 @@ func TestInternalServerErrorHandler(t *testing.T) { }) } } + +func TestNewRouter(t *testing.T) { + t.Run("Basic Router Initialization", func(t *testing.T) { + server := NewRouter() + assert.NotNil(t, server) + assert.NotNil(t, server.Handler) + + assert.Equal(t, ":5002", server.Addr) + }) + + t.Run("Custom Port Configuration", func(t *testing.T) { + originalPort := os.Getenv("PORT") + os.Setenv("PORT", "8080") + defer os.Setenv("PORT", originalPort) + + server := NewRouter() + assert.Equal(t, ":8080", server.Addr) + }) + + t.Run("Route Mounting Verification", func(t *testing.T) { + server := NewRouter() + router := server.Handler.(*chi.Mux) + + routes := []string{ + "/tribes", + "/bots", + "/bot", + "/people", + "/person", + "/connectioncodes", + "/github_issue", + "/gobounties", + "/workspaces", + "/metrics", + "/features", + "/workflows", + "/bounties/ticket", + "/hivechat", + "/test", + "/feature-flags", + "/snippet", + "/activities", + } + + for _, route := range routes { + req := httptest.NewRequest("GET", route, nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + assert.NotEqual(t, http.StatusNotFound, rr.Code, "Route %s should exist", route) + } + }) + + t.Run("Public Endpoints Group", func(t *testing.T) { + server := NewRouter() + router := server.Handler.(*chi.Mux) + + publicEndpoints := []struct { + method string + path string + }{ + {"GET", "/tribe_by_feed"}, + {"GET", "/leaderboard/test-uuid"}, + {"GET", "/tribe_by_un/test-name"}, + {"GET", "/tribes_by_owner/test-pubkey"}, + {"GET", "/search/bots/test-query"}, + {"GET", "/podcast"}, + {"GET", "/feed"}, + {"GET", "/search_podcasts"}, + {"GET", "/search_podcast_episodes"}, + {"GET", "/search_youtube"}, + {"GET", "/search_youtube_videos"}, + {"GET", "/youtube_videos"}, + {"GET", "/admin_pubkeys"}, + {"GET", "/ask"}, + {"GET", "/poll/test-challenge"}, + {"POST", "/save"}, + {"GET", "/save/test-key"}, + {"GET", "/migrate_bounties"}, + {"GET", "/websocket"}, + } + + for _, endpoint := range publicEndpoints { + req := httptest.NewRequest(endpoint.method, endpoint.path, nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + assert.NotEqual(t, http.StatusNotFound, rr.Code, + "Endpoint %s %s should exist", endpoint.method, endpoint.path) + } + }) + + t.Run("Protected Endpoints Group", func(t *testing.T) { + r := chi.NewRouter() + r.Use(auth.PubKeyContext) + + r.Post("/channel", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Post("/leaderboard/{tribe_uuid}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Put("/leaderboard/{tribe_uuid}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Put("/tribe", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Put("/tribestats", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Delete("/tribe/{uuid}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Put("/tribeactivity/{uuid}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Put("/tribepreview/{uuid}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Post("/verify/{challenge}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Post("/badges", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Delete("/channel/{id}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Delete("/ticket/{pubKey}/{created}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Get("/poll/invoice/{paymentRequest}", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Post("/meme_upload", FeatureMockHandler(t, http.StatusUnauthorized)) + r.Get("/admin/auth", FeatureMockHandler(t, http.StatusUnauthorized)) + + protectedEndpoints := []struct { + name string + method string + path string + expectedStatus int + }{ + { + name: "Test Protected POST /channel Route", + method: http.MethodPost, + path: "/channel", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected POST /leaderboard/{tribe_uuid} Route", + method: http.MethodPost, + path: "/leaderboard/test-uuid", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected PUT /leaderboard/{tribe_uuid} Route", + method: http.MethodPut, + path: "/leaderboard/test-uuid", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected PUT /tribe Route", + method: http.MethodPut, + path: "/tribe", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected PUT /tribestats Route", + method: http.MethodPut, + path: "/tribestats", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected DELETE /tribe/{uuid} Route", + method: http.MethodDelete, + path: "/tribe/test-uuid", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected PUT /tribeactivity/{uuid} Route", + method: http.MethodPut, + path: "/tribeactivity/test-uuid", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected PUT /tribepreview/{uuid} Route", + method: http.MethodPut, + path: "/tribepreview/test-uuid", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected POST /verify/{challenge} Route", + method: http.MethodPost, + path: "/verify/test-challenge", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected POST /badges Route", + method: http.MethodPost, + path: "/badges", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected DELETE /channel/{id} Route", + method: http.MethodDelete, + path: "/channel/test-id", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected DELETE /ticket/{pubKey}/{created} Route", + method: http.MethodDelete, + path: "/ticket/test-pubkey/test-created", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected GET /poll/invoice/{paymentRequest} Route", + method: http.MethodGet, + path: "/poll/invoice/test-request", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected POST /meme_upload Route", + method: http.MethodPost, + path: "/meme_upload", + expectedStatus: http.StatusUnauthorized, + }, + { + name: "Test Protected GET /admin/auth Route", + method: http.MethodGet, + path: "/admin/auth", + expectedStatus: http.StatusUnauthorized, + }, + } + + for _, tc := range protectedEndpoints { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.method, tc.path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + assert.Equal(t, tc.expectedStatus, w.Code, "Handler returned wrong status code") + }) + } + }) + + t.Run("Authentication Endpoints Group", func(t *testing.T) { + server := NewRouter() + router := server.Handler.(*chi.Mux) + + authEndpoints := []struct { + method string + path string + }{ + {"GET", "/lnauth_login"}, + {"GET", "/lnauth"}, + {"GET", "/refresh_jwt"}, + {"POST", "/invoices"}, + {"POST", "/budgetinvoices"}, + } + + for _, endpoint := range authEndpoints { + req := httptest.NewRequest(endpoint.method, endpoint.path, nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + assert.NotEqual(t, http.StatusNotFound, rr.Code, + "Auth endpoint %s %s should exist", endpoint.method, endpoint.path) + } + }) + + t.Run("Timeout Middleware", func(t *testing.T) { + router := chi.NewRouter() + router.Use(middleware.Timeout(10 * time.Millisecond)) + + router.Get("/slow", func(w http.ResponseWriter, r *http.Request) { + timer := time.NewTimer(100 * time.Millisecond) + defer timer.Stop() + + select { + case <-r.Context().Done(): + w.WriteHeader(http.StatusServiceUnavailable) + return + case <-timer.C: + w.WriteHeader(http.StatusOK) + } + }) + + req := httptest.NewRequest("GET", "/slow", nil) + w := httptest.NewRecorder() + + done := make(chan bool) + go func() { + router.ServeHTTP(w, req) + done <- true + }() + + select { + case <-done: + assert.Equal(t, http.StatusServiceUnavailable, w.Code, "Should timeout and return 503") + case <-time.After(62 * time.Second): + t.Fatal("Request did not timeout as expected") + } + }) + + t.Run("Error Handler", func(t *testing.T) { + server := NewRouter() + router := server.Handler.(*chi.Mux) + + router.Get("/panic", func(w http.ResponseWriter, r *http.Request) { + panic("test panic") + }) + + req := httptest.NewRequest("GET", "/panic", nil) + rr := httptest.NewRecorder() + + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusInternalServerError, rr.Code) + }) + + t.Run("Feature Flag Middleware", func(t *testing.T) { + server := NewRouter() + router := server.Handler.(*chi.Mux) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + + router.ServeHTTP(rr, req) + assert.NotEqual(t, http.StatusNotFound, rr.Code) + }) + + t.Run("Request ID Generation", func(t *testing.T) { + router := chi.NewRouter() + router.Use(middleware.RequestID) + + router.Get("/test-id", func(w http.ResponseWriter, r *http.Request) { + reqID := middleware.GetReqID(r.Context()) + assert.NotEmpty(t, reqID, "Request ID should be present") + w.Header().Set("X-Request-ID", reqID) + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/test-id", nil) + rr := httptest.NewRecorder() + + router.ServeHTTP(rr, req) + assert.Equal(t, http.StatusOK, rr.Code, "Handler should execute successfully") + assert.NotEmpty(t, rr.Header().Get("X-Request-ID"), "Request ID should be present in response header") +}) + +t.Run("Logger Middleware", func(t *testing.T) { + router := chi.NewRouter() + router.Use(middleware.Logger) + + router.Get("/test-log", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + req := httptest.NewRequest("GET", "/test-log", nil) + w := httptest.NewRecorder() + + router.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code, "Handler should execute successfully") +}) + +}