feat: added state management for workers and an API route to query workers state - untested
equals215 committed Jun 24, 2024
1 parent ae4878c commit 62a8c11
Showing 6 changed files with 252 additions and 41 deletions.
4 changes: 2 additions & 2 deletions cmd/utils.go
@@ -10,7 +10,6 @@ import (
"github.com/internetarchive/Zeno/internal/pkg/frontier"
"github.com/internetarchive/Zeno/internal/pkg/utils"
"github.com/paulbellamy/ratecounter"
"github.com/remeh/sizedwaitgroup"
"github.com/sirupsen/logrus"
)

@@ -50,7 +49,8 @@ func InitCrawlWithCMD(flags config.Flags) *crawl.Crawl {
c.JobPath = path.Join("jobs", flags.Job)

c.Workers = flags.Workers
c.WorkerPool = sizedwaitgroup.New(c.Workers)
c.WorkerPool = make([]*crawl.Worker, 0)
c.WorkerStopTimeout = time.Second * 60 // Placeholder for WorkerStopTimeout
c.MaxConcurrentAssets = flags.MaxConcurrentAssets

c.Seencheck = flags.Seencheck
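With the sizedwaitgroup import dropped, concurrency is no longer bounded by a wait group sized to flags.Workers: the crawl now keeps an explicit slice of Worker instances (see the crawl.go changes below), and WorkerStopTimeout is hard-coded to 60 seconds as an explicitly marked placeholder.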
30 changes: 30 additions & 0 deletions internal/pkg/crawl/api.go
@@ -2,6 +2,7 @@ package crawl

import (
"os"
"strconv"
"time"

"github.com/gin-contrib/pprof"
@@ -11,6 +12,17 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
)

type APIWorkersState struct {
Workers []*APIWorkerState `json:"workers"`
}

type APIWorkerState struct {
WorkerID int `json:"worker_id"`
Status string `json:"status"`
LastError string `json:"last_error"`
Locked bool `json:"locked"`
}

func (crawl *Crawl) startAPI() {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = logInfo.Out
@@ -56,6 +68,24 @@ func (crawl *Crawl) startAPI() {
r.GET("/metrics", gin.WrapH(promhttp.Handler()))
}

r.GET("/workers", func(c *gin.Context) {
workersState := crawl.GetWorkerState(-1)
c.JSON(200, workersState)
})

r.GET("/workers/:worker_id", func(c *gin.Context) {
workerID := c.Param("worker_id")
workerIDInt, err := strconv.Atoi(workerID)
if err != nil {
c.JSON(404, gin.H{
"error": "Worker not found",
})
return
}
workersState := crawl.GetWorkerState(workerIDInt)
c.JSON(200, workersState)
})

err := r.Run(":" + crawl.APIPort)
if err != nil {
logError.Fatalf("unable to start API: %s", err.Error())
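For illustration, a client of the new routes could look like the minimal sketch below; it is not part of the commit. Only the route paths and the APIWorkerState JSON fields come from the diff above, while the host, the port, and the printed output are assumptions.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
)

// apiWorkerState mirrors the APIWorkerState JSON emitted by the new /workers routes.
type apiWorkerState struct {
	WorkerID  int    `json:"worker_id"`
	Status    string `json:"status"`
	LastError string `json:"last_error"`
	Locked    bool   `json:"locked"`
}

func main() {
	// GET /workers returns the state of every worker in the pool.
	// 9443 stands in for whatever crawl.APIPort is configured to (assumption).
	resp, err := http.Get("http://localhost:9443/workers")
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var all struct {
		Workers []apiWorkerState `json:"workers"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&all); err != nil {
		panic(err)
	}
	for _, w := range all.Workers {
		fmt.Printf("worker %d: status=%s locked=%t last_error=%q\n",
			w.WorkerID, w.Status, w.Locked, w.LastError)
	}

	// GET /workers/:worker_id returns a single worker's state;
	// a non-numeric id is answered with a 404.
}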
41 changes: 21 additions & 20 deletions internal/pkg/crawl/capture.go
@@ -219,7 +219,7 @@ func (c *Crawl) captureAsset(item *frontier.Item, cookies []*http.Cookie) error
}

// Capture capture the URL and return the outlinks
func (c *Crawl) Capture(item *frontier.Item) {
func (c *Crawl) Capture(item *frontier.Item) error {
var (
resp *http.Response
waitGroup sync.WaitGroup
@@ -237,7 +237,7 @@ func (c *Crawl) Capture(item *frontier.Item) {
req, err := http.NewRequest("GET", utils.URLToString(item.URL), nil)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while preparing GET request")
return
return err
}

if item.Hop > 0 && item.ParentItem != nil {
@@ -307,14 +307,14 @@ func (c *Crawl) Capture(item *frontier.Item) {
// Execute request
resp, err = c.executeGET(item, req, false)
if err != nil && err.Error() == "URL from redirection has already been seen" {
return
return err
} else if err != nil && err.Error() == "URL is being rate limited, sending back to HQ" {
c.HQProducerChannel <- frontier.NewItem(item.URL, item.ParentItem, item.Type, item.Hop, "", true)
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("URL is being rate limited, sending back to HQ")
return
return err
} else if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while executing GET request")
return
return err
}
defer resp.Body.Close()

@@ -335,41 +335,41 @@ func (c *Crawl) Capture(item *frontier.Item) {
base, err := url.Parse(utils.URLToString(resp.Request.URL))
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while parsing base URL")
return
return err
}

// If the response is a JSON document, we want to scrape it for links
if strings.Contains(resp.Header.Get("Content-Type"), "json") {
jsonBody, err := io.ReadAll(resp.Body)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while reading JSON body")
return
return err
}

outlinksFromJSON, err := getURLsFromJSON(string(jsonBody))
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while getting URLs from JSON")
return
return err
}

waitGroup.Add(1)
go c.queueOutlinks(utils.MakeAbsolute(item.URL, utils.StringSliceToURLSlice(outlinksFromJSON)), item, &waitGroup)

return
return err
}

// If the response is an XML document, we want to scrape it for links
if strings.Contains(resp.Header.Get("Content-Type"), "xml") {
xmlBody, err := io.ReadAll(resp.Body)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while reading XML body")
return
return err
}

mv, err := mxj.NewMapXml(xmlBody)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while parsing XML body")
return
return err
}

for _, value := range mv.LeafValues() {
@@ -390,14 +390,14 @@ func (c *Crawl) Capture(item *frontier.Item) {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while reading response body")
}

return
return err
}

// Turn the response into a doc that we will scrape for outlinks and assets.
doc, err := goquery.NewDocumentFromReader(resp.Body)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while creating goquery document")
return
return err
}

// Execute site-specific code on the document
@@ -406,7 +406,7 @@ func (c *Crawl) Capture(item *frontier.Item) {
cfstreamURLs, err := cloudflarestream.GetJSFiles(doc, base, *c.Client)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while getting JS files from cloudflarestream")
return
return err
}

// Seencheck the URLs we captured, we ignore the returned value here
@@ -464,26 +464,26 @@ func (c *Crawl) Capture(item *frontier.Item) {
outlinks, err := extractOutlinks(base, doc)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while extracting outlinks")
return
return err
}

waitGroup.Add(1)
go c.queueOutlinks(outlinks, item, &waitGroup)

if c.DisableAssetsCapture {
return
return err
}

// Extract and capture assets
assets, err := c.extractAssets(base, item, doc)
if err != nil {
logError.WithFields(c.genLogFields(err, item.URL, nil)).Error("error while extracting assets")
return
return err
}

// If we didn't find any assets, let's stop here
if len(assets) == 0 {
return
return err
}

// If --local-seencheck is enabled, then we check if the assets are in the
@@ -502,7 +502,7 @@ func (c *Crawl) Capture(item *frontier.Item) {
}

if len(seencheckedBatch) == 0 {
return
return err
}

assets = seencheckedBatch
@@ -522,7 +522,7 @@ func (c *Crawl) Capture(item *frontier.Item) {
}

if len(assets) == 0 {
return
return err
}
}

Expand Down Expand Up @@ -584,6 +584,7 @@ func (c *Crawl) Capture(item *frontier.Item) {
}

swg.Wait()
return err
}

func getURLsFromJSON(jsonString string) ([]string, error) {
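The net effect of the changes in this file is that Capture now returns its error to the caller instead of only logging it. Presumably the new worker loop records that value as the lastError surfaced by the /workers API; a hypothetical sketch of such a worker follows the crawl.go changes below.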
101 changes: 97 additions & 4 deletions internal/pkg/crawl/crawl.go
@@ -1,6 +1,7 @@
package crawl

import (
"fmt"
"net/http"
"sync"
"time"
@@ -11,7 +12,6 @@ import (
"github.com/internetarchive/Zeno/internal/pkg/utils"
"github.com/paulbellamy/ratecounter"
"github.com/prometheus/client_golang/prometheus"
"github.com/remeh/sizedwaitgroup"
"github.com/sirupsen/logrus"
"github.com/telanflow/cookiejar"
"mvdan.cc/xurls/v2"
@@ -42,8 +42,13 @@ type Crawl struct {
// Frontier
Frontier *frontier.Frontier

// Worker pool
WorkerMutex sync.RWMutex
WorkerPool []*Worker
WorkerStopSignal chan bool
WorkerStopTimeout time.Duration

// Crawl settings
WorkerPool sizedwaitgroup.SizedWaitGroup
MaxConcurrentAssets int
Client *warc.CustomHTTPClient
ClientProxied *warc.CustomHTTPClient
@@ -287,9 +292,11 @@ func (c *Crawl) Start() (err error) {

// Fire up the desired amount of workers
for i := 0; i < c.Workers; i++ {
c.WorkerPool.Add()
go c.Worker()
worker := newWorker(c)
c.WorkerPool = append(c.WorkerPool, worker)
go worker.Run()
}
go c.WorkerWatcher()

// Start the process responsible for printing live stats on the standard output
if c.LiveStats {
@@ -335,3 +342,89 @@ func (c *Crawl) Start() (err error) {

return
}

func (c *Crawl) WorkerWatcher() {
var toEnd = false

for {
select {
// Stop the workers
case <-c.WorkerStopSignal:
for _, worker := range c.WorkerPool {
worker.doneSignal <- true
}
toEnd = true

// Check for finished workers and remove them from the pool
// End the watcher if a stop signal was received and all workers are completed
default:
c.WorkerMutex.RLock()
for i, worker := range c.WorkerPool {
if worker.state.status == completed {
// Remove the worker from the pool
c.WorkerMutex.RUnlock()
c.WorkerMutex.Lock()
c.WorkerPool = append(c.WorkerPool[:i], c.WorkerPool[i+1:]...)
c.WorkerMutex.Unlock()
}
}

if toEnd && len(c.WorkerPool) == 0 {
return // All workers are completed
}
}
}
}

func (c *Crawl) EnsureWorkersFinished() bool {
var workerPoolLen int
var timer = time.NewTimer(c.WorkerStopTimeout)

for {
c.WorkerMutex.RLock()
workerPoolLen = len(c.WorkerPool)
if workerPoolLen == 0 {
c.WorkerMutex.RUnlock()
return true
}
c.WorkerMutex.RUnlock()
select {
case <-timer.C:
c.Logger.Warning(fmt.Sprintf("[WORKERS] Timeout reached. %d workers still running", workerPoolLen))
return false
default:
c.Logger.Warning(fmt.Sprintf("[WORKERS] Waiting for %d workers to finish", workerPoolLen))
time.Sleep(time.Second * 5)
}
}
}

// GetWorkerState returns the state of a worker given its index in the worker pool
// if the provided index is -1 then the state of all workers is returned
func (c *Crawl) GetWorkerState(index int) interface{} {
if index == -1 {
var workersStatus = new(APIWorkersState)
for i, worker := range c.WorkerPool {
workersStatus.Workers = append(workersStatus.Workers, _getWorkerState(worker, i))
}
return workersStatus
}
if index >= len(c.WorkerPool) {
return nil
}
return _getWorkerState(c.WorkerPool[index], index)
}

func _getWorkerState(worker *Worker, index int) *APIWorkerState {
isLocked := false
if worker.TryLock() {
isLocked = true
worker.Unlock()
}
return &APIWorkerState{
WorkerID: index,
Status: worker.state.status.String(),
LastError: worker.state.lastError.Error(),
Locked: isLocked,
}
}
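The Worker implementation itself is in the sixth changed file, which did not load on this page. Purely as a reading aid, here is a hypothetical sketch that would satisfy the way crawl.go uses the type above (newWorker, Run, doneSignal, an embedded mutex for TryLock/Unlock, and a state carrying status and lastError). The status value names and the body of Run are assumptions, not the commit's actual code.

package crawl

import "sync"

// Hypothetical status values; crawl.go only requires a String() method
// and a `completed` value to compare against.
type status int

const (
	idle status = iota
	processing
	completed
)

func (s status) String() string {
	return [...]string{"idle", "processing", "completed"}[s]
}

type workerState struct {
	status    status
	lastError error
}

// Worker carries the fields crawl.go relies on: an embedded mutex
// (TryLock/Unlock in _getWorkerState), a doneSignal channel (WorkerWatcher),
// and a state struct (GetWorkerState).
type Worker struct {
	sync.Mutex
	crawl      *Crawl
	doneSignal chan bool
	state      workerState
}

func newWorker(c *Crawl) *Worker {
	return &Worker{
		crawl:      c,
		doneSignal: make(chan bool),
		state:      workerState{status: idle},
	}
}

// Run pulls items from the frontier until doneSignal fires or the pull
// channel is closed, recording the error returned by Capture so the
// /workers API can expose it.
func (w *Worker) Run() {
	for {
		select {
		case <-w.doneSignal:
			w.state.status = completed
			return
		case item, ok := <-w.crawl.Frontier.PullChan:
			if !ok {
				w.state.status = completed
				return
			}
			w.Lock()
			w.state.status = processing
			w.state.lastError = w.crawl.Capture(item)
			w.state.status = idle
			w.Unlock()
		}
	}
}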
3 changes: 2 additions & 1 deletion internal/pkg/crawl/finish.go
@@ -32,6 +32,7 @@ func (crawl *Crawl) catchFinish() {
}

func (crawl *Crawl) finish() {
crawl.WorkerStopSignal <- true
crawl.Finished.Set(true)

// First we wait for the queue reader to finish its current work,
@@ -45,7 +46,7 @@
close(crawl.Frontier.PullChan)

crawl.Logger.Warning("[WORKERS] Waiting for workers to finish")
crawl.WorkerPool.Wait()
crawl.EnsureWorkersFinished()
crawl.Logger.Warning("[WORKERS] All workers finished")

// When all workers are finished, we can safely close the HQ related channels
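Taken together, shutdown now works as follows: finish() pushes to WorkerStopSignal, WorkerWatcher forwards a doneSignal to every worker and prunes completed workers from the pool, and EnsureWorkersFinished (which replaces the old WorkerPool.Wait()) polls the pool until it is empty or until WorkerStopTimeout (the 60-second placeholder set in cmd/utils.go) elapses.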
(The sixth changed file did not load on this page.)
