You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
tailchat/server/mixins/socketio.mixin.ts

574 lines
16 KiB
TypeScript

import { Server as SocketServer } from 'socket.io';
import { createAdapter } from '@socket.io/redis-adapter';
import { instrument } from '@socket.io/admin-ui';
import RedisClient from 'ioredis';
import {
TcService,
TcContext,
UserJWTPayload,
parseLanguageFromHead,
config,
PureContext,
PureService,
PureServiceSchema,
Utils,
Errors,
} from 'tailchat-server-sdk';
import _ from 'lodash';
import { ServiceUnavailableError } from 'tailchat-server-sdk';
import { isValidStr } from '../lib/utils';
import bcrypt from 'bcryptjs';
import msgpackParser from 'socket.io-msgpack-parser';
/** Event names that socket clients may never invoke: wildcard patterns or regular expressions. */
const blacklist: Array<string | RegExp> = ['gateway.*'];
/**
 * Decide whether a socket event name is forbidden.
 *
 * String entries of `blacklist` are matched with `Utils.match` (wildcard
 * patterns such as `gateway.*`); RegExp entries are matched with `test`.
 *
 * @param eventName - the event (action) name sent by the client
 * @returns true as soon as one entry matches, false otherwise
 */
function checkBlacklist(eventName: string): boolean {
  for (const rule of blacklist) {
    if (_.isString(rule)) {
      if (Utils.match(eventName, rule)) {
        return true;
      }
    } else if (_.isRegExp(rule) && rule.test(eventName)) {
      return true;
    }
  }
  return false;
}
/**
 * Name of the personal room that every socket of `userId` joins on connection,
 * so events can be addressed to a user regardless of how many sockets they have.
 */
function buildUserRoomId(userId: string): string {
  const roomPrefix = 'u-';
  return `${roomPrefix}${userId}`;
}
/**
 * Redis key of the hash that maps each online socket id of `userId` to the
 * node id holding that socket.
 */
function buildUserOnlineKey(userId: string): string {
  const keyPrefix = 'tailchat-socketio.online';
  return `${keyPrefix}:${userId}`;
}
/** TTL in seconds (one day) re-applied to a user's online hash on every new connection. */
const expiredTime = 60 * 60 * 24;
/**
 * `this` shape of a service that uses the socket.io mixin.
 */
interface SocketIOService extends PureService {
  /** Redis connection (the adapter's publisher client) used for online-state bookkeeping. */
  redis: RedisClient.Redis;
  /** socket.io server attached to the gateway's HTTP server. */
  io: SocketServer;
  /** Cleanup callbacks that are all invoked when the service stops. */
  socketCloseCallbacks: Array<() => Promise<unknown>>;
}
/**
 * Options accepted by `TcSocketIOService`.
 */
interface TcSocketIOServiceOptions {
  /**
   * When true, socket.io's default parser is used instead of
   * `socket.io-msgpack-parser`.
   */
  disableMsgpack?: boolean;
  /**
   * Resolve the user behind the token sent in the socket handshake
   * (`handshake.auth.token`). The connection is rejected unless the resolved
   * payload has an `_id`.
   */
  userAuth: (token: string) => Promise<UserJWTPayload>;
}
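/**
 * Socket.IO mixin: attaches a socket.io server (Redis adapter, JWT handshake
 * authentication) to a service that also uses the API gateway mixin.
 */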
export const TcSocketIOService = (
options: TcSocketIOServiceOptions
): Partial<PureServiceSchema> => {
const { userAuth } = options;
const schema: Partial<PureServiceSchema> = {
created(this: SocketIOService) {
  // Gauge of authenticated sockets held by this node, labelled by node id;
  // the connection handlers increment/decrement the same metric name.
  const onlineCountMetric = 'tailchat.socketio.online.count';
  const { metrics } = this.broker;
  metrics.register({
    name: onlineCountMetric,
    type: 'gauge',
    description: 'Number of online user',
    labelNames: ['nodeId'],
  });
},
async started(this: SocketIOService) {
  // Lazily create the socket.io server on top of the gateway's HTTP server.
  if (!this.io) {
    this.initSocketIO();
  }
  this.logger.info('SocketIO service started');
  const io: SocketServer = this.io;
  if (!config.redisUrl) {
    throw new Errors.MoleculerClientError(
      'SocketIO service failed to start, environment variables are required: `REDIS_URL`'
    );
  }

  // Callbacks that must run when the socket.io service shuts down.
  this.socketCloseCallbacks = [];

  // Redis adapter: rooms and broadcasts span every node using the same Redis.
  const pubClient = new RedisClient(config.redisUrl, {
    retryStrategy(times) {
      // Reconnect delay grows by 50ms per attempt, capped at 2s.
      const delay = Math.min(times * 50, 2000);
      return delay;
    },
  });
  const subClient = pubClient.duplicate();
  io.adapter(
    createAdapter(pubClient, subClient, {
      key: 'tailchat-socket',
    })
  );
  this.socketCloseCallbacks.push(async () => {
    pubClient.disconnect(false);
    subClient.disconnect(false);
  });
  this.logger.info('SocketIO is using Redis Adapter');
  // The publisher client doubles as the general-purpose Redis connection.
  this.redis = pubClient;

  // Online-state writes are fire-and-forget; log their failures instead of
  // letting them surface as unhandled promise rejections.
  const logRedisError = (err: unknown) => {
    this.logger.error('[SocketIO] Failed to update online state in Redis:', err);
  };

  // Handshake authentication.
  io.use(async (socket, next) => {
    try {
      if (
        config.enableSocketAdmin &&
        socket.handshake.headers['origin'] === 'https://admin.socket.io'
      ) {
        // Connections coming from the socket.io admin UI are let through as-is.
        next();
        return;
      }
      const token = socket.handshake.auth['token'];
      if (typeof token !== 'string') {
        throw new Errors.MoleculerError('Token cannot be empty');
      }
      const user: UserJWTPayload = await userAuth(token);
      if (!(user && user._id)) {
        throw new Error('Token invalid');
      }
      this.logger.info('[Socket] Authenticated via JWT: ', user.nickname);
      socket.data.user = user;
      socket.data.token = token;
      socket.data.userId = user._id;
      next();
    } catch (e) {
      return next(e);
    }
  });

  io.on('connection', (socket) => {
    if (typeof socket.data.userId !== 'string') {
      // Should not happen: the middleware sets `userId` for every socket except
      // admin-UI connections, which are not wired up below.
      return;
    }
    this.broker.metrics.increment(
      'tailchat.socketio.online.count',
      {
        nodeId: this.broker.nodeID,
      },
      1
    );
    const userId: string = socket.data.userId;
    const onlineKey = buildUserOnlineKey(userId);
    // Record socketId -> nodeId in the user's online hash, then (re)arm its TTL.
    pubClient
      .hset(onlineKey, socket.id, this.broker.nodeID)
      .then(() => pubClient.expire(onlineKey, expiredTime))
      .catch(logRedisError);

    // Every socket of a user joins that user's personal room.
    socket.join(buildUserRoomId(userId));

    /** Remove this socket from the user's online hash. */
    const removeOnlineMapping = () => {
      return pubClient.hdel(onlineKey, socket.id);
    };
    this.socketCloseCallbacks.push(removeOnlineMapping);

    // User disconnects.
    socket.on('disconnecting', (reason) => {
      this.logger.info(
        'Socket Disconnect:',
        reason,
        '| Rooms:',
        socket.rooms
      );
      this.broker.metrics.decrement(
        'tailchat.socketio.online.count',
        {
          nodeId: this.broker.nodeID,
        },
        1
      );
      removeOnlineMapping().catch(logRedisError);
      _.pull(this.socketCloseCallbacks, removeOnlineMapping);
    });

    // Every client event is forwarded to the moleculer action of the same name.
    socket.onAny(
      async (
        eventName: string,
        eventData: unknown,
        cb?: (data: unknown) => void
      ) => {
        // Clients may emit without an acknowledgement callback; only answer
        // when one was supplied (calling `undefined` would throw here).
        const reply = (payload: unknown) => {
          if (typeof cb === 'function') {
            cb(payload);
          }
        };

        this.logger.info(
          '[SocketIO]',
          eventName,
          '<=',
          JSON.stringify(eventData)
        );

        // Reject events that clients are never allowed to call.
        if (checkBlacklist(eventName)) {
          const message = 'Not allowed request';
          this.logger.warn('[SocketIO]', '=>', message);
          reply({
            result: false,
            message,
          });
          return;
        }

        try {
          const endpoint = this.broker.findNextActionEndpoint(eventName);
          if (endpoint instanceof Error) {
            if (endpoint instanceof Errors.ServiceNotFoundError) {
              throw new ServiceUnavailableError();
            }
            throw endpoint;
          }
          // Only `published` actions (or actions without a visibility) are reachable from sockets.
          if (
            typeof endpoint.action.visibility === 'string' &&
            endpoint.action.visibility !== 'published'
          ) {
            throw new Errors.ServiceNotFoundError({
              visibility: endpoint.action.visibility,
              action: eventName,
            });
          }
          if (endpoint.action.disableSocket === true) {
            throw new Errors.ServiceNotFoundError({
              disableSocket: true,
              action: eventName,
            });
          }

          // Forward the socket's auth data, socket id and preferred language
          // to the action as call meta.
          const language = parseLanguageFromHead(
            socket.handshake.headers['accept-language']
          );
          const data = await this.broker.call(eventName, eventData, {
            meta: {
              ...socket.data,
              socketId: socket.id,
              language,
            },
          });
          if (typeof cb === 'function') {
            this.logger.debug(
              '[SocketIO]',
              eventName,
              '=>',
              JSON.stringify(data)
            );
            cb({ result: true, data });
          }
        } catch (err: unknown) {
          const message = _.get(err, 'message', 'Service Error');
          this.logger.debug('[SocketIO]', eventName, '=>', message);
          this.logger.error('[SocketIO]', err);
          reply({
            result: false,
            message,
          });
        }
      }
    );
  });
},
async stopped(this: SocketIOService) {
  if (this.io) {
    this.io.close();
    // `socketCloseCallbacks` is only initialised part-way through `started`,
    // so it can be missing when startup failed (e.g. REDIS_URL not set).
    const callbacks = this.socketCloseCallbacks ?? [];
    // Invoke every cleanup callback; a single failure (for instance a Redis
    // command issued after the client was disconnected) is logged instead of
    // rejecting the whole shutdown.
    await Promise.all(
      callbacks.map(async (fn) => {
        try {
          await fn();
        } catch (err: unknown) {
          this.logger.warn('[SocketIO] Close callback failed:', err);
        }
      })
    );
  }
  // Log message: "all connections closed".
  this.logger.info('断开所有连接');
},
actions: {
joinRoom: {
  visibility: 'public',
  params: {
    roomIds: 'array',
    // Optional: when omitted, only the calling socket (ctx.meta.socketId) joins.
    userId: [{ type: 'string', optional: true }],
  },
  async handler(
    this: TcService,
    ctx: TcContext<{ roomIds: string[]; userId?: string }>
  ) {
    const { roomIds, userId } = ctx.params;
    // Target every socket of `userId` (through its personal room) or just the caller.
    const searchId: unknown = isValidStr(userId)
      ? buildUserRoomId(userId)
      : ctx.meta.socketId;

    if (typeof searchId !== 'string') {
      throw new Error(
        'Unable to join the room, the query condition is invalid, please contact the administrator'
      );
    }
    if (!Array.isArray(roomIds)) {
      throw new Error(
        'Unable to join the room, the parameter must be an array'
      );
    }

    // Look up the (possibly remote) sockets and add them to the rooms.
    const io: SocketServer = this.io;
    const sockets = await io.in(searchId).fetchSockets();
    if (sockets.length === 0) {
      this.logger.warn(
        'Unable to join the room, unable to find the current socket link:',
        searchId
      );
      return;
    }
    // Coerce ids to strings so ObjectId-like values map to the same room name.
    const rooms = roomIds.map(String);
    for (const remote of sockets) {
      remote.join(rooms);
    }
  },
},
leaveRoom: {
  visibility: 'public',
  params: {
    roomIds: 'array',
    // Optional: when omitted, only the calling socket (ctx.meta.socketId) leaves.
    userId: [{ type: 'string', optional: true }],
  },
  async handler(
    this: TcService,
    ctx: TcContext<{ roomIds: string[]; userId?: string }>
  ) {
    const roomIds = ctx.params.roomIds;
    const userId = ctx.params.userId;
    // Target every socket of `userId` (through its personal room) or just the caller.
    const searchId = isValidStr(userId)
      ? buildUserRoomId(userId)
      : ctx.meta.socketId;
    if (typeof searchId !== 'string') {
      this.logger.error(
        'Unable to leave the room, the current socket connection does not exist'
      );
      return;
    }

    // Look up the (possibly remote) sockets and remove them from the rooms.
    const io: SocketServer = this.io;
    const remoteSockets = await io.in(searchId).fetchSockets();
    if (remoteSockets.length === 0) {
      this.logger.error(
        `Can't leave room, can't find current socket link`
      );
      return;
    }
    // Coerce ids to strings exactly like `joinRoom` does; otherwise an
    // ObjectId-like id never matches the string room the socket joined, and
    // the socket silently stays in the room.
    const rooms = roomIds.map(String);
    remoteSockets.forEach((rs) => {
      rooms.forEach((roomId) => {
        rs.leave(roomId);
      });
    });
  },
},
/**
 * List the ids of the sockets currently in `userId`'s personal room
 * (fetched through the adapter, so sockets on other nodes are presumably
 * included when the Redis adapter is active).
 */
fetchUserSocketIds: {
  visibility: 'public',
  params: {
    userId: 'string',
  },
  async handler(
    this: TcService,
    ctx: TcContext<{ userId: string }>
  ): Promise<string[]> {
    const { userId } = ctx.params;
    const userRoom = buildUserRoomId(userId);
    const io: SocketServer = this.io;
    const sockets = await io.in(userRoom).fetchSockets();
    return sockets.map(({ id }) => id);
  },
},
/**
 * Return the handshake token stored on each socket of `userId`
 * (one entry per socket, in the order the adapter returns them).
 */
getUserSocketToken: {
  visibility: 'public',
  params: {
    userId: 'string',
  },
  async handler(
    this: TcService,
    ctx: TcContext<{ userId: string }>
  ): Promise<string[]> {
    const io: SocketServer = this.io;
    const userRoom = buildUserRoomId(ctx.params.userId);
    const sockets = await io.in(userRoom).fetchSockets();
    const tokens: string[] = [];
    for (const remote of sockets) {
      tokens.push(remote.data.token);
    }
    return tokens;
  },
},
/**
 * Forcefully disconnect every socket of `userId` (the name presumably means
 * "kick user").
 */
tickUser: {
  visibility: 'public',
  params: {
    userId: 'string',
  },
  async handler(this: TcService, ctx: TcContext<{ userId: string }>) {
    const io: SocketServer = this.io;
    const userRoom = buildUserRoomId(ctx.params.userId);
    const sockets = await io.in(userRoom).fetchSockets();
    for (const remote of sockets) {
      // `true` also closes the underlying connection.
      remote.disconnect(true);
    }
  },
},
/**
 * Push a server-side event to socket clients.
 *
 * - `unicast`:   `target` is one userId; emits to that user's personal room.
 * - `listcast`:  `target` is an array of userIds; emits to each user's room.
 * - `roomcast`:  `target` is a room id or an array of room ids.
 * - `broadcast`: emits to every connected socket; `target` is ignored.
 *
 * An empty target array means "nobody". socket.io treats `to([])` (an empty
 * room set) as "no room filter" and would otherwise deliver to every socket.
 */
notify: {
  visibility: 'public',
  params: {
    type: 'string',
    target: [
      { type: 'string', optional: true },
      { type: 'array', optional: true },
    ],
    eventName: 'string',
    eventData: 'any',
  },
  handler(
    this: TcService,
    ctx: PureContext<{
      type: string;
      target: string | string[];
      eventName: string;
      eventData: any;
    }>
  ) {
    const { type, target, eventName, eventData } = ctx.params;
    const io: SocketServer = this.io;
    if (type === 'unicast' && typeof target === 'string') {
      // Single user.
      io.to(buildUserRoomId(target)).emit(eventName, eventData);
    } else if (type === 'listcast' && Array.isArray(target)) {
      // List of users. Nobody to notify: do NOT fall through to `io.to([])`,
      // which would broadcast to every connected user.
      if (target.length === 0) {
        return;
      }
      io.to(target.map((t) => buildUserRoomId(t))).emit(
        eventName,
        eventData
      );
    } else if (type === 'roomcast') {
      // One or more rooms. Same empty-array hazard as listcast.
      if (Array.isArray(target) && target.length === 0) {
        return;
      }
      io.to(target).emit(eventName, eventData);
    } else if (type === 'broadcast') {
      // Everyone.
      io.emit(eventName, eventData);
    } else {
      this.logger.warn(
        '[SocketIO]',
        'Unknown notify type or target',
        type,
        target
      );
    }
  },
},
/**
 * Report, for each userId, whether Redis currently holds an online-socket
 * hash for that user. The booleans are returned in the same order as `userIds`.
 *
 * NOTE(review): unlike the other actions this one sets no `visibility`, so the
 * moleculer default applies — confirm that is intended.
 */
checkUserOnline: {
  params: {
    userIds: 'array',
  },
  async handler(
    this: TcService,
    ctx: PureContext<{ userIds: string[] }>
  ) {
    const { userIds } = ctx.params;
    const lookups = userIds.map((userId) => {
      const redis = this.redis as RedisClient.Redis;
      return redis.exists(buildUserOnlineKey(userId));
    });
    const existsCounts = await Promise.all(lookups);
    return existsCounts.map((count) => !!count);
  },
},
},
methods: {
initSocketIO() {
  // socket.io is attached to the HTTP server created by the API gateway mixin.
  if (!this.server) {
    throw new Errors.ServiceNotAvailableError(
      'Need to use with [ApiGatewayMixin]'
    );
  }

  // msgpack parser by default; socket.io's built-in parser when disabled.
  const parser = options.disableMsgpack ? undefined : msgpackParser;
  this.io = new SocketServer(this.server, {
    serveClient: false,
    transports: ['websocket'],
    cors: { origin: '*', methods: ['GET', 'POST'] },
    parser,
  });

  // The admin UI is only instrumented when both credentials are configured.
  const adminUser = process.env.ADMIN_USER;
  const adminPass = process.env.ADMIN_PASS;
  if (!isValidStr(adminUser) || !isValidStr(adminPass)) {
    return;
  }

  const banner = '****************************************';
  this.logger.info(banner);
  this.logger.info(`Detected that Admin management is enabled`);
  this.logger.info(banner);
  instrument(this.io, {
    auth: {
      type: 'basic',
      username: adminUser,
      // The admin UI's basic auth is configured with a bcrypt hash, not the plain password.
      password: bcrypt.hashSync(adminPass, 10),
    },
  });
},
},
};
return schema;
};