Sfoglia il codice sorgente

refactor(controllers): 重构 OCR 相关控制器的权限验证和数据查询逻辑

- 优化了 OcrChannelController、OcrDevicesController 和 OcrRecordController 的权限验证逻辑
- 重构了这些控制器中的数据查询逻辑,提高了代码复用性和可维护性
- 新增 UserService.findReferredUsers 方法,用于获取指定用户的下级用户
- 改进了 API 用户的数据脱敏处理,确保敏感信息不被泄露
- 优化了分页查询的处理,提高了响应速度和用户体验
wui 6 mesi fa
parent
commit
f14ec59384

+ 99 - 10
app/Controllers/Http/OcrChannelController.ts

@@ -3,16 +3,90 @@ import PaginationService from 'App/Services/PaginationService'
 import OcrChannel from 'App/Models/OcrChannel'
 import { schema } from '@ioc:Adonis/Core/Validator'
 import { DateTime } from 'luxon'
-import OcrDevice from 'App/Models/OcrDevice'
-import OcrRecord from 'App/Models/OcrRecord'
-import * as console from 'node:console'
 import Database from '@ioc:Adonis/Lucid/Database'
+import User, { UserRoles } from 'App/Models/User'
+import UserService from 'App/Services/UserService'
 
 export default class OcrChannelController {
     private paginationService = new PaginationService(OcrChannel)
 
-    public async index({ request, bouncer }: HttpContextContract) {
-        return await this.paginationService.paginate(request.all())
+    public async index({ request, response, auth }: HttpContextContract) {
+        try {
+            const { page = 1, size = 20, id, deviceId, channel } = request.qs()
+            const user = auth.user
+            const role = user?.$attributes?.role
+
+            if (!role) {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'Unauthorized access'
+                })
+            }
+
+            let userChannel: string | undefined
+
+            if (role === UserRoles.Api) {
+                userChannel = user!.username
+            } else if (['admin', 'operator'].includes(role)) {
+                const apiUsers = await UserService.findReferredUsers(user!.id)
+                const allowedChannels = apiUsers.map((user) => user.username as string)
+
+                // 如果请求中指定了channel,检查是否在允许的channel列表中
+                if (channel) {
+                    if (!allowedChannels.includes(channel)) {
+                        return response.ok({
+                            data: [],
+                            meta: {
+                                total: 0,
+                                per_page: Number(size),
+                                current_page: Number(page),
+                                last_page: 0,
+                                first_page: 1,
+                                first_page_url: '/?page=1',
+                                last_page_url: '/?page=0',
+                                next_page_url: null,
+                                previous_page_url: null
+                            }
+                        })
+                    }
+                    userChannel = channel
+                } else {
+                    userChannel = allowedChannels.join(',')
+                }
+            } else {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'You are not authorized to access this resource'
+                })
+            }
+
+            // 构建查询条件
+            const query = OcrChannel.query()
+
+            if (id) {
+                query.where('id', id)
+            }
+
+            if (deviceId) {
+                query.whereIn('name', function (builder) {
+                    builder.select('channel').from('ocr_devices').where('id', deviceId)
+                })
+            }
+
+            if (userChannel) {
+                query.whereIn('name', userChannel.split(','))
+            }
+
+            // 执行分页查询
+            const result = await query.orderBy('created_at', 'desc').paginate(page, size)
+
+            return response.ok(result)
+        } catch (error) {
+            return response.internalServerError({
+                message: '获取OCR渠道列表时发生错误',
+                error: error.message
+            })
+        }
     }
 
     public async store({ request, bouncer }: HttpContextContract) {
@@ -157,15 +231,30 @@ export default class OcrChannelController {
             const user = auth.user
             const role = user?.$attributes?.role
 
-            // 验证管理员或操作员权限
-            if (role !== 'admin' && role !== 'operator') {
+            if (!role) {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'Unauthorized access'
+                })
+            }
+
+            let channel: string = ''
+            if (['admin', 'operator'].includes(role)) {
+                const apiUsers = await UserService.findReferredUsers(user!.id)
+                const allowedChannels = apiUsers.map((user) => user.username as string)
+                channel = allowedChannels.join(',')
+            } else {
                 return response.forbidden({
-                    message: 'unauthorized'
+                    error: 'Forbidden',
+                    message: 'You are not authorized to access this resource'
                 })
             }
 
-            // 获取所有渠道名称
-            const channels = await OcrChannel.query().select('name').orderBy('name', 'asc')
+            // 获取指定渠道名称
+            const channels = await OcrChannel.query()
+                .select('name')
+                .whereIn('name', channel.split(','))
+                .orderBy('name', 'asc')
 
             return response.ok({
                 data: channels.map((channel) => channel.name)

+ 156 - 50
app/Controllers/Http/OcrDevicesController.ts

@@ -7,19 +7,87 @@ import OcrChannel from 'App/Models/OcrChannel'
 import { DateTime } from 'luxon'
 import * as console from 'node:console'
 import Database from '@ioc:Adonis/Lucid/Database'
+import { UserRoles } from 'App/Models/User'
+import UserService from 'App/Services/UserService'
 
 export default class OcrDevicesController {
     private paginationService = new PaginationService(OcrDevice)
 
-    public async index({ request, auth }: HttpContextContract) {
-        const user = auth.user
-        const isApiUser = user?.$attributes?.role === 'api'
+    public async index({ request, response, auth }: HttpContextContract) {
+        try {
+            const { page = 1, size = 20, id, deviceId, channel } = request.qs()
+            const user = auth.user
+            const role = user?.$attributes?.role
+
+            if (!role) {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'Unauthorized access'
+                })
+            }
+
+            let userChannel: string | undefined
+
+            if (role === UserRoles.Api) {
+                userChannel = user!.username
+            } else if (['admin', 'operator'].includes(role)) {
+                const apiUsers = await UserService.findReferredUsers(user!.id)
+                const allowedChannels = apiUsers.map((user) => user.username as string)
+
+                // 如果请求中指定了channel,检查是否在允许的channel列表中
+                if (channel) {
+                    if (!allowedChannels.includes(channel)) {
+                        return response.ok({
+                            data: [],
+                            meta: {
+                                total: 0,
+                                per_page: Number(size),
+                                current_page: Number(page),
+                                last_page: 0,
+                                first_page: 1,
+                                first_page_url: '/?page=1',
+                                last_page_url: '/?page=0',
+                                next_page_url: null,
+                                previous_page_url: null
+                            }
+                        })
+                    }
+                    userChannel = channel
+                } else {
+                    userChannel = allowedChannels.join(',')
+                }
+            } else {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'You are not authorized to access this resource'
+                })
+            }
 
-        const requestData = request.all()
-        if (isApiUser) {
-            requestData.channel = user.username
+            // 构建查询条件
+            const query = OcrDevice.query()
+
+            if (id) {
+                query.where('id', id)
+            }
+
+            if (deviceId) {
+                query.where('id', deviceId)
+            }
+
+            if (userChannel) {
+                query.whereIn('channel', userChannel.split(','))
+            }
+
+            // 执行分页查询
+            const result = await query.orderBy('created_at', 'desc').paginate(page, size)
+
+            return response.ok(result)
+        } catch (error) {
+            return response.internalServerError({
+                message: '获取OCR设备列表时发生错误',
+                error: error.message
+            })
         }
-        return await this.paginationService.paginate(request.all())
     }
 
     public async store({ request, bouncer }: HttpContextContract) {
@@ -126,7 +194,44 @@ export default class OcrDevicesController {
     public async getStatistics({ request, response, auth }: HttpContextContract) {
         try {
             const user = auth.user
-            const isApiUser = user?.$attributes?.role === 'api'
+            const role = user?.$attributes?.role
+
+            if (!role) {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'Unauthorized access'
+                })
+            }
+
+            let userChannel: string | undefined
+
+            if (role === UserRoles.Api) {
+                userChannel = user!.username
+            } else if (['admin', 'operator'].includes(role)) {
+                const apiUsers = await UserService.findReferredUsers(user!.id)
+                const allowedChannels = apiUsers.map((user) => user.username as string)
+
+                // 如果请求中指定了channel,检查是否在允许的channel列表中
+                const requestChannel = request.input('channel')
+                if (requestChannel) {
+                    if (!allowedChannels.includes(requestChannel)) {
+                        return response.ok({
+                            dates: [],
+                            total: [],
+                            scanned: [],
+                            deviceCount: []
+                        })
+                    }
+                    userChannel = requestChannel
+                } else {
+                    userChannel = allowedChannels.join(',')
+                }
+            } else {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'You are not authorized to access this resource'
+                })
+            }
 
             // 获取开始日期和结束日期,默认为不包括今天的最近七天
             let startDate = request.input(
@@ -200,14 +305,6 @@ export default class OcrDevicesController {
             }
 
             // 准备查询条件
-            let channelCondition = {}
-            if (isApiUser) {
-                channelCondition = { channel: user.username }
-            } else if (request.input('channel')) {
-                channelCondition = { channel: request.input('channel') }
-            }
-
-            // 使用SQL直接获取每日设备统计数据
             const deviceStatsQuery = Database.from('ocr_devices')
                 .select(
                     Database.raw("DATE_FORMAT(created_at, '%Y-%m-%d') as date"),
@@ -215,14 +312,10 @@ export default class OcrDevicesController {
                     Database.raw('SUM(scanned) as scanned_count')
                 )
                 .whereBetween('created_at', [startDate, endDate])
+                .whereIn('channel', userChannel!.split(','))
                 .groupBy('date')
                 .orderBy('date', 'asc')
 
-            // 添加渠道条件
-            if (Object.keys(channelCondition).length > 0) {
-                deviceStatsQuery.where(channelCondition)
-            }
-
             const deviceStats = await deviceStatsQuery
 
             // 使用SQL直接获取每日记录统计数据
@@ -232,14 +325,10 @@ export default class OcrDevicesController {
                     Database.raw('COUNT(id) as record_count')
                 )
                 .whereBetween('created_at', [startDate, endDate])
+                .whereIn('channel', userChannel!.split(','))
                 .groupBy('date')
                 .orderBy('date', 'asc')
 
-            // 添加渠道条件
-            if (Object.keys(channelCondition).length > 0) {
-                recordStatsQuery.where(channelCondition)
-            }
-
             const recordStats = await recordStatsQuery
 
             // 合并结果
@@ -291,18 +380,43 @@ export default class OcrDevicesController {
     public async getTodayStatistics({ request, response, auth }: HttpContextContract) {
         try {
             const user = auth.user
-            const isApiUser = user?.$attributes?.role === 'api'
-            const deviceQuery = OcrDevice.query()
+            const role = user?.$attributes?.role
 
-            // 如果是API用户,强制使用其username作为channel
-            if (isApiUser) {
-                deviceQuery.where('channel', user.username)
-            } else {
-                // 如果不是API用户,则使用请求中的channel参数
-                const channel = request.input('channel')
-                if (channel) {
-                    deviceQuery.where('channel', channel)
+            if (!role) {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'Unauthorized access'
+                })
+            }
+
+            let userChannel: string | undefined
+
+            if (role === UserRoles.Api) {
+                userChannel = user!.username
+            } else if (['admin', 'operator'].includes(role)) {
+                const apiUsers = await UserService.findReferredUsers(user!.id)
+                const allowedChannels = apiUsers.map((user) => user.username as string)
+
+                // 如果请求中指定了channel,检查是否在允许的channel列表中
+                const requestChannel = request.input('channel')
+                if (requestChannel) {
+                    if (!allowedChannels.includes(requestChannel)) {
+                        return response.ok({
+                            date: request.input('date', DateTime.now().toFormat('yyyy-MM-dd')),
+                            total: 0,
+                            scanned: 0,
+                            deviceCount: 0
+                        })
+                    }
+                    userChannel = requestChannel
+                } else {
+                    userChannel = allowedChannels.join(',')
                 }
+            } else {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'You are not authorized to access this resource'
+                })
             }
 
             // 获取指定日期的数据,默认为今天
@@ -311,29 +425,21 @@ export default class OcrDevicesController {
             const dayEnd = DateTime.fromFormat(targetDate, 'yyyy-MM-dd').endOf('day').toSQL()
 
             // 获取设备数据
-            const deviceData = await deviceQuery
-                .where('createdAt', '>=', dayStart)
-                .where('createdAt', '<=', dayEnd)
+            const deviceData = await Database.from('ocr_devices')
+                .where('created_at', '>=', dayStart)
+                .where('created_at', '<=', dayEnd)
+                .whereIn('channel', userChannel!.split(','))
                 .select('scanned')
 
             // 获取OcrRecord数据
             const recordCount = await Database.from('ocr_records')
                 .where('created_at', '>=', dayStart)
                 .where('created_at', '<=', dayEnd)
-                .where(function (query) {
-                    if (isApiUser) {
-                        query.where('channel', user.username)
-                    } else {
-                        const channel = request.input('channel')
-                        if (channel) {
-                            query.where('channel', channel)
-                        }
-                    }
-                })
+                .whereIn('channel', userChannel!.split(','))
                 .count('* as total')
 
             // 计算统计数据
-            const scanned = deviceData.reduce((acc, item) => acc + item.scanned, 0)
+            const scanned = deviceData.reduce((acc, item) => acc + Number(item.scanned || 0), 0)
             const total = Number(recordCount[0].total) || 0
             const deviceCount = deviceData.length
 

+ 193 - 54
app/Controllers/Http/OcrRecordController.ts

@@ -9,44 +9,113 @@ import { HttpStatusCode } from 'axios'
 import { HttpException } from '@adonisjs/http-server/build/src/Exceptions/HttpException'
 import FilesService from 'App/Services/FilesService'
 import * as console from 'node:console'
+import { UserRoles } from 'App/Models/User'
+import UserService from 'App/Services/UserService'
+import { SimplePaginatorContract } from '@ioc:Adonis/Lucid/Database'
 
 export default class OcrRecordController {
     private paginationService = new PaginationService(OcrRecord)
 
-    public async index({ request, auth }: HttpContextContract) {
-        const user = auth.user
-        const isApiUser = user?.$attributes?.role === 'api'
+    public async index({ request, response, auth }: HttpContextContract) {
+        try {
+            const { page = 1, size = 20, id, deviceId, channel } = request.qs()
+            const user = auth.user
+            const role = user?.$attributes?.role
 
-        const requestData = request.all()
-        if (isApiUser) {
-            requestData.channel = user.username
-        }
+            if (!role) {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'Unauthorized access'
+                })
+            }
 
-        const res = await this.paginationService.paginate(requestData)
-        if (isApiUser) {
-            res.forEach((record) => {
-                record.content = ''
-                record.record = ''
-                record.img = ''
-                record.thumbnail = ''
-            })
-        } else {
-            await Promise.all(
-                res.map(async (record) => {
-                    if (record.img && record.img !== '-') {
-                        const url = new URL(record.img)
-                        const filePath = url.pathname.replace(/^\//, '')
-                        record.img = await Drive.getSignedUrl(filePath)
-                        record.thumbnail = await FilesService.generateThumbnailUrl(filePath)
-                    } else {
-                        record.img = ''
-                        record.thumbnail = ''
+            let userChannel: string | undefined
+
+            if (role === UserRoles.Api) {
+                userChannel = user!.username
+            } else if (['admin', 'operator'].includes(role)) {
+                const apiUsers = await UserService.findReferredUsers(user!.id)
+                const allowedChannels = apiUsers.map((user) => user.username as string)
+
+                // 如果请求中指定了channel,检查是否在允许的channel列表中
+                if (channel) {
+                    if (!allowedChannels.includes(channel)) {
+                        return response.ok({
+                            data: [],
+                            meta: {
+                                total: 0,
+                                per_page: Number(size),
+                                current_page: Number(page),
+                                last_page: 0,
+                                first_page: 1,
+                                first_page_url: '/?page=1',
+                                last_page_url: '/?page=0',
+                                next_page_url: null,
+                                previous_page_url: null
+                            }
+                        })
                     }
+                    userChannel = channel
+                } else {
+                    userChannel = allowedChannels.join(',')
+                }
+            } else {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'You are not authorized to access this resource'
                 })
-            )
-        }
+            }
+
+            // 构建查询条件
+            const query = OcrRecord.query()
+
+            if (id) {
+                query.where('id', id)
+            }
 
-        return res
+            if (deviceId) {
+                query.where('deviceId', deviceId)
+            }
+
+            if (userChannel) {
+                query.whereIn('channel', userChannel.split(','))
+            }
+
+            // 执行分页查询
+            const result = await query.orderBy('created_at', 'desc').paginate(page, size)
+
+            // 处理API用户的数据脱敏
+            if (role === UserRoles.Api) {
+                result.forEach((record) => {
+                    record.content = ''
+                    record.record = ''
+                    record.img = ''
+                    record.thumbnail = ''
+                })
+            } else {
+                // 处理图片URL
+                await Promise.all(
+                    result.map(async (record) => {
+                        if (record.img && record.img !== '-') {
+                            const url = new URL(record.img)
+                            const filePath = url.pathname.replace(/^\//, '')
+                            record.img = await Drive.getSignedUrl(filePath)
+                            record.thumbnail = await FilesService.generateThumbnailUrl(filePath)
+                        } else {
+                            record.img = ''
+                            record.thumbnail = ''
+                        }
+                    })
+                )
+            }
+
+            return response.ok(result)
+        } catch (error) {
+            return response.internalServerError({
+                message: '获取OCR记录列表时发生错误',
+                error: error.message
+            })
+        }
     }
 
     public async store({ request, bouncer }: HttpContextContract) {
@@ -104,36 +173,106 @@ export default class OcrRecordController {
         }
     }
 
-    public async favorite({ request, auth }: HttpContextContract) {
-        const user = auth.user
-        const isApiUser = user?.$attributes?.role === 'api'
+    public async favorite({ request, response, auth }: HttpContextContract) {
+        try {
+            const { page = 1, size = 20, id, deviceId, channel } = request.qs()
+            const user = auth.user
+            const role = user?.$attributes?.role
 
-        const requestData = request.all()
-        requestData.favorite = 1
-        if (isApiUser) {
-            requestData.channel = user.username
-        }
+            if (!role) {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'Unauthorized access'
+                })
+            }
 
-        const res = await this.paginationService.paginate(requestData)
-        if (isApiUser) {
-            res.forEach((record) => {
-                record.content = ''
-                record.record = ''
-                record.img = ''
-            })
-        } else {
-            await Promise.all(
-                res.map(async (record) => {
-                    if (record.img && record.img !== '-') {
-                        record.img = await Drive.getSignedUrl(
-                            new URL(record.img).pathname.replace(/^\//, '')
-                        )
+            let userChannel: string | undefined
+
+            if (role === UserRoles.Api) {
+                userChannel = user!.username
+            } else if (['admin', 'operator'].includes(role)) {
+                const apiUsers = await UserService.findReferredUsers(user!.id)
+                const allowedChannels = apiUsers.map((user) => user.username as string)
+
+                // 如果请求中指定了channel,检查是否在允许的channel列表中
+                if (channel) {
+                    if (!allowedChannels.includes(channel)) {
+                        return response.ok({
+                            data: [],
+                            meta: {
+                                total: 0,
+                                per_page: Number(size),
+                                current_page: Number(page),
+                                last_page: 0,
+                                first_page: 1,
+                                first_page_url: '/?page=1',
+                                last_page_url: '/?page=0',
+                                next_page_url: null,
+                                previous_page_url: null
+                            }
+                        })
                     }
+                    userChannel = channel
+                } else {
+                    userChannel = allowedChannels.join(',')
+                }
+            } else {
+                return response.forbidden({
+                    error: 'Forbidden',
+                    message: 'You are not authorized to access this resource'
                 })
-            )
-        }
+            }
+
+            // 构建查询条件
+            const query = OcrRecord.query().where('favorite', true)
+
+            if (id) {
+                query.where('id', id)
+            }
 
-        return res
+            if (deviceId) {
+                query.where('deviceId', deviceId)
+            }
+
+            if (userChannel) {
+                query.whereIn('channel', userChannel.split(','))
+            }
+
+            // 执行分页查询
+            const result = await query.orderBy('created_at', 'desc').paginate(page, size)
+
+            // 处理API用户的数据脱敏
+            if (role === UserRoles.Api) {
+                result.forEach((record) => {
+                    record.content = ''
+                    record.record = ''
+                    record.img = ''
+                    record.thumbnail = ''
+                })
+            } else {
+                // 处理图片URL
+                await Promise.all(
+                    result.map(async (record) => {
+                        if (record.img && record.img !== '-') {
+                            const url = new URL(record.img)
+                            const filePath = url.pathname.replace(/^\//, '')
+                            record.img = await Drive.getSignedUrl(filePath)
+                            record.thumbnail = await FilesService.generateThumbnailUrl(filePath)
+                        } else {
+                            record.img = ''
+                            record.thumbnail = ''
+                        }
+                    })
+                )
+            }
+
+            return response.ok(result)
+        } catch (error) {
+            return response.internalServerError({
+                message: '获取收藏记录列表时发生错误',
+                error: error.message
+            })
+        }
     }
 
     public async getAllAddresses({ request }: HttpContextContract) {

+ 31 - 1
app/Services/UserService.ts

@@ -1,12 +1,13 @@
 import Database from '@ioc:Adonis/Lucid/Database'
 import Membership from 'App/Models/Membership'
 import Referrer from 'App/Models/Referrer'
-import User from 'App/Models/User'
+import User, { UserRoles } from 'App/Models/User'
 import UserBalance from 'App/Models/UserBalance'
 import Decimal from 'decimal.js'
 import { DateTime } from 'luxon'
 import randomstring from 'randomstring'
 import { addDays } from 'date-fns'
+
 class UserService {
     public async findById(id: number) {
         return await User.findByOrFail('id', id)
@@ -57,5 +58,34 @@ class UserService {
 
         return user
     }
+
+    public async findReferredUsers(userId: number) {
+        const result: Partial<User>[] = []
+
+        const findChildren = async (referrerId: number) => {
+            const users = await User.query()
+                .select(['id', 'username', 'role', 'referrer'])
+                .where('visitor', false)
+                .where('referrer', referrerId)
+
+            result.push(...users)
+
+            for (const user of users) {
+                await findChildren(user.id)
+            }
+        }
+
+        await findChildren(userId)
+        if (result.length === 0) {
+            return [
+                {
+                    username: 'nouser',
+                    role: UserRoles.Api
+                }
+            ]
+        }
+        return result
+    }
 }
+
 export default new UserService()