From 4f12ed332c01fd941900b1615d1f864c98563f3c Mon Sep 17 00:00:00 2001 From: xion Date: Fri, 21 Mar 2025 13:41:34 +0800 Subject: [PATCH] feat: add oauth --- package.json | 4 + rollup.config.mjs | 112 +++++++++---------- src/core-models.ts | 6 +- src/lib.ts | 2 +- src/middleware/auth.ts | 53 +++++++++ src/models/user.ts | 71 ++++++------ src/oauth/auth.ts | 6 ++ src/oauth/index.ts | 2 + src/oauth/oauth.ts | 238 +++++++++++++++++++++++++++++++++++++++++ src/oauth/salt.ts | 32 ++++++ 10 files changed, 431 insertions(+), 95 deletions(-) create mode 100644 src/middleware/auth.ts create mode 100644 src/oauth/auth.ts create mode 100644 src/oauth/index.ts create mode 100644 src/oauth/oauth.ts create mode 100644 src/oauth/salt.ts diff --git a/package.json b/package.json index 6990463..cf935a9 100644 --- a/package.json +++ b/package.json @@ -83,6 +83,10 @@ "./models": { "import": "./dist/models.mjs", "types": "./dist/models.d.ts" + }, + "./oauth": { + "import": "./dist/oauth.mjs", + "types": "./dist/oauth.d.ts" } } } \ No newline at end of file diff --git a/rollup.config.mjs b/rollup.config.mjs index 379594f..8c87eaa 100644 --- a/rollup.config.mjs +++ b/rollup.config.mjs @@ -13,20 +13,16 @@ const version = pkgs.version|| '1.0.0'; const external = [ /@kevisual\/router(\/.*)?/, //, // 路由 /@kevisual\/use-config(\/.*)?/, // - /@kevisual\/auth(\/.*)?/, // - 'sequelize', // 数据库 orm 'ioredis', // redis - 'socket.io', // socket.io 'minio', // minio - - 'pm2', - 'pg', // pg - 'pino', // pino - 'pino-pretty', // pino-pretty - '@msgpack/msgpack', // msgpack ] +const replaceConfig = { + preventAssignment: true, // 防止意外赋值 + DEV_SERVER: JSON.stringify(isDev), // 替换 process.env.NODE_ENV + VERSION: JSON.stringify(version), // 替换版本号 +} /** * @type {import('rollup').RollupOptions} */ @@ -35,50 +31,11 @@ const config = { output: { dir: './dist', entryFileNames: 'lib.mjs', - chunkFileNames: '[name]-[hash].mjs', format: 'esm', }, plugins: [ - replace({ - preventAssignment: true, // 防止意外赋值 - DEV_SERVER: JSON.stringify(isDev), // 替换 process.env.NODE_ENV - VERSION: JSON.stringify(version), // 替换版本号 - }), + replace(replaceConfig), alias({ - // only esbuild needs to be configured - entries: [ - { find: '@', replacement: path.resolve('src') }, // 配置 @ 为 src 目录 - ], - }), - resolve({ - preferBuiltins: true, // 强制优先使用内置模块 - }), - commonjs(), - esbuild({ - target: 'node22', // 目标为 Node.js 14 - minify: false, // 启用代码压缩 - tsconfig: 'tsconfig.json', - }), - json(), - ], - external: external, -}; -const configCjs = { - input: './src/lib.ts', - output: { - dir: './dist', - entryFileNames: 'lib.cjs', - chunkFileNames: '[name]-[hash].cjs', - format: 'cjs', - }, - plugins: [ - replace({ - preventAssignment: true, // 防止意外赋值 - DEV_SERVER: JSON.stringify(isDev), // 替换 process.env.NODE_ENV - VERSION: JSON.stringify(version), // 替换版本号 - }), - alias({ - // only esbuild needs to be configured entries: [ { find: '@', replacement: path.resolve('src') }, // 配置 @ 为 src 目录 ], @@ -96,6 +53,7 @@ const configCjs = { ], external: external, }; + const dtsConfig = { input: './src/lib.ts', output: { @@ -113,15 +71,10 @@ const systemConfig = [ output: { dir: './dist', entryFileNames: 'system.mjs', - chunkFileNames: '[name]-[hash].mjs', format: 'esm', }, plugins: [ - replace({ - preventAssignment: true, // 防止意外赋值 - DEV_SERVER: JSON.stringify(isDev), // 替换 process.env.NODE_ENV - VERSION: JSON.stringify(version), // 替换版本号 - }), + replace(replaceConfig), alias({ entries: [ { find: '@', replacement: path.resolve('src') }, // 配置 @ 为 src 目录 @@ -158,15 +111,10 @@ export const modelConfig = [ output: { dir: './dist', entryFileNames: 'models.mjs', - chunkFileNames: '[name]-[hash].mjs', format: 'esm', }, plugins: [ - replace({ - preventAssignment: true, // 防止意外赋值 - DEV_SERVER: JSON.stringify(isDev), // 替换 process.env.NODE_ENV - VERSION: JSON.stringify(version), // 替换版本号 - }), + replace(replaceConfig), alias({ entries: [ { find: '@', replacement: path.resolve('src') }, // 配置 @ 为 src 目录 @@ -199,4 +147,44 @@ export const modelConfig = [ ], }, ] -export default [config, dtsConfig, ...systemConfig, ...modelConfig]; +const oauthConfig = [ + { + input: './src/oauth/index.ts', + output: { + dir: './dist', + entryFileNames: 'oauth.mjs', + format: 'esm', + }, + plugins: [ + replace(replaceConfig), + alias({ + entries: [ + { find: '@', replacement: path.resolve('src') }, // 配置 @ 为 src 目录 + ], + }), + resolve({ + preferBuiltins: true, // 强制优先使用内置模块 + }), + commonjs(), + esbuild({ + target: 'node22', // 目标为 Node.js 14 + minify: false, // 启用代码压缩 + tsconfig: 'tsconfig.json', + }), + json(), + ], + }, + { + input: './src/oauth/index.ts', + output: { + dir: './dist', + entryFileNames: 'oauth.d.ts', + format: 'esm', + }, + plugins: [ + dts(), + ], + }, +] + +export default [config, dtsConfig, ...systemConfig, ...modelConfig, ...oauthConfig]; diff --git a/src/core-models.ts b/src/core-models.ts index becaea6..eb43c58 100644 --- a/src/core-models.ts +++ b/src/core-models.ts @@ -1,4 +1,8 @@ +/** + * Sequlize也不要了,只要核心的模块,sequelize自己默认已经有了 + */ import { UserServices, User, UserInit, UserModel } from './models/user.ts'; import { Org, OrgInit, OrgModel } from './models/org.ts'; - +import { addAuth } from './middleware/auth.ts'; export { User, Org, UserServices, UserInit, OrgInit, UserModel, OrgModel }; +export { addAuth }; diff --git a/src/lib.ts b/src/lib.ts index db46b4e..0641570 100644 --- a/src/lib.ts +++ b/src/lib.ts @@ -1,5 +1,5 @@ /** - * 自己初始化redis和sequelize,的模块,放到useContextKey当中 + * @description 自己初始化redis和sequelize,的模块,放到useContextKey当中 */ import { app } from './app.ts'; import { UserServices, UserInit, UserModel, User } from './models/user.ts'; diff --git a/src/middleware/auth.ts b/src/middleware/auth.ts new file mode 100644 index 0000000..adb57b3 --- /dev/null +++ b/src/middleware/auth.ts @@ -0,0 +1,53 @@ +import { User } from '../models/user.ts'; +import type { App } from '@kevisual/router'; + +/** + * 添加auth中间件, 用于验证token + * @param app + */ +export const addAuth = (app: App) => { + app + .route({ + path: 'auth', + id: 'auth', + }) + .define(async (ctx) => { + const token = ctx.query.token; + if (!token) { + app.throw(401, 'Token is required'); + } + const user = await User.getOauthUser(token); + if (!user) { + app.throw(401, 'Token is invalid'); + } + if (ctx.state) { + ctx.state.tokenUser = user; + } else { + ctx.state = { + tokenUser: user, + }; + } + }) + .addTo(app); + + app + .route({ + path: 'auth', + key: 'can', + id: 'auth-can', + }) + .define(async (ctx) => { + if (ctx.query?.token) { + const token = ctx.query.token; + const user = await User.getOauthUser(token); + if (ctx.state) { + ctx.state.tokenUser = user; + } else { + ctx.state = { + tokenUser: user, + }; + } + } + }) + .addTo(app); +}; diff --git a/src/models/user.ts b/src/models/user.ts index 44762b7..dcdae68 100644 --- a/src/models/user.ts +++ b/src/models/user.ts @@ -1,13 +1,14 @@ import { useConfig } from '@kevisual/use-config'; import { DataTypes, Model, Op, Sequelize } from 'sequelize'; -import { createToken, checkToken } from '@kevisual/auth'; -import { cryptPwd } from '@kevisual/auth'; -import { customRandom, nanoid, customAlphabet } from 'nanoid'; +import { nanoid, customAlphabet } from 'nanoid'; import { CustomError } from '@kevisual/router'; import { Org } from './org.ts'; import { useContextKey } from '@kevisual/use-config/context'; import { Redis } from 'ioredis'; +import { oauth } from '../oauth/auth.ts'; +import { cryptPwd } from '../oauth/salt.ts'; +import { OauthUser } from '../oauth/oauth.ts'; export const redis = useContextKey('redis'); const config = useConfig<{ tokenSecret: string }>(); @@ -25,6 +26,7 @@ export enum UserTypes { * 用户模型,在sequelize和Org之后初始化 */ export class User extends Model { + static oauth = oauth; declare id: string; declare username: string; declare nickname: string; // 昵称 @@ -43,44 +45,51 @@ export class User extends Model { this.tokenUser = tokenUser; } /** - * uid 是用于 orgId 的用户id 真实用户的id + * uid 是用于 orgId 的用户id, 如果uid存在,则表示是用户是组织,其中uid为真实用户 * @param uid * @returns */ - async createToken(uid?: string, loginType?: 'default' | 'plugin' | 'month' | 'season' | 'year') { + async createToken(uid?: string, loginType?: 'default' | 'plugin' | 'month' | 'season' | 'year', expand: any = {}) { const { id, username, type } = this; - let expireTime = 60 * 60 * 24 * 7; // 7 days - switch (loginType) { - case 'plugin': - expireTime = 60 * 60 * 24 * 30 * 12; // 365 days - break; - case 'month': - expireTime = 60 * 60 * 24 * 30; // 30 days - break; - case 'season': - expireTime = 60 * 60 * 24 * 30 * 3; // 90 days - break; - case 'year': - expireTime = 60 * 60 * 24 * 30 * 12; // 365 days - break; + const oauthUser: OauthUser = { + id, + username, + uid, + userId: uid || id, // 必存在,真实用户id + type: type as 'user' | 'org', + }; + if (uid) { + oauthUser.orgId = id; } - const now = new Date().getTime(); - const token = await createToken({ id, username, uid, type }, config.tokenSecret); - return { token, expireTime: now + expireTime }; + const token = await oauth.generateToken(oauthUser, { type: loginType, hasRefreshToken: true, ...expand }); + return { accessToken: token.accessToken, refreshToken: token.refreshToken, token: token.accessToken }; } + /** + * 验证token + * @param token + * @returns + */ static async verifyToken(token: string) { - const ct = await checkToken(token, config.tokenSecret); - const tokenUser = ct.payload; - return tokenUser; + return await oauth.verifyToken(token); + } + /** + * 刷新token + * @param refreshToken + * @returns + */ + static async refreshToken(refreshToken: string) { + const token = await oauth.refreshToken(refreshToken); + return { accessToken: token.accessToken, refreshToken: token.refreshToken, token: token.accessToken }; + } + static async getOauthUser(token: string) { + return await oauth.verifyToken(token); } static async getUserByToken(token: string) { - const ct = await checkToken(token, config.tokenSecret); - const tokenUser = ct.payload; - let userId = tokenUser.id; - if (tokenUser.uid) { - // 如果tokenUser.uid 存在,则表示是token是o用户的user,需要获取o的真实用户 - userId = tokenUser.uid; + const oauthUser = await oauth.verifyToken(token); + if (!oauthUser) { + throw new CustomError('Token is invalid'); } + const userId = oauthUser?.uid || oauthUser.id; const user = await User.findByPk(userId); return user; } diff --git a/src/oauth/auth.ts b/src/oauth/auth.ts new file mode 100644 index 0000000..2351910 --- /dev/null +++ b/src/oauth/auth.ts @@ -0,0 +1,6 @@ +import { OAuth, RedisTokenStore } from './oauth.ts'; +import { useContextKey } from '@kevisual/use-config/context'; +import { Redis } from 'ioredis'; + +export const redis = useContextKey('redis'); +export const oauth = useContextKey('oauth', () => new OAuth(new RedisTokenStore(redis))); diff --git a/src/oauth/index.ts b/src/oauth/index.ts new file mode 100644 index 0000000..3ebdf92 --- /dev/null +++ b/src/oauth/index.ts @@ -0,0 +1,2 @@ +export * from './oauth.ts'; +export * from './salt.ts'; \ No newline at end of file diff --git a/src/oauth/oauth.ts b/src/oauth/oauth.ts new file mode 100644 index 0000000..e6b2335 --- /dev/null +++ b/src/oauth/oauth.ts @@ -0,0 +1,238 @@ +/** + * 一个生成和验证token的模块,不使用jwt,使用redis缓存, + * token 分为两种,一种是access_token,一种是refresh_token + * + * access_token 用于验证用户是否登录,过期时间为1小时 + * refresh_token 用于刷新access_token,过期时间为7天 + * + * 生成token时,会根据用户信息生成一个access_token和refresh_token,并缓存到redis中 + * 验证token时,会根据token从redis中获取用户信息 + * 刷新token时,会根据refresh_token生成一个新的access_token和refresh_token,并缓存到redis中 + * + * 并删除旧的access_token和refresh_token + * + * 生成token的方法,使用nanoid生成一个随机字符串 + * 验证token的方法,使用redis的get方法验证token是否存在 + * + * 刷新token的方法,使用redis的set方法刷新token + * + * 缓存和获取都可以不使用redis,只是用可拓展的接口。store.get和store.set去实现。 + */ + +import { Redis } from 'ioredis'; +import { customAlphabet } from 'nanoid'; + +export const alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'; +export const randomId16 = customAlphabet(alphabet, 16); +export const randomId24 = customAlphabet(alphabet, 24); +export const randomId32 = customAlphabet(alphabet, 32); +export const randomId64 = customAlphabet(alphabet, 64); + +export type OauthUser = { + /** + * 真实用户,非org + */ + id: string; + /** + * 组织id,非必须存在 + */ + orgId?: string; + /** + * 必存在,真实用户id + */ + userId: string; + /** + * 当前用户的id,如果是org,则uid为org的id + */ + uid?: string; + username: string; + type?: 'user' | 'org'; // 用户类型,默认是user,token类型是用于token的扩展 + oauthType?: 'user' | 'token'; // 用户类型,默认是user,token类型是用于token的扩展 + oauthExpand?: UserExpand; +}; +export type UserExpand = { + createTime?: number; + refreshToken?: string; + [key: string]: any; +} & StoreSetOpts; + +type StoreSetOpts = { + loginType?: 'default' | 'plugin' | 'month' | 'season' | 'year'; // 登陆类型 'default' | 'plugin' | 'month' | 'season' | 'year' + expire?: number; // 过期时间,单位为秒 + hasRefreshToken?: boolean; + [key: string]: any; +}; +interface Store { + getObject: (key: string) => Promise; + setObject: (key: string, value: T, opts?: StoreSetOpts) => Promise; + expire: (key: string, ttl?: number) => Promise; + delObject: (key: string, value?: T) => Promise; + setToken: (value: { accessToken: string; refreshToken: string; value?: T }, opts?: StoreSetOpts) => Promise; +} +export class RedisTokenStore implements Store { + private redis: Redis; + private prefix: string = 'oauth:'; + constructor(redis: Redis, prefix?: string) { + this.redis = redis; + this.prefix = prefix || this.prefix; + } + async set(key: string, value: string, ttl?: number) { + await this.redis.set(this.prefix + key, value, 'EX', ttl); + } + async get(key: string) { + return await this.redis.get(this.prefix + key); + } + async getObject(key: string) { + try { + const value = await this.get(key); + if (!value) { + return null; + } + return JSON.parse(value); + } catch (error) { + return null; + } + } + async setObject(key: string, value: OauthUser, opts?: StoreSetOpts) { + await this.set(key, JSON.stringify(value), opts?.expire); + } + async expire(key: string, ttl?: number) { + await this.redis.expire(this.prefix + key, ttl); + } + async delObject(key: string, value?: OauthUser) { + await this.redis.del(this.prefix + key); + if (value) { + // await this.redis.del(this.prefix + value.refreshToken); + } + } + async setToken(data: { accessToken: string; refreshToken: string; value?: OauthUser }, opts?: StoreSetOpts) { + const { accessToken, refreshToken, value } = data; + let userPrefix = 'user:' + value?.id; + if (value?.orgId) { + userPrefix = 'org:' + value?.orgId + ':user:' + value?.id; + } + // 计算过期时间,根据opts.expire 和 opts.loginType + // 如果expire存在,则使用expire,否则使用opts.loginType 进行计算; + let expire = opts?.expire; + if (!expire) { + switch (opts.loginType) { + case 'month': + expire = 30 * 24 * 60 * 60; + break; + case 'season': + expire = 90 * 24 * 60 * 60; + break; + default: + expire = 25 * 60 * 60; // 默认过期时间为25小时 + } + } else { + expire = Math.min(expire, 60 * 60 * 24 * 30, 60 * 60 * 24 * 90); // 默认的过期时间最大为90天 + } + + await this.set(accessToken, JSON.stringify(value), expire); + await this.set(userPrefix + ':token:' + accessToken, accessToken, expire); + if (refreshToken) { + let refreshTokenExpire = Math.min(expire * 7, 60 * 60 * 24 * 30, 60 * 60 * 24 * 365); // 最大为一年 + // 小于7天, 则设置为7天 + if (refreshTokenExpire < 60 * 60 * 24 * 7) { + refreshTokenExpire = 60 * 60 * 24 * 7; + } + await this.set(refreshToken, JSON.stringify(value), refreshTokenExpire); + await this.set(userPrefix + ':refreshToken:' + refreshToken, refreshToken, refreshTokenExpire); + } + } +} + +export class OAuth { + private store: Store; + + constructor(store: Store) { + this.store = store; + } + /** + * 生成token + * @param user + * @returns + */ + async generateToken( + user: T, + expandOpts?: StoreSetOpts, + ): Promise<{ + accessToken: string; + refreshToken?: string; + }> { + // 拥有refreshToken 为 true 时,accessToken 为 st_ 开头,refreshToken 为 rk_开头 + // 意思是secretToken 和 secretKey的缩写 + const accessToken = expandOpts?.hasRefreshToken ? 'st_' + randomId32() : 'sk_' + randomId64(); + const refreshToken = expandOpts?.hasRefreshToken ? 'rk_' + randomId64() : null; + // 初始化 appExpand + user.oauthExpand = user.oauthExpand || {}; + if (expandOpts) { + user.oauthExpand = { + ...user.oauthExpand, + ...expandOpts, + createTime: new Date().getTime(), // + }; + if (expandOpts?.hasRefreshToken) { + user.oauthExpand.refreshToken = refreshToken; + } + } + await this.store.setToken({ accessToken, refreshToken, value: user }, expandOpts); + + return { accessToken, refreshToken }; + } + /** + * 验证token,如果token不存在,返回null + * @param token + * @returns + */ + async verifyToken(token: string) { + return await this.store.getObject(token); + } + /** + * 刷新token + * @param refreshToken + * @returns + */ + async refreshToken(refreshToken: string) { + const user = await this.store.getObject(refreshToken); + if (!user) { + // 过期 + throw new Error('Refresh token not found'); + } + const token = await this.generateToken(user, { + ...user.oauthExpand, + hasRefreshToken: true, + }); + // 删除旧的token + await this.store.delObject(refreshToken, user); + return token; + } + /** + * 刷新token的过期时间 + * expand 为扩展参数,可以扩展到user.oauthExpand中 + * @param token + * @returns + */ + async resetToken(accessToken: string, expand?: Record) { + const user = await this.store.getObject(accessToken); + if (!user) { + // 过期 + throw new Error('token not found'); + } + user.oauthExpand = user.oauthExpand || {}; + const refreshToken = user.oauthExpand.refreshToken; + if (refreshToken) { + await this.store.delObject(refreshToken, user); + } + user.oauthExpand = { + ...user.oauthExpand, + ...expand, + }; + const token = await this.generateToken(user, { + ...user.oauthExpand, + hasRefreshToken: true, + }); + return token; + } +} diff --git a/src/oauth/salt.ts b/src/oauth/salt.ts new file mode 100644 index 0000000..6cb4857 --- /dev/null +++ b/src/oauth/salt.ts @@ -0,0 +1,32 @@ +import MD5 from 'crypto-js/md5.js'; + +/** + * 生成随机盐 + * @returns + */ +export const getRandomSalt = () => { + return Math.random().toString().slice(2, 7); +}; + +/** + * 加密密码 + * @param password + * @param salt + * @returns + */ +export const cryptPwd = (password: string, salt = '') => { + const saltPassword = password + ':' + salt; + const md5 = MD5(saltPassword); + return md5.toString(); +}; + +/** + * Check password + * @param password + * @param salt + * @param md5 + * @returns + */ +export const checkPwd = (password: string, salt: string, md5: string) => { + return cryptPwd(password, salt) === md5; +};