From 0daaf1747cfa4e4850376ad50a7834fb78b0cc0e Mon Sep 17 00:00:00 2001
From: abhijeet45 <abhi.bh.31296@gmail.com>
Date: Thu, 22 Aug 2024 16:33:42 +0530
Subject: [PATCH] fix: AfterQuery using safer right trim while clearing from
 clause's join added as part of https://github.com/go-gorm/gorm/pull/7027
 (#7153)

Co-authored-by: Abhijeet Bhowmik <abhijeet.bhowmik@cambiumnetworks.com>
---
 callbacks/query.go  |  2 +-
 utils/utils.go      | 11 ++++++++
 utils/utils_test.go | 61 +++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 1 deletion(-)

diff --git a/callbacks/query.go b/callbacks/query.go
index 9b2b17ea94..bbf238a9fd 100644
--- a/callbacks/query.go
+++ b/callbacks/query.go
@@ -288,7 +288,7 @@ func AfterQuery(db *gorm.DB) {
 	// clear the joins after query because preload need it
 	if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
 		fromClause := db.Statement.Clauses["FROM"]
-		fromClause.Expression = clause.From{Tables: v.Tables, Joins: v.Joins[:len(v.Joins)-len(db.Statement.Joins)]} // keep the original From Joins
+		fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins
 		db.Statement.Clauses["FROM"] = fromClause
 	}
 	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
diff --git a/utils/utils.go b/utils/utils.go
index b8d30b35b9..fc615d73b6 100644
--- a/utils/utils.go
+++ b/utils/utils.go
@@ -166,3 +166,14 @@ func SplitNestedRelationName(name string) []string {
 func JoinNestedRelationNames(relationNames []string) string {
 	return strings.Join(relationNames, nestedRelationSplit)
 }
+
+// RTrimSlice Right trims the given slice by given length
+func RTrimSlice[T any](v []T, trimLen int) []T {
+	if trimLen >= len(v) { // trimLen greater than slice len means fully sliced
+		return v[:0]
+	}
+	if trimLen < 0 { // negative trimLen is ignored
+		return v[:]
+	}
+	return v[:len(v)-trimLen]
+}
diff --git a/utils/utils_test.go b/utils/utils_test.go
index 8ff42af8d1..089cc4c8e9 100644
--- a/utils/utils_test.go
+++ b/utils/utils_test.go
@@ -138,3 +138,64 @@ func TestToString(t *testing.T) {
 		})
 	}
 }
+
+func TestRTrimSlice(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    []int
+		trimLen  int
+		expected []int
+	}{
+		{
+			name:     "Trim two elements from end",
+			input:    []int{1, 2, 3, 4, 5},
+			trimLen:  2,
+			expected: []int{1, 2, 3},
+		},
+		{
+			name:     "Trim entire slice",
+			input:    []int{1, 2, 3},
+			trimLen:  3,
+			expected: []int{},
+		},
+		{
+			name:     "Trim length greater than slice length",
+			input:    []int{1, 2, 3},
+			trimLen:  5,
+			expected: []int{},
+		},
+		{
+			name:     "Zero trim length",
+			input:    []int{1, 2, 3},
+			trimLen:  0,
+			expected: []int{1, 2, 3},
+		},
+		{
+			name:     "Trim one element from end",
+			input:    []int{1, 2, 3},
+			trimLen:  1,
+			expected: []int{1, 2},
+		},
+		{
+			name:     "Empty slice",
+			input:    []int{},
+			trimLen:  2,
+			expected: []int{},
+		},
+		{
+			name:     "Negative trim length (should be treated as zero)",
+			input:    []int{1, 2, 3},
+			trimLen:  -1,
+			expected: []int{1, 2, 3},
+		},
+	}
+
+	for _, testcase := range tests {
+		t.Run(testcase.name, func(t *testing.T) {
+			result := RTrimSlice(testcase.input, testcase.trimLen)
+			if !AssertEqual(result, testcase.expected) {
+				t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected)
+			}
+		})
+	}
+}