Skip to content
Merged
176 changes: 176 additions & 0 deletions cmd/gorse-benchmark/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Copyright 2026 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"context"
"fmt"
"log"
"os"
"runtime"
"sort"

"github.com/gorse-io/gorse/config"
"github.com/gorse-io/gorse/dataset"
"github.com/gorse-io/gorse/master"
"github.com/gorse-io/gorse/model/ctr"
"github.com/gorse-io/gorse/storage"
"github.com/gorse-io/gorse/storage/data"
"github.com/samber/lo"
"github.com/spf13/cobra"
"modernc.org/sortutil"
)

// rootCmd is the top-level "gorse-benchmark" command. Subcommands and the
// persistent "config" flag are attached to it in init.
var rootCmd = &cobra.Command{
Use: "gorse-benchmark",
Short: "Gorse Benchmarking Tool",
}

// llmCmd is the "llm" subcommand. It loads the configuration named by the
// "config" flag, opens the configured data store, loads the click-through
// dataset from it, prints the dataset size, splits it into train and test
// sets and runs EvaluateFM on the split. Any failure terminates the process
// through log.Fatalf.
var llmCmd = &cobra.Command{
	Use:   "llm",
	Short: "Benchmark LLM models",
	Run: func(cmd *cobra.Command, args []string) {
		// Load configuration
		configPath, err := cmd.Flags().GetString("config")
		if err != nil {
			log.Fatalf("failed to read config flag: %v", err)
		}
		cfg, err := config.LoadConfig(configPath)
		if err != nil {
			log.Fatalf("failed to load config: %v", err)
		}
		// Load dataset
		m := master.NewMaster(cfg, os.TempDir(), false)
		m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix,
			storage.WithIsolationLevel(m.Config.Database.MySQL.IsolationLevel))
		if err != nil {
			log.Fatalf("failed to open data client: %v", err)
		}
		source := m.Config.Recommend.DataSource
		evaluator := master.NewOnlineEvaluator(
			source.PositiveFeedbackTypes,
			source.ReadFeedbackTypes)
		// The loaded set is named ctrData so that it does not shadow the
		// imported dataset package.
		ctrData, _, err := m.LoadDataFromDatabase(context.Background(), m.DataClient,
			source.PositiveFeedbackTypes,
			source.ReadFeedbackTypes,
			source.ItemTTL,
			source.PositiveFeedbackTTL,
			evaluator,
			nil)
		if err != nil {
			log.Fatalf("failed to load dataset: %v", err)
		}
		fmt.Println("Dataset loaded:")
		fmt.Printf("  Users: %d\n", ctrData.CountUsers())
		fmt.Printf("  Items: %d\n", ctrData.CountItems())
		fmt.Printf("  Positive Feedbacks: %d\n", ctrData.CountPositive())
		fmt.Printf("  Negative Feedbacks: %d\n", ctrData.CountNegative())
		// Split dataset: ratio 0.2 with fixed seed 42.
		train, test := ctrData.Split(0.2, 42)
		EvaluateFM(train, test)
		// EvaluateLLM(cfg, train, test, aux.GetItems())
	},
}

