diff --git a/__tests__/integration/core-p2p/socket-server/peer.test.ts b/__tests__/integration/core-p2p/socket-server/peer.test.ts index 9af9b22488..9c4334d0ee 100644 --- a/__tests__/integration/core-p2p/socket-server/peer.test.ts +++ b/__tests__/integration/core-p2p/socket-server/peer.test.ts @@ -6,6 +6,7 @@ import { defaults } from "../mocks/p2p-options"; import { Blocks, Managers } from "@arkecosystem/crypto/src"; import unitnetMilestones from "@arkecosystem/crypto/src/networks/unitnet/milestones.json"; import delay from "delay"; +import net from "net"; import SocketCluster from "socketcluster"; import socketCluster from "socketcluster-client"; import { startSocketServer } from "../../../../packages/core-p2p/src/socket-server"; @@ -533,6 +534,36 @@ describe("Peer socket endpoint", () => { send(stringifiedPayload); await delay(500); expect(socket.state).not.toBe("open"); + + // kill workers to reset ipLastError (or we won't pass handshake for 1 minute) + server.killWorkers({ immediate: true }); + await delay(2000); // give time to workers to respawn + }); + + it("should close the connection when the HTTP url is not valid", async () => { + const socket = new net.Socket(); + socket.connect(4007, "127.0.0.1", function() { + socket.write("GET /invalid/ HTTP/1.0\r\n\r\n"); + }); + await delay(500); + expect(socket.destroyed).toBe(true); + + socket.connect(4007, "127.0.0.1"); + await delay(500); + expect(socket.destroyed).toBe(true); + + // kill workers to reset ipLastError (or we won't pass handshake for 1 minute) + server.killWorkers({ immediate: true }); + await delay(2000); // give time to workers to respawn + }); + + it("should close the connection if the initial HTTP request is not processed within 2 seconds", async () => { + const socket = new net.Socket(); + socket.connect(4007, "127.0.0.1"); + await delay(500); + expect(socket.destroyed).toBe(false); + await delay(2000); + expect(socket.destroyed).toBe(true); }); }); }); diff --git a/__tests__/unit/core-p2p/socket-server/worker.test.ts b/__tests__/unit/core-p2p/socket-server/worker.test.ts index 20f1d9db10..bd7a0538cc 100644 --- a/__tests__/unit/core-p2p/socket-server/worker.test.ts +++ b/__tests__/unit/core-p2p/socket-server/worker.test.ts @@ -7,7 +7,15 @@ import { Worker } from "../../../../packages/core-p2p/src/socket-server/worker"; const worker = new Worker(); // @ts-ignore -worker.scServer.wsServer = { on: () => undefined }; +worker.scServer.wsServer = { + on: () => undefined, + _server: { + on: () => undefined, + }, +}; +worker.httpServer = { + on: () => undefined, +} as any; worker.scServer.setCodecEngine = codec => undefined; describe("Worker", () => { diff --git a/packages/core-p2p/src/socket-server/worker.ts b/packages/core-p2p/src/socket-server/worker.ts index f5b393cd89..2414c39887 100644 --- a/packages/core-p2p/src/socket-server/worker.ts +++ b/packages/core-p2p/src/socket-server/worker.ts @@ -40,6 +40,9 @@ export class Worker extends SCWorker { await this.loadHandlers(); + // @ts-ignore + this.scServer.wsServer._server.timeout = 2000; + // @ts-ignore this.scServer.wsServer.on("connection", (ws, req) => { const clients = [...Object.values(this.scServer.clients), ...Object.values(this.scServer.pendingClients)]; @@ -52,10 +55,17 @@ export class Worker extends SCWorker { } this.handlePayload(ws, req); }); + // @ts-ignore + this.httpServer.on("request", req => { + // @ts-ignore + if (req.method !== "GET" || req.url !== this.scServer.wsServer.options.path) { + this.setErrorForIpAndTerminate(req); + req.destroy(); + } + }); + // @ts-ignore + this.scServer.wsServer._server.on("connection", socket => this.handleSocket(socket)); this.scServer.on("connection", socket => this.handleConnection(socket)); - this.scServer.addMiddleware(this.scServer.MIDDLEWARE_HANDSHAKE_WS, (req, next) => - this.handleHandshake(req, next), - ); this.scServer.addMiddleware(this.scServer.MIDDLEWARE_EMIT, (req, next) => this.handleEmit(req, next)); } @@ -89,15 +99,15 @@ export class Worker extends SCWorker { ws.removeAllListeners("ping"); ws.removeAllListeners("pong"); ws.prependListener("ping", () => { - this.setErrorForIpAndTerminate(ws, req); + this.setErrorForIpAndTerminate(req, ws); }); ws.prependListener("pong", () => { - this.setErrorForIpAndTerminate(ws, req); + this.setErrorForIpAndTerminate(req, ws); }); ws.prependListener("error", error => { if (error instanceof RangeError) { - this.setErrorForIpAndTerminate(ws, req); + this.setErrorForIpAndTerminate(req, ws); } }); @@ -105,17 +115,17 @@ export class Worker extends SCWorker { ws.removeAllListeners("message"); ws.prependListener("message", message => { if (ws._disconnected) { - return this.setErrorForIpAndTerminate(ws, req); + return this.setErrorForIpAndTerminate(req, ws); } else if (message === "#2") { const timeNow: number = new Date().getTime() / 1000; if (ws._lastPingTime && timeNow - ws._lastPingTime < 1) { - return this.setErrorForIpAndTerminate(ws, req); + return this.setErrorForIpAndTerminate(req, ws); } ws._lastPingTime = timeNow; } else if (message.length < 10) { // except for #2 message, we should have JSON with some required properties // (see below) which implies that message length should be longer than 10 chars - return this.setErrorForIpAndTerminate(ws, req); + return this.setErrorForIpAndTerminate(req, ws); } else { try { const parsed = JSON.parse(message); @@ -123,7 +133,7 @@ export class Worker extends SCWorker { ws._disconnected = true; } else if (parsed.event === "#handshake") { if (ws._handshake) { - return this.setErrorForIpAndTerminate(ws, req); + return this.setErrorForIpAndTerminate(req, ws); } ws._handshake = true; } else if ( @@ -134,10 +144,10 @@ export class Worker extends SCWorker { (parsed.event === "#disconnect" && typeof parsed.cid !== "undefined")) || !this.handlers.includes(parsed.event) ) { - return this.setErrorForIpAndTerminate(ws, req); + return this.setErrorForIpAndTerminate(req, ws); } } catch (error) { - return this.setErrorForIpAndTerminate(ws, req); + return this.setErrorForIpAndTerminate(req, ws); } } @@ -209,9 +219,11 @@ export class Worker extends SCWorker { return false; } - private setErrorForIpAndTerminate(ws, req): void { + private setErrorForIpAndTerminate(req, ws?): void { this.ipLastError[req.socket.remoteAddress] = Date.now(); - ws.terminate(); + if (ws) { + ws.terminate(); + } } private async handleConnection(socket): Promise { @@ -227,10 +239,10 @@ export class Worker extends SCWorker { } } - private async handleHandshake(req, next): Promise { - const ip = req.socket.remoteAddress; - if (this.ipLastError[ip] && this.ipLastError[ip] > Date.now() - MINUTE_IN_MILLISECONDS) { - req.socket.destroy(); + private async handleSocket(socket): Promise { + const ip = socket.remoteAddress; + if (!ip || (this.ipLastError[ip] && this.ipLastError[ip] > Date.now() - MINUTE_IN_MILLISECONDS)) { + socket.destroy(); return; } @@ -243,7 +255,7 @@ export class Worker extends SCWorker { const isBlacklisted: boolean = (this.config.blacklist || []).includes(ip); if (data.blocked || isBlacklisted) { - req.socket.destroy(); + socket.destroy(); return; } @@ -252,11 +264,9 @@ export class Worker extends SCWorker { client => cidr(`${client.remoteAddress}/24`) === cidrRemoteAddress, ); if (sameSubnetSockets.length > this.config.maxSameSubnetPeers) { - req.socket.destroy(); + socket.destroy(); return; } - - next(); } private async handleEmit(req, next): Promise {