From 247c38ee27cfeb2d1a3d4a33913f7b85dbfa9d68 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Wed, 26 Jul 2023 21:11:00 +0800 Subject: [PATCH 01/13] refactor: implement context based progress tracker --- base/progress/progress.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 base/progress/progress.go diff --git a/base/progress/progress.go b/base/progress/progress.go new file mode 100644 index 000000000..6bf4128cd --- /dev/null +++ b/base/progress/progress.go @@ -0,0 +1,15 @@ +// Copyright 2023 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package progress From 7f60b80811598d89ecfeaadafefd883e27e01183 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Wed, 26 Jul 2023 21:41:19 +0800 Subject: [PATCH 02/13] add test --- base/progress/progress.go | 73 ++++++++++++++++++++++++++++++++++ base/progress/progress_test.go | 34 ++++++++++++++++ 2 files changed, 107 insertions(+) create mode 100644 base/progress/progress_test.go diff --git a/base/progress/progress.go b/base/progress/progress.go index 6bf4128cd..97685ace2 100644 --- a/base/progress/progress.go +++ b/base/progress/progress.go @@ -13,3 +13,76 @@ // limitations under the License. package progress + +import ( + "context" + "sync" + + "github.com/google/uuid" +) + +var spanKey = uuid.New().String() + +type Tracer struct { + spans sync.Map +} + +// Start creates a root span. +func (t *Tracer) Start(ctx context.Context, name string, total int64) (context.Context, *Span) { + span := &Span{name: name, total: total} + t.spans.Store(name, span) + return context.WithValue(ctx, spanKey, span), span +} + +func (t *Tracer) List() []Progress { + var progress []Progress + t.spans.Range(func(key, value interface{}) bool { + span := value.(*Span) + progress = append(progress, Progress{ + Name: span.name, + Total: span.total, + Count: span.count, + }) + return true + }) + return progress +} + +type Span struct { + name string + total int64 + count int64 + err error + children sync.Map +} + +func (s *Span) Add(n int64) { + s.count += n +} + +func (s *Span) End() { + s.count = s.total +} + +func (s *Span) Error(err error) { + s.err = err +} + +func Start(ctx context.Context, name string, total int64) (context.Context, *Span) { + childSpan := &Span{name: name, total: total} + if ctx == nil { + return nil, childSpan + } + span, ok := (ctx).Value(spanKey).(*Span) + if !ok { + return nil, childSpan + } + span.children.Store(name, childSpan) + return context.WithValue(ctx, spanKey, childSpan), childSpan +} + +type Progress struct { + Name string + Total int64 + Count int64 +} diff --git a/base/progress/progress_test.go b/base/progress/progress_test.go new file mode 100644 index 000000000..6e5b87b93 --- /dev/null +++ b/base/progress/progress_test.go @@ -0,0 +1,34 @@ +// Copyright 2023 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package progress + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type ProgressTestSuite struct { + suite.Suite + tracer Tracer +} + +func (suite *ProgressTestSuite) SetupTest() { + suite.tracer = Tracer{} +} + +func TestProgressTestSuite(t *testing.T) { + suite.Run(t, new(ProgressTestSuite)) +} From 1207db767cf2ca7adc1a8e8ed1cade15fe8249f2 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Thu, 27 Jul 2023 20:36:41 +0800 Subject: [PATCH 03/13] Remove monitor --- base/task/monitor.go | 256 -------------------------------------- base/task/monitor_test.go | 87 ------------- 2 files changed, 343 deletions(-) delete mode 100644 base/task/monitor.go delete mode 100644 base/task/monitor_test.go diff --git a/base/task/monitor.go b/base/task/monitor.go deleted file mode 100644 index 6c70a89e5..000000000 --- a/base/task/monitor.go +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright 2021 gorse Project Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package task - -import ( - "sort" - "strings" - "sync" - "time" - - mapset "github.com/deckarep/golang-set/v2" -) - -type Status string - -const ( - StatusPending Status = "Pending" - StatusComplete Status = "Complete" - StatusRunning Status = "Running" - StatusSuspended Status = "Suspended" - StatusFailed Status = "Failed" -) - -// Task progress information. -type Task struct { - Name string - Status Status - Done int - Total int - StartTime time.Time - FinishTime time.Time - Error string -} - -func NewTask(name string, total int) *Task { - return &Task{ - Name: name, - Status: StatusRunning, - Done: 0, - Total: total, - StartTime: time.Now(), - FinishTime: time.Time{}, - } -} - -func (t *Task) Update(done int) { - t.Done = done - t.Status = StatusRunning -} - -func (t *Task) Add(done int) { - if t != nil { - t.Status = StatusRunning - t.Done += done - } -} - -func (t *Task) Finish() { - if t != nil { - t.Status = StatusComplete - t.Done = t.Total - t.FinishTime = time.Now() - } -} - -func (t *Task) Suspend(flag bool) { - if t != nil { - if flag { - t.Status = StatusSuspended - } else { - t.Status = StatusRunning - } - } -} - -func (t *Task) Fail(err string) { - if t != nil { - t.Error = err - t.Status = StatusFailed - } -} - -func (t *Task) SubTask(done int) *SubTask { - if t == nil { - return nil - } - return &SubTask{ - Parent: t, - Start: t.Done, - End: t.Done + done, - } -} - -type SubTask struct { - Start int - End int - Parent *Task -} - -func (t *SubTask) Add(done int) { - if t != nil { - t.Parent.Add(done) - } -} - -func (t *SubTask) Finish() { - if t != nil { - t.Parent.Update(t.End) - } -} - -// Monitor monitors the progress of all tasks. -type Monitor struct { - TaskLock sync.Mutex - Tasks map[string]*Task -} - -// NewTaskMonitor creates a Monitor and add pending tasks. -func NewTaskMonitor() *Monitor { - return &Monitor{ - Tasks: make(map[string]*Task), - } -} - -func (tm *Monitor) GetTask(name string) *Task { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - return tm.Tasks[name] -} - -// Pending a task. -func (tm *Monitor) Pending(name string) { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - tm.Tasks[name] = &Task{ - Name: name, - Status: StatusPending, - } -} - -// Start a task. -func (tm *Monitor) Start(name string, total int) *Task { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - t := NewTask(name, total) - tm.Tasks[name] = t - return t -} - -// Finish a task. -func (tm *Monitor) Finish(name string) { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - task, exist := tm.Tasks[name] - if exist { - task.Finish() - } -} - -// Update the progress of a task. -func (tm *Monitor) Update(name string, done int) { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - tm.Tasks[name].Update(done) -} - -// Add the progress of a task. -func (tm *Monitor) Add(name string, done int) { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - tm.Tasks[name].Add(done) -} - -// Suspend a task. -func (tm *Monitor) Suspend(name string, flag bool) { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - tm.Tasks[name].Suspend(flag) -} - -func (tm *Monitor) Fail(name, err string) { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - tm.Tasks[name].Fail(err) -} - -// List all tasks and remove tasks from disconnected workers. -func (tm *Monitor) List(workers ...string) []Task { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - workerSet := mapset.NewSet(workers...) - var task []Task - for name, t := range tm.Tasks { - // remove tasks from disconnected workers - if strings.Contains(name, "[") && strings.Contains(name, "]") { - begin := strings.Index(name, "[") + 1 - end := strings.Index(name, "]") - if !workerSet.Contains(name[begin:end]) { - delete(tm.Tasks, name) - continue - } - } - task = append(task, *t) - } - sort.Sort(Tasks(task)) - return task -} - -// Get the progress of a task. -func (tm *Monitor) Get(name string) int { - tm.TaskLock.Lock() - defer tm.TaskLock.Unlock() - task, exist := tm.Tasks[name] - if exist { - return task.Done - } - return 0 -} - -// Tasks is used to sort []Task. -type Tasks []Task - -// Len is used to sort []Task. -func (t Tasks) Len() int { - return len(t) -} - -// Swap is used to sort []Task. -func (t Tasks) Swap(i, j int) { - t[i], t[j] = t[j], t[i] -} - -// Less is used to sort []Task. -func (t Tasks) Less(i, j int) bool { - if t[i].Status != StatusPending && t[j].Status == StatusPending { - return true - } else if t[i].Status == StatusPending && t[j].Status != StatusPending { - return false - } else if t[i].Status == StatusPending && t[j].Status == StatusPending { - return t[i].Name < t[j].Name - } else { - return t[i].StartTime.Before(t[j].StartTime) - } -} diff --git a/base/task/monitor_test.go b/base/task/monitor_test.go deleted file mode 100644 index d0f07d8b5..000000000 --- a/base/task/monitor_test.go +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2021 gorse Project Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package task - -import ( - "github.com/samber/lo" - "github.com/stretchr/testify/assert" - "testing" -) - -func TestTaskMonitor(t *testing.T) { - taskMonitor := NewTaskMonitor() - taskMonitor.Start("a", 100) - assert.Equal(t, 0, taskMonitor.Get("a")) - assert.Equal(t, "a", taskMonitor.Tasks["a"].Name) - assert.Equal(t, 100, taskMonitor.Tasks["a"].Total) - assert.Equal(t, 0, taskMonitor.Tasks["a"].Done) - assert.Equal(t, StatusRunning, taskMonitor.Tasks["a"].Status) - - taskMonitor.Update("a", 50) - assert.Equal(t, 50, taskMonitor.Get("a")) - assert.Equal(t, "a", taskMonitor.Tasks["a"].Name) - assert.Equal(t, 100, taskMonitor.Tasks["a"].Total) - assert.Equal(t, 50, taskMonitor.Tasks["a"].Done) - assert.Equal(t, StatusRunning, taskMonitor.Tasks["a"].Status) - taskMonitor.Suspend("a", true) - assert.Equal(t, StatusSuspended, taskMonitor.Tasks["a"].Status) - taskMonitor.Suspend("a", false) - assert.Equal(t, StatusRunning, taskMonitor.Tasks["a"].Status) - - taskMonitor.Add("a", 30) - assert.Equal(t, 80, taskMonitor.Get("a")) - assert.Equal(t, "a", taskMonitor.Tasks["a"].Name) - assert.Equal(t, 100, taskMonitor.Tasks["a"].Total) - assert.Equal(t, 80, taskMonitor.Tasks["a"].Done) - assert.Equal(t, StatusRunning, taskMonitor.Tasks["a"].Status) - - taskMonitor.Finish("a") - assert.Equal(t, 100, taskMonitor.Get("a")) - assert.Equal(t, "a", taskMonitor.Tasks["a"].Name) - assert.Equal(t, 100, taskMonitor.Tasks["a"].Total) - assert.Equal(t, 100, taskMonitor.Tasks["a"].Done) - assert.Equal(t, StatusComplete, taskMonitor.Tasks["a"].Status) - - taskMonitor.Start("b", 100) - taskMonitor.Fail("b", "error") - assert.Equal(t, "error", taskMonitor.Tasks["b"].Error) - assert.Equal(t, StatusFailed, taskMonitor.Tasks["b"].Status) - - taskMonitor.Pending("c") - assert.Equal(t, StatusPending, taskMonitor.Tasks["c"].Status) - - taskMonitor.Start("d [a]", 100) - taskMonitor.Start("d [b]", 100) - tasks := taskMonitor.List("a") - assert.ElementsMatch(t, lo.ToSlicePtr(tasks), []*Task{ - taskMonitor.GetTask("a"), - taskMonitor.GetTask("b"), - taskMonitor.GetTask("c"), - taskMonitor.GetTask("d [a]"), - }) -} - -func TestSubTask(t *testing.T) { - task := NewTask("a", 100) - task.Add(10) - assert.Equal(t, 10, task.Done) - s := task.SubTask(80) - assert.Equal(t, 10, s.Start) - assert.Equal(t, 90, s.End) - s.Add(20) - assert.Equal(t, 30, task.Done) - s.Finish() - assert.Equal(t, 90, task.Done) -} From 3d920c7144708e3d87b07e74d7950589c7521b76 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Thu, 27 Jul 2023 21:19:36 +0800 Subject: [PATCH 04/13] Fix click package --- model/click/deepfm.go | 12 ++++++------ model/click/deepfm_test.go | 3 ++- model/click/model.go | 23 ++++++++--------------- model/click/model_test.go | 14 +++++--------- model/click/search.go | 36 ++++++++++++++++++++---------------- model/click/search_test.go | 23 +++++++++-------------- model/ranking/model.go | 22 +++++++--------------- 7 files changed, 57 insertions(+), 76 deletions(-) diff --git a/model/click/deepfm.go b/model/click/deepfm.go index 115c211e9..ba4f93d12 100644 --- a/model/click/deepfm.go +++ b/model/click/deepfm.go @@ -15,6 +15,7 @@ package click import ( + "context" "fmt" "io" "time" @@ -24,6 +25,7 @@ import ( "github.com/samber/lo" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" "gorgonia.org/gorgonia" @@ -162,7 +164,7 @@ func (fm *DeepFM) BatchPredict(x []lo.Tuple2[[]int32, []float32]) []float32 { return predictions[:len(x)] } -func (fm *DeepFM) Fit(trainSet *Dataset, testSet *Dataset, config *FitConfig) Score { +func (fm *DeepFM) Fit(ctx context.Context, trainSet *Dataset, testSet *Dataset, config *FitConfig) Score { fm.Init(trainSet) evalStart := time.Now() score := EvaluateClassification(fm, testSet) @@ -185,6 +187,7 @@ func (fm *DeepFM) Fit(trainSet *Dataset, testSet *Dataset, config *FitConfig) Sc gorgonia.WithL2Reg(float64(fm.reg)), gorgonia.WithLearnRate(float64(fm.lr))) + _, span := progress.Start(ctx, "DeepFM.Fit", fm.nEpochs*trainSet.Count()) for epoch := 1; epoch <= fm.nEpochs; epoch++ { fitStart := time.Now() cost := float32(0) @@ -200,6 +203,7 @@ func (fm *DeepFM) Fit(trainSet *Dataset, testSet *Dataset, config *FitConfig) Sc cost += fm.cost.Value().Data().(float32) lo.Must0(solver.Step(gorgonia.NodesToValueGrads(fm.learnables))) fm.vm.Reset() + span.Add(mathutil.Min(fm.batchSize, trainSet.Count()-i)) } fitTime := time.Since(fitStart) @@ -220,8 +224,8 @@ func (fm *DeepFM) Fit(trainSet *Dataset, testSet *Dataset, config *FitConfig) Sc break } } - config.Task.Add(1) } + span.End() return score } @@ -261,10 +265,6 @@ func (fm *DeepFM) Bytes() int { return 0 } -func (fm *DeepFM) Complexity() int { - return 0 -} - func (fm *DeepFM) forward(batchSize int) { // input nodes fm.embeddingV = gorgonia.NodeFromAny(fm.g, diff --git a/model/click/deepfm_test.go b/model/click/deepfm_test.go index 9cc9082e1..5321bfa17 100644 --- a/model/click/deepfm_test.go +++ b/model/click/deepfm_test.go @@ -15,6 +15,7 @@ package click import ( + "context" "testing" "github.com/stretchr/testify/assert" @@ -32,6 +33,6 @@ func TestDeepFM_Classification_Frappe(t *testing.T) { model.Reg: 0.0001, }) fitConfig := newFitConfigWithTestTracker(20) - score := m.Fit(train, test, fitConfig) + score := m.Fit(context.Background(), train, test, fitConfig) assert.InDelta(t, 0.9271656, score.Accuracy, classificationDelta) } diff --git a/model/click/model.go b/model/click/model.go index 4abf6e784..0b43e20ce 100644 --- a/model/click/model.go +++ b/model/click/model.go @@ -15,6 +15,7 @@ package click import ( + "context" "encoding/binary" "fmt" "io" @@ -31,6 +32,7 @@ import ( "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/parallel" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" @@ -94,7 +96,6 @@ func (score Score) BetterThan(s Score) bool { type FitConfig struct { *task.JobsAllocator Verbose int - Task *task.Task } func NewFitConfig() *FitConfig { @@ -113,11 +114,6 @@ func (config *FitConfig) SetJobsAllocator(allocator *task.JobsAllocator) *FitCon return config } -func (config *FitConfig) SetTask(t *task.Task) *FitConfig { - config.Task = t - return config -} - func (config *FitConfig) LoadDefaultIfNil() *FitConfig { if config == nil { return NewFitConfig() @@ -129,10 +125,9 @@ type FactorizationMachine interface { model.Model Predict(userId, itemId string, userFeatures, itemFeatures []Feature) float32 InternalPredict(x []int32, values []float32) float32 - Fit(trainSet *Dataset, testSet *Dataset, config *FitConfig) Score + Fit(ctx context.Context, trainSet *Dataset, testSet *Dataset, config *FitConfig) Score Marshal(w io.Writer) error Bytes() int - Complexity() int } type BatchInference interface { @@ -271,7 +266,7 @@ func (fm *FM) InternalPredict(features []int32, values []float32) float32 { } // Fit trains the model. Its task complexity is O(fm.nEpochs). -func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score { +func (fm *FM) Fit(ctx context.Context, trainSet, testSet *Dataset, config *FitConfig) Score { config = config.LoadDefaultIfNil() log.Logger().Info("fit FM", zap.Int("train_size", trainSet.Count()), @@ -313,6 +308,7 @@ func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score { log.Logger().Debug(fmt.Sprintf("fit fm %v/%v", 0, fm.nEpochs), fields...) snapshots.AddSnapshot(score, fm.V, fm.W, fm.B) + _, span := progress.Start(ctx, "FM.Fit", fm.nEpochs) for epoch := 1; epoch <= fm.nEpochs; epoch++ { for i := 0; i < trainSet.Target.Len(); i++ { fm.MinTarget = math32.Min(fm.MinTarget, trainSet.Target.Get(i)) @@ -320,7 +316,7 @@ func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score { } fitStart := time.Now() cost := float32(0) - _ = parallel.BatchParallel(trainSet.Count(), config.AvailableJobs(config.Task), 128, func(workerId, beginJobId, endJobId int) error { + _ = parallel.BatchParallel(trainSet.Count(), config.AvailableJobs(), 128, func(workerId, beginJobId, endJobId int) error { for i := beginJobId; i < endJobId; i++ { features, values, target := trainSet.Get(i) prediction := fm.internalPredictImpl(features, values) @@ -442,8 +438,9 @@ func (fm *FM) Fit(trainSet, testSet *Dataset, config *FitConfig) Score { } snapshots.AddSnapshot(score, fm.V, fm.W, fm.B) } - config.Task.Add(1) + span.Add(1) } + span.End() // restore best snapshot fm.V = snapshots.BestWeights[0].([][]float32) fm.W = snapshots.BestWeights[1].([]float32) @@ -539,10 +536,6 @@ func (fm *FM) Bytes() int { return int(bytes) + fm.Index.Bytes() } -func (fm *FM) Complexity() int { - return fm.nEpochs -} - func MarshalModel(w io.Writer, m FactorizationMachine) error { return m.Marshal(w) } diff --git a/model/click/model_test.go b/model/click/model_test.go index f3c8bd44b..e9d2a39c6 100644 --- a/model/click/model_test.go +++ b/model/click/model_test.go @@ -15,6 +15,7 @@ package click import ( "bytes" + "context" "testing" "github.com/stretchr/testify/assert" @@ -28,11 +29,9 @@ const ( ) func newFitConfigWithTestTracker(numEpoch int) *FitConfig { - t := task.NewTask("test", numEpoch) cfg := NewFitConfig(). SetVerbose(1). - SetJobsAllocator(task.NewConstantJobsAllocator(1)). - SetTask(t) + SetJobsAllocator(task.NewConstantJobsAllocator(1)) return cfg } @@ -54,9 +53,8 @@ func TestFM_Classification_Frappe(t *testing.T) { model.Optimizer: optimizer, }) fitConfig := newFitConfigWithTestTracker(20) - score := m.Fit(train, test, fitConfig) + score := m.Fit(context.Background(), train, test, fitConfig) assert.InDelta(t, 0.91684, score.Accuracy, classificationDelta) - assert.Equal(t, m.Complexity(), fitConfig.Task.Done) }) } } @@ -94,9 +92,8 @@ func TestFM_Regression_Criteo(t *testing.T) { model.Reg: 0.0001, }) fitConfig := newFitConfigWithTestTracker(20) - score := m.Fit(train, test, fitConfig) + score := m.Fit(context.Background(), train, test, fitConfig) assert.InDelta(t, 0.839194, score.RMSE, regressionDelta) - assert.Equal(t, m.Complexity(), fitConfig.Task.Done) // test prediction assert.Equal(t, m.InternalPredict([]int32{1, 2, 3, 4, 5, 6}, []float32{1, 1, 0.3, 0.4, 0.5, 0.6}), @@ -119,9 +116,8 @@ func TestFM_Regression_Criteo(t *testing.T) { m = tmp.(*FM) m.nEpochs = 1 fitConfig = newFitConfigWithTestTracker(1) - scoreInc := m.Fit(train, test, fitConfig) + scoreInc := m.Fit(context.Background(), train, test, fitConfig) assert.InDelta(t, 0.839194, scoreInc.RMSE, regressionDelta) - assert.Equal(t, m.Complexity(), fitConfig.Task.Done) // test clear m.Clear() diff --git a/model/click/search.go b/model/click/search.go index da4a673bc..bd6d79a12 100644 --- a/model/click/search.go +++ b/model/click/search.go @@ -15,12 +15,14 @@ package click import ( + "context" "fmt" "sync" "time" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" @@ -37,31 +39,30 @@ type ParamsSearchResult struct { } // GridSearchCV finds the best parameters for a model. -func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid, +func GridSearchCV(ctx context.Context, estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid, _ int64, fitConfig *FitConfig) ParamsSearchResult { // Retrieve parameter names and length paramNames := make([]model.ParamName, 0, len(paramGrid)) - count := 1 + total := 1 for paramName, values := range paramGrid { paramNames = append(paramNames, paramName) - count *= len(values) + total *= len(values) } // Construct DFS procedure results := ParamsSearchResult{ - Scores: make([]Score, 0, count), - Params: make([]model.Params, 0, count), + Scores: make([]Score, 0, total), + Params: make([]model.Params, 0, total), } var dfs func(deep int, params model.Params) - progress := 0 + newCtx, span := progress.Start(ctx, "GridSearchCV", total) dfs = func(deep int, params model.Params) { if deep == len(paramNames) { - progress++ - log.Logger().Info(fmt.Sprintf("grid search %v/%v", progress, count), + log.Logger().Info(fmt.Sprintf("grid search %v/%v", span.Count(), total), zap.Any("params", params)) // Cross validate estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - score := estimator.Fit(trainSet, testSet, fitConfig) + score := estimator.Fit(newCtx, trainSet, testSet, fitConfig) // Create GridSearch result results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) @@ -71,6 +72,7 @@ func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Da results.BestIndex = len(results.Params) - 1 results.BestModel = Clone(estimator) } + span.Add(1) } else { paramName := paramNames[deep] values := paramGrid[paramName] @@ -82,21 +84,23 @@ func GridSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Da } params := make(map[model.ParamName]interface{}) dfs(0, params) + span.End() return results } // RandomSearchCV searches hyper-parameters by random. -func RandomSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid, +func RandomSearchCV(ctx context.Context, estimator FactorizationMachine, trainSet *Dataset, testSet *Dataset, paramGrid model.ParamsGrid, numTrials int, seed int64, fitConfig *FitConfig) ParamsSearchResult { // if the number of combination is less than number of trials, use grid search if paramGrid.NumCombinations() <= numTrials { - return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig) + return GridSearchCV(ctx, estimator, trainSet, testSet, paramGrid, seed, fitConfig) } rng := base.NewRandomGenerator(seed) results := ParamsSearchResult{ Scores: make([]Score, 0, numTrials), Params: make([]model.Params, 0, numTrials), } + newCtx, span := progress.Start(ctx, "RandomSearchCV", numTrials) for i := 1; i <= numTrials; i++ { // Make parameters params := model.Params{} @@ -109,7 +113,7 @@ func RandomSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet * zap.Any("params", params)) estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - score := estimator.Fit(trainSet, testSet, fitConfig) + score := estimator.Fit(newCtx, trainSet, testSet, fitConfig) results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) if len(results.Scores) == 0 || score.BetterThan(results.BestScore) { @@ -118,7 +122,9 @@ func RandomSearchCV(estimator FactorizationMachine, trainSet *Dataset, testSet * results.BestIndex = len(results.Params) - 1 results.BestModel = Clone(estimator) } + span.Add(1) } + span.End() return results } @@ -156,7 +162,7 @@ func (searcher *ModelSearcher) Complexity() int { return searcher.numTrials * searcher.numEpochs } -func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, j *task.JobsAllocator) error { +func (searcher *ModelSearcher) Fit(ctx context.Context, trainSet, valSet *Dataset, j *task.JobsAllocator) error { log.Logger().Info("click model search", zap.Int("n_users", trainSet.UserCount()), zap.Int("n_items", trainSet.ItemCount()), @@ -166,9 +172,7 @@ func (searcher *ModelSearcher) Fit(trainSet, valSet *Dataset, t *task.Task, j *t // Random search grid := searcher.model.GetParamsGrid(searcher.searchSize) - r := RandomSearchCV(searcher.model, trainSet, valSet, grid, searcher.numTrials, 0, NewFitConfig(). - SetJobsAllocator(j). - SetTask(t)) + r := RandomSearchCV(ctx, searcher.model, trainSet, valSet, grid, searcher.numTrials, 0, NewFitConfig().SetJobsAllocator(j)) searcher.bestMutex.Lock() defer searcher.bestMutex.Unlock() searcher.bestModel = r.BestModel diff --git a/model/click/search_test.go b/model/click/search_test.go index 524cac57e..77f4a996a 100644 --- a/model/click/search_test.go +++ b/model/click/search_test.go @@ -14,12 +14,14 @@ package click import ( + "context" + "io" + "testing" + "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" - "io" - "testing" ) // NewMapIndexDataset creates a data set. @@ -57,12 +59,11 @@ func (m *mockFactorizationMachineForSearch) GetItemIndex() base.Index { panic("don't call me") } -func (m *mockFactorizationMachineForSearch) Fit(_, _ *Dataset, cfg *FitConfig) Score { +func (m *mockFactorizationMachineForSearch) Fit(_ context.Context, _, _ *Dataset, cfg *FitConfig) Score { score := float32(0) score += m.Params.GetFloat32(model.NFactors, 0.0) score += m.Params.GetFloat32(model.InitMean, 0.0) score += m.Params.GetFloat32(model.InitStdDev, 0.0) - cfg.Task.Add(m.Params.GetInt(model.NEpochs, 0)) return Score{Task: FMClassification, AUC: score} } @@ -87,17 +88,15 @@ func (m *mockFactorizationMachineForSearch) GetParamsGrid(_ bool) model.ParamsGr } func newFitConfigForSearch() *FitConfig { - t := task.NewTask("test", 0) return &FitConfig{ Verbose: 1, - Task: t, } } func TestGridSearchCV(t *testing.T) { m := &mockFactorizationMachineForSearch{} fitConfig := newFitConfigForSearch() - r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig) + r := GridSearchCV(context.Background(), m, nil, nil, m.GetParamsGrid(false), 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.AUC) assert.Equal(t, model.Params{ model.NFactors: 4, @@ -109,7 +108,7 @@ func TestGridSearchCV(t *testing.T) { func TestRandomSearchCV(t *testing.T) { m := &mockFactorizationMachineForSearch{} fitConfig := newFitConfigForSearch() - r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig) + r := RandomSearchCV(context.Background(), m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.AUC) assert.Equal(t, model.Params{ model.NFactors: 4, @@ -121,8 +120,7 @@ func TestRandomSearchCV(t *testing.T) { func TestModelSearcher_RandomSearch(t *testing.T) { searcher := NewModelSearcher(2, 63, false) searcher.model = &mockFactorizationMachineForSearch{model.BaseModel{Params: model.Params{model.NEpochs: 2}}} - tk := task.NewTask("test", searcher.Complexity()) - err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1)) + err := searcher.Fit(context.Background(), NewMapIndexDataset(), NewMapIndexDataset(), task.NewConstantJobsAllocator(1)) assert.NoError(t, err) m, score := searcher.GetBestModel() assert.Equal(t, float32(12), score.AUC) @@ -132,14 +130,12 @@ func TestModelSearcher_RandomSearch(t *testing.T) { model.InitMean: 4, model.InitStdDev: 4, }, m.GetParams()) - assert.Equal(t, searcher.Complexity(), tk.Done) } func TestModelSearcher_GridSearch(t *testing.T) { searcher := NewModelSearcher(2, 64, false) searcher.model = &mockFactorizationMachineForSearch{model.BaseModel{Params: model.Params{model.NEpochs: 2}}} - tk := task.NewTask("test", searcher.Complexity()) - err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1)) + err := searcher.Fit(context.Background(), NewMapIndexDataset(), NewMapIndexDataset(), task.NewConstantJobsAllocator(1)) assert.NoError(t, err) m, score := searcher.GetBestModel() assert.Equal(t, float32(12), score.AUC) @@ -149,5 +145,4 @@ func TestModelSearcher_GridSearch(t *testing.T) { model.InitMean: 4, model.InitStdDev: 4, }, m.GetParams()) - assert.Equal(t, searcher.Complexity(), tk.Done) } diff --git a/model/ranking/model.go b/model/ranking/model.go index 022d85180..ec3a59b21 100644 --- a/model/ranking/model.go +++ b/model/ranking/model.go @@ -47,7 +47,6 @@ type FitConfig struct { Verbose int Candidates int TopK int - Task *task.Task } func NewFitConfig() *FitConfig { @@ -68,11 +67,6 @@ func (config *FitConfig) SetJobsAllocator(allocator *task.JobsAllocator) *FitCon return config } -func (config *FitConfig) SetTask(t *task.Task) *FitConfig { - config.Task = t - return config -} - func (config *FitConfig) LoadDefaultIfNil() *FitConfig { if config == nil { return NewFitConfig() @@ -421,7 +415,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { } snapshots := SnapshotManger{} evalStart := time.Now() - scores := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) + scores := Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall) evalTime := time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", 0, bpr.nEpochs), zap.String("eval_time", evalTime.String()), @@ -433,7 +427,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { for epoch := 1; epoch <= bpr.nEpochs; epoch++ { fitStart := time.Now() // Training epoch - numJobs := config.AvailableJobs(config.Task) + numJobs := config.AvailableJobs() cost := make([]float32, numJobs) _ = parallel.Parallel(trainSet.Count(), numJobs, func(workerId, _ int) error { // Select a user @@ -482,7 +476,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { // Cross validation if epoch%config.Verbose == 0 || epoch == bpr.nEpochs { evalStart = time.Now() - scores = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) + scores = Evaluate(bpr, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall) evalTime = time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit bpr %v/%v", epoch, bpr.nEpochs), zap.String("fit_time", fitTime.String()), @@ -492,7 +486,6 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2])) snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, bpr.UserFactor, bpr.ItemFactor) } - config.Task.Add(1) } // restore best snapshot bpr.UserFactor = snapshots.BestWeights[0].([][]float32) @@ -739,7 +732,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { // evaluate initial model snapshots := SnapshotManger{} evalStart := time.Now() - scores := Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) + scores := Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall) evalTime := time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", 0, ccd.nEpochs), zap.String("eval_time", evalTime.String()), @@ -761,7 +754,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { } } } - _ = parallel.Parallel(trainSet.UserCount(), config.AvailableJobs(config.Task), func(workerId, userIndex int) error { + _ = parallel.Parallel(trainSet.UserCount(), config.AvailableJobs(), func(workerId, userIndex int) error { userFeedback := trainSet.UserFeedback[userIndex] for _, i := range userFeedback { userPredictions[workerId][i] = ccd.InternalPredict(int32(userIndex), i) @@ -802,7 +795,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { } } } - _ = parallel.Parallel(trainSet.ItemCount(), config.AvailableJobs(config.Task), func(workerId, itemIndex int) error { + _ = parallel.Parallel(trainSet.ItemCount(), config.AvailableJobs(), func(workerId, itemIndex int) error { itemFeedback := trainSet.ItemFeedback[itemIndex] for _, u := range itemFeedback { itemPredictions[workerId][u] = ccd.InternalPredict(u, int32(itemIndex)) @@ -835,7 +828,7 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { // Cross validation if ep%config.Verbose == 0 || ep == ccd.nEpochs { evalStart = time.Now() - scores = Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(config.Task), NDCG, Precision, Recall) + scores = Evaluate(ccd, valSet, trainSet, config.TopK, config.Candidates, config.AvailableJobs(), NDCG, Precision, Recall) evalTime = time.Since(evalStart) log.Logger().Debug(fmt.Sprintf("fit ccd %v/%v", ep, ccd.nEpochs), zap.String("fit_time", fitTime.String()), @@ -845,7 +838,6 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2])) snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, ccd.UserFactor, ccd.ItemFactor) } - config.Task.Add(1) } // restore best snapshot ccd.UserFactor = snapshots.BestWeights[0].([][]float32) From 2d4cf31a3d7b9b12e15911e6456932eb358e5def Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Thu, 27 Jul 2023 21:44:20 +0800 Subject: [PATCH 05/13] Fix ranking package --- model/ranking/evaluator_test.go | 7 ++---- model/ranking/model.go | 26 ++++++++++---------- model/ranking/model_test.go | 23 ++++++++---------- model/ranking/search.go | 42 ++++++++++++++++----------------- model/ranking/search_test.go | 23 +++++++----------- 5 files changed, 54 insertions(+), 67 deletions(-) diff --git a/model/ranking/evaluator_test.go b/model/ranking/evaluator_test.go index 6166692e5..927138ac8 100644 --- a/model/ranking/evaluator_test.go +++ b/model/ranking/evaluator_test.go @@ -14,6 +14,7 @@ package ranking import ( + "context" "io" "strconv" "testing" @@ -68,10 +69,6 @@ type mockMatrixFactorizationForEval struct { negative []mapset.Set[int32] } -func (m *mockMatrixFactorizationForEval) Complexity() int { - panic("implement me") -} - func (m *mockMatrixFactorizationForEval) Bytes() int { panic("implement me") } @@ -112,7 +109,7 @@ func (m *mockMatrixFactorizationForEval) GetItemIndex() base.Index { panic("don't call me") } -func (m *mockMatrixFactorizationForEval) Fit(_, _ *DataSet, _ *FitConfig) Score { +func (m *mockMatrixFactorizationForEval) Fit(_ context.Context, _, _ *DataSet, _ *FitConfig) Score { panic("don't call me") } diff --git a/model/ranking/model.go b/model/ranking/model.go index ec3a59b21..43685b77b 100644 --- a/model/ranking/model.go +++ b/model/ranking/model.go @@ -15,6 +15,7 @@ package ranking import ( + "context" "fmt" "io" "reflect" @@ -31,6 +32,7 @@ import ( "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/parallel" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" @@ -77,7 +79,7 @@ func (config *FitConfig) LoadDefaultIfNil() *FitConfig { type Model interface { model.Model // Fit a model with a train set and parameters. - Fit(trainSet *DataSet, validateSet *DataSet, config *FitConfig) Score + Fit(ctx context.Context, trainSet *DataSet, validateSet *DataSet, config *FitConfig) Score // GetItemIndex returns item index. GetItemIndex() base.Index // Marshal model into byte stream. @@ -110,8 +112,6 @@ type MatrixFactorization interface { Unmarshal(r io.Reader) error // Bytes returns used memory. Bytes() int - // Complexity returns the complexity of the model. - Complexity() int } type BaseMatrixFactorization struct { @@ -335,10 +335,6 @@ func (bpr *BPR) GetItemFactor(itemIndex int32) []float32 { return bpr.ItemFactor[itemIndex] } -func (bpr *BPR) Complexity() int { - return bpr.nEpochs -} - // SetParams sets hyper-parameters of the BPR model. func (bpr *BPR) SetParams(params model.Params) { bpr.BaseMatrixFactorization.SetParams(params) @@ -387,7 +383,7 @@ func (bpr *BPR) InternalPredict(userIndex, itemIndex int32) float32 { } // Fit the BPR model. Its task complexity is O(bpr.nEpochs). -func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { +func (bpr *BPR) Fit(ctx context.Context, trainSet, valSet *DataSet, config *FitConfig) Score { config = config.LoadDefaultIfNil() log.Logger().Info("fit bpr", zap.Int("train_set_size", trainSet.Count()), @@ -424,6 +420,7 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2])) snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, bpr.UserFactor, bpr.ItemFactor) // Training + _, span := progress.Start(ctx, "BPR.Fit", bpr.nEpochs) for epoch := 1; epoch <= bpr.nEpochs; epoch++ { fitStart := time.Now() // Training epoch @@ -486,7 +483,9 @@ func (bpr *BPR) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2])) snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, bpr.UserFactor, bpr.ItemFactor) } + span.Add(1) } + span.End() // restore best snapshot bpr.UserFactor = snapshots.BestWeights[0].([][]float32) bpr.ItemFactor = snapshots.BestWeights[1].([][]float32) @@ -613,10 +612,6 @@ func (ccd *CCD) GetItemFactor(itemIndex int32) []float32 { return ccd.ItemFactor[itemIndex] } -func (ccd *CCD) Complexity() int { - return ccd.nEpochs -} - // SetParams sets hyper-parameters for the ALS model. func (ccd *CCD) SetParams(params model.Params) { ccd.BaseMatrixFactorization.SetParams(params) @@ -708,7 +703,7 @@ func (ccd *CCD) Init(trainSet *DataSet) { } // Fit the CCD model. Its task complexity is O(ccd.nEpochs). -func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { +func (ccd *CCD) Fit(ctx context.Context, trainSet, valSet *DataSet, config *FitConfig) Score { config = config.LoadDefaultIfNil() log.Logger().Info("fit ccd", zap.Int("train_set_size", trainSet.Count()), @@ -740,6 +735,8 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Float32(fmt.Sprintf("Precision@%v", config.TopK), scores[1]), zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2])) snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, ccd.UserFactor, ccd.ItemFactor) + + _, span := progress.Start(ctx, "CCD.Fit", ccd.nEpochs) for ep := 1; ep <= ccd.nEpochs; ep++ { fitStart := time.Now() // Update user factors @@ -838,7 +835,10 @@ func (ccd *CCD) Fit(trainSet, valSet *DataSet, config *FitConfig) Score { zap.Float32(fmt.Sprintf("Recall@%v", config.TopK), scores[2])) snapshots.AddSnapshot(Score{NDCG: scores[0], Precision: scores[1], Recall: scores[2]}, ccd.UserFactor, ccd.ItemFactor) } + span.Add(1) } + span.End() + // restore best snapshot ccd.UserFactor = snapshots.BestWeights[0].([][]float32) ccd.ItemFactor = snapshots.BestWeights[1].([][]float32) diff --git a/model/ranking/model_test.go b/model/ranking/model_test.go index 0508429d8..915526e4e 100644 --- a/model/ranking/model_test.go +++ b/model/ranking/model_test.go @@ -15,13 +15,15 @@ package ranking import ( "bytes" + "context" + "math" + "runtime" + "testing" + "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" - "math" - "runtime" - "testing" ) const ( @@ -30,8 +32,7 @@ const ( ) func newFitConfig(numEpoch int) *FitConfig { - t := task.NewTask("test", numEpoch) - cfg := NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU())).SetTask(t) + cfg := NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU())) return cfg } @@ -50,11 +51,10 @@ func TestBPR_MovieLens(t *testing.T) { model.InitStdDev: 0.001, }) fitConfig := newFitConfig(30) - score := m.Fit(trainSet, testSet, fitConfig) + score := m.Fit(context.Background(), trainSet, testSet, fitConfig) assert.InDelta(t, 0.36, score.NDCG, benchDelta) assert.Equal(t, trainSet.UserIndex, m.GetUserIndex()) assert.Equal(t, testSet.ItemIndex, m.GetItemIndex()) - assert.Equal(t, m.Complexity(), fitConfig.Task.Done) // test predict assert.Equal(t, m.Predict("1", "1"), m.InternalPredict(1, 1)) @@ -77,9 +77,8 @@ func TestBPR_MovieLens(t *testing.T) { m = tmp.(*BPR) m.nEpochs = 1 fitConfig = newFitConfig(1) - scoreInc := m.Fit(trainSet, testSet, fitConfig) + scoreInc := m.Fit(context.Background(), trainSet, testSet, fitConfig) assert.InDelta(t, score.NDCG, scoreInc.NDCG, incrDelta) - assert.Equal(t, m.Complexity(), fitConfig.Task.Done) // test clear m.Clear() @@ -111,9 +110,8 @@ func TestCCD_MovieLens(t *testing.T) { model.Alpha: 0.05, }) fitConfig := newFitConfig(30) - score := m.Fit(trainSet, testSet, fitConfig) + score := m.Fit(context.Background(), trainSet, testSet, fitConfig) assert.InDelta(t, 0.36, score.NDCG, benchDelta) - assert.Equal(t, m.Complexity(), fitConfig.Task.Done) // test predict assert.Equal(t, m.Predict("1", "1"), m.InternalPredict(1, 1)) @@ -128,9 +126,8 @@ func TestCCD_MovieLens(t *testing.T) { m = tmp.(*CCD) m.nEpochs = 1 fitConfig = newFitConfig(1) - scoreInc := m.Fit(trainSet, testSet, fitConfig) + scoreInc := m.Fit(context.Background(), trainSet, testSet, fitConfig) assert.InDelta(t, score.NDCG, scoreInc.NDCG, incrDelta) - assert.Equal(t, m.Complexity(), fitConfig.Task.Done) // test clear m.Clear() diff --git a/model/ranking/search.go b/model/ranking/search.go index e1f65c2c7..404632d07 100644 --- a/model/ranking/search.go +++ b/model/ranking/search.go @@ -15,12 +15,14 @@ package ranking import ( + "context" "fmt" "sync" "time" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/log" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "go.uber.org/zap" @@ -47,31 +49,30 @@ func (r *ParamsSearchResult) AddScore(params model.Params, score Score) { } // GridSearchCV finds the best parameters for a model. -func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid, +func GridSearchCV(ctx context.Context, estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid, _ int64, fitConfig *FitConfig) ParamsSearchResult { // Retrieve parameter names and length paramNames := make([]model.ParamName, 0, len(paramGrid)) - count := 1 + total := 1 for paramName, values := range paramGrid { paramNames = append(paramNames, paramName) - count *= len(values) + total *= len(values) } // Construct DFS procedure results := ParamsSearchResult{ - Scores: make([]Score, 0, count), - Params: make([]model.Params, 0, count), + Scores: make([]Score, 0, total), + Params: make([]model.Params, 0, total), } var dfs func(deep int, params model.Params) - progress := 0 + newCtx, span := progress.Start(ctx, "GridSearchCV", total) dfs = func(deep int, params model.Params) { if deep == len(paramNames) { - progress++ - log.Logger().Info(fmt.Sprintf("grid search (%v/%v)", progress, count), + log.Logger().Info(fmt.Sprintf("grid search (%v/%v)", span.Count(), total), zap.Any("params", params)) // Cross validate estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - score := estimator.Fit(trainSet, testSet, fitConfig) + score := estimator.Fit(newCtx, trainSet, testSet, fitConfig) // Create GridSearch result results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) @@ -81,6 +82,7 @@ func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *Dat results.BestParams = params.Copy() results.BestIndex = len(results.Params) - 1 } + span.Add(1) } else { paramName := paramNames[deep] values := paramGrid[paramName] @@ -92,21 +94,23 @@ func GridSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *Dat } params := make(map[model.ParamName]interface{}) dfs(0, params) + span.End() return results } // RandomSearchCV searches hyper-parameters by random. -func RandomSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid, +func RandomSearchCV(ctx context.Context, estimator MatrixFactorization, trainSet *DataSet, testSet *DataSet, paramGrid model.ParamsGrid, numTrials int, seed int64, fitConfig *FitConfig) ParamsSearchResult { // if the number of combination is less than number of trials, use grid search if paramGrid.NumCombinations() < numTrials { - return GridSearchCV(estimator, trainSet, testSet, paramGrid, seed, fitConfig) + return GridSearchCV(ctx, estimator, trainSet, testSet, paramGrid, seed, fitConfig) } rng := base.NewRandomGenerator(seed) results := ParamsSearchResult{ Scores: make([]Score, 0, numTrials), Params: make([]model.Params, 0, numTrials), } + newCtx, span := progress.Start(ctx, "RandomSearchCV", numTrials) for i := 1; i <= numTrials; i++ { // Make parameters params := model.Params{} @@ -119,7 +123,7 @@ func RandomSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *D zap.Any("params", params)) estimator.Clear() estimator.SetParams(estimator.GetParams().Overwrite(params)) - score := estimator.Fit(trainSet, testSet, fitConfig) + score := estimator.Fit(newCtx, trainSet, testSet, fitConfig) results.Scores = append(results.Scores, score) results.Params = append(results.Params, params.Copy()) if len(results.Scores) == 0 || score.NDCG > results.BestScore.NDCG { @@ -128,7 +132,9 @@ func RandomSearchCV(estimator MatrixFactorization, trainSet *DataSet, testSet *D results.BestParams = params.Copy() results.BestIndex = len(results.Params) - 1 } + span.Add(1) } + span.End() return results } @@ -165,20 +171,14 @@ func (searcher *ModelSearcher) GetBestModel() (string, MatrixFactorization, Scor return searcher.bestModelName, searcher.bestModel, searcher.bestScore } -func (searcher *ModelSearcher) Complexity() int { - return len(searcher.models) * searcher.numEpochs * searcher.numTrials -} - -func (searcher *ModelSearcher) Fit(trainSet, valSet *DataSet, t *task.Task, j *task.JobsAllocator) error { +func (searcher *ModelSearcher) Fit(ctx context.Context, trainSet, valSet *DataSet, j *task.JobsAllocator) error { log.Logger().Info("ranking model search", zap.Int("n_users", trainSet.UserCount()), zap.Int("n_items", trainSet.ItemCount())) startTime := time.Now() for _, m := range searcher.models { - r := RandomSearchCV(m, trainSet, valSet, m.GetParamsGrid(searcher.searchSize), searcher.numTrials, 0, - NewFitConfig(). - SetJobsAllocator(j). - SetTask(t)) + r := RandomSearchCV(ctx, m, trainSet, valSet, m.GetParamsGrid(searcher.searchSize), searcher.numTrials, 0, + NewFitConfig().SetJobsAllocator(j)) searcher.bestMutex.Lock() if searcher.bestModel == nil || r.BestScore.NDCG > searcher.bestScore.NDCG { searcher.bestModel = r.BestModel diff --git a/model/ranking/search_test.go b/model/ranking/search_test.go index 71141b4e8..169fa3774 100644 --- a/model/ranking/search_test.go +++ b/model/ranking/search_test.go @@ -14,12 +14,14 @@ package ranking import ( + "context" + "io" + "testing" + "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" - "io" - "testing" ) type mockMatrixFactorizationForSearch struct { @@ -30,10 +32,6 @@ func newMockMatrixFactorizationForSearch(numEpoch int) *mockMatrixFactorizationF return &mockMatrixFactorizationForSearch{model.BaseModel{Params: model.Params{model.NEpochs: numEpoch}}} } -func (m *mockMatrixFactorizationForSearch) Complexity() int { - panic("implement me") -} - func (m *mockMatrixFactorizationForSearch) Bytes() int { panic("implement me") } @@ -74,12 +72,11 @@ func (m *mockMatrixFactorizationForSearch) GetItemIndex() base.Index { panic("don't call me") } -func (m *mockMatrixFactorizationForSearch) Fit(_, _ *DataSet, cfg *FitConfig) Score { +func (m *mockMatrixFactorizationForSearch) Fit(_ context.Context, _, _ *DataSet, cfg *FitConfig) Score { score := float32(0) score += m.Params.GetFloat32(model.NFactors, 0.0) score += m.Params.GetFloat32(model.InitMean, 0.0) score += m.Params.GetFloat32(model.InitStdDev, 0.0) - cfg.Task.Add(m.Params.GetInt(model.NEpochs, 0)) return Score{NDCG: score} } @@ -104,17 +101,15 @@ func (m *mockMatrixFactorizationForSearch) GetParamsGrid(_ bool) model.ParamsGri } func newFitConfigForSearch() *FitConfig { - t := task.NewTask("test", 100) return &FitConfig{ Verbose: 1, - Task: t, } } func TestGridSearchCV(t *testing.T) { m := &mockMatrixFactorizationForSearch{} fitConfig := newFitConfigForSearch() - r := GridSearchCV(m, nil, nil, m.GetParamsGrid(false), 0, fitConfig) + r := GridSearchCV(context.Background(), m, nil, nil, m.GetParamsGrid(false), 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.NDCG) assert.Equal(t, model.Params{ model.NFactors: 4, @@ -126,7 +121,7 @@ func TestGridSearchCV(t *testing.T) { func TestRandomSearchCV(t *testing.T) { m := &mockMatrixFactorizationForSearch{} fitConfig := newFitConfigForSearch() - r := RandomSearchCV(m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig) + r := RandomSearchCV(context.Background(), m, nil, nil, m.GetParamsGrid(false), 63, 0, fitConfig) assert.Equal(t, float32(12), r.BestScore.NDCG) assert.Equal(t, model.Params{ model.NFactors: 4, @@ -138,8 +133,7 @@ func TestRandomSearchCV(t *testing.T) { func TestModelSearcher(t *testing.T) { searcher := NewModelSearcher(2, 63, false) searcher.models = []MatrixFactorization{newMockMatrixFactorizationForSearch(2)} - tk := task.NewTask("test", searcher.Complexity()) - err := searcher.Fit(NewMapIndexDataset(), NewMapIndexDataset(), tk, task.NewConstantJobsAllocator(1)) + err := searcher.Fit(context.Background(), NewMapIndexDataset(), NewMapIndexDataset(), task.NewConstantJobsAllocator(1)) assert.NoError(t, err) _, m, score := searcher.GetBestModel() assert.Equal(t, float32(12), score.NDCG) @@ -149,5 +143,4 @@ func TestModelSearcher(t *testing.T) { model.InitMean: 4, model.InitStdDev: 4, }, m.GetParams()) - assert.Equal(t, searcher.Complexity(), tk.Done) } From 6381a3802f61ccddb7ea68b2bd462e1663e02505 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Fri, 28 Jul 2023 19:43:54 +0800 Subject: [PATCH 06/13] Fix search --- base/search/bruteforce.go | 4 +++- base/search/hnsw.go | 23 ++++++++++++----------- base/search/index.go | 8 +++++--- base/search/index_test.go | 18 ++++++++---------- base/search/ivf.go | 24 ++++++++---------------- base/search/ivf_test.go | 2 +- base/task/schedule.go | 10 ++++------ base/task/schedule_test.go | 38 +++++++++++++++++++------------------- 8 files changed, 60 insertions(+), 67 deletions(-) diff --git a/base/search/bruteforce.go b/base/search/bruteforce.go index 73bfd9c59..849c23b2d 100644 --- a/base/search/bruteforce.go +++ b/base/search/bruteforce.go @@ -15,6 +15,8 @@ package search import ( + "context" + "github.com/zhenghaoz/gorse/base/heap" ) @@ -26,7 +28,7 @@ type Bruteforce struct { } // Build a vector index on data. -func (b *Bruteforce) Build() {} +func (b *Bruteforce) Build(_ context.Context) {} // NewBruteforce creates a Bruteforce vector index. func NewBruteforce(vectors []Vector) *Bruteforce { diff --git a/base/search/hnsw.go b/base/search/hnsw.go index 6ff736c15..597b17a43 100644 --- a/base/search/hnsw.go +++ b/base/search/hnsw.go @@ -15,6 +15,7 @@ package search import ( + "context" "math/rand" "runtime" "sync" @@ -26,7 +27,7 @@ import ( "github.com/zhenghaoz/gorse/base/heap" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/parallel" - "github.com/zhenghaoz/gorse/base/task" + "github.com/zhenghaoz/gorse/base/progress" "go.uber.org/zap" "modernc.org/mathutil" ) @@ -49,7 +50,6 @@ type HNSW struct { maxConnection0 int efConstruction int numJobs int - task *task.SubTask } // HNSWConfig is the configuration function for HNSW. @@ -123,24 +123,26 @@ func (h *HNSW) knnSearch(q Vector, k, ef int) *heap.PriorityQueue { } // Build a vector index on data. -func (h *HNSW) Build() { +func (h *HNSW) Build(ctx context.Context) { completed := make(chan struct{}, h.numJobs) go func() { defer base.CheckPanic() completedCount, previousCount := 0, 0 ticker := time.NewTicker(10 * time.Second) + _, span := progress.Start(ctx, "HNSW.Build", len(h.vectors)) for { select { case _, ok := <-completed: if !ok { + span.End() return } completedCount++ case <-ticker.C: throughput := completedCount - previousCount previousCount = completedCount - h.task.Add(throughput * len(h.vectors)) if throughput > 0 { + span.Add(throughput) log.Logger().Info("building index", zap.Int("n_indexed_vectors", completedCount), zap.Int("n_vectors", len(h.vectors)), @@ -320,7 +322,7 @@ func NewHNSWBuilder(data []Vector, k, numJobs int) *HNSWBuilder { rng: base.NewRandomGenerator(0), numJobs: numJobs, } - b.bruteForce.Build() + b.bruteForce.Build(context.Background()) return b } @@ -361,20 +363,19 @@ func (b *HNSWBuilder) evaluate(idx *HNSW, prune0 bool) float32 { return result / count } -func (b *HNSWBuilder) Build(recall float32, trials int, prune0 bool, t *task.Task) (idx *HNSW, score float32) { - buildTask := t.SubTask(EstimateHNSWBuilderComplexity(len(b.data), trials)) - defer buildTask.Finish() +func (b *HNSWBuilder) Build(ctx context.Context, recall float32, trials int, prune0 bool) (idx *HNSW, score float32) { ef := 1 << int(math32.Ceil(math32.Log2(float32(b.k)))) + newCtx, span := progress.Start(ctx, "HNSWBuilder.Build", trials) + defer span.End() for i := 0; i < trials; i++ { start := time.Now() idx = NewHNSW(b.data, SetEFConstruction(ef), SetHNSWNumJobs(b.numJobs)) - idx.task = buildTask - idx.Build() + idx.Build(newCtx) buildTime := time.Since(start) score = b.evaluate(idx, prune0) - idx.task.Add(b.testSize * len(b.data)) + span.Add(1) log.Logger().Info("try to build vector index", zap.String("index_type", "HNSW"), zap.Int("ef_construction", ef), diff --git a/base/search/index.go b/base/search/index.go index 1255e1f00..12e69390b 100644 --- a/base/search/index.go +++ b/base/search/index.go @@ -15,14 +15,16 @@ package search import ( + "context" "fmt" + "reflect" + "sort" + "github.com/chewxy/math32" "github.com/zhenghaoz/gorse/base/floats" "github.com/zhenghaoz/gorse/base/log" "go.uber.org/zap" "modernc.org/sortutil" - "reflect" - "sort" ) type Vector interface { @@ -182,7 +184,7 @@ func (v *DictionaryCentroidVector) Distance(vector Vector) float32 { } type VectorIndex interface { - Build() + Build(ctx context.Context) Search(q Vector, n int, prune0 bool) ([]int32, []float32) MultiSearch(q Vector, terms []string, n int, prune0 bool) (map[string][]int32, map[string][]float32) } diff --git a/base/search/index_test.go b/base/search/index_test.go index 88e1fbcf2..556e890af 100644 --- a/base/search/index_test.go +++ b/base/search/index_test.go @@ -15,13 +15,15 @@ package search import ( + "context" + "math/big" + "runtime" + "testing" + "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/model" "github.com/zhenghaoz/gorse/model/ranking" - "math/big" - "runtime" - "testing" ) func TestHNSW_InnerProduct(t *testing.T) { @@ -37,7 +39,7 @@ func TestHNSW_InnerProduct(t *testing.T) { model.InitStdDev: 0.001, }) fitConfig := ranking.NewFitConfig().SetVerbose(1).SetJobsAllocator(task.NewConstantJobsAllocator(runtime.NumCPU())) - m.Fit(trainSet, testSet, fitConfig) + m.Fit(context.Background(), trainSet, testSet, fitConfig) var vectors []Vector for i, itemFactor := range m.ItemFactor { var terms []string @@ -49,12 +51,10 @@ func TestHNSW_InnerProduct(t *testing.T) { // build vector index builder := NewHNSWBuilder(vectors, 10, runtime.NumCPU()) - tk := task.NewTask("test", EstimateHNSWBuilderComplexity(len(vectors), 5)) - idx, recall := builder.Build(0.9, 5, false, tk) + idx, recall := builder.Build(context.Background(), 0.9, 5, false) assert.Greater(t, recall, float32(0.9)) recall = builder.evaluateTermSearch(idx, true, "prime") assert.Greater(t, recall, float32(0.8)) - assert.Equal(t, EstimateHNSWBuilderComplexity(len(vectors), 5), tk.Done) } func TestIVF_Cosine(t *testing.T) { @@ -76,10 +76,8 @@ func TestIVF_Cosine(t *testing.T) { // build vector index builder := NewIVFBuilder(vectors, 10) - tk := task.NewTask("test", EstimateIVFBuilderComplexity(len(vectors), 5)) - idx, recall := builder.Build(0.9, 5, true, tk) + idx, recall := builder.Build(0.9, 5, true) assert.Greater(t, recall, float32(0.9)) recall = builder.evaluateTermSearch(idx, true, "prime") assert.Greater(t, recall, float32(0.8)) - assert.Equal(t, EstimateIVFBuilderComplexity(len(vectors), 5), tk.Done) } diff --git a/base/search/ivf.go b/base/search/ivf.go index 12bb20355..f362e3f14 100644 --- a/base/search/ivf.go +++ b/base/search/ivf.go @@ -15,6 +15,7 @@ package search import ( + "context" "math" "math/rand" "sync" @@ -47,8 +48,6 @@ type IVF struct { maxIter int numProbe int jobsAlloc *task.JobsAllocator - - task *task.SubTask } type IVFConfig func(ivf *IVF) @@ -179,7 +178,7 @@ func (idx *IVF) MultiSearch(q Vector, terms []string, n int, prune0 bool) (value return } -func (idx *IVF) Build() { +func (idx *IVF) Build(_ context.Context) { if idx.k > len(idx.data) { panic("the size of the observations set must greater than or equal to k") } else if len(idx.data) == 0 { @@ -206,7 +205,7 @@ func (idx *IVF) Build() { // reassign clusters nextClusters := make([]ivfCluster, idx.k) - _ = parallel.Parallel(len(idx.data), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error { + _ = parallel.Parallel(len(idx.data), idx.jobsAlloc.AvailableJobs(), func(_, i int) error { if !idx.data[i].IsHidden() { nextCluster, nextDistance := -1, float32(math32.MaxFloat32) for c := range clusters { @@ -237,8 +236,6 @@ func (idx *IVF) Build() { nextClusters[c].centroid = idx.data[0].Centroid(idx.data, nextClusters[c].observations) } clusters = nextClusters - - idx.task.Add(len(idx.data) * int(math.Sqrt(float64(len(idx.data))))) } } @@ -260,7 +257,7 @@ func NewIVFBuilder(data []Vector, k int, configs ...IVFConfig) *IVFBuilder { rng: base.NewRandomGenerator(0), configs: configs, } - b.bruteForce.Build() + b.bruteForce.Build(context.Background()) return b } @@ -269,7 +266,7 @@ func (b *IVFBuilder) evaluate(idx *IVF, prune0 bool) float32 { samples := b.rng.Sample(0, len(b.data), testSize) var result, count float32 var mu sync.Mutex - _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error { + _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(), func(_, i int) error { sample := samples[i] expected, _ := b.bruteForce.Search(b.data[sample], b.k, prune0) if len(expected) > 0 { @@ -287,20 +284,15 @@ func (b *IVFBuilder) evaluate(idx *IVF, prune0 bool) float32 { return result / count } -func (b *IVFBuilder) Build(recall float32, numEpoch int, prune0 bool, t *task.Task) (idx *IVF, score float32) { +func (b *IVFBuilder) Build(recall float32, numEpoch int, prune0 bool) (idx *IVF, score float32) { idx = NewIVF(b.data, b.configs...) - idx.task = t.SubTask(DefaultMaxIter * len(b.data) * int(math.Sqrt(float64(len(b.data))))) start := time.Now() - idx.Build() - idx.task.Finish() + idx.Build(context.Background()) - probeTask := t.SubTask(len(b.data) * DefaultTestSize * numEpoch) - defer probeTask.Finish() buildTime := time.Since(start) idx.numProbe = int(math32.Ceil(float32(b.k) / math32.Sqrt(float32(len(b.data))))) for i := 0; i < numEpoch; i++ { score = b.evaluate(idx, prune0) - probeTask.Add(len(b.data) * DefaultTestSize) log.Logger().Info("try to build vector index", zap.String("index_type", "IVF"), zap.Int("num_probe", idx.numProbe), @@ -320,7 +312,7 @@ func (b *IVFBuilder) evaluateTermSearch(idx *IVF, prune0 bool, term string) floa samples := b.rng.Sample(0, len(b.data), testSize) var result, count float32 var mu sync.Mutex - _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(idx.task.Parent), func(_, i int) error { + _ = parallel.Parallel(len(samples), idx.jobsAlloc.AvailableJobs(), func(_, i int) error { sample := samples[i] expected, _ := b.bruteForce.MultiSearch(b.data[sample], []string{term}, b.k, prune0) if len(expected) > 0 { diff --git a/base/search/ivf_test.go b/base/search/ivf_test.go index da875d4e8..407c61594 100644 --- a/base/search/ivf_test.go +++ b/base/search/ivf_test.go @@ -31,7 +31,7 @@ func TestIVFConfig(t *testing.T) { assert.Equal(t, float32(0.123), ivf.errorRate) SetIVFJobsAllocator(task.NewConstantJobsAllocator(234))(ivf) - assert.Equal(t, 234, ivf.jobsAlloc.AvailableJobs(nil)) + assert.Equal(t, 234, ivf.jobsAlloc.AvailableJobs()) SetMaxIteration(345)(ivf) assert.Equal(t, 345, ivf.maxIter) diff --git a/base/task/schedule.go b/base/task/schedule.go index 31e225a74..7f0bbac27 100644 --- a/base/task/schedule.go +++ b/base/task/schedule.go @@ -45,13 +45,13 @@ func (allocator *JobsAllocator) MaxJobs() int { return allocator.numJobs } -func (allocator *JobsAllocator) AvailableJobs(tracker *Task) int { +func (allocator *JobsAllocator) AvailableJobs() int { if allocator == nil || allocator.numJobs < 1 { // Return 1 for invalid allocator return 1 } else if allocator.scheduler != nil { // Use jobs scheduler - return allocator.scheduler.allocateJobsForTask(allocator.taskName, true, tracker) + return allocator.scheduler.allocateJobsForTask(allocator.taskName, true) } return allocator.numJobs } @@ -59,7 +59,7 @@ func (allocator *JobsAllocator) AvailableJobs(tracker *Task) int { // Init jobs allocation. This method is used to request allocation of jobs for the first time. func (allocator *JobsAllocator) Init() { if allocator.scheduler != nil { - allocator.scheduler.allocateJobsForTask(allocator.taskName, true, nil) + allocator.scheduler.allocateJobsForTask(allocator.taskName, true) } } @@ -126,7 +126,7 @@ func (s *JobsScheduler) GetJobsAllocator(taskName string) *JobsAllocator { } } -func (s *JobsScheduler) allocateJobsForTask(taskName string, block bool, tracker *Task) int { +func (s *JobsScheduler) allocateJobsForTask(taskName string, block bool) int { // Find current task and return the jobs temporarily. s.L.Lock() currentTask, exist := s.tasks[taskName] @@ -142,14 +142,12 @@ func (s *JobsScheduler) allocateJobsForTask(taskName string, block bool, tracker for { s.allocateJobsForAll() if currentTask.jobs == 0 && block { - tracker.Suspend(true) if currentTask.previous > 0 { log.Logger().Debug("suspend task", zap.String("task", currentTask.name)) s.Broadcast() } s.Wait() } else { - tracker.Suspend(false) if currentTask.previous == 0 { log.Logger().Debug("resume task", zap.String("task", currentTask.name)) } diff --git a/base/task/schedule_test.go b/base/task/schedule_test.go index c3541bcfb..932ee2104 100644 --- a/base/task/schedule_test.go +++ b/base/task/schedule_test.go @@ -24,15 +24,15 @@ import ( func TestConstantJobsAllocator(t *testing.T) { allocator := NewConstantJobsAllocator(314) assert.Equal(t, 314, allocator.MaxJobs()) - assert.Equal(t, 314, allocator.AvailableJobs(nil)) + assert.Equal(t, 314, allocator.AvailableJobs()) allocator = NewConstantJobsAllocator(-1) assert.Equal(t, 1, allocator.MaxJobs()) - assert.Equal(t, 1, allocator.AvailableJobs(nil)) + assert.Equal(t, 1, allocator.AvailableJobs()) allocator = nil assert.Equal(t, 1, allocator.MaxJobs()) - assert.Equal(t, 1, allocator.AvailableJobs(nil)) + assert.Equal(t, 1, allocator.AvailableJobs()) } func TestDynamicJobsAllocator(t *testing.T) { @@ -44,11 +44,11 @@ func TestDynamicJobsAllocator(t *testing.T) { s.Register("e", 4, false) c := s.GetJobsAllocator("c") assert.Equal(t, 8, c.MaxJobs()) - assert.Equal(t, 3, c.AvailableJobs(nil)) + assert.Equal(t, 3, c.AvailableJobs()) b := s.GetJobsAllocator("b") - assert.Equal(t, 3, b.AvailableJobs(nil)) + assert.Equal(t, 3, b.AvailableJobs()) a := s.GetJobsAllocator("a") - assert.Equal(t, 2, a.AvailableJobs(nil)) + assert.Equal(t, 2, a.AvailableJobs()) barrier := make(chan struct{}) var wg sync.WaitGroup @@ -57,14 +57,14 @@ func TestDynamicJobsAllocator(t *testing.T) { defer wg.Done() barrier <- struct{}{} d := s.GetJobsAllocator("d") - assert.Equal(t, 4, d.AvailableJobs(nil)) + assert.Equal(t, 4, d.AvailableJobs()) }() go func() { defer wg.Done() barrier <- struct{}{} e := s.GetJobsAllocator("e") e.Init() - assert.Equal(t, 4, s.allocateJobsForTask("e", false, nil)) + assert.Equal(t, 4, s.allocateJobsForTask("e", false)) }() <-barrier @@ -83,27 +83,27 @@ func TestJobsScheduler(t *testing.T) { assert.True(t, s.Register("d", 4, false)) assert.True(t, s.Register("e", 4, false)) assert.False(t, s.Register("c", 1, true)) - assert.Equal(t, 3, s.allocateJobsForTask("c", false, nil)) - assert.Equal(t, 3, s.allocateJobsForTask("b", false, nil)) - assert.Equal(t, 2, s.allocateJobsForTask("a", false, nil)) - assert.Equal(t, 0, s.allocateJobsForTask("d", false, nil)) - assert.Equal(t, 0, s.allocateJobsForTask("e", false, nil)) + assert.Equal(t, 3, s.allocateJobsForTask("c", false)) + assert.Equal(t, 3, s.allocateJobsForTask("b", false)) + assert.Equal(t, 2, s.allocateJobsForTask("a", false)) + assert.Equal(t, 0, s.allocateJobsForTask("d", false)) + assert.Equal(t, 0, s.allocateJobsForTask("e", false)) // several tasks complete s.Unregister("b") s.Unregister("c") - assert.Equal(t, 8, s.allocateJobsForTask("a", false, nil)) + assert.Equal(t, 8, s.allocateJobsForTask("a", false)) // privileged tasks complete s.Unregister("a") - assert.Equal(t, 4, s.allocateJobsForTask("d", false, nil)) - assert.Equal(t, 4, s.allocateJobsForTask("e", false, nil)) + assert.Equal(t, 4, s.allocateJobsForTask("d", false)) + assert.Equal(t, 4, s.allocateJobsForTask("e", false)) // block privileged tasks if normal tasks are running s.Register("a", 1, true) s.Register("b", 2, true) s.Register("c", 3, true) - assert.Equal(t, 0, s.allocateJobsForTask("c", false, nil)) - assert.Equal(t, 0, s.allocateJobsForTask("b", false, nil)) - assert.Equal(t, 0, s.allocateJobsForTask("a", false, nil)) + assert.Equal(t, 0, s.allocateJobsForTask("c", false)) + assert.Equal(t, 0, s.allocateJobsForTask("b", false)) + assert.Equal(t, 0, s.allocateJobsForTask("a", false)) } From abad6f24a86c6288206046de5461887f520abe58 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Fri, 28 Jul 2023 21:15:38 +0800 Subject: [PATCH 07/13] Update PB --- master/local_cache_test.go | 8 +- master/master.go | 18 +--- master/tasks_test.go | 145 ++++++++++++----------------- protocol/protocol.pb.go | 173 ++++++++++++++++++----------------- protocol/protocol.proto | 22 ++--- protocol/protocol_grpc.pb.go | 32 +++---- protocol/task.go | 19 ++-- protocol/task_test.go | 16 ++-- 8 files changed, 207 insertions(+), 226 deletions(-) diff --git a/master/local_cache_test.go b/master/local_cache_test.go index 3e10ded32..9f9307406 100644 --- a/master/local_cache_test.go +++ b/master/local_cache_test.go @@ -15,12 +15,14 @@ package master import ( + "context" + "testing" + "github.com/stretchr/testify/assert" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/model" "github.com/zhenghaoz/gorse/model/click" "github.com/zhenghaoz/gorse/model/ranking" - "testing" ) func newRankingDataset() (*ranking.DataSet, *ranking.DataSet) { @@ -57,7 +59,7 @@ func TestLocalCache(t *testing.T) { // write and load trainSet, testSet := newRankingDataset() bpr := ranking.NewBPR(model.Params{model.NEpochs: 0}) - bpr.Fit(trainSet, testSet, nil) + bpr.Fit(context.Background(), trainSet, testSet, nil) cache.RankingModel = bpr cache.RankingModelName = "bpr" cache.RankingModelVersion = 123 @@ -65,7 +67,7 @@ func TestLocalCache(t *testing.T) { train, test := newClickDataset() fm := click.NewFM(click.FMClassification, model.Params{model.NEpochs: 0}) - fm.Fit(train, test, nil) + fm.Fit(context.Background(), train, test, nil) cache.ClickModel = fm cache.ClickModelVersion = 456 cache.ClickModelScore = click.Score{Precision: 1, RMSE: 100, Task: click.FMClassification} diff --git a/master/master.go b/master/master.go index a9551658d..afbba8baf 100644 --- a/master/master.go +++ b/master/master.go @@ -31,6 +31,7 @@ import ( "github.com/zhenghaoz/gorse/base/encoding" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/parallel" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model/click" @@ -57,7 +58,7 @@ type Master struct { server.RestServer grpcServer *grpc.Server - taskMonitor *task.Monitor + tracer progress.Tracer jobsScheduler *task.JobsScheduler cacheFile string managedMode bool @@ -111,19 +112,11 @@ func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master { otel.SetTracerProvider(tp) otel.SetErrorHandler(log.GetErrorHandler()) otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) - // create task monitor - taskMonitor := task.NewTaskMonitor() - for _, taskName := range []string{TaskLoadDataset, TaskFindItemNeighbors, TaskFindUserNeighbors, - TaskFitRankingModel, TaskFitClickModel, TaskSearchRankingModel, TaskSearchClickModel, - TaskCacheGarbageCollection} { - taskMonitor.Pending(taskName) - } return &Master{ nodesInfo: make(map[string]*Node), // create task monitor cacheFile: cacheFile, managedMode: managedMode, - taskMonitor: taskMonitor, jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs), // default ranking model rankingModelName: "bpr", @@ -324,7 +317,7 @@ func (m *Master) RunPrivilegedTasksLoop() { j := m.jobsScheduler.GetJobsAllocator(task.name()) defer m.jobsScheduler.Unregister(task.name()) j.Init() - if err := task.run(j); err != nil { + if err := task.run(context.Background(), j); err != nil { log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err)) return } @@ -362,9 +355,8 @@ func (m *Master) RunRagtagTasksLoop() { defer m.jobsScheduler.Unregister(task.name()) j := m.jobsScheduler.GetJobsAllocator(task.name()) j.Init() - if err = task.run(j); err != nil { + if err = task.run(context.Background(), j); err != nil { log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err)) - m.taskMonitor.Fail(task.name(), err.Error()) } }(t) } @@ -434,7 +426,7 @@ func (m *Master) RunManagedTasksLoop() { defer m.jobsScheduler.Unregister(task.name()) defer wg.Done() j.Init() - if err := task.run(j); err != nil { + if err := task.run(context.Background(), j); err != nil { log.Logger().Error("failed to run task", zap.String("task", task.name()), zap.Error(err)) return } diff --git a/master/tasks_test.go b/master/tasks_test.go index 03f8d2b49..b572d125c 100644 --- a/master/tasks_test.go +++ b/master/tasks_test.go @@ -20,7 +20,6 @@ import ( "time" "github.com/samber/lo" - "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" @@ -34,16 +33,16 @@ func (s *MasterTestSuite) TestFindItemNeighborsBruteForce() { s.Config.Master.NumJobs = 4 // collect similar items := []data.Item{ - {"0", false, []string{"*"}, time.Now(), []string{"a", "b", "c", "d"}, ""}, - {"1", false, []string{"*"}, time.Now(), []string{}, ""}, - {"2", false, []string{"*"}, time.Now(), []string{"b", "c", "d"}, ""}, - {"3", false, nil, time.Now(), []string{}, ""}, - {"4", false, nil, time.Now(), []string{"b", "c"}, ""}, - {"5", false, []string{"*"}, time.Now(), []string{}, ""}, - {"6", false, []string{"*"}, time.Now(), []string{"c"}, ""}, - {"7", false, []string{"*"}, time.Now(), []string{}, ""}, - {"8", false, []string{"*"}, time.Now(), []string{"a", "b", "c", "d", "e"}, ""}, - {"9", false, nil, time.Now(), []string{}, ""}, + {ItemId: "0", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "b", "c", "d"}, Comment: ""}, + {ItemId: "1", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "2", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"b", "c", "d"}, Comment: ""}, + {ItemId: "3", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "4", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{"b", "c"}, Comment: ""}, + {ItemId: "5", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "6", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"c"}, Comment: ""}, + {ItemId: "7", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "8", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "b", "c", "d", "e"}, Comment: ""}, + {ItemId: "9", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, } feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { @@ -88,12 +87,10 @@ func (s *MasterTestSuite) TestFindItemNeighborsBruteForce() { // similar items (common users) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated neighborTask := NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err := s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindItemNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindItemNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindItemNeighbors].Status) // similar items in category (common users) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "9", []string{"*"}, 0, 100) s.NoError(err) @@ -104,12 +101,10 @@ func (s *MasterTestSuite) TestFindItemNeighborsBruteForce() { s.NoError(err) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar neighborTask = NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindItemNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindItemNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindItemNeighbors].Status) // similar items in category (common labels) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "8", []string{"*"}, 0, 100) s.NoError(err) @@ -122,15 +117,13 @@ func (s *MasterTestSuite) TestFindItemNeighborsBruteForce() { s.NoError(err) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeAuto neighborTask = NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindItemNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindItemNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindItemNeighbors].Status) } func (s *MasterTestSuite) TestFindItemNeighborsIVF() { @@ -145,16 +138,16 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF() { s.Config.Recommend.ItemNeighbors.IndexFitEpoch = 10 // collect similar items := []data.Item{ - {"0", false, []string{"*"}, time.Now(), []string{"a", "b", "c", "d"}, ""}, - {"1", false, []string{"*"}, time.Now(), []string{}, ""}, - {"2", false, []string{"*"}, time.Now(), []string{"b", "c", "d"}, ""}, - {"3", false, nil, time.Now(), []string{}, ""}, - {"4", false, nil, time.Now(), []string{"b", "c"}, ""}, - {"5", false, []string{"*"}, time.Now(), []string{}, ""}, - {"6", false, []string{"*"}, time.Now(), []string{"c"}, ""}, - {"7", false, []string{"*"}, time.Now(), []string{}, ""}, - {"8", false, []string{"*"}, time.Now(), []string{"a", "b", "c", "d", "e"}, ""}, - {"9", false, nil, time.Now(), []string{}, ""}, + {ItemId: "0", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "b", "c", "d"}, Comment: ""}, + {ItemId: "1", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "2", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"b", "c", "d"}, Comment: ""}, + {ItemId: "3", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "4", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{"b", "c"}, Comment: ""}, + {ItemId: "5", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "6", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"c"}, Comment: ""}, + {ItemId: "7", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, + {ItemId: "8", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "b", "c", "d", "e"}, Comment: ""}, + {ItemId: "9", IsHidden: false, Categories: nil, Timestamp: time.Now(), Labels: []string{}, Comment: ""}, } feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { @@ -199,12 +192,10 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF() { // similar items (common users) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated neighborTask := NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err := s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindItemNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindItemNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindItemNeighbors].Status) // similar items in category (common users) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "9", []string{"*"}, 0, 100) s.NoError(err) @@ -215,12 +206,10 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF() { s.NoError(err) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar neighborTask = NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindItemNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindItemNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindItemNeighbors].Status) // similar items in category (common labels) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "8", []string{"*"}, 0, 100) s.NoError(err) @@ -233,15 +222,13 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF() { s.NoError(err) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeAuto neighborTask = NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindItemNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindItemNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindItemNeighbors].Status) } func (s *MasterTestSuite) TestFindItemNeighborsIVF_ZeroIDF() { @@ -256,8 +243,8 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF_ZeroIDF() { // create dataset err := s.DataClient.BatchInsertItems(ctx, []data.Item{ - {"0", false, []string{"*"}, time.Now(), []string{"a", "a"}, ""}, - {"1", false, []string{"*"}, time.Now(), []string{"a", "a"}, ""}, + {ItemId: "0", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "a"}, Comment: ""}, + {ItemId: "1", IsHidden: false, Categories: []string{"*"}, Timestamp: time.Now(), Labels: []string{"a", "a"}, Comment: ""}, }) s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, []data.Feedback{ @@ -272,7 +259,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF_ZeroIDF() { // similar items (common users) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeRelated neighborTask := NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err := s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "0", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"1"}, cache.ConvertDocumentsToValues(similar)) @@ -280,7 +267,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF_ZeroIDF() { // similar items (common labels) s.Config.Recommend.ItemNeighbors.NeighborType = config.NeighborTypeSimilar neighborTask = NewFindItemNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.ItemNeighbors, "0", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"1"}, cache.ConvertDocumentsToValues(similar)) @@ -294,16 +281,16 @@ func (s *MasterTestSuite) TestFindUserNeighborsBruteForce() { s.Config.Master.NumJobs = 4 // collect similar users := []data.User{ - {"0", []string{"a", "b", "c", "d"}, nil, ""}, - {"1", []string{}, nil, ""}, - {"2", []string{"b", "c", "d"}, nil, ""}, - {"3", []string{}, nil, ""}, - {"4", []string{"b", "c"}, nil, ""}, - {"5", []string{}, nil, ""}, - {"6", []string{"c"}, nil, ""}, - {"7", []string{}, nil, ""}, - {"8", []string{"a", "b", "c", "d", "e"}, nil, ""}, - {"9", []string{}, nil, ""}, + {UserId: "0", Labels: []string{"a", "b", "c", "d"}, Subscribe: nil, Comment: ""}, + {UserId: "1", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "2", Labels: []string{"b", "c", "d"}, Subscribe: nil, Comment: ""}, + {UserId: "3", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "4", Labels: []string{"b", "c"}, Subscribe: nil, Comment: ""}, + {UserId: "5", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "6", Labels: []string{"c"}, Subscribe: nil, Comment: ""}, + {UserId: "7", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "8", Labels: []string{"a", "b", "c", "d", "e"}, Subscribe: nil, Comment: ""}, + {UserId: "9", Labels: []string{}, Subscribe: nil, Comment: ""}, } feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { @@ -332,24 +319,20 @@ func (s *MasterTestSuite) TestFindUserNeighborsBruteForce() { // similar items (common users) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated neighborTask := NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err := s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindUserNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindUserNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindUserNeighbors].Status) // similar items (common labels) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) s.NoError(err) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar neighborTask = NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindUserNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindUserNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindUserNeighbors].Status) // similar items (auto) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) @@ -358,15 +341,13 @@ func (s *MasterTestSuite) TestFindUserNeighborsBruteForce() { s.NoError(err) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeAuto neighborTask = NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindUserNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindUserNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindUserNeighbors].Status) } func (s *MasterTestSuite) TestFindUserNeighborsIVF() { @@ -380,16 +361,16 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF() { s.Config.Recommend.UserNeighbors.IndexFitEpoch = 10 // collect similar users := []data.User{ - {"0", []string{"a", "b", "c", "d"}, nil, ""}, - {"1", []string{}, nil, ""}, - {"2", []string{"b", "c", "d"}, nil, ""}, - {"3", []string{}, nil, ""}, - {"4", []string{"b", "c"}, nil, ""}, - {"5", []string{}, nil, ""}, - {"6", []string{"c"}, nil, ""}, - {"7", []string{}, nil, ""}, - {"8", []string{"a", "b", "c", "d", "e"}, nil, ""}, - {"9", []string{}, nil, ""}, + {UserId: "0", Labels: []string{"a", "b", "c", "d"}, Subscribe: nil, Comment: ""}, + {UserId: "1", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "2", Labels: []string{"b", "c", "d"}, Subscribe: nil, Comment: ""}, + {UserId: "3", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "4", Labels: []string{"b", "c"}, Subscribe: nil, Comment: ""}, + {UserId: "5", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "6", Labels: []string{"c"}, Subscribe: nil, Comment: ""}, + {UserId: "7", Labels: []string{}, Subscribe: nil, Comment: ""}, + {UserId: "8", Labels: []string{"a", "b", "c", "d", "e"}, Subscribe: nil, Comment: ""}, + {UserId: "9", Labels: []string{}, Subscribe: nil, Comment: ""}, } feedbacks := make([]data.Feedback, 0) for i := 0; i < 10; i++ { @@ -418,24 +399,20 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF() { // similar items (common users) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated neighborTask := NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err := s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindUserNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindUserNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindUserNeighbors].Status) // similar items (common labels) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) s.NoError(err) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar neighborTask = NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindUserNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindUserNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindUserNeighbors].Status) // similar items (auto) err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, "8"), time.Now())) @@ -444,15 +421,13 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF() { s.NoError(err) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeAuto neighborTask = NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "8", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"0", "2", "4"}, cache.ConvertDocumentsToValues(similar)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "9", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"7", "5", "3"}, cache.ConvertDocumentsToValues(similar)) - s.Equal(s.estimateFindUserNeighborsComplexity(dataset), s.taskMonitor.Tasks[TaskFindUserNeighbors].Done) - s.Equal(task.StatusComplete, s.taskMonitor.Tasks[TaskFindUserNeighbors].Status) } func (s *MasterTestSuite) TestFindUserNeighborsIVF_ZeroIDF() { @@ -467,8 +442,8 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF_ZeroIDF() { // create dataset err := s.DataClient.BatchInsertUsers(ctx, []data.User{ - {"0", []string{"a", "a"}, nil, ""}, - {"1", []string{"a", "a"}, nil, ""}, + {UserId: "0", Labels: []string{"a", "a"}, Subscribe: nil, Comment: ""}, + {UserId: "1", Labels: []string{"a", "a"}, Subscribe: nil, Comment: ""}, }) s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, []data.Feedback{ @@ -483,7 +458,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF_ZeroIDF() { // similar users (common items) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeRelated neighborTask := NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err := s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "0", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"1"}, cache.ConvertDocumentsToValues(similar)) @@ -491,7 +466,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF_ZeroIDF() { // similar users (common labels) s.Config.Recommend.UserNeighbors.NeighborType = config.NeighborTypeSimilar neighborTask = NewFindUserNeighborsTask(&s.Master) - s.NoError(neighborTask.run(nil)) + s.NoError(neighborTask.run(context.Background(), nil)) similar, err = s.CacheClient.SearchDocuments(ctx, cache.UserNeighbors, "0", []string{""}, 0, 100) s.NoError(err) s.Equal([]string{"1"}, cache.ConvertDocumentsToValues(similar)) diff --git a/protocol/protocol.pb.go b/protocol/protocol.pb.go index 2ee984c40..2c228779c 100644 --- a/protocol/protocol.pb.go +++ b/protocol/protocol.pb.go @@ -14,8 +14,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v3.6.1 +// protoc-gen-go v1.28.1 +// protoc v3.12.4 // source: protocol.proto package protocol @@ -335,22 +335,23 @@ func (x *NodeInfo) GetBinaryVersion() string { return "" } -type PushTaskInfoRequest struct { +type PushProgressRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` - Status string `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` - Done int64 `protobuf:"varint,3,opt,name=done,proto3" json:"done,omitempty"` - Total int64 `protobuf:"varint,4,opt,name=total,proto3" json:"total,omitempty"` - StartTime int64 `protobuf:"varint,5,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` - FinishTime int64 `protobuf:"varint,6,opt,name=finish_time,json=finishTime,proto3" json:"finish_time,omitempty"` - Error string `protobuf:"bytes,7,opt,name=error,proto3" json:"error,omitempty"` + Tracer string `protobuf:"bytes,1,opt,name=tracer,proto3" json:"tracer,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + Status string `protobuf:"bytes,3,opt,name=status,proto3" json:"status,omitempty"` + Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"` + Count int64 `protobuf:"varint,5,opt,name=count,proto3" json:"count,omitempty"` + Total int64 `protobuf:"varint,6,opt,name=total,proto3" json:"total,omitempty"` + StartTime int64 `protobuf:"varint,7,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` + FinishTime int64 `protobuf:"varint,8,opt,name=finish_time,json=finishTime,proto3" json:"finish_time,omitempty"` } -func (x *PushTaskInfoRequest) Reset() { - *x = PushTaskInfoRequest{} +func (x *PushProgressRequest) Reset() { + *x = PushProgressRequest{} if protoimpl.UnsafeEnabled { mi := &file_protocol_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -358,13 +359,13 @@ func (x *PushTaskInfoRequest) Reset() { } } -func (x *PushTaskInfoRequest) String() string { +func (x *PushProgressRequest) String() string { return protoimpl.X.MessageStringOf(x) } -func (*PushTaskInfoRequest) ProtoMessage() {} +func (*PushProgressRequest) ProtoMessage() {} -func (x *PushTaskInfoRequest) ProtoReflect() protoreflect.Message { +func (x *PushProgressRequest) ProtoReflect() protoreflect.Message { mi := &file_protocol_proto_msgTypes[4] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -376,68 +377,75 @@ func (x *PushTaskInfoRequest) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use PushTaskInfoRequest.ProtoReflect.Descriptor instead. -func (*PushTaskInfoRequest) Descriptor() ([]byte, []int) { +// Deprecated: Use PushProgressRequest.ProtoReflect.Descriptor instead. +func (*PushProgressRequest) Descriptor() ([]byte, []int) { return file_protocol_proto_rawDescGZIP(), []int{4} } -func (x *PushTaskInfoRequest) GetName() string { +func (x *PushProgressRequest) GetTracer() string { + if x != nil { + return x.Tracer + } + return "" +} + +func (x *PushProgressRequest) GetName() string { if x != nil { return x.Name } return "" } -func (x *PushTaskInfoRequest) GetStatus() string { +func (x *PushProgressRequest) GetStatus() string { if x != nil { return x.Status } return "" } -func (x *PushTaskInfoRequest) GetDone() int64 { +func (x *PushProgressRequest) GetError() string { if x != nil { - return x.Done + return x.Error } - return 0 + return "" } -func (x *PushTaskInfoRequest) GetTotal() int64 { +func (x *PushProgressRequest) GetCount() int64 { if x != nil { - return x.Total + return x.Count } return 0 } -func (x *PushTaskInfoRequest) GetStartTime() int64 { +func (x *PushProgressRequest) GetTotal() int64 { if x != nil { - return x.StartTime + return x.Total } return 0 } -func (x *PushTaskInfoRequest) GetFinishTime() int64 { +func (x *PushProgressRequest) GetStartTime() int64 { if x != nil { - return x.FinishTime + return x.StartTime } return 0 } -func (x *PushTaskInfoRequest) GetError() string { +func (x *PushProgressRequest) GetFinishTime() int64 { if x != nil { - return x.Error + return x.FinishTime } - return "" + return 0 } -type PushTaskInfoResponse struct { +type PushProgressResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } -func (x *PushTaskInfoResponse) Reset() { - *x = PushTaskInfoResponse{} +func (x *PushProgressResponse) Reset() { + *x = PushProgressResponse{} if protoimpl.UnsafeEnabled { mi := &file_protocol_proto_msgTypes[5] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -445,13 +453,13 @@ func (x *PushTaskInfoResponse) Reset() { } } -func (x *PushTaskInfoResponse) String() string { +func (x *PushProgressResponse) String() string { return protoimpl.X.MessageStringOf(x) } -func (*PushTaskInfoResponse) ProtoMessage() {} +func (*PushProgressResponse) ProtoMessage() {} -func (x *PushTaskInfoResponse) ProtoReflect() protoreflect.Message { +func (x *PushProgressResponse) ProtoReflect() protoreflect.Message { mi := &file_protocol_proto_msgTypes[5] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -463,8 +471,8 @@ func (x *PushTaskInfoResponse) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -// Deprecated: Use PushTaskInfoResponse.ProtoReflect.Descriptor instead. -func (*PushTaskInfoResponse) Descriptor() ([]byte, []int) { +// Deprecated: Use PushProgressResponse.ProtoReflect.Descriptor instead. +func (*PushProgressResponse) Descriptor() ([]byte, []int) { return file_protocol_proto_rawDescGZIP(), []int{5} } @@ -499,45 +507,46 @@ var file_protocol_proto_rawDesc = []byte{ 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x68, 0x74, 0x74, 0x70, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x62, 0x69, - 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0xc1, 0x01, 0x0a, 0x13, - 0x50, 0x75, 0x73, 0x68, 0x54, 0x61, 0x73, 0x6b, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, - 0x73, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, - 0x12, 0x0a, 0x04, 0x64, 0x6f, 0x6e, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x04, 0x64, - 0x6f, 0x6e, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x04, 0x20, 0x01, + 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0xdb, 0x01, 0x0a, 0x13, + 0x50, 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x72, 0x61, 0x63, 0x65, 0x72, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x72, 0x61, 0x63, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, + 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, + 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, + 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x14, 0x0a, + 0x05, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x63, 0x6f, + 0x75, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x73, + 0x72, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x73, 0x74, 0x61, 0x72, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x66, 0x69, 0x6e, 0x69, - 0x73, 0x68, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x66, - 0x69, 0x6e, 0x69, 0x73, 0x68, 0x54, 0x69, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, - 0x6f, 0x72, 0x18, 0x07, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x22, - 0x16, 0x0a, 0x14, 0x50, 0x75, 0x73, 0x68, 0x54, 0x61, 0x73, 0x6b, 0x49, 0x6e, 0x66, 0x6f, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x2a, 0x3a, 0x0a, 0x08, 0x4e, 0x6f, 0x64, 0x65, 0x54, - 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4e, 0x6f, 0x64, - 0x65, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x4e, 0x6f, 0x64, - 0x65, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x4e, 0x6f, 0x64, - 0x65, 0x10, 0x02, 0x32, 0x8c, 0x02, 0x0a, 0x06, 0x4d, 0x61, 0x73, 0x74, 0x65, 0x72, 0x12, 0x2f, - 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x0e, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x22, 0x00, 0x12, - 0x40, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x52, 0x61, 0x6e, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, - 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, - 0x01, 0x12, 0x3e, 0x0a, 0x0d, 0x47, 0x65, 0x74, 0x43, 0x6c, 0x69, 0x63, 0x6b, 0x4d, 0x6f, 0x64, - 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x56, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, - 0x01, 0x12, 0x4f, 0x0a, 0x0c, 0x50, 0x75, 0x73, 0x68, 0x54, 0x61, 0x73, 0x6b, 0x49, 0x6e, 0x66, - 0x6f, 0x12, 0x1d, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x50, 0x75, 0x73, - 0x68, 0x54, 0x61, 0x73, 0x6b, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x50, 0x75, 0x73, 0x68, - 0x54, 0x61, 0x73, 0x6b, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x42, 0x25, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x7a, 0x68, 0x65, 0x6e, 0x67, 0x68, 0x61, 0x6f, 0x7a, 0x2f, 0x67, 0x6f, 0x72, 0x73, 0x65, - 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x33, + 0x73, 0x68, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x66, + 0x69, 0x6e, 0x69, 0x73, 0x68, 0x54, 0x69, 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x50, 0x75, 0x73, + 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x2a, 0x3a, 0x0a, 0x08, 0x4e, 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, + 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x00, 0x12, 0x0e, 0x0a, + 0x0a, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x01, 0x12, 0x0e, 0x0a, + 0x0a, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x02, 0x32, 0x8c, 0x02, + 0x0a, 0x06, 0x4d, 0x61, 0x73, 0x74, 0x65, 0x72, 0x12, 0x2f, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, + 0x65, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, + 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x22, 0x00, 0x12, 0x40, 0x0a, 0x0f, 0x47, 0x65, 0x74, + 0x52, 0x61, 0x6e, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x49, + 0x6e, 0x66, 0x6f, 0x1a, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x46, + 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3e, 0x0a, 0x0d, 0x47, + 0x65, 0x74, 0x43, 0x6c, 0x69, 0x63, 0x6b, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x49, + 0x6e, 0x66, 0x6f, 0x1a, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x46, + 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x4f, 0x0a, 0x0c, 0x50, + 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1d, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, + 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, + 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x25, 0x5a, 0x23, + 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x68, 0x65, 0x6e, 0x67, + 0x68, 0x61, 0x6f, 0x7a, 0x2f, 0x67, 0x6f, 0x72, 0x73, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, + 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -560,19 +569,19 @@ var file_protocol_proto_goTypes = []interface{}{ (*Fragment)(nil), // 2: protocol.Fragment (*VersionInfo)(nil), // 3: protocol.VersionInfo (*NodeInfo)(nil), // 4: protocol.NodeInfo - (*PushTaskInfoRequest)(nil), // 5: protocol.PushTaskInfoRequest - (*PushTaskInfoResponse)(nil), // 6: protocol.PushTaskInfoResponse + (*PushProgressRequest)(nil), // 5: protocol.PushProgressRequest + (*PushProgressResponse)(nil), // 6: protocol.PushProgressResponse } var file_protocol_proto_depIdxs = []int32{ 0, // 0: protocol.NodeInfo.node_type:type_name -> protocol.NodeType 4, // 1: protocol.Master.GetMeta:input_type -> protocol.NodeInfo 3, // 2: protocol.Master.GetRankingModel:input_type -> protocol.VersionInfo 3, // 3: protocol.Master.GetClickModel:input_type -> protocol.VersionInfo - 5, // 4: protocol.Master.PushTaskInfo:input_type -> protocol.PushTaskInfoRequest + 5, // 4: protocol.Master.PushProgress:input_type -> protocol.PushProgressRequest 1, // 5: protocol.Master.GetMeta:output_type -> protocol.Meta 2, // 6: protocol.Master.GetRankingModel:output_type -> protocol.Fragment 2, // 7: protocol.Master.GetClickModel:output_type -> protocol.Fragment - 6, // 8: protocol.Master.PushTaskInfo:output_type -> protocol.PushTaskInfoResponse + 6, // 8: protocol.Master.PushProgress:output_type -> protocol.PushProgressResponse 5, // [5:9] is the sub-list for method output_type 1, // [1:5] is the sub-list for method input_type 1, // [1:1] is the sub-list for extension type_name @@ -635,7 +644,7 @@ func file_protocol_proto_init() { } } file_protocol_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PushTaskInfoRequest); i { + switch v := v.(*PushProgressRequest); i { case 0: return &v.state case 1: @@ -647,7 +656,7 @@ func file_protocol_proto_init() { } } file_protocol_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PushTaskInfoResponse); i { + switch v := v.(*PushProgressResponse); i { case 0: return &v.state case 1: diff --git a/protocol/protocol.proto b/protocol/protocol.proto index 20702f53e..1f1a03fb0 100644 --- a/protocol/protocol.proto +++ b/protocol/protocol.proto @@ -32,8 +32,7 @@ service Master { rpc GetRankingModel(VersionInfo) returns (stream Fragment) {} rpc GetClickModel(VersionInfo) returns (stream Fragment) {} - /* task management */ - rpc PushTaskInfo(PushTaskInfoRequest) returns (PushTaskInfoResponse) {} + rpc PushProgress(PushProgressRequest) returns (PushProgressResponse) {} } @@ -61,14 +60,15 @@ message NodeInfo { string binary_version = 4; } -message PushTaskInfoRequest { - string name = 1; - string status = 2; - int64 done = 3; - int64 total = 4; - int64 start_time = 5; - int64 finish_time = 6; - string error = 7; +message PushProgressRequest { + string tracer = 1; + string name = 2; + string status = 3; + string error = 4; + int64 count = 5; + int64 total = 6; + int64 start_time = 7; + int64 finish_time = 8; } -message PushTaskInfoResponse {} +message PushProgressResponse {} diff --git a/protocol/protocol_grpc.pb.go b/protocol/protocol_grpc.pb.go index 8da5e0dda..dbc48c9c1 100644 --- a/protocol/protocol_grpc.pb.go +++ b/protocol/protocol_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.2.0 -// - protoc v3.6.1 +// - protoc v3.12.4 // source: protocol.proto package protocol @@ -27,8 +27,7 @@ type MasterClient interface { // data distribute GetRankingModel(ctx context.Context, in *VersionInfo, opts ...grpc.CallOption) (Master_GetRankingModelClient, error) GetClickModel(ctx context.Context, in *VersionInfo, opts ...grpc.CallOption) (Master_GetClickModelClient, error) - // task management - PushTaskInfo(ctx context.Context, in *PushTaskInfoRequest, opts ...grpc.CallOption) (*PushTaskInfoResponse, error) + PushProgress(ctx context.Context, in *PushProgressRequest, opts ...grpc.CallOption) (*PushProgressResponse, error) } type masterClient struct { @@ -112,9 +111,9 @@ func (x *masterGetClickModelClient) Recv() (*Fragment, error) { return m, nil } -func (c *masterClient) PushTaskInfo(ctx context.Context, in *PushTaskInfoRequest, opts ...grpc.CallOption) (*PushTaskInfoResponse, error) { - out := new(PushTaskInfoResponse) - err := c.cc.Invoke(ctx, "/protocol.Master/PushTaskInfo", in, out, opts...) +func (c *masterClient) PushProgress(ctx context.Context, in *PushProgressRequest, opts ...grpc.CallOption) (*PushProgressResponse, error) { + out := new(PushProgressResponse) + err := c.cc.Invoke(ctx, "/protocol.Master/PushProgress", in, out, opts...) if err != nil { return nil, err } @@ -130,8 +129,7 @@ type MasterServer interface { // data distribute GetRankingModel(*VersionInfo, Master_GetRankingModelServer) error GetClickModel(*VersionInfo, Master_GetClickModelServer) error - // task management - PushTaskInfo(context.Context, *PushTaskInfoRequest) (*PushTaskInfoResponse, error) + PushProgress(context.Context, *PushProgressRequest) (*PushProgressResponse, error) mustEmbedUnimplementedMasterServer() } @@ -148,8 +146,8 @@ func (UnimplementedMasterServer) GetRankingModel(*VersionInfo, Master_GetRanking func (UnimplementedMasterServer) GetClickModel(*VersionInfo, Master_GetClickModelServer) error { return status.Errorf(codes.Unimplemented, "method GetClickModel not implemented") } -func (UnimplementedMasterServer) PushTaskInfo(context.Context, *PushTaskInfoRequest) (*PushTaskInfoResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method PushTaskInfo not implemented") +func (UnimplementedMasterServer) PushProgress(context.Context, *PushProgressRequest) (*PushProgressResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method PushProgress not implemented") } func (UnimplementedMasterServer) mustEmbedUnimplementedMasterServer() {} @@ -224,20 +222,20 @@ func (x *masterGetClickModelServer) Send(m *Fragment) error { return x.ServerStream.SendMsg(m) } -func _Master_PushTaskInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(PushTaskInfoRequest) +func _Master_PushProgress_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(PushProgressRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(MasterServer).PushTaskInfo(ctx, in) + return srv.(MasterServer).PushProgress(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/protocol.Master/PushTaskInfo", + FullMethod: "/protocol.Master/PushProgress", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(MasterServer).PushTaskInfo(ctx, req.(*PushTaskInfoRequest)) + return srv.(MasterServer).PushProgress(ctx, req.(*PushProgressRequest)) } return interceptor(ctx, in, info, handler) } @@ -254,8 +252,8 @@ var Master_ServiceDesc = grpc.ServiceDesc{ Handler: _Master_GetMeta_Handler, }, { - MethodName: "PushTaskInfo", - Handler: _Master_PushTaskInfo_Handler, + MethodName: "PushProgress", + Handler: _Master_PushProgress_Handler, }, }, Streams: []grpc.StreamDesc{ diff --git a/protocol/task.go b/protocol/task.go index 0dd7189be..903f09b0a 100644 --- a/protocol/task.go +++ b/protocol/task.go @@ -15,28 +15,31 @@ package protocol import ( - "github.com/zhenghaoz/gorse/base/task" "time" + + "github.com/zhenghaoz/gorse/base/progress" ) //go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative protocol.proto -func DecodeTask(in *PushTaskInfoRequest) *task.Task { - return &task.Task{ +func DecodeProgress(in *PushProgressRequest) *progress.Progress { + return &progress.Progress{ + Tracer: in.GetTracer(), Name: in.GetName(), - Status: task.Status(in.GetStatus()), - Done: int(in.GetDone()), + Status: progress.Status(in.GetStatus()), + Count: int(in.GetCount()), Total: int(in.GetTotal()), StartTime: time.UnixMilli(in.GetStartTime()), FinishTime: time.UnixMilli(in.GetFinishTime()), } } -func EncodeTask(t *task.Task) *PushTaskInfoRequest { - return &PushTaskInfoRequest{ +func EncodeProgress(t *progress.Progress) *PushProgressRequest { + return &PushProgressRequest{ + Tracer: t.Tracer, Name: t.Name, Status: string(t.Status), - Done: int64(t.Done), + Count: int64(t.Count), Total: int64(t.Total), StartTime: t.StartTime.UnixMilli(), FinishTime: t.FinishTime.UnixMilli(), diff --git a/protocol/task_test.go b/protocol/task_test.go index daa328736..9afb80448 100644 --- a/protocol/task_test.go +++ b/protocol/task_test.go @@ -15,21 +15,23 @@ package protocol import ( - "github.com/stretchr/testify/assert" - "github.com/zhenghaoz/gorse/base/task" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/base/progress" ) func TestEncodeDecode(t *testing.T) { - tk := &task.Task{ + tk := &progress.Progress{ + Tracer: "tracer", Name: "a", Total: 100, - Done: 50, - Status: task.StatusRunning, + Count: 50, + Status: progress.StatusRunning, StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), } - pb := EncodeTask(tk) - assert.Equal(t, tk, DecodeTask(pb)) + pb := EncodeProgress(tk) + assert.Equal(t, tk, DecodeProgress(pb)) } From 4023fa739664b5ed84812ce04d6ea381e1a14704 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Fri, 28 Jul 2023 21:41:56 +0800 Subject: [PATCH 08/13] Update PB --- base/parallel/parallel.go | 2 +- protocol/protocol.pb.go | 864 ++++++++++++++--------------------- protocol/protocol.proto | 6 +- protocol/protocol_grpc.pb.go | 35 +- protocol/task.go | 44 +- protocol/task_test.go | 31 +- worker/worker_test.go | 8 +- 7 files changed, 418 insertions(+), 572 deletions(-) diff --git a/base/parallel/parallel.go b/base/parallel/parallel.go index 600cae641..6f967ce66 100644 --- a/base/parallel/parallel.go +++ b/base/parallel/parallel.go @@ -94,7 +94,7 @@ func DynamicParallel(nJobs int, jobsAlloc *task.JobsAllocator, worker func(worke // consumer for { exit := atomic.NewBool(true) - numJobs := jobsAlloc.AvailableJobs(nil) + numJobs := jobsAlloc.AvailableJobs() var wg sync.WaitGroup wg.Add(numJobs) errs := make([]error, nJobs) diff --git a/protocol/protocol.pb.go b/protocol/protocol.pb.go index 2c228779c..f14fd456e 100644 --- a/protocol/protocol.pb.go +++ b/protocol/protocol.pb.go @@ -1,38 +1,24 @@ -// Copyright 2020 gorse Project Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - // Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.28.1 -// protoc v3.12.4 // source: protocol.proto package protocol import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" + fmt "fmt" + proto "github.com/golang/protobuf/proto" + math "math" ) -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package type NodeType int32 @@ -42,649 +28,463 @@ const ( NodeType_ClientNode NodeType = 2 ) -// Enum value maps for NodeType. -var ( - NodeType_name = map[int32]string{ - 0: "ServerNode", - 1: "WorkerNode", - 2: "ClientNode", - } - NodeType_value = map[string]int32{ - "ServerNode": 0, - "WorkerNode": 1, - "ClientNode": 2, - } -) +var NodeType_name = map[int32]string{ + 0: "ServerNode", + 1: "WorkerNode", + 2: "ClientNode", +} -func (x NodeType) Enum() *NodeType { - p := new(NodeType) - *p = x - return p +var NodeType_value = map[string]int32{ + "ServerNode": 0, + "WorkerNode": 1, + "ClientNode": 2, } func (x NodeType) String() string { - return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) + return proto.EnumName(NodeType_name, int32(x)) } -func (NodeType) Descriptor() protoreflect.EnumDescriptor { - return file_protocol_proto_enumTypes[0].Descriptor() +func (NodeType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{0} } -func (NodeType) Type() protoreflect.EnumType { - return &file_protocol_proto_enumTypes[0] +type Meta struct { + Config string `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` + RankingModelVersion int64 `protobuf:"varint,3,opt,name=ranking_model_version,json=rankingModelVersion,proto3" json:"ranking_model_version,omitempty"` + ClickModelVersion int64 `protobuf:"varint,4,opt,name=click_model_version,json=clickModelVersion,proto3" json:"click_model_version,omitempty"` + Me string `protobuf:"bytes,5,opt,name=me,proto3" json:"me,omitempty"` + Servers []string `protobuf:"bytes,6,rep,name=servers,proto3" json:"servers,omitempty"` + Workers []string `protobuf:"bytes,7,rep,name=workers,proto3" json:"workers,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Meta) Reset() { *m = Meta{} } +func (m *Meta) String() string { return proto.CompactTextString(m) } +func (*Meta) ProtoMessage() {} +func (*Meta) Descriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{0} } -func (x NodeType) Number() protoreflect.EnumNumber { - return protoreflect.EnumNumber(x) +func (m *Meta) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Meta.Unmarshal(m, b) } - -// Deprecated: Use NodeType.Descriptor instead. -func (NodeType) EnumDescriptor() ([]byte, []int) { - return file_protocol_proto_rawDescGZIP(), []int{0} +func (m *Meta) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Meta.Marshal(b, m, deterministic) } - -type Meta struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Config string `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` - RankingModelVersion int64 `protobuf:"varint,3,opt,name=ranking_model_version,json=rankingModelVersion,proto3" json:"ranking_model_version,omitempty"` - ClickModelVersion int64 `protobuf:"varint,4,opt,name=click_model_version,json=clickModelVersion,proto3" json:"click_model_version,omitempty"` - Me string `protobuf:"bytes,5,opt,name=me,proto3" json:"me,omitempty"` - Servers []string `protobuf:"bytes,6,rep,name=servers,proto3" json:"servers,omitempty"` - Workers []string `protobuf:"bytes,7,rep,name=workers,proto3" json:"workers,omitempty"` -} - -func (x *Meta) Reset() { - *x = Meta{} - if protoimpl.UnsafeEnabled { - mi := &file_protocol_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } +func (m *Meta) XXX_Merge(src proto.Message) { + xxx_messageInfo_Meta.Merge(m, src) } - -func (x *Meta) String() string { - return protoimpl.X.MessageStringOf(x) +func (m *Meta) XXX_Size() int { + return xxx_messageInfo_Meta.Size(m) } - -func (*Meta) ProtoMessage() {} - -func (x *Meta) ProtoReflect() protoreflect.Message { - mi := &file_protocol_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) +func (m *Meta) XXX_DiscardUnknown() { + xxx_messageInfo_Meta.DiscardUnknown(m) } -// Deprecated: Use Meta.ProtoReflect.Descriptor instead. -func (*Meta) Descriptor() ([]byte, []int) { - return file_protocol_proto_rawDescGZIP(), []int{0} -} +var xxx_messageInfo_Meta proto.InternalMessageInfo -func (x *Meta) GetConfig() string { - if x != nil { - return x.Config +func (m *Meta) GetConfig() string { + if m != nil { + return m.Config } return "" } -func (x *Meta) GetRankingModelVersion() int64 { - if x != nil { - return x.RankingModelVersion +func (m *Meta) GetRankingModelVersion() int64 { + if m != nil { + return m.RankingModelVersion } return 0 } -func (x *Meta) GetClickModelVersion() int64 { - if x != nil { - return x.ClickModelVersion +func (m *Meta) GetClickModelVersion() int64 { + if m != nil { + return m.ClickModelVersion } return 0 } -func (x *Meta) GetMe() string { - if x != nil { - return x.Me +func (m *Meta) GetMe() string { + if m != nil { + return m.Me } return "" } -func (x *Meta) GetServers() []string { - if x != nil { - return x.Servers +func (m *Meta) GetServers() []string { + if m != nil { + return m.Servers } return nil } -func (x *Meta) GetWorkers() []string { - if x != nil { - return x.Workers +func (m *Meta) GetWorkers() []string { + if m != nil { + return m.Workers } return nil } type Fragment struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } -func (x *Fragment) Reset() { - *x = Fragment{} - if protoimpl.UnsafeEnabled { - mi := &file_protocol_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } +func (m *Fragment) Reset() { *m = Fragment{} } +func (m *Fragment) String() string { return proto.CompactTextString(m) } +func (*Fragment) ProtoMessage() {} +func (*Fragment) Descriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{1} } -func (x *Fragment) String() string { - return protoimpl.X.MessageStringOf(x) +func (m *Fragment) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Fragment.Unmarshal(m, b) } - -func (*Fragment) ProtoMessage() {} - -func (x *Fragment) ProtoReflect() protoreflect.Message { - mi := &file_protocol_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) +func (m *Fragment) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Fragment.Marshal(b, m, deterministic) } - -// Deprecated: Use Fragment.ProtoReflect.Descriptor instead. -func (*Fragment) Descriptor() ([]byte, []int) { - return file_protocol_proto_rawDescGZIP(), []int{1} +func (m *Fragment) XXX_Merge(src proto.Message) { + xxx_messageInfo_Fragment.Merge(m, src) } +func (m *Fragment) XXX_Size() int { + return xxx_messageInfo_Fragment.Size(m) +} +func (m *Fragment) XXX_DiscardUnknown() { + xxx_messageInfo_Fragment.DiscardUnknown(m) +} + +var xxx_messageInfo_Fragment proto.InternalMessageInfo -func (x *Fragment) GetData() []byte { - if x != nil { - return x.Data +func (m *Fragment) GetData() []byte { + if m != nil { + return m.Data } return nil } type VersionInfo struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Version int64 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` + Version int64 `protobuf:"varint,1,opt,name=version,proto3" json:"version,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } -func (x *VersionInfo) Reset() { - *x = VersionInfo{} - if protoimpl.UnsafeEnabled { - mi := &file_protocol_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } +func (m *VersionInfo) Reset() { *m = VersionInfo{} } +func (m *VersionInfo) String() string { return proto.CompactTextString(m) } +func (*VersionInfo) ProtoMessage() {} +func (*VersionInfo) Descriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{2} } -func (x *VersionInfo) String() string { - return protoimpl.X.MessageStringOf(x) +func (m *VersionInfo) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_VersionInfo.Unmarshal(m, b) } - -func (*VersionInfo) ProtoMessage() {} - -func (x *VersionInfo) ProtoReflect() protoreflect.Message { - mi := &file_protocol_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) +func (m *VersionInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_VersionInfo.Marshal(b, m, deterministic) } - -// Deprecated: Use VersionInfo.ProtoReflect.Descriptor instead. -func (*VersionInfo) Descriptor() ([]byte, []int) { - return file_protocol_proto_rawDescGZIP(), []int{2} +func (m *VersionInfo) XXX_Merge(src proto.Message) { + xxx_messageInfo_VersionInfo.Merge(m, src) +} +func (m *VersionInfo) XXX_Size() int { + return xxx_messageInfo_VersionInfo.Size(m) +} +func (m *VersionInfo) XXX_DiscardUnknown() { + xxx_messageInfo_VersionInfo.DiscardUnknown(m) } -func (x *VersionInfo) GetVersion() int64 { - if x != nil { - return x.Version +var xxx_messageInfo_VersionInfo proto.InternalMessageInfo + +func (m *VersionInfo) GetVersion() int64 { + if m != nil { + return m.Version } return 0 } type NodeInfo struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - NodeType NodeType `protobuf:"varint,1,opt,name=node_type,json=nodeType,proto3,enum=protocol.NodeType" json:"node_type,omitempty"` - NodeName string `protobuf:"bytes,2,opt,name=node_name,json=nodeName,proto3" json:"node_name,omitempty"` - HttpPort int64 `protobuf:"varint,3,opt,name=http_port,json=httpPort,proto3" json:"http_port,omitempty"` - BinaryVersion string `protobuf:"bytes,4,opt,name=binary_version,json=binaryVersion,proto3" json:"binary_version,omitempty"` -} - -func (x *NodeInfo) Reset() { - *x = NodeInfo{} - if protoimpl.UnsafeEnabled { - mi := &file_protocol_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + NodeType NodeType `protobuf:"varint,1,opt,name=node_type,json=nodeType,proto3,enum=protocol.NodeType" json:"node_type,omitempty"` + NodeName string `protobuf:"bytes,2,opt,name=node_name,json=nodeName,proto3" json:"node_name,omitempty"` + HttpPort int64 `protobuf:"varint,3,opt,name=http_port,json=httpPort,proto3" json:"http_port,omitempty"` + BinaryVersion string `protobuf:"bytes,4,opt,name=binary_version,json=binaryVersion,proto3" json:"binary_version,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *NodeInfo) Reset() { *m = NodeInfo{} } +func (m *NodeInfo) String() string { return proto.CompactTextString(m) } +func (*NodeInfo) ProtoMessage() {} +func (*NodeInfo) Descriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{3} } -func (x *NodeInfo) String() string { - return protoimpl.X.MessageStringOf(x) +func (m *NodeInfo) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_NodeInfo.Unmarshal(m, b) } - -func (*NodeInfo) ProtoMessage() {} - -func (x *NodeInfo) ProtoReflect() protoreflect.Message { - mi := &file_protocol_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) +func (m *NodeInfo) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_NodeInfo.Marshal(b, m, deterministic) } - -// Deprecated: Use NodeInfo.ProtoReflect.Descriptor instead. -func (*NodeInfo) Descriptor() ([]byte, []int) { - return file_protocol_proto_rawDescGZIP(), []int{3} +func (m *NodeInfo) XXX_Merge(src proto.Message) { + xxx_messageInfo_NodeInfo.Merge(m, src) +} +func (m *NodeInfo) XXX_Size() int { + return xxx_messageInfo_NodeInfo.Size(m) +} +func (m *NodeInfo) XXX_DiscardUnknown() { + xxx_messageInfo_NodeInfo.DiscardUnknown(m) } -func (x *NodeInfo) GetNodeType() NodeType { - if x != nil { - return x.NodeType +var xxx_messageInfo_NodeInfo proto.InternalMessageInfo + +func (m *NodeInfo) GetNodeType() NodeType { + if m != nil { + return m.NodeType } return NodeType_ServerNode } -func (x *NodeInfo) GetNodeName() string { - if x != nil { - return x.NodeName +func (m *NodeInfo) GetNodeName() string { + if m != nil { + return m.NodeName } return "" } -func (x *NodeInfo) GetHttpPort() int64 { - if x != nil { - return x.HttpPort +func (m *NodeInfo) GetHttpPort() int64 { + if m != nil { + return m.HttpPort } return 0 } -func (x *NodeInfo) GetBinaryVersion() string { - if x != nil { - return x.BinaryVersion +func (m *NodeInfo) GetBinaryVersion() string { + if m != nil { + return m.BinaryVersion } return "" } -type PushProgressRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Tracer string `protobuf:"bytes,1,opt,name=tracer,proto3" json:"tracer,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` - Status string `protobuf:"bytes,3,opt,name=status,proto3" json:"status,omitempty"` - Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"` - Count int64 `protobuf:"varint,5,opt,name=count,proto3" json:"count,omitempty"` - Total int64 `protobuf:"varint,6,opt,name=total,proto3" json:"total,omitempty"` - StartTime int64 `protobuf:"varint,7,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` - FinishTime int64 `protobuf:"varint,8,opt,name=finish_time,json=finishTime,proto3" json:"finish_time,omitempty"` -} - -func (x *PushProgressRequest) Reset() { - *x = PushProgressRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_protocol_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } +type Progress struct { + Tracer string `protobuf:"bytes,1,opt,name=tracer,proto3" json:"tracer,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + Status string `protobuf:"bytes,3,opt,name=status,proto3" json:"status,omitempty"` + Error string `protobuf:"bytes,4,opt,name=error,proto3" json:"error,omitempty"` + Count int64 `protobuf:"varint,5,opt,name=count,proto3" json:"count,omitempty"` + Total int64 `protobuf:"varint,6,opt,name=total,proto3" json:"total,omitempty"` + StartTime int64 `protobuf:"varint,7,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` + FinishTime int64 `protobuf:"varint,8,opt,name=finish_time,json=finishTime,proto3" json:"finish_time,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } -func (x *PushProgressRequest) String() string { - return protoimpl.X.MessageStringOf(x) +func (m *Progress) Reset() { *m = Progress{} } +func (m *Progress) String() string { return proto.CompactTextString(m) } +func (*Progress) ProtoMessage() {} +func (*Progress) Descriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{4} } -func (*PushProgressRequest) ProtoMessage() {} - -func (x *PushProgressRequest) ProtoReflect() protoreflect.Message { - mi := &file_protocol_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) +func (m *Progress) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Progress.Unmarshal(m, b) } - -// Deprecated: Use PushProgressRequest.ProtoReflect.Descriptor instead. -func (*PushProgressRequest) Descriptor() ([]byte, []int) { - return file_protocol_proto_rawDescGZIP(), []int{4} +func (m *Progress) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Progress.Marshal(b, m, deterministic) +} +func (m *Progress) XXX_Merge(src proto.Message) { + xxx_messageInfo_Progress.Merge(m, src) +} +func (m *Progress) XXX_Size() int { + return xxx_messageInfo_Progress.Size(m) +} +func (m *Progress) XXX_DiscardUnknown() { + xxx_messageInfo_Progress.DiscardUnknown(m) } -func (x *PushProgressRequest) GetTracer() string { - if x != nil { - return x.Tracer +var xxx_messageInfo_Progress proto.InternalMessageInfo + +func (m *Progress) GetTracer() string { + if m != nil { + return m.Tracer } return "" } -func (x *PushProgressRequest) GetName() string { - if x != nil { - return x.Name +func (m *Progress) GetName() string { + if m != nil { + return m.Name } return "" } -func (x *PushProgressRequest) GetStatus() string { - if x != nil { - return x.Status +func (m *Progress) GetStatus() string { + if m != nil { + return m.Status } return "" } -func (x *PushProgressRequest) GetError() string { - if x != nil { - return x.Error +func (m *Progress) GetError() string { + if m != nil { + return m.Error } return "" } -func (x *PushProgressRequest) GetCount() int64 { - if x != nil { - return x.Count +func (m *Progress) GetCount() int64 { + if m != nil { + return m.Count } return 0 } -func (x *PushProgressRequest) GetTotal() int64 { - if x != nil { - return x.Total +func (m *Progress) GetTotal() int64 { + if m != nil { + return m.Total } return 0 } -func (x *PushProgressRequest) GetStartTime() int64 { - if x != nil { - return x.StartTime +func (m *Progress) GetStartTime() int64 { + if m != nil { + return m.StartTime } return 0 } -func (x *PushProgressRequest) GetFinishTime() int64 { - if x != nil { - return x.FinishTime +func (m *Progress) GetFinishTime() int64 { + if m != nil { + return m.FinishTime } return 0 } -type PushProgressResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields +type PushProgressRequest struct { + Progress []*Progress `protobuf:"bytes,1,rep,name=progress,proto3" json:"progress,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` } -func (x *PushProgressResponse) Reset() { - *x = PushProgressResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_protocol_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } +func (m *PushProgressRequest) Reset() { *m = PushProgressRequest{} } +func (m *PushProgressRequest) String() string { return proto.CompactTextString(m) } +func (*PushProgressRequest) ProtoMessage() {} +func (*PushProgressRequest) Descriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{5} } -func (x *PushProgressResponse) String() string { - return protoimpl.X.MessageStringOf(x) +func (m *PushProgressRequest) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_PushProgressRequest.Unmarshal(m, b) +} +func (m *PushProgressRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_PushProgressRequest.Marshal(b, m, deterministic) +} +func (m *PushProgressRequest) XXX_Merge(src proto.Message) { + xxx_messageInfo_PushProgressRequest.Merge(m, src) +} +func (m *PushProgressRequest) XXX_Size() int { + return xxx_messageInfo_PushProgressRequest.Size(m) +} +func (m *PushProgressRequest) XXX_DiscardUnknown() { + xxx_messageInfo_PushProgressRequest.DiscardUnknown(m) } -func (*PushProgressResponse) ProtoMessage() {} +var xxx_messageInfo_PushProgressRequest proto.InternalMessageInfo -func (x *PushProgressResponse) ProtoReflect() protoreflect.Message { - mi := &file_protocol_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms +func (m *PushProgressRequest) GetProgress() []*Progress { + if m != nil { + return m.Progress } - return mi.MessageOf(x) + return nil } -// Deprecated: Use PushProgressResponse.ProtoReflect.Descriptor instead. -func (*PushProgressResponse) Descriptor() ([]byte, []int) { - return file_protocol_proto_rawDescGZIP(), []int{5} -} - -var File_protocol_proto protoreflect.FileDescriptor - -var file_protocol_proto_rawDesc = []byte{ - 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x12, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0xc6, 0x01, 0x0a, 0x04, 0x4d, - 0x65, 0x74, 0x61, 0x12, 0x16, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x32, 0x0a, 0x15, 0x72, - 0x61, 0x6e, 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x13, 0x72, 0x61, 0x6e, 0x6b, - 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x2e, 0x0a, 0x13, 0x63, 0x6c, 0x69, 0x63, 0x6b, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x76, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x11, 0x63, 0x6c, - 0x69, 0x63, 0x6b, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x0e, 0x0a, 0x02, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x6d, 0x65, 0x12, - 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x77, 0x6f, 0x72, - 0x6b, 0x65, 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x77, 0x6f, 0x72, 0x6b, - 0x65, 0x72, 0x73, 0x22, 0x1e, 0x0a, 0x08, 0x46, 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x12, - 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, - 0x61, 0x74, 0x61, 0x22, 0x27, 0x0a, 0x0b, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x49, 0x6e, - 0x66, 0x6f, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0x9c, 0x01, 0x0a, - 0x08, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2f, 0x0a, 0x09, 0x6e, 0x6f, 0x64, - 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x12, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, - 0x52, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x6e, 0x6f, - 0x64, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6e, - 0x6f, 0x64, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x68, 0x74, 0x74, 0x70, 0x5f, - 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x68, 0x74, 0x74, 0x70, - 0x50, 0x6f, 0x72, 0x74, 0x12, 0x25, 0x0a, 0x0e, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, 0x5f, 0x76, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x62, 0x69, - 0x6e, 0x61, 0x72, 0x79, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x22, 0xdb, 0x01, 0x0a, 0x13, - 0x50, 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x74, 0x72, 0x61, 0x63, 0x65, 0x72, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x06, 0x74, 0x72, 0x61, 0x63, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, - 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, - 0x16, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x14, 0x0a, - 0x05, 0x63, 0x6f, 0x75, 0x6e, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x03, 0x52, 0x05, 0x63, 0x6f, - 0x75, 0x6e, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x18, 0x06, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x05, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x61, - 0x72, 0x74, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x73, - 0x74, 0x61, 0x72, 0x74, 0x54, 0x69, 0x6d, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x66, 0x69, 0x6e, 0x69, - 0x73, 0x68, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x18, 0x08, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x66, - 0x69, 0x6e, 0x69, 0x73, 0x68, 0x54, 0x69, 0x6d, 0x65, 0x22, 0x16, 0x0a, 0x14, 0x50, 0x75, 0x73, - 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x2a, 0x3a, 0x0a, 0x08, 0x4e, 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, - 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x00, 0x12, 0x0e, 0x0a, - 0x0a, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x01, 0x12, 0x0e, 0x0a, - 0x0a, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x02, 0x32, 0x8c, 0x02, - 0x0a, 0x06, 0x4d, 0x61, 0x73, 0x74, 0x65, 0x72, 0x12, 0x2f, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, - 0x65, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, - 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, - 0x6f, 0x6c, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x22, 0x00, 0x12, 0x40, 0x0a, 0x0f, 0x47, 0x65, 0x74, - 0x52, 0x61, 0x6e, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x49, - 0x6e, 0x66, 0x6f, 0x1a, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x46, - 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x3e, 0x0a, 0x0d, 0x47, - 0x65, 0x74, 0x43, 0x6c, 0x69, 0x63, 0x6b, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x15, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x49, - 0x6e, 0x66, 0x6f, 0x1a, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x46, - 0x72, 0x61, 0x67, 0x6d, 0x65, 0x6e, 0x74, 0x22, 0x00, 0x30, 0x01, 0x12, 0x4f, 0x0a, 0x0c, 0x50, - 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1d, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, - 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x50, 0x75, 0x73, 0x68, 0x50, 0x72, 0x6f, 0x67, 0x72, 0x65, - 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x25, 0x5a, 0x23, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x68, 0x65, 0x6e, 0x67, - 0x68, 0x61, 0x6f, 0x7a, 0x2f, 0x67, 0x6f, 0x72, 0x73, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} - -var ( - file_protocol_proto_rawDescOnce sync.Once - file_protocol_proto_rawDescData = file_protocol_proto_rawDesc -) +type PushProgressResponse struct { + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} -func file_protocol_proto_rawDescGZIP() []byte { - file_protocol_proto_rawDescOnce.Do(func() { - file_protocol_proto_rawDescData = protoimpl.X.CompressGZIP(file_protocol_proto_rawDescData) - }) - return file_protocol_proto_rawDescData -} - -var file_protocol_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_protocol_proto_msgTypes = make([]protoimpl.MessageInfo, 6) -var file_protocol_proto_goTypes = []interface{}{ - (NodeType)(0), // 0: protocol.NodeType - (*Meta)(nil), // 1: protocol.Meta - (*Fragment)(nil), // 2: protocol.Fragment - (*VersionInfo)(nil), // 3: protocol.VersionInfo - (*NodeInfo)(nil), // 4: protocol.NodeInfo - (*PushProgressRequest)(nil), // 5: protocol.PushProgressRequest - (*PushProgressResponse)(nil), // 6: protocol.PushProgressResponse -} -var file_protocol_proto_depIdxs = []int32{ - 0, // 0: protocol.NodeInfo.node_type:type_name -> protocol.NodeType - 4, // 1: protocol.Master.GetMeta:input_type -> protocol.NodeInfo - 3, // 2: protocol.Master.GetRankingModel:input_type -> protocol.VersionInfo - 3, // 3: protocol.Master.GetClickModel:input_type -> protocol.VersionInfo - 5, // 4: protocol.Master.PushProgress:input_type -> protocol.PushProgressRequest - 1, // 5: protocol.Master.GetMeta:output_type -> protocol.Meta - 2, // 6: protocol.Master.GetRankingModel:output_type -> protocol.Fragment - 2, // 7: protocol.Master.GetClickModel:output_type -> protocol.Fragment - 6, // 8: protocol.Master.PushProgress:output_type -> protocol.PushProgressResponse - 5, // [5:9] is the sub-list for method output_type - 1, // [1:5] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name -} - -func init() { file_protocol_proto_init() } -func file_protocol_proto_init() { - if File_protocol_proto != nil { - return - } - if !protoimpl.UnsafeEnabled { - file_protocol_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Meta); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_protocol_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Fragment); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_protocol_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*VersionInfo); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_protocol_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*NodeInfo); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_protocol_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PushProgressRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_protocol_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PushProgressResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_protocol_proto_rawDesc, - NumEnums: 1, - NumMessages: 6, - NumExtensions: 0, - NumServices: 1, - }, - GoTypes: file_protocol_proto_goTypes, - DependencyIndexes: file_protocol_proto_depIdxs, - EnumInfos: file_protocol_proto_enumTypes, - MessageInfos: file_protocol_proto_msgTypes, - }.Build() - File_protocol_proto = out.File - file_protocol_proto_rawDesc = nil - file_protocol_proto_goTypes = nil - file_protocol_proto_depIdxs = nil +func (m *PushProgressResponse) Reset() { *m = PushProgressResponse{} } +func (m *PushProgressResponse) String() string { return proto.CompactTextString(m) } +func (*PushProgressResponse) ProtoMessage() {} +func (*PushProgressResponse) Descriptor() ([]byte, []int) { + return fileDescriptor_2bc2336598a3f7e0, []int{6} +} + +func (m *PushProgressResponse) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_PushProgressResponse.Unmarshal(m, b) +} +func (m *PushProgressResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_PushProgressResponse.Marshal(b, m, deterministic) +} +func (m *PushProgressResponse) XXX_Merge(src proto.Message) { + xxx_messageInfo_PushProgressResponse.Merge(m, src) +} +func (m *PushProgressResponse) XXX_Size() int { + return xxx_messageInfo_PushProgressResponse.Size(m) +} +func (m *PushProgressResponse) XXX_DiscardUnknown() { + xxx_messageInfo_PushProgressResponse.DiscardUnknown(m) +} + +var xxx_messageInfo_PushProgressResponse proto.InternalMessageInfo + +func init() { + proto.RegisterEnum("protocol.NodeType", NodeType_name, NodeType_value) + proto.RegisterType((*Meta)(nil), "protocol.Meta") + proto.RegisterType((*Fragment)(nil), "protocol.Fragment") + proto.RegisterType((*VersionInfo)(nil), "protocol.VersionInfo") + proto.RegisterType((*NodeInfo)(nil), "protocol.NodeInfo") + proto.RegisterType((*Progress)(nil), "protocol.Progress") + proto.RegisterType((*PushProgressRequest)(nil), "protocol.PushProgressRequest") + proto.RegisterType((*PushProgressResponse)(nil), "protocol.PushProgressResponse") +} + +func init() { + proto.RegisterFile("protocol.proto", fileDescriptor_2bc2336598a3f7e0) +} + +var fileDescriptor_2bc2336598a3f7e0 = []byte{ + // 592 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x53, 0xcd, 0x6e, 0xd3, 0x4e, + 0x10, 0x8f, 0x93, 0x36, 0x71, 0xa6, 0x6d, 0xfe, 0x7f, 0xb6, 0x1f, 0xb2, 0x8a, 0x5a, 0x22, 0xa3, + 0x8a, 0x88, 0x43, 0x82, 0xca, 0x8d, 0x03, 0x42, 0x54, 0x50, 0x71, 0x68, 0xa9, 0x4c, 0x05, 0x12, + 0x97, 0x68, 0xeb, 0x4c, 0x6d, 0xab, 0xf1, 0xae, 0xd9, 0x9d, 0x80, 0xda, 0x67, 0xe0, 0x11, 0x78, + 0x1e, 0xce, 0x3c, 0x12, 0xda, 0xb1, 0x9d, 0xba, 0x45, 0x1c, 0xb8, 0xed, 0xef, 0x63, 0xf6, 0xe3, + 0x37, 0xb3, 0x30, 0x28, 0x8c, 0x26, 0x1d, 0xeb, 0xf9, 0x98, 0x17, 0xc2, 0xaf, 0x71, 0xf8, 0xd3, + 0x83, 0x95, 0x13, 0x24, 0x29, 0x76, 0xa0, 0x1b, 0x6b, 0x75, 0x99, 0x25, 0x81, 0x37, 0xf4, 0x46, + 0xfd, 0xa8, 0x42, 0xe2, 0x10, 0xb6, 0x8d, 0x54, 0x57, 0x99, 0x4a, 0xa6, 0xb9, 0x9e, 0xe1, 0x7c, + 0xfa, 0x15, 0x8d, 0xcd, 0xb4, 0x0a, 0x3a, 0x43, 0x6f, 0xd4, 0x89, 0x36, 0x2b, 0xf1, 0xc4, 0x69, + 0x1f, 0x4b, 0x49, 0x8c, 0x61, 0x33, 0x9e, 0x67, 0xf1, 0xd5, 0xbd, 0x8a, 0x15, 0xae, 0x78, 0xc0, + 0xd2, 0x1d, 0xff, 0x00, 0xda, 0x39, 0x06, 0xab, 0x7c, 0x6e, 0x3b, 0x47, 0x11, 0x40, 0xcf, 0xa2, + 0x71, 0x65, 0x41, 0x77, 0xd8, 0x19, 0xf5, 0xa3, 0x1a, 0x3a, 0xe5, 0x9b, 0x36, 0x57, 0x4e, 0xe9, + 0x95, 0x4a, 0x05, 0xc3, 0x7d, 0xf0, 0xdf, 0x1a, 0x99, 0xe4, 0xa8, 0x48, 0x08, 0x58, 0x99, 0x49, + 0x92, 0xfc, 0x92, 0xf5, 0x88, 0xd7, 0xe1, 0x13, 0x58, 0xab, 0x8e, 0x7b, 0xa7, 0x2e, 0xb5, 0xdb, + 0xa8, 0xbe, 0x96, 0xc7, 0xd7, 0xaa, 0x61, 0xf8, 0xc3, 0x03, 0xff, 0x54, 0xcf, 0x90, 0x6d, 0x13, + 0xe8, 0x2b, 0x3d, 0xc3, 0x29, 0x5d, 0x17, 0xc8, 0xc6, 0xc1, 0xa1, 0x18, 0x2f, 0xc3, 0x74, 0xb6, + 0xf3, 0xeb, 0x02, 0x23, 0x5f, 0x55, 0x2b, 0xf1, 0xb0, 0x2a, 0x50, 0x32, 0xc7, 0xa0, 0xcd, 0x2f, + 0x62, 0xf1, 0x54, 0xe6, 0x2c, 0xa6, 0x44, 0xc5, 0xb4, 0xd0, 0x86, 0xaa, 0xfc, 0x7c, 0x47, 0x9c, + 0x69, 0x43, 0xe2, 0x00, 0x06, 0x17, 0x99, 0x92, 0xe6, 0xfa, 0x4e, 0x5e, 0xfd, 0x68, 0xa3, 0x64, + 0xab, 0xcb, 0x87, 0xbf, 0x3c, 0xf0, 0xcf, 0x8c, 0x4e, 0x0c, 0x5a, 0xeb, 0x9a, 0x46, 0x46, 0xc6, + 0x68, 0xea, 0xa6, 0x95, 0xc8, 0x05, 0xd0, 0xb8, 0x00, 0xaf, 0x9d, 0xd7, 0x92, 0xa4, 0x85, 0xe5, + 0x93, 0xfb, 0x51, 0x85, 0xc4, 0x16, 0xac, 0xa2, 0x31, 0xda, 0x54, 0xc7, 0x95, 0xc0, 0xb1, 0xb1, + 0x5e, 0x28, 0xe2, 0xae, 0x74, 0xa2, 0x12, 0x38, 0x96, 0x34, 0xc9, 0x79, 0xd0, 0x2d, 0x59, 0x06, + 0x62, 0x0f, 0xc0, 0x92, 0x34, 0x34, 0xa5, 0x2c, 0xc7, 0xa0, 0xc7, 0x52, 0x9f, 0x99, 0xf3, 0x2c, + 0x47, 0xf1, 0x08, 0xd6, 0x2e, 0x33, 0x95, 0xd9, 0xb4, 0xd4, 0x7d, 0xd6, 0xa1, 0xa4, 0x9c, 0x21, + 0x7c, 0x03, 0x9b, 0x67, 0x0b, 0x9b, 0xd6, 0xaf, 0x8a, 0xf0, 0xcb, 0x02, 0x2d, 0x89, 0x31, 0xb8, + 0x31, 0x65, 0x2a, 0xf0, 0x86, 0x9d, 0xd1, 0x5a, 0x33, 0xfa, 0xa5, 0x79, 0xe9, 0x09, 0x77, 0x60, + 0xeb, 0xee, 0x36, 0xb6, 0xd0, 0xca, 0xe2, 0xd3, 0x17, 0x65, 0x3f, 0xb9, 0x3d, 0x03, 0x80, 0x0f, + 0x3c, 0x4a, 0x8e, 0xf9, 0xbf, 0xe5, 0xf0, 0x27, 0x1e, 0x20, 0xc6, 0x9e, 0xc3, 0x47, 0xf3, 0x0c, + 0x15, 0x31, 0x6e, 0x1f, 0x7e, 0x6f, 0x43, 0xf7, 0x44, 0x5a, 0x42, 0x23, 0x26, 0xd0, 0x3b, 0x46, + 0xe2, 0xbf, 0x72, 0x6f, 0x04, 0xdc, 0xa4, 0xec, 0x0e, 0x6e, 0x39, 0xe7, 0x09, 0x5b, 0xe2, 0x15, + 0xfc, 0x77, 0x8c, 0x14, 0x35, 0xfe, 0x87, 0xd8, 0xbe, 0x35, 0x35, 0x86, 0x71, 0xb7, 0xb1, 0x5f, + 0x3d, 0xc3, 0x61, 0xeb, 0x99, 0x27, 0x5e, 0xc2, 0xc6, 0x31, 0xd2, 0xd1, 0xf2, 0xbf, 0xfc, 0x6b, + 0xfd, 0x7b, 0x58, 0x6f, 0x26, 0x22, 0xf6, 0x1a, 0xf9, 0xfd, 0x19, 0xf8, 0xee, 0xfe, 0xdf, 0xe4, + 0x32, 0xc8, 0xb0, 0xf5, 0xfa, 0xe0, 0xf3, 0xe3, 0x24, 0xa3, 0x74, 0x71, 0x31, 0x8e, 0x75, 0x3e, + 0xb9, 0x49, 0x51, 0x25, 0xa9, 0xd4, 0x37, 0x93, 0x44, 0x1b, 0x8b, 0x93, 0xba, 0xfa, 0xa2, 0xcb, + 0xab, 0xe7, 0xbf, 0x03, 0x00, 0x00, 0xff, 0xff, 0x0e, 0x9c, 0xdb, 0xcd, 0x77, 0x04, 0x00, 0x00, } diff --git a/protocol/protocol.proto b/protocol/protocol.proto index 1f1a03fb0..54c1c43ec 100644 --- a/protocol/protocol.proto +++ b/protocol/protocol.proto @@ -60,7 +60,7 @@ message NodeInfo { string binary_version = 4; } -message PushProgressRequest { +message Progress { string tracer = 1; string name = 2; string status = 3; @@ -71,4 +71,8 @@ message PushProgressRequest { int64 finish_time = 8; } +message PushProgressRequest { + repeated Progress progress = 1; +} + message PushProgressResponse {} diff --git a/protocol/protocol_grpc.pb.go b/protocol/protocol_grpc.pb.go index dbc48c9c1..c76123352 100644 --- a/protocol/protocol_grpc.pb.go +++ b/protocol/protocol_grpc.pb.go @@ -1,6 +1,20 @@ +// Copyright 2020 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.2.0 +// - protoc-gen-go-grpc v1.3.0 // - protoc v3.12.4 // source: protocol.proto @@ -18,6 +32,13 @@ import ( // Requires gRPC-Go v1.32.0 or later. const _ = grpc.SupportPackageIsVersion7 +const ( + Master_GetMeta_FullMethodName = "/protocol.Master/GetMeta" + Master_GetRankingModel_FullMethodName = "/protocol.Master/GetRankingModel" + Master_GetClickModel_FullMethodName = "/protocol.Master/GetClickModel" + Master_PushProgress_FullMethodName = "/protocol.Master/PushProgress" +) + // MasterClient is the client API for Master service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. @@ -40,7 +61,7 @@ func NewMasterClient(cc grpc.ClientConnInterface) MasterClient { func (c *masterClient) GetMeta(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Meta, error) { out := new(Meta) - err := c.cc.Invoke(ctx, "/protocol.Master/GetMeta", in, out, opts...) + err := c.cc.Invoke(ctx, Master_GetMeta_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -48,7 +69,7 @@ func (c *masterClient) GetMeta(ctx context.Context, in *NodeInfo, opts ...grpc.C } func (c *masterClient) GetRankingModel(ctx context.Context, in *VersionInfo, opts ...grpc.CallOption) (Master_GetRankingModelClient, error) { - stream, err := c.cc.NewStream(ctx, &Master_ServiceDesc.Streams[0], "/protocol.Master/GetRankingModel", opts...) + stream, err := c.cc.NewStream(ctx, &Master_ServiceDesc.Streams[0], Master_GetRankingModel_FullMethodName, opts...) if err != nil { return nil, err } @@ -80,7 +101,7 @@ func (x *masterGetRankingModelClient) Recv() (*Fragment, error) { } func (c *masterClient) GetClickModel(ctx context.Context, in *VersionInfo, opts ...grpc.CallOption) (Master_GetClickModelClient, error) { - stream, err := c.cc.NewStream(ctx, &Master_ServiceDesc.Streams[1], "/protocol.Master/GetClickModel", opts...) + stream, err := c.cc.NewStream(ctx, &Master_ServiceDesc.Streams[1], Master_GetClickModel_FullMethodName, opts...) if err != nil { return nil, err } @@ -113,7 +134,7 @@ func (x *masterGetClickModelClient) Recv() (*Fragment, error) { func (c *masterClient) PushProgress(ctx context.Context, in *PushProgressRequest, opts ...grpc.CallOption) (*PushProgressResponse, error) { out := new(PushProgressResponse) - err := c.cc.Invoke(ctx, "/protocol.Master/PushProgress", in, out, opts...) + err := c.cc.Invoke(ctx, Master_PushProgress_FullMethodName, in, out, opts...) if err != nil { return nil, err } @@ -172,7 +193,7 @@ func _Master_GetMeta_Handler(srv interface{}, ctx context.Context, dec func(inte } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/protocol.Master/GetMeta", + FullMethod: Master_GetMeta_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(MasterServer).GetMeta(ctx, req.(*NodeInfo)) @@ -232,7 +253,7 @@ func _Master_PushProgress_Handler(srv interface{}, ctx context.Context, dec func } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/protocol.Master/PushProgress", + FullMethod: Master_PushProgress_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(MasterServer).PushProgress(ctx, req.(*PushProgressRequest)) diff --git a/protocol/task.go b/protocol/task.go index 903f09b0a..b35c563e7 100644 --- a/protocol/task.go +++ b/protocol/task.go @@ -22,26 +22,36 @@ import ( //go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative protocol.proto -func DecodeProgress(in *PushProgressRequest) *progress.Progress { - return &progress.Progress{ - Tracer: in.GetTracer(), - Name: in.GetName(), - Status: progress.Status(in.GetStatus()), - Count: int(in.GetCount()), - Total: int(in.GetTotal()), - StartTime: time.UnixMilli(in.GetStartTime()), - FinishTime: time.UnixMilli(in.GetFinishTime()), +func DecodeProgress(in *PushProgressRequest) []progress.Progress { + var progressList []progress.Progress + for _, p := range in.Progress { + progressList = append(progressList, progress.Progress{ + Tracer: p.GetTracer(), + Name: p.GetName(), + Status: progress.Status(p.GetStatus()), + Count: int(p.GetCount()), + Total: int(p.GetTotal()), + StartTime: time.UnixMilli(p.GetStartTime()), + FinishTime: time.UnixMilli(p.GetFinishTime()), + }) } + return progressList } -func EncodeProgress(t *progress.Progress) *PushProgressRequest { +func EncodeProgress(progressList []progress.Progress) *PushProgressRequest { + var pbList []*Progress + for _, p := range progressList { + pbList = append(pbList, &Progress{ + Tracer: p.Tracer, + Name: p.Name, + Status: string(p.Status), + Count: int64(p.Count), + Total: int64(p.Total), + StartTime: p.StartTime.UnixMilli(), + FinishTime: p.FinishTime.UnixMilli(), + }) + } return &PushProgressRequest{ - Tracer: t.Tracer, - Name: t.Name, - Status: string(t.Status), - Count: int64(t.Count), - Total: int64(t.Total), - StartTime: t.StartTime.UnixMilli(), - FinishTime: t.FinishTime.UnixMilli(), + Progress: pbList, } } diff --git a/protocol/task_test.go b/protocol/task_test.go index 9afb80448..d1db3246b 100644 --- a/protocol/task_test.go +++ b/protocol/task_test.go @@ -23,15 +23,26 @@ import ( ) func TestEncodeDecode(t *testing.T) { - tk := &progress.Progress{ - Tracer: "tracer", - Name: "a", - Total: 100, - Count: 50, - Status: progress.StatusRunning, - StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), - FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), + progressList := []progress.Progress{ + { + Tracer: "tracer", + Name: "a", + Total: 100, + Count: 50, + Status: progress.StatusRunning, + StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), + FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), + }, + { + Tracer: "tracer", + Name: "b", + Total: 100, + Count: 50, + Status: progress.StatusRunning, + StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), + FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), + }, } - pb := EncodeProgress(tk) - assert.Equal(t, tk, DecodeProgress(pb)) + pb := EncodeProgress(progressList) + assert.Equal(t, progressList, DecodeProgress(pb)) } diff --git a/worker/worker_test.go b/worker/worker_test.go index 90f1bc97b..811d5e663 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -176,7 +176,7 @@ func (m *mockMatrixFactorizationForRecommend) Invalid() bool { return false } -func (m *mockMatrixFactorizationForRecommend) Fit(_, _ *ranking.DataSet, _ *ranking.FitConfig) ranking.Score { +func (m *mockMatrixFactorizationForRecommend) Fit(_ context.Context, _, _ *ranking.DataSet, _ *ranking.FitConfig) ranking.Score { panic("implement me") } @@ -659,7 +659,7 @@ func newMockMaster(t *testing.T) *mockMaster { // create click model train, test := newClickDataset() fm := click.NewFM(click.FMClassification, model.Params{model.NEpochs: 0}) - fm.Fit(train, test, nil) + fm.Fit(context.Background(), train, test, nil) clickModelBuffer := bytes.NewBuffer(nil) err := click.MarshalModel(clickModelBuffer, fm) assert.NoError(t, err) @@ -667,7 +667,7 @@ func newMockMaster(t *testing.T) *mockMaster { // create ranking model trainSet, testSet := newRankingDataset() bpr := ranking.NewBPR(model.Params{model.NEpochs: 0}) - bpr.Fit(trainSet, testSet, nil) + bpr.Fit(context.Background(), trainSet, testSet, nil) rankingModelBuffer := bytes.NewBuffer(nil) err = ranking.MarshalModel(rankingModelBuffer, bpr) assert.NoError(t, err) @@ -848,7 +848,7 @@ func (m mockFactorizationMachine) InternalPredict(_ []int32, _ []float32) float3 panic("implement me") } -func (m mockFactorizationMachine) Fit(_, _ *click.Dataset, _ *click.FitConfig) click.Score { +func (m mockFactorizationMachine) Fit(_ context.Context, _, _ *click.Dataset, _ *click.FitConfig) click.Score { panic("implement me") } From 6c84b47388da6a5a7e5b85d9cd90c44a7e23169f Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Fri, 28 Jul 2023 21:45:50 +0800 Subject: [PATCH 09/13] Fix worker --- worker/worker.go | 36 +++++++++++------------------------- worker/worker_test.go | 2 ++ 2 files changed, 13 insertions(+), 25 deletions(-) diff --git a/worker/worker.go b/worker/worker.go index 9f74bfa4c..04cabbe43 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -37,8 +37,8 @@ import ( "github.com/zhenghaoz/gorse/base/heap" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/parallel" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/search" - "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/cmd/version" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model/click" @@ -65,6 +65,7 @@ type ScheduleState struct { // Worker manages states of a worker node. type Worker struct { + tracer *progress.Tracer oneMode bool testMode bool managedMode bool @@ -369,6 +370,9 @@ func (w *Worker) Serve() { zap.String("worker_name", w.workerName)) } + // create progress tracer + w.tracer = progress.NewTracer(w.workerName) + // connect to master conn, err := grpc.Dial(fmt.Sprintf("%v:%v", w.masterHost, w.masterPort), grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -424,14 +428,6 @@ func (w *Worker) Serve() { } } -func (w *Worker) estimateRecommendComplexity(numUsers, numItems int) int { - complexity := numUsers * numItems * recommendComplexityFactor - if w.Config.Recommend.Collaborative.EnableIndex { - complexity += search.EstimateHNSWBuilderComplexity(numItems, w.Config.Recommend.Collaborative.IndexFitEpoch) - } - return complexity -} - // Recommend items to users. The workflow of recommendation is: // 1. Skip inactive users. // 2. Load historical items. @@ -464,12 +460,8 @@ func (w *Worker) Recommend(users []data.User) { if !w.oneMode { recommendTaskName += fmt.Sprintf(" [%s]", w.workerName) } - recommendTask := task.NewTask(recommendTaskName, w.estimateRecommendComplexity(len(users), itemCache.Len())) - if w.masterClient != nil { - if _, err := w.masterClient.PushTaskInfo(context.Background(), protocol.EncodeTask(recommendTask)); err != nil { - log.Logger().Error("failed to report start task", zap.Error(err)) - } - } + _, span := w.tracer.Start(context.Background(), "Recommend", len(users)) + defer span.End() go func() { defer base.CheckPanic() @@ -487,14 +479,14 @@ func (w *Worker) Recommend(users []data.User) { previousCount = completedCount if throughput > 0 { if w.masterClient != nil { - recommendTask.Add(throughput * itemCache.Len() * recommendComplexityFactor) + span.Add(throughput) } log.Logger().Info("ranking recommendation", zap.Int("n_complete_users", completedCount), zap.Int("n_working_users", len(users)), zap.Int("throughput", throughput)) } - if _, err := w.masterClient.PushTaskInfo(context.Background(), protocol.EncodeTask(recommendTask)); err != nil { + if _, err := w.masterClient.PushProgress(context.Background(), protocol.EncodeProgress(w.tracer.List())); err != nil { log.Logger().Error("failed to report update task", zap.Error(err)) } } @@ -518,8 +510,8 @@ func (w *Worker) Recommend(users []data.User) { } builder := search.NewHNSWBuilder(vectors, w.Config.Recommend.CacheSize, w.jobs) var recall float32 - w.rankingIndex, recall = builder.Build(w.Config.Recommend.Collaborative.IndexRecall, - w.Config.Recommend.Collaborative.IndexFitEpoch, false, recommendTask) + w.rankingIndex, recall = builder.Build(ctx, w.Config.Recommend.Collaborative.IndexRecall, + w.Config.Recommend.Collaborative.IndexFitEpoch, false) CollaborativeFilteringIndexRecall.Set(float64(recall)) if err = w.CacheClient.Set(ctx, cache.String(cache.Key(cache.GlobalMeta, cache.MatchingIndexRecall), encoding.FormatFloat32(recall))); err != nil { log.Logger().Error("failed to write meta", zap.Error(err)) @@ -819,12 +811,6 @@ func (w *Worker) Recommend(users []data.User) { log.Logger().Error("failed to continue offline recommendation", zap.Error(err)) return } - if w.masterClient != nil { - recommendTask.Finish() - if _, err := w.masterClient.PushTaskInfo(context.Background(), protocol.EncodeTask(recommendTask)); err != nil { - log.Logger().Error("failed to report finish task", zap.Error(err)) - } - } log.Logger().Info("complete ranking recommendation", zap.String("used_time", time.Since(startTime).String())) UpdateUserRecommendTotal.Set(updateUserCount.Load()) diff --git a/worker/worker_test.go b/worker/worker_test.go index 811d5e663..0cac316be 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -36,6 +36,7 @@ import ( "github.com/thoas/go-funk" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/parallel" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model" "github.com/zhenghaoz/gorse/model/click" @@ -56,6 +57,7 @@ type WorkerTestSuite struct { func (suite *WorkerTestSuite) SetupSuite() { // open database var err error + suite.tracer = progress.NewTracer("test") suite.Settings = config.NewSettings() suite.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", suite.T().TempDir()), "") suite.NoError(err) From cc486fcfc6e62533caa8c177f505140c6385173d Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Fri, 28 Jul 2023 22:07:59 +0800 Subject: [PATCH 10/13] Fix tests --- base/progress/progress.go | 74 ++++++++++++++++++------- master/master.go | 9 +-- master/master_test.go | 2 - master/rest.go | 23 +++++--- master/rpc.go | 27 ++++++--- master/rpc_test.go | 47 +++++++--------- master/tasks.go | 113 ++++++-------------------------------- 7 files changed, 133 insertions(+), 162 deletions(-) diff --git a/base/progress/progress.go b/base/progress/progress.go index 97685ace2..1c7f346cc 100644 --- a/base/progress/progress.go +++ b/base/progress/progress.go @@ -17,32 +17,50 @@ package progress import ( "context" "sync" + "time" "github.com/google/uuid" ) -var spanKey = uuid.New().String() +type spanKeyType string + +var spanKeyName = spanKeyType(uuid.New().String()) + +type Status string + +const ( + StatusPending Status = "Pending" + StatusComplete Status = "Complete" + StatusRunning Status = "Running" + StatusSuspended Status = "Suspended" + StatusFailed Status = "Failed" +) type Tracer struct { + name string spans sync.Map } +func NewTracer(name string) *Tracer { + return &Tracer{name: name} +} + // Start creates a root span. -func (t *Tracer) Start(ctx context.Context, name string, total int64) (context.Context, *Span) { +func (t *Tracer) Start(ctx context.Context, name string, total int) (context.Context, *Span) { span := &Span{name: name, total: total} t.spans.Store(name, span) - return context.WithValue(ctx, spanKey, span), span + return context.WithValue(ctx, spanKeyName, span), span } func (t *Tracer) List() []Progress { var progress []Progress t.spans.Range(func(key, value interface{}) bool { - span := value.(*Span) - progress = append(progress, Progress{ - Name: span.name, - Total: span.total, - Count: span.count, - }) + // span := value.(*Span) + // progress = append(progress, Progress{ + // Name: span.name, + // Total: span.total, + // Count: span.count, + // }) return true }) return progress @@ -50,13 +68,16 @@ func (t *Tracer) List() []Progress { type Span struct { name string - total int64 - count int64 + status Status + total int + count int err error + start time.Time + finish time.Time children sync.Map } -func (s *Span) Add(n int64) { +func (s *Span) Add(n int) { s.count += n } @@ -68,21 +89,36 @@ func (s *Span) Error(err error) { s.err = err } -func Start(ctx context.Context, name string, total int64) (context.Context, *Span) { - childSpan := &Span{name: name, total: total} +func (s *Span) Count() int { + return s.count +} + +func Start(ctx context.Context, name string, total int) (context.Context, *Span) { + childSpan := &Span{ + name: name, + status: StatusRunning, + total: total, + count: 0, + start: time.Now(), + } if ctx == nil { return nil, childSpan } - span, ok := (ctx).Value(spanKey).(*Span) + span, ok := (ctx).Value(spanKeyName).(*Span) if !ok { return nil, childSpan } span.children.Store(name, childSpan) - return context.WithValue(ctx, spanKey, childSpan), childSpan + return context.WithValue(ctx, spanKeyName, childSpan), childSpan } type Progress struct { - Name string - Total int64 - Count int64 + Tracer string + Name string + Status Status + Error string + Count int + Total int + StartTime time.Time + FinishTime time.Time } diff --git a/master/master.go b/master/master.go index afbba8baf..dc36fad19 100644 --- a/master/master.go +++ b/master/master.go @@ -58,10 +58,11 @@ type Master struct { server.RestServer grpcServer *grpc.Server - tracer progress.Tracer - jobsScheduler *task.JobsScheduler - cacheFile string - managedMode bool + tracer *progress.Tracer + remoteProgress sync.Map + jobsScheduler *task.JobsScheduler + cacheFile string + managedMode bool // cluster meta cache ttlCache *ttlcache.Cache diff --git a/master/master_test.go b/master/master_test.go index a5b74fda1..8e2dfbdee 100644 --- a/master/master_test.go +++ b/master/master_test.go @@ -18,7 +18,6 @@ import ( "testing" "github.com/stretchr/testify/suite" - "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" @@ -30,7 +29,6 @@ type MasterTestSuite struct { } func (s *MasterTestSuite) SetupTest() { - s.taskMonitor = task.NewTaskMonitor() // open database var err error s.Settings = config.NewSettings() diff --git a/master/rest.go b/master/rest.go index 630fbd3f6..a1e70e765 100644 --- a/master/rest.go +++ b/master/rest.go @@ -40,7 +40,7 @@ import ( "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/encoding" "github.com/zhenghaoz/gorse/base/log" - "github.com/zhenghaoz/gorse/base/task" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/cmd/version" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model/click" @@ -82,8 +82,8 @@ func (m *Master) CreateWebService() { Doc("Get tasks."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")). - Returns(http.StatusOK, "OK", []task.Task{}). - Writes([]task.Task{})) + Returns(http.StatusOK, "OK", []progress.Progress{}). + Writes([]progress.Progress{})) ws.Route(ws.GET("/dashboard/rates").To(m.getRates). Doc("Get positive feedback rates."). Metadata(restfulspec.KeyOpenAPITags, []string{"dashboard"}). @@ -639,17 +639,24 @@ func (m *Master) getStats(request *restful.Request, response *restful.Response) func (m *Master) getTasks(_ *restful.Request, response *restful.Response) { // List workers - workers := make([]string, 0) + workers := mapset.NewSet[string]() m.nodesInfoMutex.RLock() for _, info := range m.nodesInfo { if info.Type == WorkerNode { - workers = append(workers, info.Name) + workers.Add(info.Name) } } m.nodesInfoMutex.RUnlock() - // List tasks - tasks := m.taskMonitor.List(workers...) - server.Ok(response, tasks) + // List local progress + progressList := m.tracer.List() + // list remote progress + m.remoteProgress.Range(func(key, value interface{}) bool { + if workers.Contains(key.(string)) { + progressList = append(progressList, value.([]progress.Progress)...) + } + return true + }) + server.Ok(response, progressList) } func (m *Master) getRates(request *restful.Request, response *restful.Response) { diff --git a/master/rpc.go b/master/rpc.go index 07ffc07df..11c743829 100644 --- a/master/rpc.go +++ b/master/rpc.go @@ -17,6 +17,9 @@ package master import ( "context" "encoding/json" + "io" + "strings" + "github.com/juju/errors" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/model/click" @@ -24,8 +27,6 @@ import ( "github.com/zhenghaoz/gorse/protocol" "go.uber.org/zap" "google.golang.org/grpc/peer" - "io" - "strings" ) // Node could be worker node for server node. @@ -232,11 +233,21 @@ func (m *Master) nodeDown(key string, value interface{}) { delete(m.nodesInfo, key) } -func (m *Master) PushTaskInfo( +func (m *Master) PushProgress( _ context.Context, - in *protocol.PushTaskInfoRequest) (*protocol.PushTaskInfoResponse, error) { - m.taskMonitor.TaskLock.Lock() - defer m.taskMonitor.TaskLock.Unlock() - m.taskMonitor.Tasks[in.GetName()] = protocol.DecodeTask(in) - return &protocol.PushTaskInfoResponse{}, nil + in *protocol.PushProgressRequest) (*protocol.PushProgressResponse, error) { + // check empty progress + if len(in.Progress) == 0 { + return &protocol.PushProgressResponse{}, nil + } + // check tracers + tracer := in.Progress[0].Tracer + for _, p := range in.Progress { + if p.Tracer != tracer { + return nil, errors.Errorf("tracers must be the same, expect %v, got %v", tracer, p.Tracer) + } + } + // store progress + m.remoteProgress.Store(tracer, protocol.DecodeProgress(in)) + return &protocol.PushProgressResponse{}, nil } diff --git a/master/rpc_test.go b/master/rpc_test.go index f1281a437..9c537b6a2 100644 --- a/master/rpc_test.go +++ b/master/rpc_test.go @@ -17,9 +17,13 @@ package master import ( "context" "encoding/json" + "net" + "testing" + "time" + "github.com/ReneKroon/ttlcache/v2" "github.com/stretchr/testify/assert" - "github.com/zhenghaoz/gorse/base/task" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model" "github.com/zhenghaoz/gorse/model/click" @@ -30,9 +34,6 @@ import ( "github.com/zhenghaoz/gorse/storage/data" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" - "net" - "testing" - "time" ) type mockMasterRPC struct { @@ -45,14 +46,13 @@ func newMockMasterRPC(_ *testing.T) *mockMasterRPC { // create click model train, test := newClickDataset() fm := click.NewFM(click.FMClassification, model.Params{model.NEpochs: 0}) - fm.Fit(train, test, nil) + fm.Fit(context.Background(), train, test, nil) // create ranking model trainSet, testSet := newRankingDataset() bpr := ranking.NewBPR(model.Params{model.NEpochs: 0}) - bpr.Fit(trainSet, testSet, nil) + bpr.Fit(context.Background(), trainSet, testSet, nil) return &mockMasterRPC{ Master: Master{ - taskMonitor: task.NewTaskMonitor(), nodesInfo: make(map[string]*Node), rankingModelName: "bpr", RestServer: server.RestServer{ @@ -101,26 +101,21 @@ func TestRPC(t *testing.T) { client := protocol.NewMasterClient(conn) ctx := context.Background() - testTask := task.NewTask("a", 12) - _, err = client.PushTaskInfo(ctx, protocol.EncodeTask(testTask)) - assert.NoError(t, err) - assert.Equal(t, 12, rpcServer.taskMonitor.Tasks["a"].Total) - assert.Equal(t, 0, rpcServer.taskMonitor.Tasks["a"].Done) - assert.Equal(t, task.StatusRunning, rpcServer.taskMonitor.Tasks["a"].Status) - - testTask.Update(10) - _, err = client.PushTaskInfo(ctx, protocol.EncodeTask(testTask)) - assert.NoError(t, err) - assert.Equal(t, 12, rpcServer.taskMonitor.Tasks["a"].Total) - assert.Equal(t, 10, rpcServer.taskMonitor.Tasks["a"].Done) - assert.Equal(t, task.StatusRunning, rpcServer.taskMonitor.Tasks["a"].Status) - - testTask.Finish() - _, err = client.PushTaskInfo(ctx, protocol.EncodeTask(testTask)) + progressList := []progress.Progress{{ + Tracer: "tracer", + Name: "a", + Status: progress.StatusRunning, + Total: 12, + Count: 6, + StartTime: time.Date(2018, time.January, 1, 0, 0, 0, 0, time.Local), + FinishTime: time.Date(2018, time.January, 2, 0, 0, 0, 0, time.Local), + }} + _, err = client.PushProgress(ctx, protocol.EncodeProgress(progressList)) assert.NoError(t, err) - assert.Equal(t, 12, rpcServer.taskMonitor.Tasks["a"].Total) - assert.Equal(t, 12, rpcServer.taskMonitor.Tasks["a"].Done) - assert.Equal(t, task.StatusComplete, rpcServer.taskMonitor.Tasks["a"].Status) + i, ok := rpcServer.remoteProgress.Load("tracer") + assert.True(t, ok) + remoteProgressList := i.([]progress.Progress) + assert.Equal(t, progressList, remoteProgressList) // test get click model clickModelReceiver, err := client.GetClickModel(ctx, &protocol.VersionInfo{Version: 456}) diff --git a/master/tasks.go b/master/tasks.go index 502073598..fff4d14cd 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -62,7 +62,7 @@ const ( type Task interface { name() string priority() int - run(j *task.JobsAllocator) error + run(ctx context.Context, j *task.JobsAllocator) error } // runLoadDatasetTask loads dataset. @@ -194,22 +194,6 @@ func (m *Master) runLoadDatasetTask() error { return nil } -func (m *Master) estimateFindItemNeighborsComplexity(dataset *ranking.DataSet) int { - complexity := dataset.ItemCount() * dataset.ItemCount() - if m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeRelated || - m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto { - complexity += len(dataset.UserFeedback) + len(dataset.ItemFeedback) - } - if m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeSimilar || - m.Config.Recommend.ItemNeighbors.NeighborType == config.NeighborTypeAuto { - complexity += len(dataset.ItemFeatures) + int(dataset.NumItemLabels) - } - if m.Config.Recommend.ItemNeighbors.EnableIndex { - complexity += search.EstimateIVFBuilderComplexity(dataset.ItemCount(), m.Config.Recommend.ItemNeighbors.IndexFitEpoch) - } - return complexity -} - // FindItemNeighborsTask updates neighbors of items. type FindItemNeighborsTask struct { *Master @@ -229,16 +213,14 @@ func (t *FindItemNeighborsTask) priority() int { return -t.rankingTrainSet.ItemCount() * t.rankingTrainSet.ItemCount() } -func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { +func (t *FindItemNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) error { t.rankingDataMutex.RLock() defer t.rankingDataMutex.RUnlock() dataset := t.rankingTrainSet numItems := dataset.ItemCount() numFeedback := dataset.Count() - ctx := context.Background() if numItems == 0 { - t.taskMonitor.Fail(TaskFindItemNeighbors, "No item found.") return nil } else if numItems == t.lastNumItems && numFeedback == t.lastNumFeedback { log.Logger().Info("No item neighbors need to be updated.") @@ -246,7 +228,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { } startTaskTime := time.Now() - t.taskMonitor.Start(TaskFindItemNeighbors, t.estimateFindItemNeighborsComplexity(dataset)) log.Logger().Info("start searching neighbors of items", zap.Int("n_cache", t.Config.Recommend.CacheSize)) // create progress tracker @@ -265,7 +246,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { throughput := completedCount - previousCount previousCount = completedCount if throughput > 0 { - t.taskMonitor.Add(TaskFindItemNeighbors, throughput*dataset.ItemCount()) log.Logger().Debug("searching neighbors of items", zap.Int("n_complete_items", completedCount), zap.Int("n_items", dataset.ItemCount()), @@ -281,7 +261,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { for _, feedbacks := range dataset.ItemFeedback { sort.Sort(sortutil.Int32Slice(feedbacks)) } - t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeedback)) // inverse document frequency of users for i := range dataset.UserFeedback { if dataset.ItemCount() == len(dataset.UserFeedback[i]) { @@ -290,7 +269,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { userIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(dataset.UserFeedback[i]))) } } - t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.UserFeedback)) } labeledItems := make([][]int32, dataset.NumItemLabels) labelIDF := make([]float32, dataset.NumItemLabels) @@ -304,7 +282,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { labeledItems[label.A] = append(labeledItems[label.A], int32(i)) } } - t.taskMonitor.Add(TaskFindItemNeighbors, len(dataset.ItemFeatures)) // inverse document frequency of labels for i := range labeledItems { labeledItems[i] = lo.Uniq(labeledItems[i]) @@ -314,7 +291,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { labelIDF[i] = math32.Log(float32(dataset.ItemCount()) / float32(len(labeledItems[i]))) } } - t.taskMonitor.Add(TaskFindItemNeighbors, len(labeledItems)) } start := time.Now() @@ -329,7 +305,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { close(completed) if err != nil { log.Logger().Error("failed to searching neighbors of items", zap.Error(err)) - t.taskMonitor.Fail(TaskFindItemNeighbors, err.Error()) FindItemNeighborsTotalSeconds.Set(0) } else { if err := t.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateItemNeighborsTime), time.Now())); err != nil { @@ -337,7 +312,6 @@ func (t *FindItemNeighborsTask) run(j *task.JobsAllocator) error { } log.Logger().Info("complete searching neighbors of items", zap.String("search_time", searchTime.String())) - t.taskMonitor.Finish(TaskFindItemNeighbors) FindItemNeighborsTotalSeconds.Set(time.Since(startTaskTime).Seconds()) } @@ -477,8 +451,7 @@ func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userID var recall float32 index, recall = builder.Build(m.Config.Recommend.ItemNeighbors.IndexRecall, m.Config.Recommend.ItemNeighbors.IndexFitEpoch, - true, - m.taskMonitor.GetTask(TaskFindItemNeighbors)) + true) ItemNeighborIndexRecall.Set(float64(recall)) if err := m.CacheClient.Set(ctx, cache.String(cache.Key(cache.GlobalMeta, cache.ItemNeighborIndexRecall), encoding.FormatFloat32(recall))); err != nil { return errors.Trace(err) @@ -546,22 +519,6 @@ func (m *Master) findItemNeighborsIVF(dataset *ranking.DataSet, labelIDF, userID return nil } -func (m *Master) estimateFindUserNeighborsComplexity(dataset *ranking.DataSet) int { - complexity := dataset.UserCount() * dataset.UserCount() - if m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeRelated || - m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto { - complexity += len(dataset.UserFeedback) + len(dataset.ItemFeedback) - } - if m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeSimilar || - m.Config.Recommend.UserNeighbors.NeighborType == config.NeighborTypeAuto { - complexity += len(dataset.UserFeatures) + int(dataset.NumUserLabels) - } - if m.Config.Recommend.UserNeighbors.EnableIndex { - complexity += search.EstimateIVFBuilderComplexity(dataset.UserCount(), m.Config.Recommend.UserNeighbors.IndexFitEpoch) - } - return complexity -} - // FindUserNeighborsTask updates neighbors of users. type FindUserNeighborsTask struct { *Master @@ -581,16 +538,14 @@ func (t *FindUserNeighborsTask) priority() int { return -t.rankingTrainSet.UserCount() * t.rankingTrainSet.UserCount() } -func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { +func (t *FindUserNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) error { t.rankingDataMutex.RLock() defer t.rankingDataMutex.RUnlock() dataset := t.rankingTrainSet numUsers := dataset.UserCount() numFeedback := dataset.Count() - ctx := context.Background() if numUsers == 0 { - t.taskMonitor.Fail(TaskFindItemNeighbors, "No item found.") return nil } else if numUsers == t.lastNumUsers && numFeedback == t.lastNumFeedback { log.Logger().Info("No update of user neighbors needed.") @@ -598,7 +553,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { } startTaskTime := time.Now() - t.taskMonitor.Start(TaskFindUserNeighbors, t.estimateFindUserNeighborsComplexity(dataset)) log.Logger().Info("start searching neighbors of users", zap.Int("n_cache", t.Config.Recommend.CacheSize)) // create progress tracker @@ -617,7 +571,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { throughput := completedCount - previousCount previousCount = completedCount if throughput > 0 { - t.taskMonitor.Add(TaskFindUserNeighbors, throughput*dataset.UserCount()) log.Logger().Debug("searching neighbors of users", zap.Int("n_complete_users", completedCount), zap.Int("n_users", dataset.UserCount()), @@ -633,7 +586,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { for _, feedbacks := range dataset.UserFeedback { sort.Sort(sortutil.Int32Slice(feedbacks)) } - t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeedback)) // inverse document frequency of items for i := range dataset.ItemFeedback { if dataset.UserCount() == len(dataset.ItemFeedback[i]) { @@ -642,7 +594,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { itemIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(dataset.ItemFeedback[i]))) } } - t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.ItemFeedback)) } labeledUsers := make([][]int32, dataset.NumUserLabels) labelIDF := make([]float32, dataset.NumUserLabels) @@ -656,7 +607,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { labeledUsers[label.A] = append(labeledUsers[label.A], int32(i)) } } - t.taskMonitor.Add(TaskFindUserNeighbors, len(dataset.UserFeatures)) // inverse document frequency of labels for i := range labeledUsers { labeledUsers[i] = lo.Uniq(labeledUsers[i]) @@ -666,7 +616,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { labelIDF[i] = math32.Log(float32(dataset.UserCount()) / float32(len(labeledUsers[i]))) } } - t.taskMonitor.Add(TaskFindUserNeighbors, len(labeledUsers)) } start := time.Now() @@ -681,7 +630,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { close(completed) if err != nil { log.Logger().Error("failed to searching neighbors of users", zap.Error(err)) - t.taskMonitor.Fail(TaskFindUserNeighbors, err.Error()) FindUserNeighborsTotalSeconds.Set(0) } else { if err := t.CacheClient.Set(ctx, cache.Time(cache.Key(cache.GlobalMeta, cache.LastUpdateUserNeighborsTime), time.Now())); err != nil { @@ -689,7 +637,6 @@ func (t *FindUserNeighborsTask) run(j *task.JobsAllocator) error { } log.Logger().Info("complete searching neighbors of users", zap.String("search_time", searchTime.String())) - t.taskMonitor.Finish(TaskFindUserNeighbors) FindUserNeighborsTotalSeconds.Set(time.Since(startTaskTime).Seconds()) } @@ -820,8 +767,7 @@ func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemID index, recall = builder.Build( m.Config.Recommend.UserNeighbors.IndexRecall, m.Config.Recommend.UserNeighbors.IndexFitEpoch, - true, - m.taskMonitor.GetTask(TaskFindUserNeighbors)) + true) UserNeighborIndexRecall.Set(float64(recall)) if err := m.CacheClient.Set(ctx, cache.String(cache.Key(cache.GlobalMeta, cache.UserNeighborIndexRecall), encoding.FormatFloat32(recall))); err != nil { return errors.Trace(err) @@ -1028,14 +974,13 @@ func (t *FitRankingModelTask) priority() int { return -t.rankingTrainSet.Count() } -func (t *FitRankingModelTask) run(j *task.JobsAllocator) error { +func (t *FitRankingModelTask) run(ctx context.Context, j *task.JobsAllocator) error { t.rankingDataMutex.RLock() defer t.rankingDataMutex.RUnlock() dataset := t.rankingTrainSet numFeedback := dataset.Count() var modelChanged bool - ctx := context.Background() bestRankingName, bestRankingModel, bestRankingScore := t.rankingModelSearcher.GetBestModel() t.rankingModelMutex.Lock() if bestRankingModel != nil && !bestRankingModel.Invalid() && @@ -1057,7 +1002,7 @@ func (t *FitRankingModelTask) run(j *task.JobsAllocator) error { t.rankingModelMutex.Unlock() if numFeedback == 0 { - t.taskMonitor.Fail(TaskFitRankingModel, "No feedback found.") + // t.taskMonitor.Fail(TaskFitRankingModel, "No feedback found.") return nil } else if numFeedback == t.lastNumFeedback && !modelChanged { log.Logger().Info("nothing changed") @@ -1065,9 +1010,7 @@ func (t *FitRankingModelTask) run(j *task.JobsAllocator) error { } startFitTime := time.Now() - score := rankingModel.Fit(t.rankingTrainSet, t.rankingTestSet, ranking.NewFitConfig(). - SetJobsAllocator(j). - SetTask(t.taskMonitor.Start(TaskFitRankingModel, rankingModel.Complexity()))) + score := rankingModel.Fit(ctx, t.rankingTrainSet, t.rankingTestSet, ranking.NewFitConfig().SetJobsAllocator(j)) CollaborativeFilteringFitSeconds.Set(time.Since(startFitTime).Seconds()) // update ranking model @@ -1105,7 +1048,7 @@ func (t *FitRankingModelTask) run(j *task.JobsAllocator) error { zap.Any("ranking_model_params", t.localCache.RankingModel.GetParams())) } - t.taskMonitor.Finish(TaskFitRankingModel) + // t.taskMonitor.Finish(TaskFitRankingModel) t.lastNumFeedback = numFeedback return nil } @@ -1133,7 +1076,7 @@ func (t *FitClickModelTask) priority() int { return -t.clickTrainSet.Count() } -func (t *FitClickModelTask) run(j *task.JobsAllocator) error { +func (t *FitClickModelTask) run(ctx context.Context, j *task.JobsAllocator) error { log.Logger().Info("prepare to fit click model", zap.Int("n_jobs", t.Config.Master.NumJobs)) t.clickDataMutex.RLock() defer t.clickDataMutex.RUnlock() @@ -1141,12 +1084,10 @@ func (t *FitClickModelTask) run(j *task.JobsAllocator) error { numItems := t.clickTrainSet.ItemCount() numFeedback := t.clickTrainSet.Count() var shouldFit bool - ctx := context.Background() if t.clickTrainSet == nil || numUsers == 0 || numItems == 0 || numFeedback == 0 { log.Logger().Warn("empty ranking dataset", zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes)) - t.taskMonitor.Fail(TaskFitClickModel, "No feedback found.") return nil } else if numUsers != t.lastNumUsers || numItems != t.lastNumItems || @@ -1179,9 +1120,8 @@ func (t *FitClickModelTask) run(j *task.JobsAllocator) error { return nil } startFitTime := time.Now() - score := clickModel.Fit(t.clickTrainSet, t.clickTestSet, click.NewFitConfig(). - SetJobsAllocator(j). - SetTask(t.taskMonitor.Start(TaskFitClickModel, clickModel.Complexity()))) + score := clickModel.Fit(context.Background(), t.clickTrainSet, t.clickTestSet, click.NewFitConfig(). + SetJobsAllocator(j)) RankingFitSeconds.Set(time.Since(startFitTime).Seconds()) // update match model @@ -1217,7 +1157,6 @@ func (t *FitClickModelTask) run(j *task.JobsAllocator) error { zap.Any("click_model_params", t.localCache.ClickModel.GetParams())) } - t.taskMonitor.Finish(TaskFitClickModel) t.lastNumItems = numItems t.lastNumUsers = numUsers t.lastNumFeedback = numFeedback @@ -1245,7 +1184,7 @@ func (t *SearchRankingModelTask) priority() int { return -t.rankingTrainSet.Count() } -func (t *SearchRankingModelTask) run(j *task.JobsAllocator) error { +func (t *SearchRankingModelTask) run(ctx context.Context, j *task.JobsAllocator) error { log.Logger().Info("start searching ranking model") t.rankingDataMutex.RLock() defer t.rankingDataMutex.RUnlock() @@ -1260,7 +1199,7 @@ func (t *SearchRankingModelTask) run(j *task.JobsAllocator) error { if numUsers == 0 || numItems == 0 || numFeedback == 0 { log.Logger().Warn("empty ranking dataset", zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes)) - t.taskMonitor.Fail(TaskSearchRankingModel, "No feedback found.") + // t.taskMonitor.Fail(TaskSearchRankingModel, "No feedback found.") return nil } else if numUsers == t.lastNumUsers && numItems == t.lastNumItems && @@ -1270,8 +1209,7 @@ func (t *SearchRankingModelTask) run(j *task.JobsAllocator) error { } startTime := time.Now() - err := t.rankingModelSearcher.Fit(t.rankingTrainSet, t.rankingTestSet, - t.taskMonitor.Start(TaskSearchRankingModel, t.rankingModelSearcher.Complexity()), j) + err := t.rankingModelSearcher.Fit(ctx, t.rankingTrainSet, t.rankingTestSet, nil) if err != nil { log.Logger().Error("failed to search collaborative filtering model", zap.Error(err)) return nil @@ -1280,7 +1218,6 @@ func (t *SearchRankingModelTask) run(j *task.JobsAllocator) error { _, _, bestScore := t.rankingModelSearcher.GetBestModel() CollaborativeFilteringSearchPrecision10.Set(float64(bestScore.Precision)) - t.taskMonitor.Finish(TaskSearchRankingModel) t.lastNumItems = numItems t.lastNumUsers = numUsers t.lastNumFeedback = numFeedback @@ -1308,7 +1245,7 @@ func (t *SearchClickModelTask) priority() int { return -t.clickTrainSet.Count() } -func (t *SearchClickModelTask) run(j *task.JobsAllocator) error { +func (t *SearchClickModelTask) run(ctx context.Context, j *task.JobsAllocator) error { log.Logger().Info("start searching click model") t.clickDataMutex.RLock() defer t.clickDataMutex.RUnlock() @@ -1323,7 +1260,6 @@ func (t *SearchClickModelTask) run(j *task.JobsAllocator) error { if numUsers == 0 || numItems == 0 || numFeedback == 0 { log.Logger().Warn("empty click dataset", zap.Strings("positive_feedback_type", t.Config.Recommend.DataSource.PositiveFeedbackTypes)) - t.taskMonitor.Fail(TaskSearchClickModel, "No feedback found.") return nil } else if numUsers == t.lastNumUsers && numItems == t.lastNumItems && @@ -1333,8 +1269,7 @@ func (t *SearchClickModelTask) run(j *task.JobsAllocator) error { } startTime := time.Now() - err := t.clickModelSearcher.Fit(t.clickTrainSet, t.clickTestSet, - t.taskMonitor.Start(TaskSearchClickModel, t.clickModelSearcher.Complexity()), j) + err := t.clickModelSearcher.Fit(context.Background(), t.clickTrainSet, t.clickTestSet, j) if err != nil { log.Logger().Error("failed to search ranking model", zap.Error(err)) return nil @@ -1343,7 +1278,6 @@ func (t *SearchClickModelTask) run(j *task.JobsAllocator) error { _, bestScore := t.clickModelSearcher.GetBestModel() RankingSearchPrecision.Set(float64(bestScore.Precision)) - t.taskMonitor.Finish(TaskSearchClickModel) t.lastNumItems = numItems t.lastNumUsers = numUsers t.lastNumFeedback = numFeedback @@ -1366,15 +1300,13 @@ func (t *CacheGarbageCollectionTask) priority() int { return -t.rankingTrainSet.UserCount() - t.rankingTrainSet.ItemCount() } -func (t *CacheGarbageCollectionTask) run(j *task.JobsAllocator) error { +func (t *CacheGarbageCollectionTask) run(ctx context.Context, j *task.JobsAllocator) error { if t.rankingTrainSet == nil { log.Logger().Debug("dataset has not been loaded") return nil } - ctx := context.Background() log.Logger().Info("start cache garbage collection") - t.taskMonitor.Start(TaskCacheGarbageCollection, t.rankingTrainSet.UserCount()*9+t.rankingTrainSet.ItemCount()*4) var scanCount, reclaimCount int start := time.Now() err := t.CacheClient.Scan(func(s string) error { @@ -1383,7 +1315,6 @@ func (t *CacheGarbageCollectionTask) run(j *task.JobsAllocator) error { return nil } scanCount++ - t.taskMonitor.Update(TaskCacheGarbageCollection, scanCount) switch splits[0] { case cache.UserNeighbors, cache.UserNeighborsDigest, cache.OfflineRecommend, cache.OfflineRecommendDigest, cache.CollaborativeRecommend, @@ -1437,7 +1368,6 @@ func (t *CacheGarbageCollectionTask) run(j *task.JobsAllocator) error { } return nil }) - t.taskMonitor.Finish(TaskCacheGarbageCollection) CacheScannedTotal.Set(float64(scanCount)) CacheReclaimedTotal.Set(float64(reclaimCount)) CacheScannedSeconds.Set(time.Since(start).Seconds()) @@ -1448,7 +1378,6 @@ func (t *CacheGarbageCollectionTask) run(j *task.JobsAllocator) error { func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, readTypes []string, itemTTL, positiveFeedbackTTL uint, evaluator *OnlineEvaluator) ( rankingDataset *ranking.DataSet, clickDataset *click.Dataset, latestItems *cache.DocumentAggregator, popularItems *cache.DocumentAggregator, err error) { startLoadTime := time.Now() - m.taskMonitor.Start(TaskLoadDataset, 5) ctx := context.Background() // setup time limit var itemTimeLimit, feedbackTimeLimit *time.Time @@ -1515,7 +1444,6 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, return nil, nil, nil, nil, errors.Trace(err) } rankingDataset.NumUserLabels = userLabelIndex.Len() - m.taskMonitor.Update(TaskLoadDataset, 1) log.Logger().Debug("pulled users from database", zap.Int("n_users", rankingDataset.UserCount()), zap.Int32("n_user_labels", userLabelIndex.Len()), @@ -1581,7 +1509,6 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, return nil, nil, nil, nil, errors.Trace(err) } rankingDataset.NumItemLabels = itemLabelIndex.Len() - m.taskMonitor.Update(TaskLoadDataset, 2) log.Logger().Debug("pulled items from database", zap.Int("n_items", rankingDataset.ItemCount()), zap.Int32("n_item_labels", itemLabelIndex.Len()), @@ -1623,7 +1550,6 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, if err = <-errChan; err != nil { return nil, nil, nil, nil, errors.Trace(err) } - m.taskMonitor.Update(TaskLoadDataset, 3) log.Logger().Debug("pulled positive feedback from database", zap.Int("n_positive_feedback", rankingDataset.Count()), zap.Duration("used_time", time.Since(start))) @@ -1658,7 +1584,6 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, if err = <-errChan; err != nil { return nil, nil, nil, nil, errors.Trace(err) } - m.taskMonitor.Update(TaskLoadDataset, 4) FeedbacksTotal.Set(feedbackCount) log.Logger().Debug("pulled negative feedback from database", zap.Duration("used_time", time.Since(start))) @@ -1705,7 +1630,6 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, zap.Int("n_valid_positive", clickDataset.PositiveCount), zap.Int("n_valid_negative", clickDataset.NegativeCount), zap.Duration("used_time", time.Since(start))) - m.taskMonitor.Update(TaskLoadDataset, 5) LoadDatasetStepSecondsVec.WithLabelValues("create_ranking_dataset").Set(time.Since(start).Seconds()) // collect latest items @@ -1734,6 +1658,5 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, popularItems.Add(category, items, scores) } - m.taskMonitor.Finish(TaskLoadDataset) return rankingDataset, clickDataset, latestItems, popularItems, nil } From 778f88f952e90819a7fdb042295f5a4c91c41bc2 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sat, 29 Jul 2023 13:20:37 +0800 Subject: [PATCH 11/13] Update dashboard --- base/progress/progress.go | 70 +++++++++++++++++++++++++++++++++------ go.mod | 4 +-- go.sum | 4 +-- master/master.go | 1 + master/tasks.go | 53 ++++++++++++++++++++--------- 5 files changed, 102 insertions(+), 30 deletions(-) diff --git a/base/progress/progress.go b/base/progress/progress.go index 1c7f346cc..ddab44463 100644 --- a/base/progress/progress.go +++ b/base/progress/progress.go @@ -16,10 +16,12 @@ package progress import ( "context" + "sort" "sync" "time" "github.com/google/uuid" + "modernc.org/mathutil" ) type spanKeyType string @@ -47,7 +49,12 @@ func NewTracer(name string) *Tracer { // Start creates a root span. func (t *Tracer) Start(ctx context.Context, name string, total int) (context.Context, *Span) { - span := &Span{name: name, total: total} + span := &Span{ + name: name, + status: StatusRunning, + total: total, + start: time.Now(), + } t.spans.Store(name, span) return context.WithValue(ctx, spanKeyName, span), span } @@ -55,14 +62,14 @@ func (t *Tracer) Start(ctx context.Context, name string, total int) (context.Con func (t *Tracer) List() []Progress { var progress []Progress t.spans.Range(func(key, value interface{}) bool { - // span := value.(*Span) - // progress = append(progress, Progress{ - // Name: span.name, - // Total: span.total, - // Count: span.count, - // }) + span := value.(*Span) + progress = append(progress, span.Progress()) return true }) + // sort by start time + sort.Slice(progress, func(i, j int) bool { + return progress[i].StartTime.Before(progress[j].StartTime) + }) return progress } @@ -71,28 +78,71 @@ type Span struct { status Status total int count int - err error + err string start time.Time finish time.Time children sync.Map } func (s *Span) Add(n int) { - s.count += n + s.count = mathutil.Min(s.count+n, s.total) } func (s *Span) End() { + s.status = StatusComplete s.count = s.total + s.finish = time.Now() } func (s *Span) Error(err error) { - s.err = err + s.err = err.Error() } func (s *Span) Count() int { return s.count } +func (s *Span) Progress() Progress { + // find running children + var children []Progress + s.children.Range(func(key, value interface{}) bool { + child := value.(*Span) + progress := child.Progress() + if progress.Status == StatusRunning { + children = append(children, progress) + } + return true + }) + // leaf node + if len(children) == 0 { + return Progress{ + Name: s.name, + Status: s.status, + Error: s.err, + Count: s.count, + Total: s.total, + StartTime: s.start, + FinishTime: s.finish, + } + } + // non-leaf node + childTotal := children[0].Total + parentTotal := s.total * childTotal + parentCount := s.count * childTotal + for _, child := range children { + parentCount += childTotal * child.Count / child.Total + } + return Progress{ + Name: s.name, + Status: s.status, + Error: s.err, + Count: parentCount, + Total: parentTotal, + StartTime: s.start, + FinishTime: s.finish, + } +} + func Start(ctx context.Context, name string, total int) (context.Context, *Span) { childSpan := &Span{ name: name, diff --git a/go.mod b/go.mod index c58ec072f..5a9cfdd04 100644 --- a/go.mod +++ b/go.mod @@ -20,9 +20,10 @@ require ( github.com/go-redis/redis/v9 v9.0.0-rc.1 github.com/go-resty/resty/v2 v2.7.0 github.com/go-sql-driver/mysql v1.6.0 + github.com/golang/protobuf v1.5.2 github.com/google/uuid v1.3.0 github.com/gorilla/securecookie v1.1.1 - github.com/gorse-io/dashboard v0.0.0-20230319140716-18e3dabe9366 + github.com/gorse-io/dashboard v0.0.0-20230729051855-6c53a42d2bd4 github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946 github.com/jaswdr/faker v1.16.0 github.com/json-iterator/go v1.1.12 @@ -96,7 +97,6 @@ require ( github.com/go-openapi/swag v0.22.3 // indirect github.com/go-redis/redis/extra/rediscmd/v9 v9.0.0-rc.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/protobuf v1.5.2 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v2.0.6+incompatible // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0 // indirect diff --git a/go.sum b/go.sum index e63f8e0eb..df7d9e729 100644 --- a/go.sum +++ b/go.sum @@ -310,8 +310,8 @@ github.com/gorgonia/bindgen v0.0.0-20180812032444-09626750019e/go.mod h1:YzKk63P github.com/gorgonia/bindgen v0.0.0-20210223094355-432cd89e7765/go.mod h1:BLHSe436vhQKRfm6wxJgebeK4fDY+ER/8jV3vVH9yYU= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= -github.com/gorse-io/dashboard v0.0.0-20230319140716-18e3dabe9366 h1:s4CZgfU5HnOtJnGJ1EmUipY1IBpZa0nQmtTTz9gvXvM= -github.com/gorse-io/dashboard v0.0.0-20230319140716-18e3dabe9366/go.mod h1:w74IGf70uM5ZCeXmkBhLl3Ux6D+HpBryzcc75VfZA4s= +github.com/gorse-io/dashboard v0.0.0-20230729051855-6c53a42d2bd4 h1:x0bLXsLkjEZdztd0Tw+Hx38vIjzabyj2Fk0EDitKcLk= +github.com/gorse-io/dashboard v0.0.0-20230729051855-6c53a42d2bd4/go.mod h1:bv2Yg9Pn4Dca4xPJbvibpF6LH6BjoxcjsEdIuojNano= github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e h1:uPQtYQzG1QcC3Qbv+tuEe8Q2l++V4KEcqYSSwB9qobg= github.com/gorse-io/sqlite v1.3.3-0.20220713123255-c322aec4e59e/go.mod h1:PmIOwYnI+F1lRKd6F/PdLXGgI8GZ5H8x8z1yx0+0bmQ= github.com/gorse-io/tensor v0.0.0-20230617102451-4c006ddc5162 h1:W4aIbIvkE9/9PLuGJ7OcWuEtTeUaXgTd2enX440+e7Q= diff --git a/master/master.go b/master/master.go index dc36fad19..2ba55e46d 100644 --- a/master/master.go +++ b/master/master.go @@ -119,6 +119,7 @@ func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master { cacheFile: cacheFile, managedMode: managedMode, jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs), + tracer: progress.NewTracer("master"), // default ranking model rankingModelName: "bpr", rankingModelSearcher: ranking.NewModelSearcher( diff --git a/master/tasks.go b/master/tasks.go index fff4d14cd..666c66086 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -30,6 +30,7 @@ import ( "github.com/zhenghaoz/gorse/base/heap" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/parallel" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/base/search" "github.com/zhenghaoz/gorse/base/task" "github.com/zhenghaoz/gorse/config" @@ -67,7 +68,9 @@ type Task interface { // runLoadDatasetTask loads dataset. func (m *Master) runLoadDatasetTask() error { - ctx := context.Background() + ctx, span := m.tracer.Start(context.Background(), "Load Dataset", 1) + defer span.End() + initialStartTime := time.Now() log.Logger().Info("load dataset", zap.Strings("positive_feedback_types", m.Config.Recommend.DataSource.PositiveFeedbackTypes), @@ -75,7 +78,7 @@ func (m *Master) runLoadDatasetTask() error { zap.Uint("item_ttl", m.Config.Recommend.DataSource.ItemTTL), zap.Uint("feedback_ttl", m.Config.Recommend.DataSource.PositiveFeedbackTTL)) evaluator := NewOnlineEvaluator() - rankingDataset, clickDataset, latestItems, popularItems, err := m.LoadDataFromDatabase(m.DataClient, + rankingDataset, clickDataset, latestItems, popularItems, err := m.LoadDataFromDatabase(ctx, m.DataClient, m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes, m.Config.Recommend.DataSource.ItemTTL, @@ -220,6 +223,9 @@ func (t *FindItemNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) numItems := dataset.ItemCount() numFeedback := dataset.Count() + _, span := t.tracer.Start(ctx, "Find Item Neighbors", dataset.ItemCount()) + defer span.End() + if numItems == 0 { return nil } else if numItems == t.lastNumItems && numFeedback == t.lastNumFeedback { @@ -250,6 +256,7 @@ func (t *FindItemNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) zap.Int("n_complete_items", completedCount), zap.Int("n_items", dataset.ItemCount()), zap.Int("throughput", throughput/10)) + span.Add(throughput) } } } @@ -545,6 +552,9 @@ func (t *FindUserNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) numUsers := dataset.UserCount() numFeedback := dataset.Count() + newCtx, span := t.tracer.Start(ctx, "Find User Neighbors", dataset.UserCount()) + defer span.End() + if numUsers == 0 { return nil } else if numUsers == t.lastNumUsers && numFeedback == t.lastNumFeedback { @@ -575,6 +585,7 @@ func (t *FindUserNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) zap.Int("n_complete_users", completedCount), zap.Int("n_users", dataset.UserCount()), zap.Int("throughput", throughput)) + span.Add(throughput) } } } @@ -621,9 +632,9 @@ func (t *FindUserNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) start := time.Now() var err error if t.Config.Recommend.UserNeighbors.EnableIndex { - err = t.findUserNeighborsIVF(dataset, labelIDF, itemIDF, completed, j) + err = t.findUserNeighborsIVF(newCtx, dataset, labelIDF, itemIDF, completed, j) } else { - err = t.findUserNeighborsBruteForce(dataset, labeledUsers, labelIDF, itemIDF, completed, j) + err = t.findUserNeighborsBruteForce(newCtx, dataset, labeledUsers, labelIDF, itemIDF, completed, j) } searchTime := time.Since(start) @@ -645,12 +656,11 @@ func (t *FindUserNeighborsTask) run(ctx context.Context, j *task.JobsAllocator) return nil } -func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUsers [][]int32, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { +func (m *Master) findUserNeighborsBruteForce(ctx context.Context, dataset *ranking.DataSet, labeledUsers [][]int32, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { var ( updateUserCount atomic.Float64 findNeighborSeconds atomic.Float64 ) - ctx := context.Background() var vectors VectorsInterface switch m.Config.Recommend.UserNeighbors.NeighborType { @@ -730,13 +740,12 @@ func (m *Master) findUserNeighborsBruteForce(dataset *ranking.DataSet, labeledUs return nil } -func (m *Master) findUserNeighborsIVF(dataset *ranking.DataSet, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { +func (m *Master) findUserNeighborsIVF(ctx context.Context, dataset *ranking.DataSet, labelIDF, itemIDF []float32, completed chan struct{}, j *task.JobsAllocator) error { var ( updateUserCount atomic.Float64 buildIndexSeconds atomic.Float64 findNeighborSeconds atomic.Float64 ) - ctx := context.Background() // build index buildStart := time.Now() var index search.VectorIndex @@ -975,6 +984,9 @@ func (t *FitRankingModelTask) priority() int { } func (t *FitRankingModelTask) run(ctx context.Context, j *task.JobsAllocator) error { + newCtx, span := t.Master.tracer.Start(ctx, "Fit Embedding", 1) + defer span.End() + t.rankingDataMutex.RLock() defer t.rankingDataMutex.RUnlock() dataset := t.rankingTrainSet @@ -1010,7 +1022,7 @@ func (t *FitRankingModelTask) run(ctx context.Context, j *task.JobsAllocator) er } startFitTime := time.Now() - score := rankingModel.Fit(ctx, t.rankingTrainSet, t.rankingTestSet, ranking.NewFitConfig().SetJobsAllocator(j)) + score := rankingModel.Fit(newCtx, t.rankingTrainSet, t.rankingTestSet, ranking.NewFitConfig().SetJobsAllocator(j)) CollaborativeFilteringFitSeconds.Set(time.Since(startFitTime).Seconds()) // update ranking model @@ -1077,6 +1089,9 @@ func (t *FitClickModelTask) priority() int { } func (t *FitClickModelTask) run(ctx context.Context, j *task.JobsAllocator) error { + newCtx, span := t.tracer.Start(ctx, "Fit Ranker", 1) + defer span.End() + log.Logger().Info("prepare to fit click model", zap.Int("n_jobs", t.Config.Master.NumJobs)) t.clickDataMutex.RLock() defer t.clickDataMutex.RUnlock() @@ -1120,7 +1135,7 @@ func (t *FitClickModelTask) run(ctx context.Context, j *task.JobsAllocator) erro return nil } startFitTime := time.Now() - score := clickModel.Fit(context.Background(), t.clickTrainSet, t.clickTestSet, click.NewFitConfig(). + score := clickModel.Fit(newCtx, t.clickTrainSet, t.clickTestSet, click.NewFitConfig(). SetJobsAllocator(j)) RankingFitSeconds.Set(time.Since(startFitTime).Seconds()) @@ -1375,10 +1390,12 @@ func (t *CacheGarbageCollectionTask) run(ctx context.Context, j *task.JobsAlloca } // LoadDataFromDatabase loads dataset from data store. -func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, readTypes []string, itemTTL, positiveFeedbackTTL uint, evaluator *OnlineEvaluator) ( +func (m *Master) LoadDataFromDatabase(ctx context.Context, database data.Database, posFeedbackTypes, readTypes []string, itemTTL, positiveFeedbackTTL uint, evaluator *OnlineEvaluator) ( rankingDataset *ranking.DataSet, clickDataset *click.Dataset, latestItems *cache.DocumentAggregator, popularItems *cache.DocumentAggregator, err error) { + newCtx, span := progress.Start(ctx, "LoadDataFromDatabase", 4) + defer span.End() + startLoadTime := time.Now() - ctx := context.Background() // setup time limit var itemTimeLimit, feedbackTimeLimit *time.Time if itemTTL > 0 { @@ -1404,7 +1421,7 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, userLabelFirst := make(map[string]int32) userLabelIndex := base.NewMapIndex() start := time.Now() - userChan, errChan := database.GetUserStream(ctx, batchSize) + userChan, errChan := database.GetUserStream(newCtx, batchSize) for users := range userChan { for _, user := range users { rankingDataset.AddUser(user.UserId) @@ -1449,13 +1466,14 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, zap.Int32("n_user_labels", userLabelIndex.Len()), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_users").Set(time.Since(start).Seconds()) + span.Add(1) // STEP 2: pull items itemLabelCount := make(map[string]int) itemLabelFirst := make(map[string]int32) itemLabelIndex := base.NewMapIndex() start = time.Now() - itemChan, errChan := database.GetItemStream(ctx, batchSize, itemTimeLimit) + itemChan, errChan := database.GetItemStream(newCtx, batchSize, itemTimeLimit) for items := range itemChan { for _, item := range items { rankingDataset.AddItem(item.ItemId) @@ -1514,6 +1532,7 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, zap.Int32("n_item_labels", itemLabelIndex.Len()), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_items").Set(time.Since(start).Seconds()) + span.Add(1) // create positive set popularCount := make([]int32, rankingDataset.ItemCount()) @@ -1525,7 +1544,7 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, // STEP 3: pull positive feedback var feedbackCount float64 start = time.Now() - feedbackChan, errChan := database.GetFeedbackStream(ctx, batchSize, feedbackTimeLimit, m.Config.Now(), posFeedbackTypes...) + feedbackChan, errChan := database.GetFeedbackStream(newCtx, batchSize, feedbackTimeLimit, m.Config.Now(), posFeedbackTypes...) for feedback := range feedbackChan { for _, f := range feedback { feedbackCount++ @@ -1554,6 +1573,7 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, zap.Int("n_positive_feedback", rankingDataset.Count()), zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_positive_feedback").Set(time.Since(start).Seconds()) + span.Add(1) // create negative set negativeSet := make([]mapset.Set[int32], rankingDataset.UserCount()) @@ -1563,7 +1583,7 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, // STEP 4: pull negative feedback start = time.Now() - feedbackChan, errChan = database.GetFeedbackStream(ctx, batchSize, feedbackTimeLimit, m.Config.Now(), readTypes...) + feedbackChan, errChan = database.GetFeedbackStream(newCtx, batchSize, feedbackTimeLimit, m.Config.Now(), readTypes...) for feedback := range feedbackChan { for _, f := range feedback { feedbackCount++ @@ -1588,6 +1608,7 @@ func (m *Master) LoadDataFromDatabase(database data.Database, posFeedbackTypes, log.Logger().Debug("pulled negative feedback from database", zap.Duration("used_time", time.Since(start))) LoadDatasetStepSecondsVec.WithLabelValues("load_negative_feedback").Set(time.Since(start).Seconds()) + span.Add(1) // STEP 5: create click dataset start = time.Now() From 19a07678e48e51440b8cbbda63c6828756aaae75 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sat, 29 Jul 2023 13:25:47 +0800 Subject: [PATCH 12/13] Fix lint --- master/tasks_test.go | 12 ++++++------ worker/worker.go | 4 ---- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/master/tasks_test.go b/master/tasks_test.go index b572d125c..dc02c87d0 100644 --- a/master/tasks_test.go +++ b/master/tasks_test.go @@ -80,7 +80,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsBruteForce() { } // load mock dataset - dataset, _, _, _, err := s.LoadDataFromDatabase(s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + dataset, _, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) s.NoError(err) s.rankingTrainSet = dataset @@ -185,7 +185,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF() { } // load mock dataset - dataset, _, _, _, err := s.LoadDataFromDatabase(s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + dataset, _, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) s.NoError(err) s.rankingTrainSet = dataset @@ -252,7 +252,7 @@ func (s *MasterTestSuite) TestFindItemNeighborsIVF_ZeroIDF() { {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "0", ItemId: "1"}}, }, true, true, true) s.NoError(err) - dataset, _, _, _, err := s.LoadDataFromDatabase(s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + dataset, _, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) s.NoError(err) s.rankingTrainSet = dataset @@ -312,7 +312,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsBruteForce() { s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) s.NoError(err) - dataset, _, _, _, err := s.LoadDataFromDatabase(s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + dataset, _, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) s.NoError(err) s.rankingTrainSet = dataset @@ -392,7 +392,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF() { s.NoError(err) err = s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) s.NoError(err) - dataset, _, _, _, err := s.LoadDataFromDatabase(s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + dataset, _, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) s.NoError(err) s.rankingTrainSet = dataset @@ -451,7 +451,7 @@ func (s *MasterTestSuite) TestFindUserNeighborsIVF_ZeroIDF() { {FeedbackKey: data.FeedbackKey{FeedbackType: "FeedbackType", UserId: "1", ItemId: "0"}}, }, true, true, true) s.NoError(err) - dataset, _, _, _, err := s.LoadDataFromDatabase(s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) + dataset, _, _, _, err := s.LoadDataFromDatabase(context.Background(), s.DataClient, []string{"FeedbackType"}, nil, 0, 0, NewOnlineEvaluator()) s.NoError(err) s.rankingTrainSet = dataset diff --git a/worker/worker.go b/worker/worker.go index 04cabbe43..914080c69 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -456,10 +456,6 @@ func (w *Worker) Recommend(users []data.User) { // progress tracker completed := make(chan struct{}, 1000) - recommendTaskName := "Generate offline recommendation" - if !w.oneMode { - recommendTaskName += fmt.Sprintf(" [%s]", w.workerName) - } _, span := w.tracer.Start(context.Background(), "Recommend", len(users)) defer span.End() From 6aafb0d966941e916aa7fcd9329263fee6d4415a Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sat, 29 Jul 2023 13:36:58 +0800 Subject: [PATCH 13/13] Fix test --- master/master_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/master/master_test.go b/master/master_test.go index 8e2dfbdee..a58277645 100644 --- a/master/master_test.go +++ b/master/master_test.go @@ -18,6 +18,7 @@ import ( "testing" "github.com/stretchr/testify/suite" + "github.com/zhenghaoz/gorse/base/progress" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/storage/cache" "github.com/zhenghaoz/gorse/storage/data" @@ -31,6 +32,7 @@ type MasterTestSuite struct { func (s *MasterTestSuite) SetupTest() { // open database var err error + s.tracer = progress.NewTracer("test") s.Settings = config.NewSettings() s.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", s.T().TempDir()), "") s.NoError(err)