// EvaluateFM fits the model returned by ctr.NewAFM on train (using test as
// the evaluation split during fitting) and then computes a grouped AUC (GAUC)
// over the test samples.
//
// The first feature index of every sample is treated as the user id. For each
// user that has between 1 and 100 positive training samples and at least one
// positive and one negative test sample, AUC is computed over that user's test
// predictions; the per-user values are averaged, weighted by the user's number
// of positive test samples.
//
// The score is printed and returned. If no user qualifies, 0 is returned and
// nothing is printed.
func EvaluateFM(train, test dataset.CTRSplit) float32 {
	// Users with more positive training samples than this are excluded.
	const maxTrainPositives = 100

	fmt.Println("Training FM...")
	jobs := runtime.NumCPU()
	ml := ctr.NewAFM(nil)
	ml.Fit(context.Background(), train, test,
		ctr.NewFitConfig().
			SetVerbose(10).
			SetJobs(jobs).
			SetPatience(10))

	// Count positive training samples per user.
	userTrain := make(map[int32]int, train.CountUsers())
	for i := 0; i < train.Count(); i++ {
		indices, _, _, target := train.Get(i)
		if target > 0 {
			userTrain[indices[0]]++
		}
	}

	// Partition the test samples by label, remembering each sample's user.
	var posFeatures, negFeatures []lo.Tuple2[[]int32, []float32]
	var posEmbeddings, negEmbeddings [][][]float32
	var posUsers, negUsers []int32
	for i := 0; i < test.Count(); i++ {
		indices, values, embeddings, target := test.Get(i)
		userID := indices[0]
		sample := lo.Tuple2[[]int32, []float32]{A: indices, B: values}
		if target > 0 {
			posFeatures = append(posFeatures, sample)
			posEmbeddings = append(posEmbeddings, embeddings)
			posUsers = append(posUsers, userID)
		} else {
			negFeatures = append(negFeatures, sample)
			negEmbeddings = append(negEmbeddings, embeddings)
			negUsers = append(negUsers, userID)
		}
	}
	posPrediction := ml.BatchInternalPredict(posFeatures, posEmbeddings, jobs)
	negPrediction := ml.BatchInternalPredict(negFeatures, negEmbeddings, jobs)

	// Group the predictions by user.
	userPosPrediction := make(map[int32][]float32)
	userNegPrediction := make(map[int32][]float32)
	for i, p := range posPrediction {
		userPosPrediction[posUsers[i]] = append(userPosPrediction[posUsers[i]], p)
	}
	for i, p := range negPrediction {
		userNegPrediction[negUsers[i]] = append(userNegPrediction[negUsers[i]], p)
	}

	// Visit users in a fixed order: floating-point addition is not
	// associative, so summing in map iteration order would make the reported
	// score vary between runs on identical predictions.
	users := make([]int32, 0, len(userPosPrediction))
	for user := range userPosPrediction {
		users = append(users, user)
	}
	sort.Slice(users, func(i, j int) bool { return users[i] < users[j] })

	var weightedAUC float64 // sum of per-user AUC times positive test count
	var weight int          // total positive test samples of qualifying users
	for _, user := range users {
		if n := userTrain[user]; n == 0 || n > maxTrainPositives {
			continue
		}
		neg, ok := userNegPrediction[user]
		if !ok {
			continue
		}
		pos := userPosPrediction[user]
		weightedAUC += float64(AUC(pos, neg)) * float64(len(pos))
		weight += len(pos)
	}
	if weight == 0 {
		return 0
	}
	score := float32(weightedAUC / float64(weight))

	fmt.Println("FM GAUC:", score)
	return score
}

// AUC returns the fraction of (positive, negative) prediction pairs in which
// the positive prediction is strictly greater than the negative one. Pairs
// with equal predictions are not counted as correctly ordered.
//
// Both slices are sorted in place in ascending order. If either slice is
// empty the result is 0.
func AUC(posPrediction, negPrediction []float32) float32 {
	sort.Sort(sortutil.Float32Slice(posPrediction))
	sort.Sort(sortutil.Float32Slice(negPrediction))
	pairs := len(posPrediction) * len(negPrediction)
	if pairs == 0 {
		return 0
	}
	// Count in an integer: a float32 accumulator can no longer represent
	// every integer once it passes 2^24, so it would silently drop
	// increments on larger inputs.
	var ordered int
	var below int
	for _, pos := range posPrediction {
		// Advance to the first negative prediction that is not less than the
		// current positive one. Both slices are ascending, so below never has
		// to move backwards.
		for below < len(negPrediction) && negPrediction[below] < pos {
			below++
		}
		// below is the number of negative predictions that are less than the
		// current positive prediction.
		ordered += below
	}
	return float32(ordered) / float32(pairs)
}

// init registers the persistent "config" flag (shorthand "c", empty default)
// on rootCmd and attaches the llm subcommand.
func init() {
	flags := rootCmd.PersistentFlags()
	flags.StringP("config", "c", "", "Path to configuration file")
	rootCmd.AddCommand(llmCmd)
}

// main executes rootCmd and terminates through log.Fatal if it returns an
// error.
func main() {
	err := rootCmd.Execute()
	if err == nil {
		return
	}
	log.Fatal(err)
}
1 change: 1 addition & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ coverage:

