Skip to content

Commit

Permalink
feat: use dedicated .lua files (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
arusakov authored Oct 13, 2024
1 parent d4eb678 commit 5e3c0a5
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 72 deletions.
20 changes: 20 additions & 0 deletions lua/channelAdd.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
local key = KEYS[1]
local chnl_key = ARGV[1]
local chnl_arg = ARGV[2]

local chnl_str = redis.call('HGET', key, chnl_key)
for match in chnl_str:gmatch('([^,]+)') do
if chnl_arg == match then
return 0
end
end

if chnl_str == '' then
chnl_str = chnl_arg
else
chnl_str = chnl_str .. ',' .. chnl_arg
end

redis.call('HSET', key, chnl_key, chnl_str)

return 1
22 changes: 22 additions & 0 deletions lua/channelRemove.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
local key = KEYS[1]
local chnl_key = ARGV[1]
local chnl_arg = ARGV[2]

local chnl_str = redis.call('HGET', key, chnl_key)

local chnl = {}
local removed = false
for match in chnl_str:gmatch('([^,]+)') do
if match ~= chnl_arg then
table.insert(chnl, match)
else
removed = true
end
end

if removed then
redis.call('HSET', key, chnl_key, table.concat(chnl, ','))
return 1
end

return 0
1 change: 1 addition & 0 deletions src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export const IP = 'ip'
export const SID = 'sid'
export const ID = 'id'
export const CHNL = 'chnl'
export const LUA = 'lua'

export const MAX_INT_ID = 2 ** 30

Expand Down
121 changes: 51 additions & 70 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import { readFile } from 'fs/promises'
import { resolve } from 'path'

import type { Callback, Redis, Result } from 'ioredis'
import { __MIGRATIONS, CHNL, CLNT, ID, IP, MAX_INT_ID, SID, SRVR, TTL_DEFAULT } from './constants'

import { __MIGRATIONS, CHNL, CLNT, ID, IP, LUA, MAX_INT_ID, SID, SRVR, TTL_DEFAULT } from './constants'
import { sleep } from './utils'

export type WSDiscoveryOptions = {
Expand Down Expand Up @@ -38,14 +42,27 @@ const assertChannel = (channel: string): void | never => {
// Add declarations
declare module "ioredis" {
interface RedisCommander<Context> {
myecho(
channelAdd(
key: string,
channelProp: string,
channel: string,
callback?: Callback<0 | 1>
): Result<0 | 1, Context>;

channelRemove(
key: string,
argv: string,
callback?: Callback<string>
): Result<string, Context>;
channelProp: string,
channel: string,
callback?: Callback<0 | 1>
): Result<0 | 1, Context>;
}
}

enum CustomScripts {
CHANNEL_ADD = 'channelAdd',
CHANNEL_REMOVE = 'channelRemove'
}

export class WSDiscovery {
readonly prefix: string
readonly prefixServer: string
Expand All @@ -57,19 +74,18 @@ export class WSDiscovery {

protected readonly redis: Redis


constructor({
redis,
ttl = TTL_DEFAULT,
prefix = 'wsd',
}: WSDiscoveryOptions) {
this.redis = redis

assertTTL(ttl.server)
this.ttlServer = ttl.server || TTL_DEFAULT.server
assertTTL(ttl.client)
this.ttlClient = ttl.client || TTL_DEFAULT.client

this.redis = redis

this.prefix = `${prefix}:`
this.prefixServer = `${this.prefix}${SRVR}:`
this.prefixClient = `${this.prefix}${CLNT}:`
Expand All @@ -78,6 +94,29 @@ export class WSDiscovery {
}

async connect() {
type ScriptData = {
keys: number
readOnly: boolean
}
const scripts: Record<CustomScripts, ScriptData> = {
[CustomScripts.CHANNEL_ADD]: {
keys: 1,
readOnly: false,
},
[CustomScripts.CHANNEL_REMOVE]: {
keys: 1,
readOnly: false,
}
}

for (const [scriptName, scriptData] of Object.entries(scripts)) {
this.redis.defineCommand(scriptName, {
lua: await readFile(resolve(__dirname, '..', LUA, `${scriptName}.${LUA}`), 'utf8'),
numberOfKeys: scriptData.keys,
readOnly: scriptData.readOnly,
})
}

await this.redis.ping()
await this.migrate()
}
Expand Down Expand Up @@ -165,78 +204,20 @@ export class WSDiscovery {
async addChannel(clientId: number, channel: string) {
assertChannel(channel)

const script =
`
local key = KEYS[1]
local chnl_key = ARGV[1]
local chnl_arg = ARGV[2]
local chnl_str = redis.call('HGET', key, chnl_key)
for match in chnl_str:gmatch('([^,]+)') do
if chnl_arg == match then
return 0
end
end
if chnl_str == '' then
chnl_str = chnl_arg
else
chnl_str = chnl_str .. ',' .. chnl_arg
end
redis.call('HSET', key, chnl_key, chnl_str)
return 1
`.trim()

const result = await this.redis.eval(
script,
1,
[this.getClientKey(clientId), CHNL, channel],
) as 0 | 1

const result = await this.redis.channelAdd(this.getClientKey(clientId), CHNL, channel)
return result === 1
}

async removeChannel(clientId: number, channel: string) {
assertChannel(channel)

const script =
`
local key = KEYS[1]
local chnl_key = ARGV[1]
local chnl_arg = ARGV[2]
local chnl_str = redis.call('HGET', key, chnl_key)
local chnl = {}
local removed = false
for match in chnl_str:gmatch('([^,]+)') do
if match ~= chnl_arg then
table.insert(chnl, match)
else
removed = true
end
end
if removed then
redis.call('HSET', key, chnl_key, table.concat(chnl, ','))
return 1
end
return 0
`
.trim()

const result = await this.redis.eval(
script,
1,
[this.getClientKey(clientId), CHNL, channel],
) as 0 | 1

const result = await this.redis.channelRemove(this.getClientKey(clientId), CHNL, channel)
return result === 1
}

async getClientsByChannel(channel: string, batch = 100): Promise<ClientWithServer[]> {
assertChannel(channel)

type KeyAndKey = ['__key', string]
type AggregateAndCursorResponse = [[1, ...KeyAndKey[]], number]

Expand Down
10 changes: 8 additions & 2 deletions test/channel.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,20 @@ describe('Channels', () => {
})
})

it('getClientsByChannel() empty', async () => {
it('getClientsByChannel() validation', async () => {
await rejects(() => wsd.getClientsByChannel(''), (err) => {
return err instanceof Error && err.message === 'Empty channel is not allowed'
})
})

it('getClientsByChannel() no clients', async () => {
deepEqual(
await wsd.getClientsByChannel('xyz'),
[],
)
})

it('getClientsByChannel() return one', async () => {
it('getClientsByChannel() return empty array', async () => {
await wsd.addChannel(clientId1, 'abc')

deepEqual(
Expand Down

0 comments on commit 5e3c0a5

Please sign in to comment.