From 3b2c4bd70cbf218c2ac70df859d078cb7a39de37 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Thu, 2 Jan 2025 17:05:11 -1000 Subject: [PATCH 1/7] Add take/takeRight/drop/dropRight to Chain --- core/src/main/scala/cats/data/Chain.scala | 164 ++++++++++++++++++ .../test/scala/cats/tests/ChainSuite.scala | 23 +++ 2 files changed, 187 insertions(+) diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index c93a1719e9..bb803c9a92 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -256,6 +256,90 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { result } + /** + * take a certain amount of items from the front of the Chain + */ + final def take(count: Int): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(lhs: Chain[A], count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] = + arg match { + case Wrap(seq) => + if (count == 1) { + lhs.append(seq(0)) + } else { + // count > 1 + val taken = seq.take(count) + // we may have not takeped all of count + val newCount = count - taken.length + val newLhs = lhs.concat(Wrap(taken)) + if (newCount > 0) { + // we have to keep takeping on the rhs + go(newLhs, newCount, rhs, Chain.nil) + } else { + // newCount == 0, we have taken enough + newLhs + } + } + case Append(l, r) => + go(lhs, count, l, r.concat(rhs)) + case s @ Singleton(_) => + // due to the invariant count >= 1 + val newLhs = if (lhs.isEmpty) s else Append(lhs, s) + if (count > 1) { + go(newLhs, count - 1, rhs, Chain.nil) + } else newLhs + case Empty => + if (rhs.isEmpty) lhs + else go(lhs, count, rhs, Chain.nil) + } + + if (count <= 0) Empty + else go(Empty, count, this, Empty) + } + + /** + * take a certain amount of items from the back of the Chain + */ + final def takeRight(count: Int): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(lhs: Chain[A], count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] = + arg match { + case Wrap(seq) => + if (count == 1) { + lhs.append(seq.last) + } else { + // count > 1 + val taken = seq.takeRight(count) + // we may have not takeped all of count + val newCount = count - taken.length + val newRhs = Wrap(taken).concat(rhs) + if (newCount > 0) { + // we have to keep takeping on the rhs + go(Chain.nil, newCount, lhs, newRhs) + } else { + // newCount == 0, we have taken enough + newRhs + } + } + case Append(l, r) => + go(lhs.concat(l), count, r, rhs) + case s @ Singleton(_) => + // due to the invariant count >= 1 + val newRhs = if (rhs.isEmpty) s else Append(s, rhs) + if (count > 1) { + go(Empty, count - 1, lhs, newRhs) + } else newRhs + case Empty => + if (lhs.isEmpty) rhs + else go(Chain.nil, count, lhs, rhs) + } + + if (count <= 0) Empty + else go(Empty, count, this, Empty) + } + /** * Drops longest prefix of elements that satisfy a predicate. * @@ -275,6 +359,86 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { go(this) } + /** + * Drop a certain amount of items from the front of the Chain + */ + final def drop(count: Int): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] = + arg match { + case Wrap(seq) => + val dropped = seq.drop(count) + if (dropped.isEmpty) { + // we may have not dropped all of count + val newCount = count - seq.length + if (newCount > 0) { + // we have to keep dropping on the rhs + go(newCount, rhs, Chain.nil) + } else { + // we know that count >= seq.length else we wouldn't be empty + // so in this case, it is exactly count == seq.length + rhs + } + } else { + // we must be done + Chain.fromSeq(dropped).concat(rhs) + } + case Append(l, r) => + go(count, l, r.concat(rhs)) + case Singleton(_) => + // due to the invariant count >= 1 + if (count > 1) go(count - 1, rhs, Chain.nil) + else rhs + case Empty => + if (rhs.isEmpty) Empty + else go(count, rhs, Chain.nil) + } + + if (count <= 0) this + else go(count, this, Empty) + } + + /** + * Drop a certain amount of items from the back of the Chain + */ + final def dropRight(count: Int): Chain[A] = { + // invariant count >= 1 + @tailrec + def go(lhs: Chain[A], count: Int, arg: Chain[A]): Chain[A] = + arg match { + case Wrap(seq) => + val dropped = seq.dropRight(count) + if (dropped.isEmpty) { + // we may have not dropped all of count + val newCount = count - seq.length + if (newCount > 0) { + // we have to keep dropping on the rhs + go(Chain.nil, newCount, lhs) + } else { + // we know that count >= seq.length else we wouldn't be empty + // so in this case, it is exactly count == seq.length + lhs + } + } else { + // we must be done + lhs.concat(Chain.fromSeq(dropped)) + } + case Append(l, r) => + go(lhs.concat(l), count, r) + case Singleton(_) => + // due to the invariant count >= 1 + if (count > 1) go(Chain.nil, count - 1, lhs) + else lhs + case Empty => + if (lhs.isEmpty) Empty + else go(Chain.nil, count, lhs) + } + + if (count <= 0) this + else go(Empty, count, this) + } + /** * Folds over the elements from right to left using the supplied initial value and function. */ diff --git a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala index 4a920ad4c4..7794823bfd 100644 --- a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala @@ -448,4 +448,27 @@ class ChainSuite extends CatsSuite { assert(chain.foldRight(init)(fn) == chain.toList.foldRight(init)(fn)) } } + + test("drop(cnt).toList == toList.drop(cnt)") { + forAll { (chain: Chain[Int], count: Int) => + assert(chain.drop(count).toList == chain.toList.drop(count)) + } + } + + test("dropRight(cnt).toList == toList.dropRight(cnt)") { + forAll { (chain: Chain[Int], count: Int) => + assert(chain.dropRight(count).toList == chain.toList.dropRight(count)) + } + } + test("take(cnt).toList == toList.take(cnt)") { + forAll { (chain: Chain[Int], count: Int) => + assert(chain.take(count).toList == chain.toList.take(count)) + } + } + + test("takeRight(cnt).toList == toList.takeRight(cnt)") { + forAll { (chain: Chain[Int], count: Int) => + assert(chain.takeRight(count).toList == chain.toList.takeRight(count)) + } + } } From 83fe74c3567c2c99d4e9e3802a5f59281a082df9 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Thu, 2 Jan 2025 18:06:34 -1000 Subject: [PATCH 2/7] fix bug, improve scalachecks --- core/src/main/scala/cats/data/Chain.scala | 27 ++++++++++++------- .../test/scala/cats/tests/ChainSuite.scala | 19 ++++++++++--- 2 files changed, 32 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index bb803c9a92..262beaaf87 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -266,13 +266,15 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { arg match { case Wrap(seq) => if (count == 1) { - lhs.append(seq(0)) + lhs.append(seq.head) } else { // count > 1 val taken = seq.take(count) // we may have not takeped all of count val newCount = count - taken.length - val newLhs = lhs.concat(Wrap(taken)) + val wrapped = Wrap(taken) + // this is more efficient than using concat + val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped) if (newCount > 0) { // we have to keep takeping on the rhs go(newLhs, newCount, rhs, Chain.nil) @@ -282,7 +284,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } } case Append(l, r) => - go(lhs, count, l, r.concat(rhs)) + go(lhs, count, l, if (rhs.isEmpty) r else Append(r, rhs)) case s @ Singleton(_) => // due to the invariant count >= 1 val newLhs = if (lhs.isEmpty) s else Append(lhs, s) @@ -308,13 +310,14 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { arg match { case Wrap(seq) => if (count == 1) { - lhs.append(seq.last) + seq.last +: rhs } else { // count > 1 val taken = seq.takeRight(count) // we may have not takeped all of count val newCount = count - taken.length - val newRhs = Wrap(taken).concat(rhs) + val wrapped = Wrap(taken) + val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs) if (newCount > 0) { // we have to keep takeping on the rhs go(Chain.nil, newCount, lhs, newRhs) @@ -324,7 +327,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } } case Append(l, r) => - go(lhs.concat(l), count, r, rhs) + go(if (lhs.isEmpty) l else Append(lhs, l), count, r, rhs) case s @ Singleton(_) => // due to the invariant count >= 1 val newRhs = if (rhs.isEmpty) s else Append(s, rhs) @@ -381,11 +384,13 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { rhs } } else { + // dropped is not empty + val wrapped = Wrap(dropped) // we must be done - Chain.fromSeq(dropped).concat(rhs) + if (rhs.isEmpty) wrapped else Append(wrapped, rhs) } case Append(l, r) => - go(count, l, r.concat(rhs)) + go(count, l, if (rhs.isEmpty) r else Append(r, rhs)) case Singleton(_) => // due to the invariant count >= 1 if (count > 1) go(count - 1, rhs, Chain.nil) @@ -422,10 +427,12 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } } else { // we must be done - lhs.concat(Chain.fromSeq(dropped)) + // note: dropped.nonEmpty + val wrapped = Wrap(dropped) + if (lhs.isEmpty) wrapped else Append(lhs, wrapped) } case Append(l, r) => - go(lhs.concat(l), count, r) + go(if (lhs.isEmpty) l else Append(lhs, l), count, r) case Singleton(_) => // due to the invariant count >= 1 if (count > 1) go(Chain.nil, count - 1, lhs) diff --git a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala index 7794823bfd..154ec56304 100644 --- a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala @@ -449,25 +449,36 @@ class ChainSuite extends CatsSuite { } } + private val genChainDropTakeArgs = + Arbitrary.arbitrary[Chain[Int]].flatMap { chain => + // Bias to values close to the length + Gen + .oneOf( + Gen.choose(Int.MinValue, Int.MaxValue), + Gen.choose(-1, chain.length.toInt + 1) + ) + .map((chain, _)) + } + test("drop(cnt).toList == toList.drop(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.drop(count).toList == chain.toList.drop(count)) } } test("dropRight(cnt).toList == toList.dropRight(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.dropRight(count).toList == chain.toList.dropRight(count)) } } test("take(cnt).toList == toList.take(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.take(count).toList == chain.toList.take(count)) } } test("takeRight(cnt).toList == toList.takeRight(cnt)") { - forAll { (chain: Chain[Int], count: Int) => + forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => assert(chain.takeRight(count).toList == chain.toList.takeRight(count)) } } From 49625fef4932854bae0fbfe5efdb4f12d98c4968 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Fri, 3 Jan 2025 08:11:44 -1000 Subject: [PATCH 3/7] respond to review comments --- core/src/main/scala/cats/data/Chain.scala | 68 +++++++++++++---------- 1 file changed, 40 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index 262beaaf87..0b96e46857 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -259,24 +259,26 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { /** * take a certain amount of items from the front of the Chain */ - final def take(count: Int): Chain[A] = { + final def take(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(lhs: Chain[A], count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] = + def go(lhs: Chain[A], count: Long, arg: Chain[A], rhs: Chain[A]): Chain[A] = arg match { case Wrap(seq) => if (count == 1) { lhs.append(seq.head) } else { // count > 1 - val taken = seq.take(count) - // we may have not takeped all of count + val taken = + if (count < Int.MaxValue) seq.take(count.toInt) + else seq.take(Int.MaxValue) + // we may have not taken all of count val newCount = count - taken.length val wrapped = Wrap(taken) // this is more efficient than using concat val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped) if (newCount > 0) { - // we have to keep takeping on the rhs + // we have to keep taking on the rhs go(newLhs, newCount, rhs, Chain.nil) } else { // newCount == 0, we have taken enough @@ -288,38 +290,42 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { case s @ Singleton(_) => // due to the invariant count >= 1 val newLhs = if (lhs.isEmpty) s else Append(lhs, s) - if (count > 1) { - go(newLhs, count - 1, rhs, Chain.nil) + if (count > 1L) { + go(newLhs, count - 1L, rhs, Chain.nil) } else newLhs case Empty => + // this empty check isn't an optimization but to ensure + // the recursion terminates. if (rhs.isEmpty) lhs - else go(lhs, count, rhs, Chain.nil) + else go(lhs, count, rhs, Empty) } - if (count <= 0) Empty + if (count <= 0L) Empty else go(Empty, count, this, Empty) } /** * take a certain amount of items from the back of the Chain */ - final def takeRight(count: Int): Chain[A] = { + final def takeRight(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(lhs: Chain[A], count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] = + def go(lhs: Chain[A], count: Long, arg: Chain[A], rhs: Chain[A]): Chain[A] = arg match { case Wrap(seq) => - if (count == 1) { + if (count == 1L) { seq.last +: rhs } else { // count > 1 - val taken = seq.takeRight(count) - // we may have not takeped all of count + val taken = + if (count < Int.MaxValue) seq.takeRight(count.toInt) + else seq.takeRight(Int.MaxValue) + // we may have not taken all of count val newCount = count - taken.length val wrapped = Wrap(taken) val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs) if (newCount > 0) { - // we have to keep takeping on the rhs + // we have to keep taking on the rhs go(Chain.nil, newCount, lhs, newRhs) } else { // newCount == 0, we have taken enough @@ -335,8 +341,10 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { go(Empty, count - 1, lhs, newRhs) } else newRhs case Empty => + // this empty check isn't an optimization but to ensure + // the recursion terminates. if (lhs.isEmpty) rhs - else go(Chain.nil, count, lhs, rhs) + else go(Empty, count, lhs, rhs) } if (count <= 0) Empty @@ -365,13 +373,13 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { /** * Drop a certain amount of items from the front of the Chain */ - final def drop(count: Int): Chain[A] = { + final def drop(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(count: Int, arg: Chain[A], rhs: Chain[A]): Chain[A] = + def go(count: Long, arg: Chain[A], rhs: Chain[A]): Chain[A] = arg match { case Wrap(seq) => - val dropped = seq.drop(count) + val dropped = if (count < Int.MaxValue) seq.drop(count.toInt) else seq.drop(Int.MaxValue) if (dropped.isEmpty) { // we may have not dropped all of count val newCount = count - seq.length @@ -393,31 +401,33 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { go(count, l, if (rhs.isEmpty) r else Append(r, rhs)) case Singleton(_) => // due to the invariant count >= 1 - if (count > 1) go(count - 1, rhs, Chain.nil) + if (count > 1L) go(count - 1L, rhs, Chain.nil) else rhs case Empty => + // this empty check isn't an optimization but to ensure + // the recursion terminates. if (rhs.isEmpty) Empty - else go(count, rhs, Chain.nil) + else go(count, rhs, Empty) } - if (count <= 0) this + if (count <= 0L) this else go(count, this, Empty) } /** * Drop a certain amount of items from the back of the Chain */ - final def dropRight(count: Int): Chain[A] = { + final def dropRight(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(lhs: Chain[A], count: Int, arg: Chain[A]): Chain[A] = + def go(lhs: Chain[A], count: Long, arg: Chain[A]): Chain[A] = arg match { case Wrap(seq) => - val dropped = seq.dropRight(count) + val dropped = if (count < Int.MaxValue) seq.dropRight(count.toInt) else seq.dropRight(Int.MaxValue) if (dropped.isEmpty) { // we may have not dropped all of count val newCount = count - seq.length - if (newCount > 0) { + if (newCount > 0L) { // we have to keep dropping on the rhs go(Chain.nil, newCount, lhs) } else { @@ -435,11 +445,13 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { go(if (lhs.isEmpty) l else Append(lhs, l), count, r) case Singleton(_) => // due to the invariant count >= 1 - if (count > 1) go(Chain.nil, count - 1, lhs) + if (count > 1L) go(Chain.nil, count - 1L, lhs) else lhs case Empty => + // this empty check isn't an optimization but to ensure + // the recursion terminates. if (lhs.isEmpty) Empty - else go(Chain.nil, count, lhs) + else go(Empty, count, lhs) } if (count <= 0) this From f56cad6883b6661abb5e566a73f420594fd11fd6 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Fri, 3 Jan 2025 08:55:31 -1000 Subject: [PATCH 4/7] recurse on NonEmpty --- core/src/main/scala/cats/data/Chain.scala | 141 +++++++++--------- .../test/scala/cats/tests/ChainSuite.scala | 8 +- 2 files changed, 77 insertions(+), 72 deletions(-) diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index 0b96e46857..ce4e1eb9b1 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -262,7 +262,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { final def take(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(lhs: Chain[A], count: Long, arg: Chain[A], rhs: Chain[A]): Chain[A] = + def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] = arg match { case Wrap(seq) => if (count == 1) { @@ -277,12 +277,12 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { val wrapped = Wrap(taken) // this is more efficient than using concat val newLhs = if (lhs.isEmpty) wrapped else Append(lhs, wrapped) - if (newCount > 0) { - // we have to keep taking on the rhs - go(newLhs, newCount, rhs, Chain.nil) - } else { - // newCount == 0, we have taken enough - newLhs + rhs match { + case rhsNE: NonEmpty[A] if newCount > 0L => + // we have to keep taking on the rhs + go(newLhs, newCount, rhsNE, Empty) + case _ => + newLhs } } case Append(l, r) => @@ -290,18 +290,18 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { case s @ Singleton(_) => // due to the invariant count >= 1 val newLhs = if (lhs.isEmpty) s else Append(lhs, s) - if (count > 1L) { - go(newLhs, count - 1L, rhs, Chain.nil) - } else newLhs - case Empty => - // this empty check isn't an optimization but to ensure - // the recursion terminates. - if (rhs.isEmpty) lhs - else go(lhs, count, rhs, Empty) + rhs match { + case rhsNE: NonEmpty[A] if count > 1L => + go(newLhs, count - 1L, rhsNE, Empty) + case _ => newLhs + } } - if (count <= 0L) Empty - else go(Empty, count, this, Empty) + this match { + case ne: NonEmpty[A] if count > 0L => + go(Empty, count, ne, Empty) + case _ => Empty + } } /** @@ -310,7 +310,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { final def takeRight(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(lhs: Chain[A], count: Long, arg: Chain[A], rhs: Chain[A]): Chain[A] = + def go(lhs: Chain[A], count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] = arg match { case Wrap(seq) => if (count == 1L) { @@ -324,12 +324,10 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { val newCount = count - taken.length val wrapped = Wrap(taken) val newRhs = if (rhs.isEmpty) wrapped else Append(wrapped, rhs) - if (newCount > 0) { - // we have to keep taking on the rhs - go(Chain.nil, newCount, lhs, newRhs) - } else { - // newCount == 0, we have taken enough - newRhs + lhs match { + case lhsNE: NonEmpty[A] if newCount > 0 => + go(Empty, newCount, lhsNE, newRhs) + case _ => newRhs } } case Append(l, r) => @@ -337,18 +335,18 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { case s @ Singleton(_) => // due to the invariant count >= 1 val newRhs = if (rhs.isEmpty) s else Append(s, rhs) - if (count > 1) { - go(Empty, count - 1, lhs, newRhs) - } else newRhs - case Empty => - // this empty check isn't an optimization but to ensure - // the recursion terminates. - if (lhs.isEmpty) rhs - else go(Empty, count, lhs, rhs) + lhs match { + case lhsNE: NonEmpty[A] if count > 1 => + go(Empty, count - 1, lhsNE, newRhs) + case _ => newRhs + } } - if (count <= 0) Empty - else go(Empty, count, this, Empty) + this match { + case ne: NonEmpty[A] if count > 0L => + go(Empty, count, ne, Empty) + case _ => Empty + } } /** @@ -376,20 +374,21 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { final def drop(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(count: Long, arg: Chain[A], rhs: Chain[A]): Chain[A] = + def go(count: Long, arg: NonEmpty[A], rhs: Chain[A]): Chain[A] = arg match { case Wrap(seq) => val dropped = if (count < Int.MaxValue) seq.drop(count.toInt) else seq.drop(Int.MaxValue) if (dropped.isEmpty) { // we may have not dropped all of count val newCount = count - seq.length - if (newCount > 0) { - // we have to keep dropping on the rhs - go(newCount, rhs, Chain.nil) - } else { - // we know that count >= seq.length else we wouldn't be empty - // so in this case, it is exactly count == seq.length - rhs + rhs match { + case rhsNE: NonEmpty[A] if newCount > 0 => + // we have to keep dropping on the rhs + go(newCount, rhsNE, Empty) + case _ => + // we know that count >= seq.length else we wouldn't be empty + // so in this case, it is exactly count == seq.length + rhs } } else { // dropped is not empty @@ -401,17 +400,19 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { go(count, l, if (rhs.isEmpty) r else Append(r, rhs)) case Singleton(_) => // due to the invariant count >= 1 - if (count > 1L) go(count - 1L, rhs, Chain.nil) - else rhs - case Empty => - // this empty check isn't an optimization but to ensure - // the recursion terminates. - if (rhs.isEmpty) Empty - else go(count, rhs, Empty) + rhs match { + case rhsNE: NonEmpty[A] if count > 1L => + go(count - 1L, rhsNE, Empty) + case _ => + rhs + } } - if (count <= 0L) this - else go(count, this, Empty) + this match { + case ne: NonEmpty[A] if count > 0L => + go(count, ne, Empty) + case _ => this + } } /** @@ -420,20 +421,21 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { final def dropRight(count: Long): Chain[A] = { // invariant count >= 1 @tailrec - def go(lhs: Chain[A], count: Long, arg: Chain[A]): Chain[A] = + def go(lhs: Chain[A], count: Long, arg: NonEmpty[A]): Chain[A] = arg match { case Wrap(seq) => val dropped = if (count < Int.MaxValue) seq.dropRight(count.toInt) else seq.dropRight(Int.MaxValue) if (dropped.isEmpty) { // we may have not dropped all of count val newCount = count - seq.length - if (newCount > 0L) { - // we have to keep dropping on the rhs - go(Chain.nil, newCount, lhs) - } else { - // we know that count >= seq.length else we wouldn't be empty - // so in this case, it is exactly count == seq.length - lhs + lhs match { + case lhsNE: NonEmpty[A] if newCount > 0L => + // we have to keep dropping on the lhs + go(Empty, newCount, lhsNE) + case _ => + // we know that count >= seq.length else we wouldn't be empty + // so in this case, it is exactly count == seq.length + lhs } } else { // we must be done @@ -445,17 +447,20 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { go(if (lhs.isEmpty) l else Append(lhs, l), count, r) case Singleton(_) => // due to the invariant count >= 1 - if (count > 1L) go(Chain.nil, count - 1L, lhs) - else lhs - case Empty => - // this empty check isn't an optimization but to ensure - // the recursion terminates. - if (lhs.isEmpty) Empty - else go(Empty, count, lhs) + lhs match { + case lhsNE: NonEmpty[A] if count > 1L => + go(Empty, count - 1L, lhsNE) + case _ => + lhs + } } - if (count <= 0) this - else go(Empty, count, this) + this match { + case ne: NonEmpty[A] if count > 0L => + go(Empty, count, ne) + case _ => + this + } } /** diff --git a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala index 154ec56304..b25e2022e3 100644 --- a/tests/shared/src/test/scala/cats/tests/ChainSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/ChainSuite.scala @@ -462,24 +462,24 @@ class ChainSuite extends CatsSuite { test("drop(cnt).toList == toList.drop(cnt)") { forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => - assert(chain.drop(count).toList == chain.toList.drop(count)) + assertEquals(chain.drop(count).toList, chain.toList.drop(count)) } } test("dropRight(cnt).toList == toList.dropRight(cnt)") { forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => - assert(chain.dropRight(count).toList == chain.toList.dropRight(count)) + assertEquals(chain.dropRight(count).toList, chain.toList.dropRight(count)) } } test("take(cnt).toList == toList.take(cnt)") { forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => - assert(chain.take(count).toList == chain.toList.take(count)) + assertEquals(chain.take(count).toList, chain.toList.take(count)) } } test("takeRight(cnt).toList == toList.takeRight(cnt)") { forAll(genChainDropTakeArgs) { case (chain: Chain[Int], count: Int) => - assert(chain.takeRight(count).toList == chain.toList.takeRight(count)) + assertEquals(chain.takeRight(count).toList, chain.toList.takeRight(count)) } } } From 4dddf8c2f87c874d45251c5557fa2cea80f6a9ee Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Thu, 9 Jan 2025 10:22:58 -1000 Subject: [PATCH 5/7] maintain Wrap invariant --- core/src/main/scala/cats/data/Chain.scala | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index ce4e1eb9b1..11f8a51ca8 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -378,7 +378,9 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { arg match { case Wrap(seq) => val dropped = if (count < Int.MaxValue) seq.drop(count.toInt) else seq.drop(Int.MaxValue) - if (dropped.isEmpty) { + val lc = dropped.lengthCompare(1) + if (lc < 0) { + // if dropped.length < 1, then it is zero // we may have not dropped all of count val newCount = count - seq.length rhs match { @@ -392,7 +394,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } } else { // dropped is not empty - val wrapped = Wrap(dropped) + val wrapped = if (lc > 0) Wrap(dropped) else Singleton(dropped.head) // we must be done if (rhs.isEmpty) wrapped else Append(wrapped, rhs) } @@ -425,7 +427,9 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { arg match { case Wrap(seq) => val dropped = if (count < Int.MaxValue) seq.dropRight(count.toInt) else seq.dropRight(Int.MaxValue) - if (dropped.isEmpty) { + val lc = dropped.lengthCompare(1) + if (lc < 0) { + // if dropped.length < 1, then it is zero // we may have not dropped all of count val newCount = count - seq.length lhs match { @@ -440,7 +444,7 @@ sealed abstract class Chain[+A] extends ChainCompat[A] { } else { // we must be done // note: dropped.nonEmpty - val wrapped = Wrap(dropped) + val wrapped = if (lc > 0) Wrap(dropped) else Singleton(dropped.head) if (lhs.isEmpty) wrapped else Append(lhs, wrapped) } case Append(l, r) => @@ -1128,7 +1132,8 @@ object Chain extends ChainInstances with ChainCompanionCompat { * if the length is one, fromSeq returns Singleton * * The only places we create Wrap is in fromSeq and in methods that preserve - * length: zipWithIndex, map, sort + * length: zipWithIndex, map, sort. Additionally, in drop/dropRight we carefully + * preserve this invariant. */ final private[data] case class Wrap[A](seq: immutable.Seq[A]) extends NonEmpty[A] From 71e076b2eb140d10fa2b77691fdf48a717a58922 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Thu, 9 Jan 2025 11:13:45 -1000 Subject: [PATCH 6/7] slightly optimize Chain.fromSeq --- .../cats/data/ChainCompanionCompat.scala | 14 ++++++++------ .../cats/data/ChainCompanionCompat.scala | 10 ++++++---- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala b/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala index d113e733f3..c35c704ccd 100644 --- a/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala +++ b/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala @@ -38,15 +38,17 @@ private[data] trait ChainCompanionCompat { } private def fromImmutableSeq[A](s: immutable.Seq[A]): Chain[A] = { - if (s.isEmpty) nil - else if (s.lengthCompare(1) == 0) one(s.head) - else Wrap(s) + val lc = s.lengthCompare(1) + if (lc < 0) nil + else if (lc > 0) Wrap(s.toVector) + else one(s.head) } private def fromMutableSeq[A](s: Seq[A]): Chain[A] = { - if (s.isEmpty) nil - else if (s.lengthCompare(1) == 0) one(s.head) - else Wrap(s.toVector) + val lc = s.lengthCompare(1) + if (lc < 0) nil + else if (lc > 0) Wrap(s.toVector) + else one(s.head) } /** diff --git a/core/src/main/scala-2.13+/cats/data/ChainCompanionCompat.scala b/core/src/main/scala-2.13+/cats/data/ChainCompanionCompat.scala index 1a8be19c79..3821a1ce8c 100644 --- a/core/src/main/scala-2.13+/cats/data/ChainCompanionCompat.scala +++ b/core/src/main/scala-2.13+/cats/data/ChainCompanionCompat.scala @@ -28,10 +28,12 @@ private[data] trait ChainCompanionCompat { /** * Creates a Chain from the specified sequence. */ - def fromSeq[A](s: Seq[A]): Chain[A] = - if (s.isEmpty) nil - else if (s.lengthCompare(1) == 0) one(s.head) - else Wrap(s) + def fromSeq[A](s: Seq[A]): Chain[A] = { + val lc = s.lengthCompare(1) + if (lc < 0) nil + else if (lc > 0) Wrap(s) + else one(s.head) + } /** * Creates a Chain from the specified IterableOnce. From fdcff72d177ad169ca44a18b5f1965b4b645dac4 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Fri, 10 Jan 2025 14:46:50 -1000 Subject: [PATCH 7/7] don't convert immutable to Vector --- core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala b/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala index c35c704ccd..de7cd35d22 100644 --- a/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala +++ b/core/src/main/scala-2.12/cats/data/ChainCompanionCompat.scala @@ -40,7 +40,7 @@ private[data] trait ChainCompanionCompat { private def fromImmutableSeq[A](s: immutable.Seq[A]): Chain[A] = { val lc = s.lengthCompare(1) if (lc < 0) nil - else if (lc > 0) Wrap(s.toVector) + else if (lc > 0) Wrap(s) else one(s.head) }