diff --git a/src/main/scala/com/twitter/finagle/postgres/PostgresClientImpl.scala b/src/main/scala/com/twitter/finagle/postgres/PostgresClientImpl.scala index 96f22078..16f83c93 100644 --- a/src/main/scala/com/twitter/finagle/postgres/PostgresClientImpl.scala +++ b/src/main/scala/com/twitter/finagle/postgres/PostgresClientImpl.scala @@ -158,7 +158,7 @@ class PostgresClientImpl( typeMap().flatMap { _ => for { service <- factory() - statement = new PreparedStatementImpl("", sql, service) + statement = new PreparedStatementImpl(sql, service) result <- statement.selectToStream(params: _*)(f) } yield result } @@ -171,7 +171,7 @@ class PostgresClientImpl( typeMap().flatMap { _ => for { service <- factory() - statement = new PreparedStatementImpl("", sql, service) + statement = new PreparedStatementImpl(sql, service) OK(count) <- statement.exec(params: _*) } yield count } @@ -215,11 +215,12 @@ class PostgresClientImpl( private[this] class PreparedStatementImpl( - name: String, sql: String, service: Service[PgRequest, PgResponse] ) extends PreparedStatement { + private[this] val name = s"fin-pg-$id-" + counter.incrementAndGet + def closeService = service.close() private[this] def parse(params: Param[_]*): Future[Unit] = { @@ -306,8 +307,6 @@ class PostgresClientImpl( }.ensure(service.close()) } } - - private[this] def genName() = s"fin-pg-$id-" + counter.incrementAndGet } diff --git a/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala b/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala index d49e1383..1451ecfb 100644 --- a/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala +++ b/src/test/scala/com/twitter/finagle/postgres/integration/IntegrationSpec.scala @@ -39,15 +39,14 @@ class IntegrationSpec extends Spec { useSsl = sys.env.getOrElse("USE_PG_SSL", "0") == "1" sslHost = sys.env.get("PG_SSL_HOST") } yield { - - val queryTimeout = Duration.fromSeconds(2) - def getClient: PostgresClientImpl = { + def getClient: PostgresClientImpl = getConcurrentClient(maxConcurrency = 1) + def getConcurrentClient(maxConcurrency: Int): PostgresClientImpl = { val client = Postgres.Client() .withCredentials(user, password) .database(dbname) - .withSessionPool.maxSize(1) + .withSessionPool.maxSize(maxConcurrency) .conditionally(useSsl, c => sslHost.fold(c.withTransport.tls)(c.withTransport.tls(_))) .newRichClient(hostPort) @@ -440,6 +439,29 @@ class IntegrationSpec extends Spec { client.status must equal(Status.Closed) } } + + "generate unique names per prepared statement" in { + val client = getConcurrentClient(maxConcurrency = 10) + cleanDb(client) + + val concurrentQueries = (1 to 1000).par.map { + case number if number % 2 == 0 => + client.prepareAndExecute( + "INSERT INTO %s (str_field, int_field, double_field, bool_field) VALUES ($1, $2, $3, $4)".format(IntegrationSpec.pgTestTable), + number.toString, number, number, true + ) + case number => + client.prepareAndExecute( + "INSERT INTO %s (str_field) VALUES ($1)".format(IntegrationSpec.pgTestTable), + number.toString + ) + } + + val result = Await.result( + Future.collect(concurrentQueries.seq) + ) + result.size must equal(1000) + } } } diff --git a/version.sbt b/version.sbt index ca330acc..7e7c28a0 100644 --- a/version.sbt +++ b/version.sbt @@ -1 +1 @@ -version in ThisBuild := "0.13.0-SNAPSHOT" +version in ThisBuild := "0.13.1-SNAPSHOT"