ignore:
- "protocol/*.pb.go"
- "cmd/**"
32 changes: 32 additions & 0 deletions common/nn/layers.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,38 @@ func (s *Sequential) SetJobs(jobs int) {
}
}

// Attention is a layer that scales its input by softmax weights computed from
// the input itself (see Forward).
type Attention struct {
// W is the layer applied to the input before the ReLu activation.
W Layer
// H is the tensor the activated projection is matrix-multiplied with.
H *Tensor
// jobs is passed to MatMul in Forward; SetJobs keeps it at 1 or more.
jobs int
}

// NewAttention builds an Attention layer whose W is NewLinear(dimensions, k)
// and whose H is Normal(0, 0.01, k, dimensions). The jobs field is left at its
// zero value until SetJobs is called.
func NewAttention(dimensions, k int) *Attention {
	projection := NewLinear(dimensions, k)
	context := Normal(0, 0.01, k, dimensions)
	layer := new(Attention)
	layer.W = projection
	layer.H = context
	return layer
}

// Parameters returns H followed by the parameters of W.
func (a *Attention) Parameters() []*Tensor {
	return append([]*Tensor{a.H}, a.W.Parameters()...)
}

// Forward passes x through W and ReLu, matrix-multiplies the result with H,
// applies Softmax with argument 1 (presumably the axis; confirm against
// Softmax) and combines the resulting weights with x through Mul.
func (a *Attention) Forward(x *Tensor) *Tensor {
	hidden := ReLu(a.W.Forward(x))
	scores := MatMul(hidden, a.H, false, false, a.jobs)
	weights := Softmax(scores, 1)
	return Mul(weights, x)
}

// SetJobs forwards jobs unchanged to W and stores it on the layer, raising
// values below 1 to 1.
func (a *Attention) SetJobs(jobs int) {
	a.W.SetJobs(jobs)
	if jobs < 1 {
		a.jobs = 1
		return
	}
	a.jobs = jobs
}

