diff --git a/go/vt/vttablet/tabletserver/state_manager.go b/go/vt/vttablet/tabletserver/state_manager.go
index 11da06d6d4e..e231a628748 100644
--- a/go/vt/vttablet/tabletserver/state_manager.go
+++ b/go/vt/vttablet/tabletserver/state_manager.go
@@ -485,7 +485,7 @@ func (sm *stateManager) unservePrimary() error {
 	sm.se.MakePrimary(false)
 	sm.hs.MakePrimary(false)
 	sm.rt.MakePrimary()
-	sm.txThrottler.MakeNonPrimary()
+	sm.txThrottler.MakePrimary()
 	sm.setState(topodatapb.TabletType_PRIMARY, StateNotServing)
 	return nil
 }
diff --git a/go/vt/vttablet/tabletserver/state_manager_test.go b/go/vt/vttablet/tabletserver/state_manager_test.go
index 02896eeefe0..45257c995ad 100644
--- a/go/vt/vttablet/tabletserver/state_manager_test.go
+++ b/go/vt/vttablet/tabletserver/state_manager_test.go
@@ -77,18 +77,17 @@ func TestStateManagerServePrimary(t *testing.T) {
 	assert.Equal(t, testNow, sm.ptsTimestamp)
 
 	verifySubcomponent(t, 1, sm.watcher, testStateClosed)
-
 	verifySubcomponent(t, 2, sm.se, testStateOpen)
 	verifySubcomponent(t, 3, sm.vstreamer, testStateOpen)
 	verifySubcomponent(t, 4, sm.qe, testStateOpen)
-	verifySubcomponent(t, 5, sm.txThrottler, testStateOpen)
 	verifySubcomponent(t, 6, sm.rt, testStatePrimary)
-	verifySubcomponent(t, 7, sm.tracker, testStateOpen)
-	verifySubcomponent(t, 8, sm.te, testStatePrimary)
-	verifySubcomponent(t, 9, sm.messager, testStateOpen)
-	verifySubcomponent(t, 10, sm.throttler, testStateOpen)
-	verifySubcomponent(t, 11, sm.tableGC, testStateOpen)
-	verifySubcomponent(t, 12, sm.ddle, testStateOpen)
+	verifySubcomponent(t, 7, sm.txThrottler, testStatePrimary)
+	verifySubcomponent(t, 8, sm.tracker, testStateOpen)
+	verifySubcomponent(t, 9, sm.te, testStatePrimary)
+	verifySubcomponent(t, 10, sm.messager, testStateOpen)
+	verifySubcomponent(t, 11, sm.throttler, testStateOpen)
+	verifySubcomponent(t, 12, sm.tableGC, testStateOpen)
+	verifySubcomponent(t, 13, sm.ddle, testStateOpen)
 
 	assert.False(t, sm.se.(*testSchemaEngine).nonPrimary)
 	assert.True(t, sm.se.(*testSchemaEngine).ensureCalled)
@@ -109,14 +108,14 @@ func TestStateManagerServeNonPrimary(t *testing.T) {
 	verifySubcomponent(t, 4, sm.tracker, testStateClosed)
 	assert.True(t, sm.se.(*testSchemaEngine).nonPrimary)
 
-	verifySubcomponent(t, 5, sm.se, testStateOpen)
-	verifySubcomponent(t, 6, sm.vstreamer, testStateOpen)
-	verifySubcomponent(t, 7, sm.qe, testStateOpen)
-	verifySubcomponent(t, 8, sm.txThrottler, testStateOpen)
-	verifySubcomponent(t, 9, sm.te, testStateNonPrimary)
-	verifySubcomponent(t, 10, sm.rt, testStateNonPrimary)
-	verifySubcomponent(t, 11, sm.watcher, testStateOpen)
-	verifySubcomponent(t, 12, sm.throttler, testStateOpen)
+	verifySubcomponent(t, 6, sm.se, testStateOpen)
+	verifySubcomponent(t, 7, sm.vstreamer, testStateOpen)
+	verifySubcomponent(t, 8, sm.qe, testStateOpen)
+	verifySubcomponent(t, 9, sm.txThrottler, testStateOpen)
+	verifySubcomponent(t, 10, sm.te, testStateNonPrimary)
+	verifySubcomponent(t, 11, sm.rt, testStateNonPrimary)
+	verifySubcomponent(t, 12, sm.watcher, testStateOpen)
+	verifySubcomponent(t, 13, sm.throttler, testStateOpen)
 
 	assert.Equal(t, topodatapb.TabletType_REPLICA, sm.target.TabletType)
 	assert.Equal(t, StateServing, sm.state)
@@ -139,9 +138,9 @@ func TestStateManagerUnservePrimary(t *testing.T) {
 	verifySubcomponent(t, 8, sm.se, testStateOpen)
 	verifySubcomponent(t, 9, sm.vstreamer, testStateOpen)
 	verifySubcomponent(t, 10, sm.qe, testStateOpen)
-	verifySubcomponent(t, 11, sm.txThrottler, testStateOpen)
 
 	verifySubcomponent(t, 12, sm.rt, testStatePrimary)
+	verifySubcomponent(t, 13, sm.txThrottler, testStatePrimary)
 
 	assert.Equal(t, topodatapb.TabletType_PRIMARY, sm.target.TabletType)
 	assert.Equal(t, StateNotServing, sm.state)
@@ -165,7 +164,7 @@ func TestStateManagerUnserveNonPrimary(t *testing.T) {
 	verifySubcomponent(t, 7, sm.se, testStateOpen)
 	verifySubcomponent(t, 8, sm.vstreamer, testStateOpen)
 	verifySubcomponent(t, 9, sm.qe, testStateOpen)
-	verifySubcomponent(t, 10, sm.txThrottler, testStateOpen)
+	verifySubcomponent(t, 10, sm.txThrottler, testStateNonPrimary)
 
 	verifySubcomponent(t, 11, sm.rt, testStateNonPrimary)
 	verifySubcomponent(t, 12, sm.watcher, testStateOpen)
@@ -932,6 +931,16 @@ func (te *testTxThrottler) Close() {
 	te.state = testStateClosed
 }
 
+func (te *testTxThrottler) MakePrimary() {
+	te.order = order.Add(1)
+	te.state = testStatePrimary
+}
+
+func (te *testTxThrottler) MakeNonPrimary() {
+	te.order = order.Add(1)
+	te.state = testStateNonPrimary
+}
+
 type testOnlineDDLExecutor struct {
 	testOrderState
 }