Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: implement context based progress tracker #741

Merged
merged 13 commits into from
Jul 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/parallel/parallel.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func DynamicParallel(nJobs int, jobsAlloc *task.JobsAllocator, worker func(worke
// consumer
for {
exit := atomic.NewBool(true)
numJobs := jobsAlloc.AvailableJobs(nil)
numJobs := jobsAlloc.AvailableJobs()
var wg sync.WaitGroup
wg.Add(numJobs)
errs := make([]error, nJobs)
Expand Down
174 changes: 174 additions & 0 deletions base/progress/progress.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
// Copyright 2023 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package progress

import (
"context"
"sort"
"sync"
"time"

"github.com/google/uuid"
"modernc.org/mathutil"
)

// spanKeyType is a dedicated string type for the context key below, so the
// key cannot collide with plain string keys used by other packages.
type spanKeyType string

// spanKeyName is the context key under which the current *Span is stored.
// Its value is a UUID string generated once, when the package is initialized.
var spanKeyName = spanKeyType(uuid.New().String())

// Status is the lifecycle state reported for a span.
type Status string

// Known span states. Within this file only StatusRunning (set when a span is
// created) and StatusComplete (set by Span.End) are ever assigned; the other
// values are declared for callers.
const (
	StatusPending Status = "Pending"
	StatusComplete Status = "Complete"
	StatusRunning Status = "Running"
	StatusSuspended Status = "Suspended"
	StatusFailed Status = "Failed"
)

// Tracer owns a set of root spans and can list their progress.
type Tracer struct {
	// name is the label given to NewTracer.
	name string
	// spans maps a root span name (string) to its *Span.
	spans sync.Map
}

// NewTracer returns an empty tracer labelled with name.
func NewTracer(name string) *Tracer {
	tracer := new(Tracer)
	tracer.name = name
	return tracer
}

// Start registers a new root span called name on the tracer and marks it as
// running. Root spans are stored by name, so starting another span with the
// same name replaces the earlier entry. The returned context is derived from
// ctx and carries the span, which lets the package-level Start attach
// children to it.
func (t *Tracer) Start(ctx context.Context, name string, total int) (context.Context, *Span) {
	root := new(Span)
	root.name = name
	root.status = StatusRunning
	root.total = total
	root.start = time.Now()

	t.spans.Store(name, root)

	spanCtx := context.WithValue(ctx, spanKeyName, root)
	return spanCtx, root
}

// List returns a progress snapshot for every root span held by the tracer,
// ordered by ascending start time. The result is nil when the tracer has no
// spans.
func (t *Tracer) List() []Progress {
	var snapshots []Progress

	// Gather one snapshot per root span.
	collect := func(_, value interface{}) bool {
		snapshots = append(snapshots, value.(*Span).Progress())
		return true
	}
	t.spans.Range(collect)

	// Earliest started span first.
	startedEarlier := func(a, b int) bool {
		return snapshots[a].StartTime.Before(snapshots[b].StartTime)
	}
	sort.Slice(snapshots, startedEarlier)

	return snapshots
}

// Span tracks the progress of one named unit of work together with the child
// spans started beneath it.
type Span struct {
	// name identifies the span; it is also the key under which the span is
	// stored by its tracer or parent.
	name string
	// status is StatusRunning on creation and StatusComplete after End.
	status Status
	// total is the upper bound that count is clamped to.
	total int
	// count is the amount of work recorded so far.
	count int
	// err holds the message of the error passed to Error, if any.
	err string
	// start is the creation time of the span.
	start time.Time
	// finish is set by End and is the zero time until then.
	finish time.Time
	// children maps a child span name (string) to its *Span; entries are added
	// by the package-level Start.
	children sync.Map
}

// Add advances the span's counter by n, never letting it exceed the span's
// total.
func (s *Span) Add(n int) {
	advanced := s.count + n
	s.count = mathutil.Min(advanced, s.total)
}

// End marks the span as complete: the counter is moved to the total and the
// finish time is set to the current time.
func (s *Span) End() {
	s.count, s.status = s.total, StatusComplete
	s.finish = time.Now()
}

// Error records the message of err on the span so that it is reported through
// Progress. The span's status is left untouched.
//
// A nil err is ignored: calling err.Error() on a nil error would panic, and
// there is no message to record in that case.
func (s *Span) Error(err error) {
	if err == nil {
		return
	}
	s.err = err.Error()
}

// Count returns the span's own counter. Progress of child spans is not
// included; use Progress for the aggregated view.
func (s *Span) Count() int {
	return s.count
}

// Progress returns a snapshot of the span, folding in the progress of child
// spans that are still running.
//
// Without running children the snapshot reports the span's own count and
// total. Otherwise both are rescaled by the total of one running child, so
// that one unit of the parent equals one whole child, and each running child
// contributes the fraction it has completed.
//
// Running children whose total is not positive are skipped: their completed
// fraction is undefined and dividing by a zero total would panic. The
// aggregated count is capped at the aggregated total, matching the
// count <= total invariant that Add maintains.
func (s *Span) Progress() Progress {
	// Collect running children that can contribute a well-defined fraction.
	var running []Progress
	s.children.Range(func(_, value interface{}) bool {
		child := value.(*Span).Progress()
		if child.Status == StatusRunning && child.Total > 0 {
			running = append(running, child)
		}
		return true
	})

	snapshot := Progress{
		Name:       s.name,
		Status:     s.status,
		Error:      s.err,
		Count:      s.count,
		Total:      s.total,
		StartTime:  s.start,
		FinishTime: s.finish,
	}
	// Leaf node, or no child that can be measured.
	if len(running) == 0 {
		return snapshot
	}

	// Non-leaf node: express the parent in units of one child's total.
	scale := running[0].Total
	snapshot.Total = s.total * scale
	snapshot.Count = s.count * scale
	for _, child := range running {
		snapshot.Count += scale * child.Count / child.Total
	}
	// The parent counter may already include work that a still-running child
	// is reporting; never report more than the total.
	if snapshot.Count > snapshot.Total {
		snapshot.Count = snapshot.Total
	}
	return snapshot
}

// Start creates a running span called name and, when ctx carries a parent
// span, registers the new span as a child of that parent (children are stored
// by name, so reusing a name replaces the earlier child).
//
// The returned span is always usable. The returned context is:
//   - nil when ctx is nil;
//   - ctx itself when it carries no parent span, so the caller's cancellation,
//     deadline and values are preserved instead of being replaced by a nil
//     context that would panic on first use;
//   - a context derived from ctx that carries the new span otherwise.
func Start(ctx context.Context, name string, total int) (context.Context, *Span) {
	childSpan := &Span{
		name:   name,
		status: StatusRunning,
		total:  total,
		count:  0,
		start:  time.Now(),
	}
	if ctx == nil {
		return nil, childSpan
	}
	parent, ok := ctx.Value(spanKeyName).(*Span)
	if !ok {
		// No parent to attach to: the span is detached, but the caller's
		// context is still valid and is handed back unchanged.
		return ctx, childSpan
	}
	parent.children.Store(name, childSpan)
	return context.WithValue(ctx, spanKeyName, childSpan), childSpan
}

// Progress is an exported snapshot of a span as produced by Span.Progress.
type Progress struct {
	// Tracer is not populated by any code in this file.
	Tracer string
	// Name is the span's name.
	Name string
	// Status is the span's status at snapshot time.
	Status Status
	// Error is the message recorded through Span.Error, if any.
	Error string
	// Count is the completed amount, rescaled when running children exist.
	Count int
	// Total is the expected amount, rescaled when running children exist.
	Total int
	// StartTime is when the span was created.
	StartTime time.Time
	// FinishTime is when the span ended; the zero time while it has not.
	FinishTime time.Time
}
34 changes: 34 additions & 0 deletions base/progress/progress_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright 2023 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package progress

import (
"testing"

"github.com/stretchr/testify/suite"
)

// ProgressTestSuite is the testify suite for the progress package. It embeds
// suite.Suite and holds the Tracer that SetupTest resets.
type ProgressTestSuite struct {
	suite.Suite
	tracer Tracer
}

// SetupTest replaces the suite's tracer with a zero-value Tracer. testify's
// suite runner is expected to call it before each test method.
func (suite *ProgressTestSuite) SetupTest() {
	suite.tracer = Tracer{}
}

// TestProgressTestSuite is the go test entry point that runs
// ProgressTestSuite through testify's suite runner.
func TestProgressTestSuite(t *testing.T) {
	suite.Run(t, new(ProgressTestSuite))
}
4 changes: 3 additions & 1 deletion base/search/bruteforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package search

import (
"context"

"github.com/zhenghaoz/gorse/base/heap"
)

Expand All @@ -26,7 +28,7 @@ type Bruteforce struct {
}

// Build a vector index on data.
func (b *Bruteforce) Build() {}
func (b *Bruteforce) Build(_ context.Context) {}

// NewBruteforce creates a Bruteforce vector index.
func NewBruteforce(vectors []Vector) *Bruteforce {
Expand Down
23 changes: 12 additions & 11 deletions base/search/hnsw.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package search

import (
"context"
"math/rand"
"runtime"
"sync"
Expand All @@ -26,7 +27,7 @@ import (
"github.com/zhenghaoz/gorse/base/heap"
"github.com/zhenghaoz/gorse/base/log"
"github.com/zhenghaoz/gorse/base/parallel"
"github.com/zhenghaoz/gorse/base/task"
"github.com/zhenghaoz/gorse/base/progress"
"go.uber.org/zap"
"modernc.org/mathutil"
)
Expand All @@ -49,7 +50,6 @@ type HNSW struct {
maxConnection0 int
efConstruction int
numJobs int
task *task.SubTask
}

// HNSWConfig is the configuration function for HNSW.
Expand Down Expand Up @@ -123,24 +123,26 @@ func (h *HNSW) knnSearch(q Vector, k, ef int) *heap.PriorityQueue {
}

// Build a vector index on data.
func (h *HNSW) Build() {
func (h *HNSW) Build(ctx context.Context) {
completed := make(chan struct{}, h.numJobs)
go func() {
defer base.CheckPanic()
completedCount, previousCount := 0, 0
ticker := time.NewTicker(10 * time.Second)
_, span := progress.Start(ctx, "HNSW.Build", len(h.vectors))
for {
select {
case _, ok := <-completed:
if !ok {
span.End()
return
}
completedCount++
case <-ticker.C:
throughput := completedCount - previousCount
previousCount = completedCount
h.task.Add(throughput * len(h.vectors))
if throughput > 0 {
span.Add(throughput)
log.Logger().Info("building index",
zap.Int("n_indexed_vectors", completedCount),
zap.Int("n_vectors", len(h.vectors)),
Expand Down Expand Up @@ -320,7 +322,7 @@ func NewHNSWBuilder(data []Vector, k, numJobs int) *HNSWBuilder {
rng: base.NewRandomGenerator(0),
numJobs: numJobs,
}
b.bruteForce.Build()
b.bruteForce.Build(context.Background())
return b
}

Expand Down Expand Up @@ -361,20 +363,19 @@ func (b *HNSWBuilder) evaluate(idx *HNSW, prune0 bool) float32 {
return result / count
}

func (b *HNSWBuilder) Build(recall float32, trials int, prune0 bool, t *task.Task) (idx *HNSW, score float32) {
buildTask := t.SubTask(EstimateHNSWBuilderComplexity(len(b.data), trials))
defer buildTask.Finish()
func (b *HNSWBuilder) Build(ctx context.Context, recall float32, trials int, prune0 bool) (idx *HNSW, score float32) {
ef := 1 << int(math32.Ceil(math32.Log2(float32(b.k))))
newCtx, span := progress.Start(ctx, "HNSWBuilder.Build", trials)
defer span.End()
for i := 0; i < trials; i++ {
start := time.Now()
idx = NewHNSW(b.data,
SetEFConstruction(ef),
SetHNSWNumJobs(b.numJobs))
idx.task = buildTask
idx.Build()
idx.Build(newCtx)
buildTime := time.Since(start)
score = b.evaluate(idx, prune0)
idx.task.Add(b.testSize * len(b.data))
span.Add(1)
log.Logger().Info("try to build vector index",
zap.String("index_type", "HNSW"),
zap.Int("ef_construction", ef),
Expand Down
8 changes: 5 additions & 3 deletions base/search/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
package search

import (
"context"
"fmt"
"reflect"
"sort"

"github.com/chewxy/math32"
"github.com/zhenghaoz/gorse/base/floats"
"github.com/zhenghaoz/gorse/base/log"
"go.uber.org/zap"
"modernc.org/sortutil"
"reflect"
"sort"
)

type Vector interface {
Expand Down Expand Up @@ -182,7 +184,7 @@ func (v *DictionaryCentroidVector) Distance(vector Vector) float32 {
}

type VectorIndex interface {
Build()
Build(ctx context.Context)
Search(q Vector, n int, prune0 bool) ([]int32, []float32)
MultiSearch(q Vector, terms []string, n int, prune0 bool) (map[string][]int32, map[string][]float32)
}
Loading
Loading