diff --git a/docker-compose.yml b/docker-compose.yml index f0fe62c..536c916 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,8 +3,7 @@ services: redis: - # image: redis/redis-stack:7.2.0-v12 - image: redis/redis-stack:7.2.0-v12 + image: redis/redis-stack:7.2.0-v13 ports: - "127.0.0.1:6379:6379" environment: diff --git a/package.json b/package.json index d6b7820..1b203f7 100644 --- a/package.json +++ b/package.json @@ -21,7 +21,7 @@ }, "scripts": { "build": "tsc -p .", - "compile": "yarn build --noEmit", + "compile": "tsc --noEmit -p ./test", "test:all": "yarn test ./test/*.test.ts", "test:coverage": "c8 --reporter=lcovonly --reporter=text yarn test:all", "test": "node --test --test-concurrency=1 --require=ts-node/register" diff --git a/src/constants.ts b/src/constants.ts index 151e71d..30d005c 100644 --- a/src/constants.ts +++ b/src/constants.ts @@ -10,4 +10,6 @@ export const SID = 'sid' export const ID = 'id' export const CHNL = 'chnl' -export const MAX_INT_ID = 2 ** 30 \ No newline at end of file +export const MAX_INT_ID = 2 ** 30 + +export const __MIGRATIONS = '__migrations' \ No newline at end of file diff --git a/src/index.ts b/src/index.ts index 71d6c12..1caafd7 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,5 +1,6 @@ -import type { Redis } from 'ioredis' -import { CHNL, CLNT, ID, IP, MAX_INT_ID, SID, SRVR, TTL_DEFAULT } from './constants' +import type { Callback, Redis, Result } from 'ioredis' +import { __MIGRATIONS, CHNL, CLNT, ID, IP, MAX_INT_ID, SID, SRVR, TTL_DEFAULT } from './constants' +import { sleep } from './utils' export type WSDiscoveryOptions = { redis: Redis @@ -11,6 +12,16 @@ export type WSDiscoveryOptions = { } +export type ClientWithServer = { + [CLNT]: number + [SRVR]: number +} + +export type ClientFields = + | typeof SID + | typeof CHNL + | typeof SRVR + const assertTTL = (ttl?: number): void | never => { if (ttl != null && ttl <= 0) { @@ -24,15 +35,28 @@ const assertChannel = (channel: string): void | never => { } } -export class WSDiscovery { +// Add declarations +declare module "ioredis" { + interface RedisCommander { + myecho( + key: string, + argv: string, + callback?: Callback + ): Result; + } +} +export class WSDiscovery { + readonly prefix: string + readonly prefixServer: string + readonly prefixClient: string + readonly ttlServer: number + readonly ttlClient: number + + readonly indexClntChnl: string + protected readonly redis: Redis - protected readonly ttlServer: number - protected readonly ttlClient: number - - protected readonly prefixServer: string - protected readonly prefixClient: string constructor({ redis, @@ -40,21 +64,22 @@ export class WSDiscovery { prefix = 'wsd', }: WSDiscoveryOptions) { this.redis = redis - + assertTTL(ttl.server) this.ttlServer = ttl.server || TTL_DEFAULT.server assertTTL(ttl.client) this.ttlClient = ttl.client || TTL_DEFAULT.client - this.prefixServer = `${prefix}:${SRVR}:` - this.prefixClient = `${prefix}:${CLNT}:` + this.prefix = `${prefix}:` + this.prefixServer = `${this.prefix}${SRVR}:` + this.prefixClient = `${this.prefix}${CLNT}:` + + this.indexClntChnl = `${this.prefix}__idx_${CLNT}_${CHNL}` } async connect() { await this.redis.ping() - - const list = await this.listIndexes() - // TODO load lua to redis + await this.migrate() } async registerServer(serverIp: string, ttl?: number) { @@ -112,6 +137,13 @@ export class WSDiscovery { return clientId } + async getClient(clientId: number, fields: ClientFields[] = [CHNL, SID, SRVR]) { + return this.redis.hmget( + this.getClientKey(clientId), + ...fields, + ) + } + /** * * @returns @@ -174,7 +206,7 @@ export class WSDiscovery { script, 1, [this.getClientKey(clientId), CHNL, channel], - ) + ) as Promise } async removeChannel(clientId: number, channel: string) { @@ -207,12 +239,113 @@ export class WSDiscovery { ) } + async getClientsByChannel(channel: string, batch = 100): Promise { + type KeyAndKey = ['__key', string] + type AggregateAndCursorResponse = [[1, ...KeyAndKey[]], number] + + let [aggregateResult, cursor] = await this.redis.call( + 'FT.AGGREGATE', + this.indexClntChnl, + `@${CHNL}:{${channel}}`, + 'LOAD', 1, '@__key', + 'WITHCURSOR', + 'COUNT', batch, + ) as AggregateAndCursorResponse + + const keys: string[] = [] + + while (true) { + keys.push(...((aggregateResult.slice(1) as KeyAndKey[]).map(([_key, key]) => key))) + if (!cursor) { + break + } + + [aggregateResult, cursor] = await this.redis.call( + 'FT.CURSOR', 'READ', this.indexClntChnl, cursor + ) as AggregateAndCursorResponse + } + + const hgetResults = await this.redis.pipeline( + keys.map((k) => ['hget', k, SRVR]), + ).exec() + + if (!hgetResults) { + throw new Error('multiple hget error') + } + + const results: ClientWithServer[] = [] + + for (const [index, key] of keys.entries()) { + const [err, serverId] = hgetResults[index] + + if (err) { + continue + } + + results.push({ + [CLNT]: Number(key.substring(this.prefixClient.length)), + [SRVR]: Number(serverId), + }) + } + + return results + } + protected getClientKey(clientId: number) { return this.prefixClient + clientId } - protected listIndexes() { - return this.redis.call('FT._LIST') + protected async lock(key: string, token: string) { + + for (let i = 0; i < 1000; ++i) { + const result = await this.redis.set(key, token, 'EX', 30, 'NX') + if (result) { + return + } + + await sleep(500) + } + throw new Error(`can not take redis lock on key=${key}`) + } + + protected async unlock(key: string, token: string) { + await this.redis.eval(` + if redis.call("get",KEYS[1]) == ARGV[1] then + return redis.call("del",KEYS[1]) + else + return 0 + end + `, 1, key, token) + } + + protected async migrate() { + const migrations: Array<[string, string[]]> = [ + ['FT.CREATE', `${this.indexClntChnl} PREFIX 1 ${this.prefixClient} SCHEMA ${CHNL} TAG`.split(' ')], + ] + + const token = `${Date.now()}_${Math.floor(Math.random() * 99999999)}` + const migrationsKey = this.prefix + __MIGRATIONS + const lockKey = migrationsKey + '_lock' + + await this.lock(lockKey, token) + + const applyedMigrations = new Set(await this.redis.smembers(migrationsKey)) + + for (let i = 0; i < migrations.length; ++i) { + const migrationId = 'm' + i + if (applyedMigrations.has(migrationId)) { + continue + } + + const migration = migrations[i] + // TODO there can be logical errors inside transaction!! + await this.redis + .multi() + .call(migration[0], migration[1]) + .sadd(migrationsKey, migrationId) + .exec() + } + await this.unlock(lockKey, token) } } \ No newline at end of file diff --git a/src/utils.ts b/src/utils.ts new file mode 100644 index 0000000..4123ecc --- /dev/null +++ b/src/utils.ts @@ -0,0 +1,3 @@ +export const sleep = (ms: number) => new Promise((resolve) => { + setTimeout(resolve, ms) +}) \ No newline at end of file diff --git a/test/channel.test.ts b/test/channel.test.ts index 9c24299..c39b913 100644 --- a/test/channel.test.ts +++ b/test/channel.test.ts @@ -1,8 +1,9 @@ -import { deepEqual } from 'assert/strict' -import { describe, it, before, after } from 'node:test' +import { deepEqual, rejects } from 'assert/strict' +import { describe, it, before, after, beforeEach, afterEach } from 'node:test' + +import { CLNT, SRVR } from '../src/constants' import { clearRedis, createRedis, WSDiscoveryForTests } from './utils' -import { rejects } from 'assert' describe('Channels', () => { const redis = createRedis() @@ -22,16 +23,24 @@ describe('Channels', () => { before(async () => { await wsd.connect() - serverId1 = await wsd.registerServer(ip1) - serverId2 = await wsd.registerServer(ip2) + serverId1 = await wsd.registerServer(ip1, 300) + serverId2 = await wsd.registerServer(ip2, 300) + }) + beforeEach(async () => { clientId1 = await wsd.registerClient(serverId1, 1) clientId2 = await wsd.registerClient(serverId1, 2) clientId3 = await wsd.registerClient(serverId2, 1) }) + afterEach(async () => { + await wsd.deleteClient(clientId1) + await wsd.deleteClient(clientId2) + await wsd.deleteClient(clientId3) + }) + after(async () => { - await clearRedis(redis, '') + await clearRedis(redis, wsd.prefix) await redis.quit() }) @@ -80,5 +89,95 @@ describe('Channels', () => { return err instanceof Error && err.message === 'Empty channel is not allowed' }) }) - + + it('getClientsByChannel() empty', async () => { + deepEqual( + await wsd.getClientsByChannel('xyz'), + [], + ) + }) + + it('getClientsByChannel() return one', async () => { + await wsd.addChannel(clientId1, 'abc') + + deepEqual( + await wsd.getClientsByChannel('xyz'), + [], + ) + }) + + it('getClientsByChannel() one', async () => { + await wsd.addChannel(clientId1, 'abc') + await wsd.addChannel(clientId2, 'xyz') + + deepEqual( + await wsd.getClientsByChannel('xyz'), + [{ + [CLNT]: clientId2, + [SRVR]: serverId1, + }], + ) + }) + + it('getClientsByChannel() return two', async () => { + await wsd.addChannel(clientId1, 'xyz') + await wsd.addChannel(clientId2, 'abc') + await wsd.addChannel(clientId3, 'xyz') + + deepEqual( + await wsd.getClientsByChannel('xyz'), + [ + { + [CLNT]: clientId1, + [SRVR]: serverId1, + }, + { + [CLNT]: clientId3, + [SRVR]: serverId2, + }, + ], + ) + }) + + it('getClientsByChannel() with batch=1', async () => { + await wsd.addChannel(clientId1, 'xyz') + await wsd.addChannel(clientId3, 'xyz') + + deepEqual( + await wsd.getClientsByChannel('xyz'), + [ + { + [CLNT]: clientId1, + [SRVR]: serverId1, + }, + { + [CLNT]: clientId3, + [SRVR]: serverId2, + }, + ], + ) + }) + + it('getClientsByChannel() multiple channels', async () => { + await wsd.addChannel(clientId1, 'xyz') + await wsd.addChannel(clientId1, 'abc') + await wsd.addChannel(clientId1, '123') + + await wsd.addChannel(clientId3, 'qwerty') + await wsd.addChannel(clientId3, 'xyz') + + deepEqual( + await wsd.getClientsByChannel('xyz'), + [ + { + [CLNT]: clientId1, + [SRVR]: serverId1, + }, + { + [CLNT]: clientId3, + [SRVR]: serverId2, + }, + ], + ) + }) }) diff --git a/test/client.test.ts b/test/client.test.ts index 2bd25fd..bcac92c 100644 --- a/test/client.test.ts +++ b/test/client.test.ts @@ -1,9 +1,10 @@ -import { strictEqual, rejects } from 'assert' +import { equal, rejects } from 'assert/strict' import { describe, it, before, after } from 'node:test' -import { clearRedis, createRedis, sleep, WSDiscoveryForTests } from './utils' import { MAX_INT_ID } from '../src/constants' +import { clearRedis, createRedis, sleep, WSDiscoveryForTests } from './utils' + describe('Client', () => { const redis = createRedis() const wsd = new WSDiscoveryForTests({ @@ -23,24 +24,24 @@ describe('Client', () => { }) after(async () => { - await clearRedis(redis, '') + await clearRedis(redis, wsd.prefix) await redis.quit() }) it('registerClient() OK', async () => { const cid = await wsd.registerClient(serverId1, 1) - strictEqual(typeof cid, 'number') + equal(typeof cid, 'number') const serverId = await wsd.getServerIdByClientId(cid) - strictEqual(serverId, serverId1) + equal(serverId, serverId1) }) it('registerClient() twice', async () => { const cid1 = await wsd.registerClient(serverId1, 1) const cid2 = await wsd.registerClient(serverId2, 2) - strictEqual(cid1 + 1, cid2) + equal(cid1 + 1, cid2) }) @@ -55,19 +56,19 @@ describe('Client', () => { await wsd.registerClient(serverId1, 1) const newId = await wsd.registerClient(serverId2, 2) - strictEqual(newId, 1) + equal(newId, 1) }) it('updateClientTTL()', async () => { const cid = await wsd.registerClient(serverId1, 11) const result = await wsd.updateClientTTL(cid, 1000) - strictEqual(result, true) + equal(result, true) const ttl = await wsd.getClientTTL(cid) - strictEqual(ttl > 1000 * 0.99, true) - strictEqual(ttl <= 1000, true) + equal(ttl > 1000 * 0.99, true) + equal(ttl <= 1000, true) }) it('updateClientTTL() expired', async () => { @@ -75,7 +76,7 @@ describe('Client', () => { const result = await wsd.updateClientTTL(cid, 10) - strictEqual(result, false) + equal(result, false) }) it('updateClientTTL() bad ttl', async () => { @@ -90,17 +91,17 @@ describe('Client', () => { it('client ttl expires', async () => { const cid = await wsd.registerClient(serverId1, 1, 1) - strictEqual(await wsd.getServerIdByClientId(cid), serverId1) + equal(await wsd.getServerIdByClientId(cid), serverId1) await sleep(1000) - strictEqual(await wsd.getServerIdByClientId(cid), 0) + equal(await wsd.getServerIdByClientId(cid), 0) }) it('delete client', async () => { const cid = await wsd.registerClient(serverId2, 2, 2) - strictEqual(await wsd.getServerIdByClientId(cid), serverId2) - strictEqual(await wsd.deleteClient(cid), true) - strictEqual(await wsd.getServerIdByClientId(cid), 0) + equal(await wsd.getServerIdByClientId(cid), serverId2) + equal(await wsd.deleteClient(cid), true) + equal(await wsd.getServerIdByClientId(cid), 0) }) }) diff --git a/test/server.test.ts b/test/server.test.ts index 49298b4..85a7bca 100644 --- a/test/server.test.ts +++ b/test/server.test.ts @@ -1,9 +1,10 @@ -import { strictEqual, rejects } from 'assert' +import { equal, rejects } from 'assert/strict' import { describe, it, before, after } from 'node:test' -import { clearRedis, createRedis, sleep, WSDiscoveryForTests } from './utils' import { MAX_INT_ID } from '../src/constants' +import { clearRedis, createRedis, sleep, WSDiscoveryForTests } from './utils' + describe('Server', () => { const redis = createRedis() const wsd = new WSDiscoveryForTests({ @@ -15,7 +16,7 @@ describe('Server', () => { }) after(async () => { - await clearRedis(redis, '') + await clearRedis(redis, wsd.prefix) await redis.quit() }) @@ -25,17 +26,17 @@ describe('Server', () => { it('registerServer() OK', async () => { const sid = await wsd.registerServer(ip1) - strictEqual(typeof sid, 'number') + equal(typeof sid, 'number') const ip = await wsd.getServerIp(sid) - strictEqual(ip, ip1) + equal(ip, ip1) }) it('registerServer() twice', async () => { const sid1 = await wsd.registerServer(ip1) const sid2 = await wsd.registerServer(ip2) - strictEqual(sid1 + 1, sid2) + equal(sid1 + 1, sid2) }) @@ -53,7 +54,7 @@ describe('Server', () => { await wsd.registerServer('abc') const newId = await wsd.registerServer('bcd') - strictEqual(newId, 1) + equal(newId, 1) }) it('updateServerTTL()', async () => { @@ -63,8 +64,8 @@ describe('Server', () => { const ttl = await wsd.getServerTTL(sid) - strictEqual(ttl > 1000 * 0.99, true) - strictEqual(ttl <= 1000, true) + equal(ttl > 1000 * 0.99, true) + equal(ttl <= 1000, true) }) it('updateServerTTL() expired', async () => { @@ -72,7 +73,7 @@ describe('Server', () => { const result = await wsd.updateServerTTL(sid, 10) - strictEqual(result, false) + equal(result, false) }) it('updateServerTTL() bad ttl', async () => { @@ -86,9 +87,9 @@ describe('Server', () => { it('server ttl expires', async () => { const sid = await wsd.registerServer(ip1, 1) - strictEqual(await wsd.getServerIp(sid), ip1) + equal(await wsd.getServerIp(sid), ip1) await sleep(1000) - strictEqual(await wsd.getServerIp(sid), null) + equal(await wsd.getServerIp(sid), null) }) }) diff --git a/test/utils.ts b/test/utils.ts index 64c42f2..a3f363d 100644 --- a/test/utils.ts +++ b/test/utils.ts @@ -21,7 +21,7 @@ export const clearRedis = async (redis: Redis, prefix: string) => { const [nextCursor, keys] = await redis.scan(cursor, 'MATCH', prefix + '*') if (keys.length) { - await redis.del(...keys) + await redis.del(...(keys.filter((k) => !k.startsWith(prefix + '__')))) } cursor = nextCursor } while (cursor !== '0')