From 36dbd4f35cf7e4bcc7d6548f109074272e6d8768 Mon Sep 17 00:00:00 2001 From: Anton <14254374+0xmad@users.noreply.github.com> Date: Fri, 24 May 2024 07:56:42 -0500 Subject: [PATCH] feat: coordinator public key method - [x] Add api method for getting rsa public key - [x] Use dependency injection for crypto service - [x] Move file related logic to a separate service - [x] Add explicit public by-pass for authorization --- coordinator/tests/app.test.ts | 54 ++++++------- coordinator/ts/app.controller.test.ts | 34 +++++++- coordinator/ts/app.controller.ts | 31 ++++++- coordinator/ts/app.module.ts | 4 + .../ts/auth/AccountSignatureGuard.service.ts | 42 ++++++++-- .../__tests__/AccountSignatureGuard.test.ts | 51 ++++++++---- coordinator/ts/common/errors.ts | 1 + .../crypto/__tests__/crypto.service.test.ts | 6 +- coordinator/ts/crypto/crypto.module.ts | 9 +++ coordinator/ts/crypto/crypto.service.ts | 25 +----- .../ts/file/__tests__/file.service.test.ts | 74 +++++++++++++++++ coordinator/ts/file/file.module.ts | 9 +++ coordinator/ts/file/file.service.ts | 81 +++++++++++++++++++ coordinator/ts/file/types.ts | 39 +++++++++ .../ts/proof/__tests__/proof.service.test.ts | 38 +++++---- coordinator/ts/proof/proof.service.ts | 42 +++------- 16 files changed, 411 insertions(+), 129 deletions(-) create mode 100644 coordinator/ts/crypto/crypto.module.ts create mode 100644 coordinator/ts/file/__tests__/file.service.test.ts create mode 100644 coordinator/ts/file/file.module.ts create mode 100644 coordinator/ts/file/file.service.ts create mode 100644 coordinator/ts/file/types.ts diff --git a/coordinator/tests/app.test.ts b/coordinator/tests/app.test.ts index 390f068f..63ac3391 100644 --- a/coordinator/tests/app.test.ts +++ b/coordinator/tests/app.test.ts @@ -26,6 +26,7 @@ import type { App } from "supertest/types"; import { AppModule } from "../ts/app.module"; import { ErrorCodes } from "../ts/common"; import { CryptoService } from "../ts/crypto/crypto.service"; +import { FileModule } from "../ts/file/file.module"; const STATE_TREE_DEPTH = 10; const INT_STATE_TREE_DEPTH = 1; @@ -40,11 +41,13 @@ describe("AppController (e2e)", () => { let maciAddresses: DeployedContracts; let pollContracts: PollContracts; + const cryptoService = new CryptoService(); + const getAuthorizationHeader = async () => { const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); const signature = await signer.signMessage("message"); const digest = Buffer.from(getBytes(hashMessage("message"))).toString("hex"); - return `Bearer ${CryptoService.getInstance().encrypt(publicKey, `${signature}:${digest}`)}`; + return `Bearer ${cryptoService.encrypt(publicKey, `${signature}:${digest}`)}`; }; beforeAll(async () => { @@ -88,7 +91,7 @@ describe("AppController (e2e)", () => { beforeEach(async () => { const moduleFixture = await Test.createTestingModule({ - imports: [AppModule], + imports: [AppModule, FileModule], }).compile(); app = moduleFixture.createNestApplication(); @@ -117,10 +120,7 @@ describe("AppController (e2e)", () => { test("should throw an error if poll id is invalid", async () => { const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); - const encryptedCoordinatorPrivateKey = CryptoService.getInstance().encrypt( - publicKey, - coordinatorKeypair.privKey.serialize(), - ); + const encryptedCoordinatorPrivateKey = cryptoService.encrypt(publicKey, coordinatorKeypair.privKey.serialize()); const encryptedHeader = await getAuthorizationHeader(); const result = await request(app.getHttpServer() as App) @@ -166,10 +166,7 @@ describe("AppController (e2e)", () => { test("should throw an error if maci address is invalid", async () => { const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); - const encryptedCoordinatorPrivateKey = CryptoService.getInstance().encrypt( - publicKey, - coordinatorKeypair.privKey.serialize(), - ); + const encryptedCoordinatorPrivateKey = cryptoService.encrypt(publicKey, coordinatorKeypair.privKey.serialize()); const encryptedHeader = await getAuthorizationHeader(); const result = await request(app.getHttpServer() as App) @@ -193,10 +190,7 @@ describe("AppController (e2e)", () => { test("should throw an error if tally address is invalid", async () => { const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); - const encryptedCoordinatorPrivateKey = CryptoService.getInstance().encrypt( - publicKey, - coordinatorKeypair.privKey.serialize(), - ); + const encryptedCoordinatorPrivateKey = cryptoService.encrypt(publicKey, coordinatorKeypair.privKey.serialize()); const encryptedHeader = await getAuthorizationHeader(); const result = await request(app.getHttpServer() as App) @@ -219,6 +213,18 @@ describe("AppController (e2e)", () => { }); }); + describe("/v1/proof/publicKey GET", () => { + test("should get public key properly", async () => { + const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); + + const result = await request(app.getHttpServer() as App) + .get("/v1/proof/publicKey") + .expect(200); + + expect(result.body).toStrictEqual({ publicKey: publicKey.toString() }); + }); + }); + describe("/v1/proof/generate POST", () => { beforeAll(async () => { const user = new Keypair(); @@ -240,10 +246,7 @@ describe("AppController (e2e)", () => { test("should throw an error if poll is not over", async () => { const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); - const encryptedCoordinatorPrivateKey = CryptoService.getInstance().encrypt( - publicKey, - coordinatorKeypair.privKey.serialize(), - ); + const encryptedCoordinatorPrivateKey = cryptoService.encrypt(publicKey, coordinatorKeypair.privKey.serialize()); const encryptedHeader = await getAuthorizationHeader(); const result = await request(app.getHttpServer() as App) @@ -266,10 +269,7 @@ describe("AppController (e2e)", () => { test("should throw an error if signups are not merged", async () => { const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); - const encryptedCoordinatorPrivateKey = CryptoService.getInstance().encrypt( - publicKey, - coordinatorKeypair.privKey.serialize(), - ); + const encryptedCoordinatorPrivateKey = cryptoService.encrypt(publicKey, coordinatorKeypair.privKey.serialize()); const encryptedHeader = await getAuthorizationHeader(); const result = await request(app.getHttpServer() as App) @@ -295,10 +295,7 @@ describe("AppController (e2e)", () => { await mergeSignups({ pollId: 0n, signer }); const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); - const encryptedCoordinatorPrivateKey = CryptoService.getInstance().encrypt( - publicKey, - coordinatorKeypair.privKey.serialize(), - ); + const encryptedCoordinatorPrivateKey = cryptoService.encrypt(publicKey, coordinatorKeypair.privKey.serialize()); const encryptedHeader = await getAuthorizationHeader(); const result = await request(app.getHttpServer() as App) @@ -405,10 +402,7 @@ describe("AppController (e2e)", () => { test("should generate proofs properly", async () => { const publicKey = await fs.promises.readFile(process.env.COORDINATOR_PUBLIC_KEY_PATH!); - const encryptedCoordinatorPrivateKey = CryptoService.getInstance().encrypt( - publicKey, - coordinatorKeypair.privKey.serialize(), - ); + const encryptedCoordinatorPrivateKey = cryptoService.encrypt(publicKey, coordinatorKeypair.privKey.serialize()); const encryptedHeader = await getAuthorizationHeader(); await request(app.getHttpServer() as App) diff --git a/coordinator/ts/app.controller.test.ts b/coordinator/ts/app.controller.test.ts index 89caa450..d059f3ac 100644 --- a/coordinator/ts/app.controller.test.ts +++ b/coordinator/ts/app.controller.test.ts @@ -1,10 +1,12 @@ import { HttpException, HttpStatus } from "@nestjs/common"; import { Test } from "@nestjs/testing"; +import type { IGetPublicKeyData } from "./file/types"; import type { IGenerateArgs, IGenerateData } from "./proof/types"; import type { TallyData } from "maci-cli"; import { AppController } from "./app.controller"; +import { FileService } from "./file/file.service"; import { ProofGeneratorService } from "./proof/proof.service"; describe("AppController", () => { @@ -25,10 +27,18 @@ describe("AppController", () => { tallyData: {} as TallyData, }; + const defaultPublicKeyData: IGetPublicKeyData = { + publicKey: "key", + }; + const mockGeneratorService = { generate: jest.fn(), }; + const mockFileService = { + getPublicKey: jest.fn(), + }; + beforeEach(async () => { const app = await Test.createTestingModule({ controllers: [AppController], @@ -40,6 +50,12 @@ describe("AppController", () => { return mockGeneratorService; } + if (token === FileService) { + mockFileService.getPublicKey.mockResolvedValue(defaultPublicKeyData); + + return mockFileService; + } + return jest.fn(); }) .compile(); @@ -51,7 +67,7 @@ describe("AppController", () => { jest.clearAllMocks(); }); - describe("v1/proof", () => { + describe("v1/proof/generate", () => { test("should return generated proof data", async () => { const data = await appController.generate(defaultProofGeneratorArgs); expect(data).toStrictEqual(defaultProofGeneratorData); @@ -66,4 +82,20 @@ describe("AppController", () => { ); }); }); + + describe("v1/proof/publicKey", () => { + test("should return public key properly", async () => { + const data = await appController.getPublicKey(); + expect(data).toStrictEqual(defaultPublicKeyData); + }); + + test("should throw an error if file service throws an error", async () => { + const error = new Error("error"); + mockFileService.getPublicKey.mockRejectedValue(error); + + await expect(appController.getPublicKey()).rejects.toThrow( + new HttpException(error.message, HttpStatus.BAD_REQUEST), + ); + }); + }); }); diff --git a/coordinator/ts/app.controller.ts b/coordinator/ts/app.controller.ts index 71f58477..39e01c13 100644 --- a/coordinator/ts/app.controller.ts +++ b/coordinator/ts/app.controller.ts @@ -1,10 +1,13 @@ -import { Body, Controller, HttpException, HttpStatus, Logger, Post, UseGuards } from "@nestjs/common"; +import { Body, Controller, Get, HttpException, HttpStatus, Logger, Post, UseGuards } from "@nestjs/common"; import { ApiBearerAuth, ApiBody, ApiResponse, ApiTags } from "@nestjs/swagger"; -import { AccountSignatureGuard } from "./auth/AccountSignatureGuard.service"; +import type { IGetPublicKeyData } from "./file/types"; +import type { IGenerateData } from "./proof/types"; + +import { AccountSignatureGuard, Public } from "./auth/AccountSignatureGuard.service"; +import { FileService } from "./file/file.service"; import { GenerateProofDto } from "./proof/dto"; import { ProofGeneratorService } from "./proof/proof.service"; -import { IGenerateData } from "./proof/types"; @ApiTags("v1/proof") @ApiBearerAuth() @@ -20,8 +23,12 @@ export class AppController { * Initialize AppController * * @param proofGeneratorService - proof generator service + * @param fileService - file service */ - constructor(private readonly proofGeneratorService: ProofGeneratorService) {} + constructor( + private readonly proofGeneratorService: ProofGeneratorService, + private readonly fileService: FileService, + ) {} /** * Generate proofs api method @@ -40,4 +47,20 @@ export class AppController { throw new HttpException(error.message, HttpStatus.BAD_REQUEST); }); } + + /** + * Get RSA public key for authorization setup + * + * @returns RSA public key + */ + @ApiResponse({ status: HttpStatus.OK, description: "Public key was successfully returned" }) + @ApiResponse({ status: HttpStatus.BAD_REQUEST, description: "BadRequest" }) + @Public() + @Get("publicKey") + async getPublicKey(): Promise { + return this.fileService.getPublicKey().catch((error: Error) => { + this.logger.error(`Error:`, error); + throw new HttpException(error.message, HttpStatus.BAD_REQUEST); + }); + } } diff --git a/coordinator/ts/app.module.ts b/coordinator/ts/app.module.ts index 14c78690..9bcb94e0 100644 --- a/coordinator/ts/app.module.ts +++ b/coordinator/ts/app.module.ts @@ -2,6 +2,8 @@ import { Module } from "@nestjs/common"; import { ThrottlerModule } from "@nestjs/throttler"; import { AppController } from "./app.controller"; +import { CryptoModule } from "./crypto/crypto.module"; +import { FileModule } from "./file/file.module"; import { ProofGeneratorService } from "./proof/proof.service"; @Module({ @@ -12,6 +14,8 @@ import { ProofGeneratorService } from "./proof/proof.service"; limit: Number(process.env.LIMIT), }, ]), + FileModule, + CryptoModule, ], controllers: [AppController], providers: [ProofGeneratorService], diff --git a/coordinator/ts/auth/AccountSignatureGuard.service.ts b/coordinator/ts/auth/AccountSignatureGuard.service.ts index 4d3d9333..bfbbfd42 100644 --- a/coordinator/ts/auth/AccountSignatureGuard.service.ts +++ b/coordinator/ts/auth/AccountSignatureGuard.service.ts @@ -1,4 +1,12 @@ -import { Logger, CanActivate, type ExecutionContext, Injectable } from "@nestjs/common"; +import { + Logger, + CanActivate, + Injectable, + SetMetadata, + type ExecutionContext, + type CustomDecorator, +} from "@nestjs/common"; +import { Reflector } from "@nestjs/core"; import { ethers } from "ethers"; import fs from "fs"; @@ -8,6 +16,18 @@ import type { Request as Req } from "express"; import { CryptoService } from "../crypto/crypto.service"; +/** + * Public metadata key + */ +export const PUBLIC_METADATA_KEY = "isPublic"; + +/** + * Public decorator to by-pass auth checks + * + * @returns public decorator + */ +export const Public = (): CustomDecorator => SetMetadata(PUBLIC_METADATA_KEY, true); + /** * AccountSignatureGuard is responsible for protecting calling controller functions. * If account address is not added to .env file, you will not be allowed to call any API methods. @@ -24,15 +44,17 @@ import { CryptoService } from "../crypto/crypto.service"; */ @Injectable() export class AccountSignatureGuard implements CanActivate { - /** - * Crypto service - */ - private readonly cryptoService = CryptoService.getInstance(); - /** * Logger */ - private readonly logger = new Logger(AccountSignatureGuard.name); + private readonly logger: Logger; + + constructor( + private readonly cryptoService: CryptoService, + private readonly reflector: Reflector, + ) { + this.logger = new Logger(AccountSignatureGuard.name); + } /** * This function should return a boolean, indicating whether the request is allowed or not based on message signature and digest. @@ -42,6 +64,12 @@ export class AccountSignatureGuard implements CanActivate { */ async canActivate(ctx: ExecutionContext): Promise { try { + const isPublic = this.reflector.get(PUBLIC_METADATA_KEY, ctx.getHandler()); + + if (isPublic) { + return true; + } + const request = ctx.switchToHttp().getRequest(); const encryptedHeader = request.headers.authorization; diff --git a/coordinator/ts/auth/__tests__/AccountSignatureGuard.test.ts b/coordinator/ts/auth/__tests__/AccountSignatureGuard.test.ts index 0ab00a99..a7e3f2b7 100644 --- a/coordinator/ts/auth/__tests__/AccountSignatureGuard.test.ts +++ b/coordinator/ts/auth/__tests__/AccountSignatureGuard.test.ts @@ -1,3 +1,4 @@ +import { Reflector } from "@nestjs/core"; import dotenv from "dotenv"; import { getBytes, hashMessage } from "ethers"; import hardhat from "hardhat"; @@ -5,7 +6,7 @@ import hardhat from "hardhat"; import type { ExecutionContext } from "@nestjs/common"; import { CryptoService } from "../../crypto/crypto.service"; -import { AccountSignatureGuard } from "../AccountSignatureGuard.service"; +import { AccountSignatureGuard, PUBLIC_METADATA_KEY, Public } from "../AccountSignatureGuard.service"; dotenv.config(); @@ -21,6 +22,7 @@ describe("AccountSignatureGuard", () => { }; const mockContext = { + getHandler: jest.fn(), switchToHttp: jest.fn().mockReturnValue({ getRequest: jest.fn(() => mockRequest), }), @@ -32,20 +34,30 @@ describe("AccountSignatureGuard", () => { const mockCryptoService = { decrypt: jest.fn(), - }; + } as unknown as CryptoService; + + const reflector = { + get: jest.fn(), + } as Reflector & { get: jest.Mock }; beforeEach(() => { mockCryptoService.decrypt = jest.fn(() => `${mockSignature}:${mockDigest}`); - - (CryptoService.getInstance as jest.Mock).mockReturnValue(mockCryptoService); + reflector.get.mockReturnValue(false); }); afterEach(() => { jest.clearAllMocks(); }); + test("should create public decorator properly", () => { + const decorator = Public(); + + expect(decorator.KEY).toBe(PUBLIC_METADATA_KEY); + }); + test("should return false if there is no Authorization header", async () => { const ctx = { + getHandler: jest.fn(), switchToHttp: jest.fn().mockReturnValue({ getRequest: jest.fn(() => ({ headers: { authorization: "" }, @@ -53,7 +65,7 @@ describe("AccountSignatureGuard", () => { }), } as unknown as ExecutionContext; - const guard = new AccountSignatureGuard(); + const guard = new AccountSignatureGuard(mockCryptoService, reflector); const result = await guard.canActivate(ctx); @@ -61,9 +73,9 @@ describe("AccountSignatureGuard", () => { }); test("should return false if there is no signature", async () => { - mockCryptoService.decrypt.mockReturnValue(`:${mockDigest}`); + (mockCryptoService.decrypt as jest.Mock).mockReturnValue(`:${mockDigest}`); - const guard = new AccountSignatureGuard(); + const guard = new AccountSignatureGuard(mockCryptoService, reflector); const result = await guard.canActivate(mockContext); @@ -71,9 +83,9 @@ describe("AccountSignatureGuard", () => { }); test("should return false if there is no digest", async () => { - mockCryptoService.decrypt.mockReturnValue(mockSignature); + (mockCryptoService.decrypt as jest.Mock).mockReturnValue(mockSignature); - const guard = new AccountSignatureGuard(); + const guard = new AccountSignatureGuard(mockCryptoService, reflector); const result = await guard.canActivate(mockContext); @@ -81,9 +93,9 @@ describe("AccountSignatureGuard", () => { }); test("should return false if signature or digest are invalid", async () => { - mockCryptoService.decrypt.mockReturnValue(`signature:digest`); + (mockCryptoService.decrypt as jest.Mock).mockReturnValue(`signature:digest`); - const guard = new AccountSignatureGuard(); + const guard = new AccountSignatureGuard(mockCryptoService, reflector); const result = await guard.canActivate(mockContext); @@ -95,9 +107,9 @@ describe("AccountSignatureGuard", () => { const signature = await signer.signMessage("message"); const digest = Buffer.from(getBytes(hashMessage("message"))).toString("hex"); - mockCryptoService.decrypt.mockReturnValue(`${signature}:${digest}`); + (mockCryptoService.decrypt as jest.Mock).mockReturnValue(`${signature}:${digest}`); - const guard = new AccountSignatureGuard(); + const guard = new AccountSignatureGuard(mockCryptoService, reflector); const result = await guard.canActivate(mockContext); @@ -110,9 +122,18 @@ describe("AccountSignatureGuard", () => { const signature = await signer.signMessage("message"); const digest = Buffer.from(getBytes(hashMessage("message"))).toString("hex"); - mockCryptoService.decrypt.mockReturnValue(`${signature}:${digest}`); + (mockCryptoService.decrypt as jest.Mock).mockReturnValue(`${signature}:${digest}`); + + const guard = new AccountSignatureGuard(mockCryptoService, reflector); + + const result = await guard.canActivate(mockContext); + + expect(result).toBe(true); + }); - const guard = new AccountSignatureGuard(); + test("should return true if can skip authorization", async () => { + reflector.get.mockReturnValue(true); + const guard = new AccountSignatureGuard(mockCryptoService, reflector); const result = await guard.canActivate(mockContext); diff --git a/coordinator/ts/common/errors.ts b/coordinator/ts/common/errors.ts index 621970a6..9d38bf37 100644 --- a/coordinator/ts/common/errors.ts +++ b/coordinator/ts/common/errors.ts @@ -8,4 +8,5 @@ export enum ErrorCodes { POLL_NOT_FOUND = "3", DECRYPTION = "4", ENCRYPTION = "5", + FILE_NOT_FOUND = "6", } diff --git a/coordinator/ts/crypto/__tests__/crypto.service.test.ts b/coordinator/ts/crypto/__tests__/crypto.service.test.ts index 493e098e..8643cd89 100644 --- a/coordinator/ts/crypto/__tests__/crypto.service.test.ts +++ b/coordinator/ts/crypto/__tests__/crypto.service.test.ts @@ -7,13 +7,13 @@ import { CryptoService } from "../crypto.service"; describe("CryptoService", () => { test("should throw encryption error if key is invalid", () => { - const service = CryptoService.getInstance(); + const service = new CryptoService(); expect(() => service.encrypt("", "")).toThrow(ErrorCodes.ENCRYPTION); }); test("should throw decryption error if key is invalid", () => { - const service = CryptoService.getInstance(); + const service = new CryptoService(); expect(() => service.decrypt("", "")).toThrow(ErrorCodes.DECRYPTION); }); @@ -21,7 +21,7 @@ describe("CryptoService", () => { test("should encrypt and decrypt properly", () => { fc.assert( fc.property(fc.string(), (text: string) => { - const service = CryptoService.getInstance(); + const service = new CryptoService(); const keypair = generateKeyPairSync("rsa", { modulusLength: 2048, diff --git a/coordinator/ts/crypto/crypto.module.ts b/coordinator/ts/crypto/crypto.module.ts new file mode 100644 index 00000000..fcb331b7 --- /dev/null +++ b/coordinator/ts/crypto/crypto.module.ts @@ -0,0 +1,9 @@ +import { Module } from "@nestjs/common"; + +import { CryptoService } from "./crypto.service"; + +@Module({ + exports: [CryptoService], + providers: [CryptoService], +}) +export class CryptoModule {} diff --git a/coordinator/ts/crypto/crypto.service.ts b/coordinator/ts/crypto/crypto.service.ts index c6f7748f..7875c12c 100644 --- a/coordinator/ts/crypto/crypto.service.ts +++ b/coordinator/ts/crypto/crypto.service.ts @@ -1,4 +1,4 @@ -import { Logger } from "@nestjs/common"; +import { Injectable, Logger } from "@nestjs/common"; import { publicEncrypt, privateDecrypt, type KeyLike } from "crypto"; @@ -7,37 +7,20 @@ import { ErrorCodes } from "../common"; /** * CryptoService is responsible for encrypting and decrypting user sensitive data */ +@Injectable() export class CryptoService { - /** - * Singleton instance - */ - private static INSTANCE?: CryptoService; - /** * Logger */ private readonly logger: Logger; /** - * Empty constructor + * Initialize service */ - private constructor() { + constructor() { this.logger = new Logger(CryptoService.name); } - /** - * Get singleton crypto service instance - * - * @returns crypto service instance - */ - static getInstance(): CryptoService { - if (!CryptoService.INSTANCE) { - CryptoService.INSTANCE = new CryptoService(); - } - - return CryptoService.INSTANCE; - } - /** * Encrypt plaintext with public key * diff --git a/coordinator/ts/file/__tests__/file.service.test.ts b/coordinator/ts/file/__tests__/file.service.test.ts new file mode 100644 index 00000000..c87c47e4 --- /dev/null +++ b/coordinator/ts/file/__tests__/file.service.test.ts @@ -0,0 +1,74 @@ +import dotenv from "dotenv"; + +import { ErrorCodes } from "../../common"; +import { FileService } from "../file.service"; + +dotenv.config(); + +describe("FileService", () => { + afterEach(() => { + jest.clearAllMocks(); + }); + + test("should return public key properly", async () => { + const service = new FileService(); + + const { publicKey } = await service.getPublicKey(); + + expect(publicKey).toBeDefined(); + }); + + test("should return private key properly", async () => { + const service = new FileService(); + + const { privateKey } = await service.getPrivateKey(); + + expect(privateKey).toBeDefined(); + }); + + test("should return zkey filepaths for tally qv properly", () => { + const service = new FileService(); + + const { zkey, wasm, witgen } = service.getZkeyFilePaths(process.env.COORDINATOR_TALLY_ZKEY_NAME!, true); + + expect(zkey).toBeDefined(); + expect(wasm).toBeDefined(); + expect(witgen).toBeDefined(); + }); + + test("should return zkey filepaths for tally non-qv properly", () => { + const service = new FileService(); + + const { zkey, wasm, witgen } = service.getZkeyFilePaths(process.env.COORDINATOR_TALLY_ZKEY_NAME!, false); + + expect(zkey).toBeDefined(); + expect(wasm).toBeDefined(); + expect(witgen).toBeDefined(); + }); + + test("should return zkey filepaths for message process qv properly", () => { + const service = new FileService(); + + const { zkey, wasm, witgen } = service.getZkeyFilePaths(process.env.COORDINATOR_MESSAGE_PROCESS_ZKEY_NAME!, true); + + expect(zkey).toBeDefined(); + expect(wasm).toBeDefined(); + expect(witgen).toBeDefined(); + }); + + test("should return zkey filepaths for message process non-qv properly", () => { + const service = new FileService(); + + const { zkey, wasm, witgen } = service.getZkeyFilePaths(process.env.COORDINATOR_MESSAGE_PROCESS_ZKEY_NAME!, false); + + expect(zkey).toBeDefined(); + expect(wasm).toBeDefined(); + expect(witgen).toBeDefined(); + }); + + test("should throw an error if there are not zkey filepaths", () => { + const service = new FileService(); + + expect(() => service.getZkeyFilePaths("unknown", false)).toThrow(ErrorCodes.FILE_NOT_FOUND); + }); +}); diff --git a/coordinator/ts/file/file.module.ts b/coordinator/ts/file/file.module.ts new file mode 100644 index 00000000..09a46ce1 --- /dev/null +++ b/coordinator/ts/file/file.module.ts @@ -0,0 +1,9 @@ +import { Module } from "@nestjs/common"; + +import { FileService } from "./file.service"; + +@Module({ + exports: [FileService], + providers: [FileService], +}) +export class FileModule {} diff --git a/coordinator/ts/file/file.service.ts b/coordinator/ts/file/file.service.ts new file mode 100644 index 00000000..d0bdff26 --- /dev/null +++ b/coordinator/ts/file/file.service.ts @@ -0,0 +1,81 @@ +import { Injectable, Logger } from "@nestjs/common"; + +import fs from "fs"; +import path from "path"; + +import type { IGetPrivateKeyData, IGetPublicKeyData, IGetZkeyFilePathsData } from "./types"; + +import { ErrorCodes } from "../common"; + +/** + * FileService is responsible for working with local files like: + * 1. RSA public/private keys + * 2. Zkey files + */ +@Injectable() +export class FileService { + /** + * Logger + */ + private readonly logger: Logger; + + /** + * Initialize service + */ + constructor() { + this.logger = new Logger(FileService.name); + } + + /** + * Get RSA private key for coordinator service + * + * @returns serialized RSA public key + */ + async getPublicKey(): Promise { + const publicKey = await fs.promises.readFile(path.resolve(process.env.COORDINATOR_PUBLIC_KEY_PATH!)); + + return { publicKey: publicKey.toString() }; + } + + /** + * Get RSA private key for coordinator service + * + * @returns serialized RSA private key + */ + async getPrivateKey(): Promise { + const privateKey = await fs.promises.readFile(path.resolve(process.env.COORDINATOR_PRIVATE_KEY_PATH!)); + + return { privateKey: privateKey.toString() }; + } + + /** + * Get zkey, wasm and witgen filepaths for zkey set + * + * @param name - zkey set name + * @param useQuadraticVoting - whether to use Qv or NonQv + * @returns zkey and wasm filepaths + */ + getZkeyFilePaths(name: string, useQuadraticVoting: boolean): IGetZkeyFilePathsData { + const root = path.resolve(process.env.COORDINATOR_ZKEY_PATH!); + const index = name.indexOf("_"); + const type = name.slice(0, index); + const params = name.slice(index + 1); + const mode = useQuadraticVoting ? "" : "NonQv"; + const filename = `${type}${mode}_${params}`; + + const zkey = path.resolve(root, `${filename}/${filename}.0.zkey`); + const wasm = path.resolve(root, `${filename}/${filename}_js/${filename}.wasm`); + const witgen = path.resolve(root, `${filename}/${filename}_cpp/${filename}`); + + if (!fs.existsSync(zkey) || (!fs.existsSync(wasm) && !fs.existsSync(witgen))) { + this.logger.error(`Error: ${ErrorCodes.FILE_NOT_FOUND}, zkey: ${zkey}, wasm: ${wasm}, witgen: ${witgen}`); + throw new Error(ErrorCodes.FILE_NOT_FOUND); + } + + return { + zkey, + wasm, + witgen, + }; + } +} diff --git a/coordinator/ts/file/types.ts b/coordinator/ts/file/types.ts new file mode 100644 index 00000000..b609a25a --- /dev/null +++ b/coordinator/ts/file/types.ts @@ -0,0 +1,39 @@ +/** + * Interface that represents public key return data + */ +export interface IGetPublicKeyData { + /** + * RSA public key + */ + publicKey: string; +} + +/** + * Interface that represents private key return data + */ +export interface IGetPrivateKeyData { + /** + * RSA private key + */ + privateKey: string; +} + +/** + * Interface that represents zkey file paths return data + */ +export interface IGetZkeyFilePathsData { + /** + * Zkey filepath + */ + zkey: string; + + /** + * Wasm filepath + */ + wasm: string; + + /** + * Witgen filepath + */ + witgen: string; +} diff --git a/coordinator/ts/proof/__tests__/proof.service.test.ts b/coordinator/ts/proof/__tests__/proof.service.test.ts index 0ae9f2f8..2a0bd95e 100644 --- a/coordinator/ts/proof/__tests__/proof.service.test.ts +++ b/coordinator/ts/proof/__tests__/proof.service.test.ts @@ -7,6 +7,7 @@ import type { IGenerateArgs } from "../types"; import { ErrorCodes } from "../../common"; import { CryptoService } from "../../crypto/crypto.service"; +import { FileService } from "../../file/file.service"; import { ProofGeneratorService } from "../proof.service"; dotenv.config(); @@ -58,9 +59,9 @@ describe("ProofGeneratorService", () => { generateTallyProofs: jest.fn(), }; - let defaultCryptoService = { + const defaultCryptoService = { decrypt: jest.fn(), - }; + } as unknown as CryptoService; const defaultDeploymentService = { setHre: jest.fn(), @@ -68,6 +69,8 @@ describe("ProofGeneratorService", () => { getContract: jest.fn(() => Promise.resolve(mockContract)), }; + const fileService = new FileService(); + beforeEach(() => { mockContract = { polls: jest.fn(() => Promise.resolve(ZeroAddress.replace("0x0", "0x1"))), @@ -88,9 +91,9 @@ describe("ProofGeneratorService", () => { generateTallyProofs: jest.fn(() => Promise.resolve({ proofs: [1], tallyData: {} })), }; - defaultCryptoService = { - decrypt: jest.fn(() => "macisk.6d5efa8ebc6f7a6ee3e9bf573346af2df29b007b29ef420c030aa4a7f3410182"), - }; + (defaultCryptoService.decrypt as jest.Mock) = jest.fn( + () => "macisk.6d5efa8ebc6f7a6ee3e9bf573346af2df29b007b29ef420c030aa4a7f3410182", + ); (Deployment.getInstance as jest.Mock).mockReturnValue(defaultDeploymentService); @@ -103,8 +106,6 @@ describe("ProofGeneratorService", () => { polls: new Map([[1n, {}]]), }), ); - - (CryptoService.getInstance as jest.Mock).mockReturnValue(defaultCryptoService); }); afterEach(() => { @@ -114,7 +115,7 @@ describe("ProofGeneratorService", () => { test("should throw error if state is not merged yet", async () => { mockContract.stateMerged.mockResolvedValue(false); - const service = new ProofGeneratorService(); + const service = new ProofGeneratorService(defaultCryptoService, fileService); await expect(service.generate(defaultArgs)).rejects.toThrow(ErrorCodes.NOT_MERGED_STATE_TREE); }); @@ -123,7 +124,7 @@ describe("ProofGeneratorService", () => { const keypair = new Keypair(new PrivKey(0n)); mockContract.coordinatorPubKey.mockResolvedValue(keypair.pubKey.asContractParam()); - const service = new ProofGeneratorService(); + const service = new ProofGeneratorService(defaultCryptoService, fileService); await expect(service.generate(defaultArgs)).rejects.toThrow(ErrorCodes.PRIVATE_KEY_MISMATCH); }); @@ -131,21 +132,28 @@ describe("ProofGeneratorService", () => { test("should throw error if there is no any poll", async () => { mockContract.getMainRoot.mockResolvedValue(0n); - const service = new ProofGeneratorService(); + const service = new ProofGeneratorService(defaultCryptoService, fileService); await expect(service.generate(defaultArgs)).rejects.toThrow(ErrorCodes.NOT_MERGED_MESSAGE_TREE); }); test("should throw error if poll is not found", async () => { - const service = new ProofGeneratorService(); + const service = new ProofGeneratorService(defaultCryptoService, fileService); + + await expect(service.generate({ ...defaultArgs, poll: 2 })).rejects.toThrow(ErrorCodes.POLL_NOT_FOUND); + }); + + test("should throw error if poll is not found in maci contract", async () => { + mockContract.polls.mockResolvedValue(ZeroAddress); + const service = new ProofGeneratorService(defaultCryptoService, fileService); await expect(service.generate({ ...defaultArgs, poll: 2 })).rejects.toThrow(ErrorCodes.POLL_NOT_FOUND); }); test("should throw error if coordinator key cannot be decrypted", async () => { - defaultCryptoService.decrypt.mockReturnValue("unknown"); + (defaultCryptoService.decrypt as jest.Mock).mockReturnValue("unknown"); - const service = new ProofGeneratorService(); + const service = new ProofGeneratorService(defaultCryptoService, fileService); await expect(service.generate({ ...defaultArgs, encryptedCoordinatorPrivateKey: "unknown" })).rejects.toThrow( "Cannot convert 0x to a BigInt", @@ -153,7 +161,7 @@ describe("ProofGeneratorService", () => { }); test("should generate proofs properly for NonQv", async () => { - const service = new ProofGeneratorService(); + const service = new ProofGeneratorService(defaultCryptoService, fileService); const data = await service.generate(defaultArgs); @@ -162,7 +170,7 @@ describe("ProofGeneratorService", () => { }); test("should generate proofs properly for Qv", async () => { - const service = new ProofGeneratorService(); + const service = new ProofGeneratorService(defaultCryptoService, fileService); const data = await service.generate({ ...defaultArgs, useQuadraticVoting: true }); diff --git a/coordinator/ts/proof/proof.service.ts b/coordinator/ts/proof/proof.service.ts index b10cf1c7..a8d249ec 100644 --- a/coordinator/ts/proof/proof.service.ts +++ b/coordinator/ts/proof/proof.service.ts @@ -4,13 +4,13 @@ import hre from "hardhat"; import { Deployment, EContracts, ProofGenerator, type Poll, type MACI, type AccQueue } from "maci-contracts"; import { Keypair, PrivKey, PubKey } from "maci-domainobjs"; -import fs from "fs"; import path from "path"; import type { IGenerateArgs, IGenerateData } from "./types"; import { ErrorCodes } from "../common"; import { CryptoService } from "../crypto/crypto.service"; +import { FileService } from "../file/file.service"; /** * ProofGeneratorService is responsible for generating message processing and tally proofs. @@ -22,11 +22,6 @@ export class ProofGeneratorService { */ private readonly deployment: Deployment; - /** - * CryptoService for user sensitive data decryption - */ - private readonly cryptoService: CryptoService; - /** * Logger */ @@ -35,10 +30,13 @@ export class ProofGeneratorService { /** * Proof generator initialization */ - constructor() { + constructor( + private readonly cryptoService: CryptoService, + private readonly fileService: FileService, + ) { this.deployment = Deployment.getInstance(hre); this.deployment.setHre(hre); - this.cryptoService = CryptoService.getInstance(); + this.fileService = fileService; this.logger = new Logger(ProofGeneratorService.name); } @@ -97,7 +95,7 @@ export class ProofGeneratorService { throw new Error(ErrorCodes.NOT_MERGED_MESSAGE_TREE); } - const privateKey = await fs.promises.readFile(path.resolve(process.env.COORDINATOR_PRIVATE_KEY_PATH!)); + const { privateKey } = await this.fileService.getPrivateKey(); const maciPrivateKey = PrivKey.deserialize(this.cryptoService.decrypt(privateKey, encryptedCoordinatorPrivateKey)); const coordinatorKeypair = new Keypair(maciPrivateKey); const publicKey = new PubKey([ @@ -136,8 +134,8 @@ export class ProofGeneratorService { poll: foundPoll, maciContractAddress, tallyContractAddress, - tally: this.getZkeyFiles(process.env.COORDINATOR_TALLY_ZKEY_NAME!, useQuadraticVoting), - mp: this.getZkeyFiles(process.env.COORDINATOR_MESSAGE_PROCESS_ZKEY_NAME!, useQuadraticVoting), + tally: this.fileService.getZkeyFilePaths(process.env.COORDINATOR_TALLY_ZKEY_NAME!, useQuadraticVoting), + mp: this.fileService.getZkeyFilePaths(process.env.COORDINATOR_MESSAGE_PROCESS_ZKEY_NAME!, useQuadraticVoting), rapidsnark: process.env.COORDINATOR_RAPIDSNARK_EXE, outputDir: path.resolve("./proofs"), tallyOutputFile: path.resolve("./tally.json"), @@ -153,26 +151,4 @@ export class ProofGeneratorService { tallyData, }; } - - /** - * Get zkey, wasm and witgen filepaths for zkey set - * - * @param name - zkey set name - * @param useQuadraticVoting - whether to use Qv or NonQv - * @returns zkey and wasm filepaths - */ - private getZkeyFiles(name: string, useQuadraticVoting: boolean): { zkey: string; wasm: string; witgen: string } { - const root = path.resolve(process.env.COORDINATOR_ZKEY_PATH!); - const index = name.indexOf("_"); - const type = name.slice(0, index); - const params = name.slice(index + 1); - const mode = useQuadraticVoting ? "" : "NonQv"; - const filename = `${type}${mode}_${params}`; - - return { - zkey: path.resolve(root, `${filename}/${filename}.0.zkey`), - wasm: path.resolve(root, `${filename}/${filename}_js/${filename}.wasm`), - witgen: path.resolve(root, `${filename}/${filename}_cpp/${filename}`), - }; - } }