diff --git a/packages/pglite/src/pglite.ts b/packages/pglite/src/pglite.ts index 726e490f..10d36728 100644 --- a/packages/pglite/src/pglite.ts +++ b/packages/pglite/src/pglite.ts @@ -726,7 +726,15 @@ export class PGlite this.#notifyListeners.set(pgChannel, new Set()) } this.#notifyListeners.get(pgChannel)!.add(callback) - await this.exec(`LISTEN ${channel}`) + try { + await this.exec(`LISTEN ${channel}`) + } catch (e) { + this.#notifyListeners.get(pgChannel)!.delete(callback) + if (this.#notifyListeners.get(pgChannel)?.size === 0) { + this.#notifyListeners.delete(pgChannel) + } + throw e + } return async () => { await this.unlisten(pgChannel, callback) } diff --git a/packages/pglite/tests/notify.test.js b/packages/pglite/tests/notify.test.js index a2496e77..96980acb 100644 --- a/packages/pglite/tests/notify.test.js +++ b/packages/pglite/tests/notify.test.js @@ -1,5 +1,6 @@ import { describe, it, expect, vi } from 'vitest' import { PGlite } from '../dist/index.js' +import { expectToThrowAsync } from './test-utils.js' describe('notify API', () => { it('notify', async () => { @@ -42,7 +43,7 @@ describe('notify API', () => { await new Promise((resolve) => setTimeout(resolve, 1000)) }) - it('check notify case sensitivity as Postgresql', async () => { + it('check notify case sensitivity + special chars as Postgresql', async () => { const pg = new PGlite() const allLower1 = vi.fn() @@ -73,6 +74,30 @@ describe('notify API', () => { await pg.listen('testNotCalled2', caseSensitive3) await pg.query(`NOTIFY "testNotCalled2", 'paYloAd2'`) + const quotedWithSpaces = vi.fn() + await pg.listen('"Quoted Channel With Spaces"', quotedWithSpaces) + await pg.query(`NOTIFY "Quoted Channel With Spaces", 'payload1'`) + + const unquotedWithSpaces = vi.fn() + await expectToThrowAsync( + pg.listen('Unquoted Channel With Spaces', unquotedWithSpaces), + ) + await expectToThrowAsync( + pg.query(`NOTIFY Unquoted Channel With Spaces, 'payload1'`), + ) + + const otherCharsWithQuotes = vi.fn() + await pg.listen('"test&me"', otherCharsWithQuotes) + await pg.query(`NOTIFY "test&me", 'paYloAd2'`) + + const otherChars = vi.fn() + await expectToThrowAsync( + pg.listen('test&me', otherChars), + ) + await expectToThrowAsync( + pg.query(`NOTIFY test&me, 'payload1'`), + ) + expect(allLower1).toHaveBeenCalledTimes(4) expect(autoLowerTest1).toHaveBeenCalledTimes(3) expect(autoLowerTest2).toHaveBeenCalledTimes(2) @@ -80,9 +105,12 @@ describe('notify API', () => { expect(caseSensitive1).toHaveBeenCalledOnce() expect(caseSensitive2).not.toHaveBeenCalled() expect(caseSensitive3).not.toHaveBeenCalled() + expect(otherCharsWithQuotes).toHaveBeenCalledOnce() + expect(quotedWithSpaces).toHaveBeenCalledOnce() + expect(unquotedWithSpaces).not.toHaveBeenCalled() }) - it('check unlisten case sensitivity as Postgresql', async () => { + it('check unlisten case sensitivity + special chars as Postgresql', async () => { const pg = new PGlite() const allLower1 = vi.fn() @@ -121,10 +149,25 @@ describe('notify API', () => { await pg.query(`NOTIFY "CaSESEnsiTIvE", 'payload1'`) } + const quotedWithSpaces = vi.fn() + { + await pg.listen('"Quoted Channel With Spaces"', quotedWithSpaces) + await pg.query(`NOTIFY "Quoted Channel With Spaces", 'payload1'`) + await pg.unlisten('"Quoted Channel With Spaces"') + } + + const otherCharsWithQuotes = vi.fn() + { + await pg.listen('"test&me"', otherCharsWithQuotes) + await pg.query(`NOTIFY "test&me", 'payload'`) + await pg.unlisten('"test&me"') + } + expect(allLower1).toHaveBeenCalledOnce() expect(autoLowerTest1).toHaveBeenCalledOnce() expect(autoLowerTest2).toHaveBeenCalledOnce() expect(autoLowerTest3).toHaveBeenCalledOnce() expect(caseSensitive1).toHaveBeenCalledOnce() + expect(otherCharsWithQuotes).toHaveBeenCalledOnce() }) })