func Save(o any, w io.Writer) error {
var save func(o any, key []string) error
save = func(o any, key []string) error {
Expand Down
4 changes: 3 additions & 1 deletion dataset/dataset.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,9 @@ type CTRSplit interface {
CountNegative() int
GetIndex() UnifiedIndex
GetTarget(i int) float32
Get(i int) ([]int32, []float32, float32)
Get(i int) ([]int32, []float32, [][]float32, float32)
GetItemEmbeddingDim() []int
GetItemEmbeddingIndex() *Index
}

type Dataset struct {
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ require (
golang.org/x/crypto v0.45.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/sync v0.18.0 // indirect
golang.org/x/term v0.37.0 // indirect
golang.org/x/term v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
golang.org/x/time v0.14.0 // indirect
gonum.org/v1/gonum v0.16.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1054,8 +1054,8 @@ golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
4 changes: 3 additions & 1 deletion logics/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,11 @@ func (r *ChatRanker) Rank(ctx context.Context, user *data.User, feedback []*Feed
s.Add(item.ItemId)
}
var result []string
m := mapset.NewSet[string]()
for _, itemId := range parsed {
if s.Contains(itemId) {
if s.Contains(itemId) && !m.Contains(itemId) {
result = append(result, itemId)
m.Add(itemId)
}
}
return result, nil
Expand Down
2 changes: 1 addition & 1 deletion master/rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func newMockMasterRPC(t *testing.T) *mockMasterRPC {
assert.NoError(t, err)
// create click model
train, test := newClickDataset()
fm := ctr.NewFMV2(model.Params{model.NEpochs: 0})
fm := ctr.NewAFM(model.Params{model.NEpochs: 0})
fm.Fit(context.Background(), train, test, &ctr.FitConfig{})
// create ranking model
trainSet, testSet := newRankingDataset()
Expand Down
4 changes: 2 additions & 2 deletions master/tasks.go
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ func (m *Master) trainClickThroughRatePrediction(parent context.Context, trainSe
zap.Float32("Recall", m.clickThroughRateTarget.Score.Recall),
zap.Any("params", clickThroughRateParams))
}
clickModel := ctr.NewFMV2(clickThroughRateParams)
clickModel := ctr.NewAFM(clickThroughRateParams)
m.clickThroughRateModelMutex.Unlock()

startFitTime := time.Now()
Expand Down Expand Up @@ -1244,7 +1244,7 @@ func (m *Master) optimizeClickThroughRatePrediction(parent context.Context, trai

search := ctr.NewModelSearch(map[string]ctr.ModelCreator{
"FM": func() ctr.FactorizationMachines {
return ctr.NewFMV2(nil)
return ctr.NewAFM(nil)
},
}, trainSet, testSet,
ctr.NewFitConfig().
Expand Down
42 changes: 30 additions & 12 deletions model/ctr/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ type Dataset struct {
Users []int32
Items []int32
Target []float32
ItemEmbeddings [][][]float32
ItemEmbeddings [][][]float32 // Index by row id, embedding id, embedding dimension
ItemEmbeddingDimension []int
ItemEmbeddingIndex *dataset.Index
PositiveCount int
Expand Down Expand Up @@ -207,11 +207,12 @@ func (dataset *Dataset) GetTarget(i int) float32 {
}

// Get returns the i-th sample.
func (dataset *Dataset) Get(i int) ([]int32, []float32, float32) {
func (dataset *Dataset) Get(i int) ([]int32, []float32, [][]float32, float32) {
var (
indices []int32
values []float32
position int32
indices []int32
values []float32
embedding [][]float32
position int32
)
// append user id
if len(dataset.Users) > 0 {
Expand All @@ -224,6 +225,9 @@ func (dataset *Dataset) Get(i int) ([]int32, []float32, float32) {
indices = append(indices, position+dataset.Items[i])
values = append(values, 1)
position += int32(dataset.CountItems())
if len(dataset.ItemEmbeddings) > 0 {
embedding = dataset.ItemEmbeddings[dataset.Items[i]]
}
}
// append user indices
if len(dataset.Users) > 0 {
Expand All @@ -248,7 +252,7 @@ func (dataset *Dataset) Get(i int) ([]int32, []float32, float32) {
indices = append(indices, contextIndices...)
values = append(values, contextValues...)
}
return indices, values, dataset.Target[i]
return indices, values, embedding, dataset.Target[i]
}

// LoadLibFMFile loads libFM format file.
Expand Down Expand Up @@ -325,14 +329,20 @@ func LoadDataFromBuiltIn(name string) (train, test *Dataset, err error) {
func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
// create train/test dataset
trainSet := &Dataset{
Index: dataset.Index,
UserLabels: dataset.UserLabels,
ItemLabels: dataset.ItemLabels,
Index: dataset.Index,
UserLabels: dataset.UserLabels,
ItemLabels: dataset.ItemLabels,
ItemEmbeddings: dataset.ItemEmbeddings,
ItemEmbeddingIndex: dataset.ItemEmbeddingIndex,
ItemEmbeddingDimension: dataset.ItemEmbeddingDimension,
}
testSet := &Dataset{
Index: dataset.Index,
UserLabels: dataset.UserLabels,
ItemLabels: dataset.ItemLabels,
Index: dataset.Index,
UserLabels: dataset.UserLabels,
ItemLabels: dataset.ItemLabels,
ItemEmbeddings: dataset.ItemEmbeddings,
ItemEmbeddingIndex: dataset.ItemEmbeddingIndex,
ItemEmbeddingDimension: dataset.ItemEmbeddingDimension,
}
// split by random
numTestSize := int(float32(dataset.Count()) * ratio)
Expand Down Expand Up @@ -369,3 +379,11 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
}
return trainSet, testSet
}

// GetItemEmbeddingDim returns the dataset's ItemEmbeddingDimension slice. The
// slice is returned as is, not copied.
func (dataset *Dataset) GetItemEmbeddingDim() []int {
return dataset.ItemEmbeddingDimension
}

// GetItemEmbeddingIndex returns the dataset's ItemEmbeddingIndex pointer.
func (dataset *Dataset) GetItemEmbeddingIndex() *dataset.Index {
return dataset.ItemEmbeddingIndex
}
Loading
Loading