diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 08df3e3aa..e012fe017 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,9 +27,9 @@ jobs: name: Test strategy: matrix: - os: [ubuntu-latest] + os: [ubuntu-22.04] scala: [2.13, 3] - java: [temurin@11] + java: [temurin@17] project: [skunkJS, skunkJVM, skunkNative] runs-on: ${{ matrix.os }} timeout-minutes: 60 @@ -42,17 +42,17 @@ jobs: - name: Setup sbt uses: sbt/setup-sbt@v1 - - name: Setup Java (temurin@11) - id: setup-java-temurin-11 - if: matrix.java == 'temurin@11' + - name: Setup Java (temurin@17) + id: setup-java-temurin-17 + if: matrix.java == 'temurin@17' uses: actions/setup-java@v4 with: distribution: temurin - java-version: 11 + java-version: 17 cache: sbt - name: sbt update - if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + if: matrix.java == 'temurin@17' && steps.setup-java-temurin-17.outputs.cache-hit == 'false' run: sbt +update - name: Start up Postgres @@ -69,7 +69,7 @@ jobs: run: sbt githubWorkflowCheck - name: Check headers - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@17' && matrix.os == 'ubuntu-22.04' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' headerCheckAll - name: scalaJSLink @@ -84,11 +84,11 @@ jobs: run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' test - name: Check binary compatibility - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@17' && matrix.os == 'ubuntu-22.04' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' mimaReportBinaryIssues - name: Generate API documentation - if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@17' && matrix.os == 'ubuntu-22.04' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' doc - name: Make target directories @@ -112,8 +112,8 @@ jobs: if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v') || github.ref == 'refs/heads/main' || github.ref == 'refs/heads/series/0.6.x') strategy: matrix: - os: [ubuntu-latest] - java: [temurin@11] + os: [ubuntu-22.04] + java: [temurin@17] runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) @@ -124,17 +124,17 @@ jobs: - name: Setup sbt uses: sbt/setup-sbt@v1 - - name: Setup Java (temurin@11) - id: setup-java-temurin-11 - if: matrix.java == 'temurin@11' + - name: Setup Java (temurin@17) + id: setup-java-temurin-17 + if: matrix.java == 'temurin@17' uses: actions/setup-java@v4 with: distribution: temurin - java-version: 11 + java-version: 17 cache: sbt - name: sbt update - if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + if: matrix.java == 'temurin@17' && steps.setup-java-temurin-17.outputs.cache-hit == 'false' run: sbt +update - name: Start up Postgres @@ -233,7 +233,7 @@ jobs: strategy: matrix: os: [ubuntu-22.04] - java: [temurin@11] + java: [temurin@17] runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) @@ -244,17 +244,17 @@ jobs: - name: Setup sbt uses: sbt/setup-sbt@v1 - - name: Setup Java (temurin@11) - id: setup-java-temurin-11 - if: matrix.java == 'temurin@11' + - name: Setup Java (temurin@17) + id: setup-java-temurin-17 + if: matrix.java == 'temurin@17' uses: actions/setup-java@v4 with: distribution: temurin - java-version: 11 + java-version: 17 cache: sbt - name: sbt update - if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + if: matrix.java == 'temurin@17' && steps.setup-java-temurin-17.outputs.cache-hit == 'false' run: sbt +update - name: Start up Postgres @@ -274,7 +274,7 @@ jobs: strategy: matrix: os: [ubuntu-22.04] - java: [temurin@11] + java: [temurin@17] runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) @@ -285,17 +285,17 @@ jobs: - name: Setup sbt uses: sbt/setup-sbt@v1 - - name: Setup Java (temurin@11) - id: setup-java-temurin-11 - if: matrix.java == 'temurin@11' + - name: Setup Java (temurin@17) + id: setup-java-temurin-17 + if: matrix.java == 'temurin@17' uses: actions/setup-java@v4 with: distribution: temurin - java-version: 11 + java-version: 17 cache: sbt - name: sbt update - if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + if: matrix.java == 'temurin@17' && steps.setup-java-temurin-17.outputs.cache-hit == 'false' run: sbt +update - name: Start up Postgres @@ -320,7 +320,7 @@ jobs: strategy: matrix: os: [ubuntu-22.04] - java: [temurin@11] + java: [temurin@17] runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) @@ -331,17 +331,17 @@ jobs: - name: Setup sbt uses: sbt/setup-sbt@v1 - - name: Setup Java (temurin@11) - id: setup-java-temurin-11 - if: matrix.java == 'temurin@11' + - name: Setup Java (temurin@17) + id: setup-java-temurin-17 + if: matrix.java == 'temurin@17' uses: actions/setup-java@v4 with: distribution: temurin - java-version: 11 + java-version: 17 cache: sbt - name: sbt update - if: matrix.java == 'temurin@11' && steps.setup-java-temurin-11.outputs.cache-hit == 'false' + if: matrix.java == 'temurin@17' && steps.setup-java-temurin-17.outputs.cache-hit == 'false' run: sbt +update - name: Start up Postgres diff --git a/build.sbt b/build.sbt index 797e9d36b..246e315a4 100644 --- a/build.sbt +++ b/build.sbt @@ -16,8 +16,7 @@ ThisBuild / developers := List( ThisBuild / tlCiReleaseBranches += "series/0.6.x" ThisBuild / tlCiScalafmtCheck := false ThisBuild / tlSitePublishBranch := Some("series/0.6.x") -ThisBuild / githubWorkflowOSes := Seq("ubuntu-latest") -ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11")) +ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("17")) ThisBuild / tlJdkRelease := Some(8) ThisBuild / githubWorkflowBuildPreamble ++= nativeBrewInstallWorkflowSteps.value @@ -38,6 +37,7 @@ ThisBuild / githubWorkflowAddedJobs += id = "coverage", name = s"Generate coverage report (2.13 JVM only)", scalas = Nil, + javas = githubWorkflowJavaVersions.value.toList, sbtStepPreamble = Nil, steps = githubWorkflowJobSetup.value.toList ++ List( @@ -59,7 +59,7 @@ ThisBuild / mimaBinaryIssueFilters ++= List( ThisBuild / tlFatalWarnings := false // This is used in a couple places -lazy val fs2Version = "3.12.0" +lazy val fs2Version = "3.13.0-M5" lazy val openTelemetryVersion = "1.44.1" lazy val otel4sVersion = "0.11.1" lazy val refinedVersion = "0.11.0" @@ -193,10 +193,16 @@ lazy val tests = crossProject(JVMPlatform, JSPlatform, NativePlatform) ), testFrameworks += new TestFramework("munit.Framework"), testOptions += { - if(System.getProperty("os.arch").startsWith("aarch64")) { - Tests.Argument(TestFrameworks.MUnit, "--exclude-tags=X86ArchOnly") - } else Tests.Argument() - } + var excludedTags = List.empty[String] + if (System.getProperty("os.arch").startsWith("aarch64")) + excludedTags = "X86ArchOnly" :: excludedTags + if (!System.getProperty("os.name").contains("linux")) + excludedTags = "LinuxOnly" :: excludedTags + if (excludedTags.nonEmpty) + Tests.Argument(TestFrameworks.MUnit, "--exclude-tags=" + excludedTags.mkString(",")) + else Tests.Argument() + }, + Test / baseDirectory := (ThisBuild / Test / run / baseDirectory).value ) .jvmSettings( Test / fork := true, diff --git a/docker-compose.yml b/docker-compose.yml index 2d2317334..35a3e6945 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -46,6 +46,15 @@ services: POSTGRES_PASSWORD: banana POSTGRES_HOST_AUTH_METHOD: password POSTGRES_INITDB_ARGS: --auth-host=password + # for testing domain sockets + unixsockets: + image: postgres:11 + environment: + POSTGRES_DB: world + POSTGRES_USER: jimmy + POSTGRES_PASSWORD: banana + volumes: + - ./test-unix-socket:/var/run/postgresql # for testing redshift connections redshift: image: guildeducation/docker-amazon-redshift @@ -65,4 +74,4 @@ services: - 4317:4317 - 4318:4318 environment: - COLLECTOR_OTLP_ENABLED: "true" \ No newline at end of file + COLLECTOR_OTLP_ENABLED: "true" diff --git a/modules/core/shared/src/main/scala/Session.scala b/modules/core/shared/src/main/scala/Session.scala index 214d940f5..3c48b2efb 100644 --- a/modules/core/shared/src/main/scala/Session.scala +++ b/modules/core/shared/src/main/scala/Session.scala @@ -8,9 +8,9 @@ import cats._ import cats.effect._ import cats.effect.std.Console import cats.syntax.all._ -import com.comcast.ip4s.{Host, Port, SocketAddress} +import com.comcast.ip4s.* import fs2.concurrent.Signal -import fs2.io.net.{ Network, Socket, SocketGroup, SocketOption } +import fs2.io.net.{ Network, Socket, SocketOption } import fs2.Pipe import fs2.Stream import org.typelevel.otel4s.trace.Tracer @@ -297,8 +297,6 @@ sealed trait Session[F[_]] { */ def closeEvictedPreparedStatements: F[Unit] - - /** * Transform this `Session` by a given `FunctionK`. * @group Transformations @@ -443,6 +441,15 @@ object Session { Recycler(_.execute(Command("RESET ALL", Origin.unknown, Void.codec)).as(true)) } + /** Enumeration of protocols that can be used to connect to a Postgres server. */ + sealed trait ConnectionType + object ConnectionType { + /** Connect via TCP using a host and port. */ + case object TCP extends ConnectionType + /** Connect via a Unix domain socket. */ + case object Unix extends ConnectionType + } + /** * Supports creation of a `Session`. * @@ -451,10 +458,13 @@ object Session { * After overriding the various defaults, call `single` to create a single-use session or `pooled` * to create a pool of sessions. * + * @param connectionType type of connection to use to connect to server; defaults to TCP * @param host Postgres server host; defaults to localhost * @param port Postgres server port; defaults to 5432 * @param credentials user and optional password, evaluated for each session; defaults to user "postgres" with no password * @param database database to use; defaults to None and hence whatever user is used to authenticate (e.g. "postgres" when using default user) + * @param unixSocketAddress explicit path to the Postgres server unix domain socket; if not defined and connection type is Unix, defaults to ${unixSocketsDirectory}/.s.PGSQL.nnnn where nnnn is the port + * @param unixSocketDirectory directory Postgres server uses for unix domain sockets; defaults to /tmp * @param debug whether debug logs should be written to the console; defaults to false * @param typingStrategy typing strategy; defaults to [[TypingStrategy.BuiltinsOnly]] * @param redactionStrategy redaction strategy; defaults to [[RedactionStrategy.OptIn]] @@ -467,8 +477,11 @@ object Session { * @param parseCacheSize size of the pool-level cache for parsing statements; defaults to 2048 */ final class Builder[F[_]: Temporal: Network: Console] private ( - val host: String, - val port: Int, + val connectionType: ConnectionType, + val host: Host, + val port: Port, + val unixSocketAddress: Option[UnixSocketAddress], + val unixSocketDirectory: String, val credentials: F[Credentials], val database: Option[String], val debug: Boolean, @@ -484,8 +497,11 @@ object Session { ) { self => private def copy( - host: String = self.host, - port: Int = self.port, + connectionType: ConnectionType = self.connectionType, + host: Host = self.host, + port: Port = self.port, + unixSocketAddress: Option[UnixSocketAddress] = self.unixSocketAddress, + unixSocketDirectory: String = self.unixSocketDirectory, credentials: F[Credentials] = self.credentials, database: Option[String] = self.database, debug: Boolean = self.debug, @@ -499,14 +515,48 @@ object Session { queryCacheSize: Int = self.queryCacheSize, parseCacheSize: Int = self.parseCacheSize, ): Builder[F] = - new Builder(host, port, credentials, database, debug, typingStrategy, redactionStrategy, ssl, connectionParameters, socketOptions, readTimeout, commandCacheSize, queryCacheSize, parseCacheSize) + new Builder(connectionType, host, port, unixSocketAddress, unixSocketDirectory, credentials, database, debug, typingStrategy, redactionStrategy, ssl, connectionParameters, socketOptions, readTimeout, commandCacheSize, queryCacheSize, parseCacheSize) + + /** Configures the connection type. */ + def withConnectionType(newConnectionType: ConnectionType): Builder[F] = + copy(connectionType = newConnectionType) + /** Configures the host of the Postgres server. Throws `IllegalArgumentException` if the specified host is not syntactically valid. */ def withHost(newHost: String): Builder[F] = + withHost(Host.fromString(newHost).getOrElse(throw new SkunkException(sql = None, message = s"""Hostname: "$newHost" is not syntactically valid."""))) + + /** Configures the host of the Postgres server. */ + def withHost(newHost: Host): Builder[F] = copy(host = newHost) + /** Configures the post of the Postgres server. Throws `IllegalArgumentException` if the specified port is not a valid port number. */ def withPort(newPort: Int): Builder[F] = + withPort(Port.fromInt(newPort).getOrElse(throw new SkunkException(sql = None, message = s"Port: $newPort falls out of the allowed range."))) + + /** Configures the port of the Postgres server. */ + def withPort(newPort: Port): Builder[F] = copy(port = newPort) + /** Configures this session for connecting via unix domain sockets. */ + def withUnixSockets: Builder[F] = + copy(connectionType = ConnectionType.Unix) + + /** Configures the Postgres directory for unix domain sockets. */ + def withUnixSocketDirectory(newUnixSocketDirectory: String): Builder[F] = + withUnixSockets.copy(unixSocketDirectory = newUnixSocketDirectory) + + /** Configures this session for connecting via unix domain sockets using the specified path. */ + def withUnixSocketAddress(path: String): Builder[F] = + withUnixSockets.withUnixSocketAddress(UnixSocketAddress(path)) + + /** Configures this session for connecting via unix domain sockets using the specified address. */ + def withUnixSocketAddress(newUnixSocketAddress: UnixSocketAddress): Builder[F] = + withUnixSockets.copy(unixSocketAddress = Some(newUnixSocketAddress)) + + /** Clears the explicitly configured unix socket address. */ + def withoutUnixSocketAddress: Builder[F] = + copy(unixSocketAddress = None) + def withCredentials(newCredentials: F[Credentials]): Builder[F] = copy(credentials = newCredentials) @@ -595,27 +645,33 @@ object Session { for { dc <- Resource.eval(Describe.Cache.empty[F](commandCacheSize, queryCacheSize)) sslOp <- ssl.toSSLNegotiationOptions(if (debug) logger.some else none) - pool <- Pool.ofF({implicit T: Tracer[F] => fromSocketGroup(Network[F], sslOp, dc)}, max)(Recyclers.full) + pool <- Pool.ofF({implicit T: Tracer[F] => sessions(sslOp, dc)}, max)(Recyclers.full) } yield pool } - private def fromSocketGroup( - socketGroup: SocketGroup[F], - sslOptions: Option[SSLNegotiation.Options[F]], - describeCache: Describe.Cache[F] + private def sessions( + sslOptions: Option[SSLNegotiation.Options[F]], + describeCache: Describe.Cache[F] )(implicit T: Tracer[F]): Resource[F, Session[F]] = { - def fail[A](msg: String): Resource[F, A] = - Resource.eval(Temporal[F].raiseError(new SkunkException(message = msg, sql = None))) - - val sockets: Resource[F, Socket[F]] = { - (Host.fromString(host), Port.fromInt(port)) match { - case (Some(validHost), Some(validPort)) => socketGroup.client(SocketAddress(validHost, validPort), socketOptions) - case (None, _) => fail(s"""Hostname: "$host" is not syntactically valid.""") - case (_, None) => fail(s"Port: $port falls out of the allowed range.") - } + val sockets = connectionType match { + case ConnectionType.TCP => + val address = SocketAddress(host, port) + Network[F].connect(address, socketOptions) + + case ConnectionType.Unix => + val address = unixSocketAddress.getOrElse(UnixSocketAddress(s"${unixSocketDirectory}/.s.PGSQL.${port}")) + val filteredSocketOptions = socketOptions.filter(o => o.key != SocketOption.NoDelay) + Network[F].connect(address, filteredSocketOptions) } + fromSockets(sockets, sslOptions, describeCache) + } - for { + private def fromSockets( + sockets: Resource[F, Socket[F]], + sslOptions: Option[SSLNegotiation.Options[F]], + describeCache: Describe.Cache[F] + )(implicit T: Tracer[F]): Resource[F, Session[F]] = + for { namer <- Resource.eval(Namer[F]) pc <- Resource.eval(Parse.Cache.empty[F](parseCacheSize)) proto <- Protocol[F](debug, namer, sockets, sslOptions, describeCache, pc, readTimeout, redactionStrategy) @@ -623,7 +679,6 @@ object Session { _ <- Resource.eval(proto.startup(creds.user, database.getOrElse(creds.user), creds.password, connectionParameters)) sess <- Resource.make(fromProtocol(proto, namer, typingStrategy, redactionStrategy))(_ => proto.cleanup) } yield sess - } } /** @@ -648,8 +703,11 @@ object Session { object Builder { def apply[F[_]: Temporal: Network: Console]: Builder[F] = new Builder[F]( - host = "localhost", - port = 5432, + connectionType = ConnectionType.TCP, + host = host"localhost", + port = port"5432", + unixSocketAddress = None, + unixSocketDirectory = "/tmp", database = None, credentials = Credentials("postgres", None).pure[F], debug = false, diff --git a/modules/core/shared/src/main/scala/net/BufferedMessageSocket.scala b/modules/core/shared/src/main/scala/net/BufferedMessageSocket.scala index 6dde0ec4e..e05da2fb6 100644 --- a/modules/core/shared/src/main/scala/net/BufferedMessageSocket.scala +++ b/modules/core/shared/src/main/scala/net/BufferedMessageSocket.scala @@ -164,8 +164,8 @@ object BufferedMessageSocket { noTop.subscribeAwait(maxQueued) override protected def terminate: F[Unit] = - fib.cancel *> // stop processing incoming messages - send(Terminate) // server will close the socket when it sees this + fib.cancel *> // stop processing incoming messages + send(Terminate).attempt.void // server will close the socket when it sees this; ignore failure as socket may be closed mid-write override def history(max: Int): F[List[Either[Any, Any]]] = ms.history(max) diff --git a/modules/core/shared/src/main/scala/net/Protocol.scala b/modules/core/shared/src/main/scala/net/Protocol.scala index 00794b2ea..4b57c7bfb 100644 --- a/modules/core/shared/src/main/scala/net/Protocol.scala +++ b/modules/core/shared/src/main/scala/net/Protocol.scala @@ -208,19 +208,14 @@ object Protocol { def execute(maxRows: Int): F[List[B] ~ Boolean] } - /** - * Resource yielding a new `Protocol` with the given `host` and `port`. - * @param host Postgres server host - * @param port Postgres port, default 5432 - */ def apply[F[_]: Temporal: Tracer: Console]( - debug: Boolean, - nam: Namer[F], - sockets: Resource[F, Socket[F]], - sslOptions: Option[SSLNegotiation.Options[F]], - describeCache: Describe.Cache[F], - parseCache: Parse.Cache[F], - readTimeout: Duration, + debug: Boolean, + nam: Namer[F], + sockets: Resource[F, Socket[F]], + sslOptions: Option[SSLNegotiation.Options[F]], + describeCache: Describe.Cache[F], + parseCache: Parse.Cache[F], + readTimeout: Duration, redactionStrategy: RedactionStrategy ): Resource[F, Protocol[F]] = for { diff --git a/modules/tests/shared/src/test/scala/SessionTest.scala b/modules/tests/shared/src/test/scala/SessionTest.scala index b0e87e2f4..27e693dec 100644 --- a/modules/tests/shared/src/test/scala/SessionTest.scala +++ b/modules/tests/shared/src/test/scala/SessionTest.scala @@ -12,12 +12,16 @@ import skunk.exception.SkunkException class SessionTest extends ffstest.FTest { test("Invalid host") { - Session.Builder[IO].withHost("").single.use(_ => IO.unit).assertFailsWith[SkunkException] - .flatMap(e => assertEqual("message", e.message, """Hostname: "" is not syntactically valid.""")) + val e = intercept[SkunkException] { + Session.Builder[IO].withHost("") + } + assertEquals(e.message, """Hostname: "" is not syntactically valid.""") } + test("Invalid port") { - Session.Builder[IO].withPort(-1).single.use(_ => IO.unit).assertFailsWith[SkunkException] - .flatMap(e => assertEqual("message", e.message, "Port: -1 falls out of the allowed range.")) + val e = intercept[SkunkException] { + Session.Builder[IO].withPort(-1).single.use(_ => IO.unit) + } + assertEquals(e.message, "Port: -1 falls out of the allowed range.") } - } diff --git a/modules/tests/shared/src/test/scala/StartupTest.scala b/modules/tests/shared/src/test/scala/StartupTest.scala index 3e0a369ca..62a6ac5ed 100644 --- a/modules/tests/shared/src/test/scala/StartupTest.scala +++ b/modules/tests/shared/src/test/scala/StartupTest.scala @@ -239,4 +239,15 @@ class StartupTest extends ffstest.FTest { .single .use(_ => IO.unit).assertFailsWith[UnknownHostException] } + + object LinuxOnly extends munit.Tag("LinuxOnly") + + tracedTest("unix domain sockets - successful login".tag(LinuxOnly)) { implicit tracer: Tracer[IO] => + Session.Builder[IO] + .withUnixSocketDirectory("test-unix-socket") + .withUser("jimmy") + .withDatabase("world") + .single + .use(_ => IO.unit) + } }