[RFC] Return 401 for an authentication error on WebSockets (#3411)
* Return 401 for an authentication error on WebSocket * Use upgradeReq instead of a custom object
This commit is contained in:
		| @@ -95,7 +95,6 @@ const startWorker = (workerId) => { | ||||
|   const app    = express(); | ||||
|   const pgPool = new pg.Pool(Object.assign(pgConfigs[env], dbUrlToConfig(process.env.DATABASE_URL))); | ||||
|   const server = http.createServer(app); | ||||
|   const wss    = new WebSocket.Server({ server }); | ||||
|   const redisNamespace = process.env.REDIS_NAMESPACE || null; | ||||
|  | ||||
|   const redisParams = { | ||||
| @@ -186,14 +185,10 @@ const startWorker = (workerId) => { | ||||
|     }); | ||||
|   }; | ||||
|  | ||||
|   const authenticationMiddleware = (req, res, next) => { | ||||
|     if (req.method === 'OPTIONS') { | ||||
|       next(); | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     const authorization = req.get('Authorization'); | ||||
|     const accessToken = req.query.access_token; | ||||
|   const accountFromRequest = (req, next) => { | ||||
|     const authorization = req.headers.authorization; | ||||
|     const location = url.parse(req.url, true); | ||||
|     const accessToken = location.query.access_token; | ||||
|  | ||||
|     if (!authorization && !accessToken) { | ||||
|       const err = new Error('Missing access token'); | ||||
| @@ -208,6 +203,26 @@ const startWorker = (workerId) => { | ||||
|     accountFromToken(token, req, next); | ||||
|   }; | ||||
|  | ||||
|   const wsVerifyClient = (info, cb) => { | ||||
|     accountFromRequest(info.req, err => { | ||||
|       if (!err) { | ||||
|         cb(true, undefined, undefined); | ||||
|       } else { | ||||
|         log.error(info.req.requestId, err.toString()); | ||||
|         cb(false, 401, 'Unauthorized'); | ||||
|       } | ||||
|     }); | ||||
|   }; | ||||
|  | ||||
|   const authenticationMiddleware = (req, res, next) => { | ||||
|     if (req.method === 'OPTIONS') { | ||||
|       next(); | ||||
|       return; | ||||
|     } | ||||
|  | ||||
|     accountFromRequest(req, next); | ||||
|   }; | ||||
|  | ||||
|   const errorMiddleware = (err, req, res, next) => { | ||||
|     log.error(req.requestId, err.toString()); | ||||
|     res.writeHead(err.statusCode || 500, { 'Content-Type': 'application/json' }); | ||||
| @@ -352,10 +367,12 @@ const startWorker = (workerId) => { | ||||
|     streamFrom(`timeline:hashtag:${req.query.tag}:local`, req, streamToHttp(req, res), streamHttpEnd(req), true); | ||||
|   }); | ||||
|  | ||||
|   const wss    = new WebSocket.Server({ server, verifyClient: wsVerifyClient }); | ||||
|  | ||||
|   wss.on('connection', ws => { | ||||
|     const location = url.parse(ws.upgradeReq.url, true); | ||||
|     const token    = location.query.access_token; | ||||
|     const req      = { requestId: uuid.v4() }; | ||||
|     const req      = ws.upgradeReq; | ||||
|     const location = url.parse(req.url, true); | ||||
|     req.requestId  = uuid.v4(); | ||||
|  | ||||
|     ws.isAlive = true; | ||||
|  | ||||
| @@ -363,33 +380,25 @@ const startWorker = (workerId) => { | ||||
|       ws.isAlive = true; | ||||
|     }); | ||||
|  | ||||
|     accountFromToken(token, req, err => { | ||||
|       if (err) { | ||||
|         log.error(req.requestId, err); | ||||
|         ws.close(); | ||||
|         return; | ||||
|       } | ||||
|  | ||||
|       switch(location.query.stream) { | ||||
|       case 'user': | ||||
|         streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws)); | ||||
|         break; | ||||
|       case 'public': | ||||
|         streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|         break; | ||||
|       case 'public:local': | ||||
|         streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|         break; | ||||
|       case 'hashtag': | ||||
|         streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|         break; | ||||
|       case 'hashtag:local': | ||||
|         streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|         break; | ||||
|       default: | ||||
|         ws.close(); | ||||
|       } | ||||
|     }); | ||||
|     switch(location.query.stream) { | ||||
|     case 'user': | ||||
|       streamFrom(`timeline:${req.accountId}`, req, streamToWs(req, ws), streamWsEnd(req, ws)); | ||||
|       break; | ||||
|     case 'public': | ||||
|       streamFrom('timeline:public', req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|       break; | ||||
|     case 'public:local': | ||||
|       streamFrom('timeline:public:local', req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|       break; | ||||
|     case 'hashtag': | ||||
|       streamFrom(`timeline:hashtag:${location.query.tag}`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|       break; | ||||
|     case 'hashtag:local': | ||||
|       streamFrom(`timeline:hashtag:${location.query.tag}:local`, req, streamToWs(req, ws), streamWsEnd(req, ws), true); | ||||
|       break; | ||||
|     default: | ||||
|       ws.close(); | ||||
|     } | ||||
|   }); | ||||
|  | ||||
|   const wsInterval = setInterval(() => { | ||||
|   | ||||
		Reference in New Issue
	
	Block a user