Skip to content

Commit a19d05f

Browse files
committed
Support for quoted identifiers
1 parent b9b19ec commit a19d05f

9 files changed

Lines changed: 195 additions & 9 deletions

File tree

modules/core/shared/src/main/scala-2/syntax/StringContextOps.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ class StringContextOps private[skunk](sc: StringContext) {
1919
def id(): Identifier =
2020
macro StringContextOps.StringOpsMacros.identifier_impl
2121

22+
def qid(): Identifier =
23+
macro StringContextOps.StringOpsMacros.quotedIdentifier_impl
24+
2225
/** Construct a constant `Fragment` with no interpolated values. */
2326
def const()(implicit or: Origin): Fragment[Void] =
2427
Fragment(sc.parts.toList.map(Left(_)), Void.codec, or)
@@ -156,6 +159,14 @@ object StringContextOps {
156159
}
157160
}
158161

162+
def quotedIdentifier_impl(): Tree = {
163+
val Apply(_, List(Apply(_, List(Literal(Constant(part: String)))))) = c.prefix.tree: @unchecked
164+
Identifier.fromStringQuoted(part) match {
165+
case Left(s) => c.abort(c.enclosingPosition, s)
166+
case Right(Identifier(s)) => q"_root_.skunk.data.Identifier.fromStringQuoted($s).fold(sys.error, identity)"
167+
}
168+
}
169+
159170
}
160171

161172
}

modules/core/shared/src/main/scala-3/syntax/StringContextOps.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,21 @@ object StringContextOps {
154154
return '{???}
155155
}
156156

157+
def qidImpl(sc: Expr[StringContext])(using qc: Quotes): Expr[Identifier] =
158+
import qc.reflect.report
159+
sc match {
160+
case '{ StringContext(${Varargs(Exprs(Seq(part)))}: _*) } =>
161+
Identifier.fromStringQuoted(part) match {
162+
case Right(Identifier(s)) => '{ Identifier.fromStringQuoted(${Expr(s)}).fold(sys.error, identity) }
163+
case Left(s) =>
164+
report.error(s)
165+
return '{???}
166+
}
167+
case _ =>
168+
report.error(s"Identifiers cannot have interpolated arguments")
169+
return '{???}
170+
}
171+
157172
}
158173

