From 5e3c0a5fd350c4eacc7edeebfac244f3c72994a6 Mon Sep 17 00:00:00 2001 From: Alexander Rusakov Date: Sun, 13 Oct 2024 20:56:28 +0300 Subject: [PATCH] feat: use dedicated .lua files (#4) --- lua/channelAdd.lua | 20 +++++++ lua/channelRemove.lua | 22 ++++++++ src/constants.ts | 1 + src/index.ts | 121 ++++++++++++++++++------------------------ test/channel.test.ts | 10 +++- 5 files changed, 102 insertions(+), 72 deletions(-) create mode 100644 lua/channelAdd.lua create mode 100644 lua/channelRemove.lua diff --git a/lua/channelAdd.lua b/lua/channelAdd.lua new file mode 100644 index 0000000..b1037df --- /dev/null +++ b/lua/channelAdd.lua @@ -0,0 +1,20 @@ +local key = KEYS[1] +local chnl_key = ARGV[1] +local chnl_arg = ARGV[2] + +local chnl_str = redis.call('HGET', key, chnl_key) +for match in chnl_str:gmatch('([^,]+)') do + if chnl_arg == match then + return 0 + end +end + +if chnl_str == '' then + chnl_str = chnl_arg +else + chnl_str = chnl_str .. ',' .. chnl_arg +end + +redis.call('HSET', key, chnl_key, chnl_str) + +return 1 \ No newline at end of file diff --git a/lua/channelRemove.lua b/lua/channelRemove.lua new file mode 100644 index 0000000..46b1663 --- /dev/null +++ b/lua/channelRemove.lua @@ -0,0 +1,22 @@ +local key = KEYS[1] +local chnl_key = ARGV[1] +local chnl_arg = ARGV[2] + +local chnl_str = redis.call('HGET', key, chnl_key) + +local chnl = {} +local removed = false +for match in chnl_str:gmatch('([^,]+)') do + if match ~= chnl_arg then + table.insert(chnl, match) + else + removed = true + end +end + +if removed then + redis.call('HSET', key, chnl_key, table.concat(chnl, ',')) + return 1 +end + +return 0 diff --git a/src/constants.ts b/src/constants.ts index 30d005c..b4ec85e 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -9,6 +9,7 @@ export const IP = 'ip' export const SID = 'sid' export const ID = 'id' export const CHNL = 'chnl' +export const LUA = 'lua' export const MAX_INT_ID = 2 ** 30 diff --git a/src/index.ts b/src/index.ts index 4717c0f..23cdf1b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,9 @@ +import { readFile } from 'fs/promises' +import { resolve } from 'path' + import type { Callback, Redis, Result } from 'ioredis' -import { __MIGRATIONS, CHNL, CLNT, ID, IP, MAX_INT_ID, SID, SRVR, TTL_DEFAULT } from './constants' + +import { __MIGRATIONS, CHNL, CLNT, ID, IP, LUA, MAX_INT_ID, SID, SRVR, TTL_DEFAULT } from './constants' import { sleep } from './utils' export type WSDiscoveryOptions = { @@ -38,14 +42,27 @@ const assertChannel = (channel: string): void | never => { // Add declarations declare module "ioredis" { interface RedisCommander { - myecho( + channelAdd( + key: string, + channelProp: string, + channel: string, + callback?: Callback<0 | 1> + ): Result<0 | 1, Context>; + + channelRemove( key: string, - argv: string, - callback?: Callback - ): Result; + channelProp: string, + channel: string, + callback?: Callback<0 | 1> + ): Result<0 | 1, Context>; } } +enum CustomScripts { + CHANNEL_ADD = 'channelAdd', + CHANNEL_REMOVE = 'channelRemove' +} + export class WSDiscovery { readonly prefix: string readonly prefixServer: string @@ -57,19 +74,18 @@ export class WSDiscovery { protected readonly redis: Redis - constructor({ redis, ttl = TTL_DEFAULT, prefix = 'wsd', }: WSDiscoveryOptions) { - this.redis = redis - assertTTL(ttl.server) this.ttlServer = ttl.server || TTL_DEFAULT.server assertTTL(ttl.client) this.ttlClient = ttl.client || TTL_DEFAULT.client + this.redis = redis + this.prefix = `${prefix}:` this.prefixServer = `${this.prefix}${SRVR}:` this.prefixClient = `${this.prefix}${CLNT}:` @@ -78,6 +94,29 @@ export class WSDiscovery { } async connect() { + type ScriptData = { + keys: number + readOnly: boolean + } + const scripts: Record = { + [CustomScripts.CHANNEL_ADD]: { + keys: 1, + readOnly: false, + }, + [CustomScripts.CHANNEL_REMOVE]: { + keys: 1, + readOnly: false, + } + } + + for (const [scriptName, scriptData] of Object.entries(scripts)) { + this.redis.defineCommand(scriptName, { + lua: await readFile(resolve(__dirname, '..', LUA, `${scriptName}.${LUA}`), 'utf8'), + numberOfKeys: scriptData.keys, + readOnly: scriptData.readOnly, + }) + } + await this.redis.ping() await this.migrate() } @@ -165,78 +204,20 @@ export class WSDiscovery { async addChannel(clientId: number, channel: string) { assertChannel(channel) - const script = - ` - local key = KEYS[1] - local chnl_key = ARGV[1] - local chnl_arg = ARGV[2] - - local chnl_str = redis.call('HGET', key, chnl_key) - for match in chnl_str:gmatch('([^,]+)') do - if chnl_arg == match then - return 0 - end - end - - if chnl_str == '' then - chnl_str = chnl_arg - else - chnl_str = chnl_str .. ',' .. chnl_arg - end - - redis.call('HSET', key, chnl_key, chnl_str) - - return 1 - `.trim() - - const result = await this.redis.eval( - script, - 1, - [this.getClientKey(clientId), CHNL, channel], - ) as 0 | 1 - + const result = await this.redis.channelAdd(this.getClientKey(clientId), CHNL, channel) return result === 1 } async removeChannel(clientId: number, channel: string) { assertChannel(channel) - const script = - ` - local key = KEYS[1] - local chnl_key = ARGV[1] - local chnl_arg = ARGV[2] - local chnl_str = redis.call('HGET', key, chnl_key) - - local chnl = {} - local removed = false - for match in chnl_str:gmatch('([^,]+)') do - if match ~= chnl_arg then - table.insert(chnl, match) - else - removed = true - end - end - - if removed then - redis.call('HSET', key, chnl_key, table.concat(chnl, ',')) - return 1 - end - - return 0 - ` - .trim() - - const result = await this.redis.eval( - script, - 1, - [this.getClientKey(clientId), CHNL, channel], - ) as 0 | 1 - + const result = await this.redis.channelRemove(this.getClientKey(clientId), CHNL, channel) return result === 1 } async getClientsByChannel(channel: string, batch = 100): Promise { + assertChannel(channel) + type KeyAndKey = ['__key', string] type AggregateAndCursorResponse = [[1, ...KeyAndKey[]], number] diff --git a/test/channel.test.ts b/test/channel.test.ts index e27ed9b..bdbc240 100644 --- a/test/channel.test.ts +++ b/test/channel.test.ts @@ -97,14 +97,20 @@ describe('Channels', () => { }) }) - it('getClientsByChannel() empty', async () => { + it('getClientsByChannel() validation', async () => { + await rejects(() => wsd.getClientsByChannel(''), (err) => { + return err instanceof Error && err.message === 'Empty channel is not allowed' + }) + }) + + it('getClientsByChannel() no clients', async () => { deepEqual( await wsd.getClientsByChannel('xyz'), [], ) }) - it('getClientsByChannel() return one', async () => { + it('getClientsByChannel() return empty array', async () => { await wsd.addChannel(clientId1, 'abc') deepEqual(