159174
trait ToStringContextOps {
@@ -164,6 +179,9 @@ trait ToStringContextOps {
164179
extension (inline sc: StringContext) inline def id(): Identifier =
165180
${ StringContextOps.idImpl('sc) }
166181

182+
extension (inline sc: StringContext) inline def qid(): Identifier =
183+
${ StringContextOps.qidImpl('sc) }
184+
167185
implicit def toStringOps(sc: StringContext): StringContextOps =
168186
new StringContextOps(sc)
169187
}

modules/core/shared/src/main/scala/Channel.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -120,29 +120,29 @@ object Channel {
120120
new Channel[F, String, String] {
121121

122122
val listen: F[Unit] =
123-
proto.execute(Command(s"LISTEN ${name.value}", Origin.unknown, Void.codec)).void
123+
proto.execute(Command(s"LISTEN ${name.asSql}", Origin.unknown, Void.codec)).void
124124

125125
val unlisten: F[Unit] =
126-
proto.execute(Command(s"UNLISTEN ${name.value}", Origin.unknown, Void.codec)).void
126+
proto.execute(Command(s"UNLISTEN ${name.asSql}", Origin.unknown, Void.codec)).void
127127

128128
def listen(maxQueued: Int): Stream[F, Notification[String]] =
129129
for {
130130
_ <- Stream.resource(Resource.make(listen)(_ => unlisten))
131131
s <- Stream.resource(proto.notifications(maxQueued))
132-
n <- s.filter(_.channel === name)
132+
n <- s.filter(_.channel.value === name.value)
133133
} yield n
134134

135135

136136
def listenR(maxQueued: Int): Resource[F, Stream[F, Notification[String]]] =
137137
for {
138138
_ <- Resource.make(listen)(_ => unlisten)
139139
stream <- proto.notifications(maxQueued)
140-
} yield stream.filter(_.channel === name)
140+
} yield stream.filter(_.channel.value === name.value)
141141

142142

143143
def notify(message: String): F[Unit] =
144144
// TODO: escape the message
145-
proto.execute(Command(s"NOTIFY ${name.value}, '$message'", Origin.unknown, Void.codec)).void
145+
proto.execute(Command(s"NOTIFY ${name.asSql}, '$message'", Origin.unknown, Void.codec)).void
146146

147147
}
148148

modules/core/shared/src/main/scala/data/Identifier.scala

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@ import cats.Eq
99
import scala.util.matching.Regex
1010

1111
sealed abstract case class Identifier(value: String) {
12-
override def toString: String = value // ok?
12+
def quoted: Boolean = false
13+
def asSql: String = if (quoted) s"\"${value.replace("\"", "\"\"")}\"" else value
14+
override def toString: String = asSql
15+
16+
override def equals(other: Any): Boolean = other match {
17+
case that: Identifier => this.value == that.value && this.quoted == that.quoted
18+
case _ => false
19+
}
20+
override def hashCode: Int = (value, quoted).##
1321
}
1422

1523
object Identifier {
@@ -20,7 +28,7 @@ object Identifier {
2028
val pat: Regex = "([A-Za-z_][A-Za-z_0-9$]*)".r
2129

2230
implicit val EqIdentifier: Eq[Identifier] =
23-
Eq.by(_.value)
31+
Eq.instance((a, b) => a.value == b.value && a.quoted == b.quoted)
2432

2533
def fromString(s: String): Either[String, Identifier] =
2634
s match {
@@ -34,6 +42,20 @@ object Identifier {
3442
case _ => Left(s"Malformed identifier: does not match ${pat.regex}")
3543
}
3644

45+
def fromStringQuoted(s: String): Either[String, Identifier] = {
46+
if (s.isEmpty)
47+
Left("Illegal identifier: zero-length delimited identifier.")
48+
else if (s.contains('\u0000'))
49+
Left("Illegal identifier: cannot contain the null byte (\\u0000).")
50+
else {
51+
val byteLen = s.getBytes(java.nio.charset.StandardCharsets.UTF_8).length
52+
if (byteLen > maxLen)
53+
Left(s"Identifier too long: $byteLen bytes (max allowed is $maxLen)")
54+
else
55+
Right(new Identifier(s) { override val quoted: Boolean = true })
56+
}
57+
}
58+
3759
val keywords: Set[String] =
3860
Set(
3961
"A", "ABORT", "ABS", "ABSENT", "ABSOLUTE",
@@ -190,4 +212,4 @@ object Identifier {
190212
"YES", "ZONE"
191213
)
192214

193-
}
215+
}

modules/core/shared/src/main/scala/net/message/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ package object message { module =>
5858

5959
val identifier: SCodec[Identifier] =
6060
utf8z.exmap(
61-
s => Attempt.fromEither(Identifier.fromString(s).leftMap(Err(_))),
61+
s => Attempt.fromEither(Identifier.fromString(s).orElse(Identifier.fromStringQuoted(s)).leftMap(Err(_))),
6262
id => Attempt.successful(id.value)
6363
)
6464

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,52 @@
1+
```scala mdoc:invisible
2+
import skunk.data.Identifier
3+
import skunk.implicits._
4+
```
15
# Identifiers
26

7+
`skunk.data.Identifier` represents a Postgres SQL identifier — the name of a table, column, schema, channel, etc. Skunk validates identifiers up front so they can be safely spliced into SQL without risking injection.
8+
9+
Postgres recognises two flavours of identifier, and `Identifier` supports both.
10+
11+
## Unquoted identifiers
12+
13+
An *unquoted* identifier matches `[A-Za-z_][A-Za-z_0-9$]*`, is at most 63 characters, and is not a reserved keyword. Postgres folds unquoted identifiers to lower case, so `FOO`, `Foo`, and `foo` all refer to the same object.
14+
15+
Construct one with `Identifier.fromString` or the `id"…"` interpolator:
16+
17+
```scala mdoc:compile-only
18+
val a: Either[String, Identifier] = Identifier.fromString("my_table")
19+
val b: Identifier = id"my_table"
20+
```
21+
22+
The `id"…"` form validates at compile time and fails the build for malformed input.
23+
24+
## Quoted (delimited) identifiers
25+
26+
A *quoted* (delimited) identifier is any non-empty character sequence of at most 63 bytes that does not contain the NUL character (`\u0000`). Quoting preserves case and lets you use characters or reserved words that an unquoted identifier cannot.
27+
28+
Construct one with `Identifier.fromStringQuoted` or the `qid"…"` interpolator:
29+
30+
```scala mdoc:compile-only
31+
val a: Either[String, Identifier] = Identifier.fromStringQuoted("MyTable") // case preserved
32+
val b: Identifier = qid"q_my_queue.INSERT" // keywords allowed
33+
```
34+
35+
Like `id"…"`, the `qid"…"` form validates at compile time and fails the build for malformed input (empty string, embedded NUL character, or > 63 bytes).
36+
37+
Length is checked in UTF-8 **bytes** rather than characters (Postgres' `NAMEDATALEN-1` limit is byte-counted), so an identifier containing multibyte characters can exceed the limit with fewer than 63 characters.
38+
39+
## Rendering as SQL
40+
41+
`Identifier#asSql` returns the SQL-ready form: the bare value for unquoted identifiers, or the value wrapped in double quotes (with any embedded `"` doubled) for quoted ones. `toString` returns `asSql`, so logged identifiers show their SQL-correct form. `value` always returns the bare, unescaped name.
42+
43+
```scala mdoc:compile-only
44+
val unq = id"my_table"
45+
val unqRendered = unq.asSql // "my_table"
46+
47+
val q = qid"My.Channel"
48+
val qBare = q.value // "My.Channel"
49+
val qRendered = q.asSql // "\"My.Channel\""
50+
```
51+
52+
`Channel` uses `asSql` internally when issuing `LISTEN`/`UNLISTEN`/`NOTIFY`, so quoted channel names round-trip correctly.

modules/docs/src/main/laika/tutorial/Channels.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ Observe the following:
2727
- `ch` is a `Channel` which consumes `String`s and emits `Notification[String]`s. A notification is a structure that includes the process ID and channel identifier as well as the payload.
2828
- `Channel` is a profunctor and thus can be contramapped to change the input type, and mapped to change the output type.
2929

30+
If the channel name contains characters that are not valid in an unquoted identifier, use the `qid"…"` interpolator (or `Identifier.fromStringQuoted`) to build a quoted identifier:
31+
32+
```scala mdoc:compile-only
33+
// assume s: Session[IO]
34+
val ch = s.channel(qid"q_my_queue.INSERT")
35+
```
36+
37+
The resulting `LISTEN`/`UNLISTEN`/`NOTIFY` statements wrap the name in double quotes so Postgres parses it correctly.
38+
3039
## Listening to a Channel
3140

3241
To listen on a channel, construct a stream via `.listen`.

modules/tests/shared/src/test/scala/ChannelTest.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,16 @@ class ChannelTest extends SkunkTest {
5656
}
5757
}
5858

59+
sessionTest("channel with quoted identifier round-trips through LISTEN/NOTIFY/UNLISTEN") { s =>
60+
val data = List("foo", "bar", "baz")
61+
val ch = s.channel(qid"pgmq.q_my_queue.INSERT")
62+
ch.listenR(42).use { r =>
63+
for {
64+
_ <- data.traverse_(ch.notify)
65+
d <- r.map(_.value).takeThrough(_ != data.last).compile.toList
66+
_ <- assert(s"channel data $d $data", data == d)
67+
} yield "ok"
68+
}
69+
}
5970

6071
}

modules/tests/shared/src/test/scala/data/IdentifierTest.scala

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,5 +51,70 @@ class IdentifierTest extends ffstest.FTest {
5151
assertEqual("value", id"foo".toString, "foo")
5252
}
5353

54+
test("quoted - valid with dots") {
55+
Identifier.fromStringQuoted("pgmq.q_my_queue.INSERT") match {
56+
case Left(err) => fail(err)
57+
case Right(id) =>
58+
for {
59+
_ <- assertEqual("value", id.value, "pgmq.q_my_queue.INSERT")
60+
_ <- assertEqual("asSql", id.asSql, "\"pgmq.q_my_queue.INSERT\"")
61+
_ <- assertEqual("quoted", id.quoted, true)
62+
} yield ()
63+
}
64+
}
65+
66+
test("quoted - escapes embedded double quotes") {
67+
Identifier.fromStringQuoted("with\"quote") match {
68+
case Left(err) => fail(err)
69+
case Right(id) => assertEqual("asSql", id.asSql, "\"with\"\"quote\"")
70+
}
71+
}
72+
73+
test("quoted - empty rejected") {
74+
Identifier.fromStringQuoted("") match {
75+
case Left(err) => err.pure[IO]
76+
case Right(value) => fail(s"expected error, got $value")
77+
}
78+
}
79+
80+
test("quoted - null byte rejected") {
81+
val nullByte = '\u0000'
82+
Identifier.fromStringQuoted(s"a${nullByte}b") match {
83+
case Left(err) => err.pure[IO]
84+
case Right(value) => fail(s"expected error, got $value")
85+
}
86+
}
87+
88+
test("quoted - too long in bytes") {
89+
Identifier.fromStringQuoted("é" * 32) match {
90+
case Left(err) => err.pure[IO]
91+
case Right(value) => fail(s"expected error, got $value")
92+
}
93+
}
94+
95+
test("quoted - keywords allowed") {
96+
Identifier.fromStringQuoted("SELECT") match {
97+
case Left(err) => fail(err)
98+
case Right(id) =>
99+
for {
100+
_ <- assertEqual("value", id.value, "SELECT")
101+
_ <- assertEqual("asSql", id.asSql, "\"SELECT\"")
102+
} yield ()
103+
}
104+
}
105+
106+
test("unquoted - asSql equals value") {
107+
assertEqual("asSql", id"foo".asSql, "foo")
108+
}
109+
110+
test("Eq distinguishes quoted from unquoted with same value") {
111+
val unquoted = Identifier.fromString("foo").toOption.get
112+
val quoted = Identifier.fromStringQuoted("foo").toOption.get
113+
for {
114+
_ <- IO(assert(unquoted =!= quoted))
115+
_ <- IO(assert(unquoted != quoted))
116+
} yield ()
117+
}
118+
54119
}
55120

0 commit comments

Comments
 (0)