Forráskód Böngészése

Merge branch 'master' of http://git.izouma.com/xiongzhu/chat-api

wuyi 3 éve
szülő
commit
978a529eb8

+ 9 - 8
src/api-users/api-user.controller.ts

@@ -18,18 +18,16 @@ import { PageRequest } from '../common/dto/page-request'
 import { ApiUser } from './entities/api-user.entity'
 import { ApiUserDto } from './dto/api-user.dto'
 import { ApiUserService } from './api-user.service'
+import { Public } from 'src/auth/public.decorator'
 
 @ApiTags('apiUser')
 @Controller('/apiUser')
 @ApiBearerAuth()
 export class ApiUserController {
-
-    constructor(private readonly apiUserService: ApiUserService) { }
-
+    constructor(private readonly apiUserService: ApiUserService) {}
 
     @Post()
-    public async list(@Body() page: PageRequest<ApiUser>) {
-
+    public async list(@Body() page: PageRequest<ApiUser>, @Req() req) {
         return await this.apiUserService.findAll(page)
     }
 
@@ -39,14 +37,18 @@ export class ApiUserController {
         return await this.apiUserService.create(userId)
     }
 
+    @Public()
     @Get('/get/:id')
-    public async get(@Param('id') id: string) {
+    public async get(@Param('id') id: string, @Req() req) {
         const chatRole = await this.apiUserService.findById(Number(id))
+        // if (!req.user || req.user.apiUserId != chatRole.id) {
+        //     chatRole.code = ''
+        // }
         return chatRole
     }
 
     @Put('/:id')
-    @HasRoles(Role.Admin)
+    @HasRoles(Role.Api)
     public async update(@Param('id') id: string, @Body() apiUser: ApiUser) {
         try {
             await this.apiUserService.update(Number(id), apiUser)
@@ -73,5 +75,4 @@ export class ApiUserController {
             throw new BadRequestException(err, 'Error: apiuser not deleted!')
         }
     }
-
 }

+ 9 - 0
src/api-users/entities/api-user.entity.ts

@@ -16,9 +16,18 @@ export class ApiUser {
     @Column({ nullable: true })
     name: string
 
+    @Column()
+    avatar: string
+
+    @Column()
+    desc: string
+
     @Column({ nullable: true })
     code: string
 
     @Column({ type: 'enum', enum: ApiType })
     type: ApiType
+
+    @CreateDateColumn()
+    createdAt: Date
 }

+ 9 - 1
src/auth/auth.controller.ts

@@ -1,5 +1,5 @@
 import { PhoneLoginDto } from './dto/login.dto'
-import { Body, Controller, Get, Param, Post } from '@nestjs/common'
+import { Body, Controller, Get, Param, Post, Req } from '@nestjs/common'
 import { AuthService } from './auth.service'
 import { ApiTags } from '@nestjs/swagger'
 import { Public } from './public.decorator'
@@ -23,6 +23,14 @@ export class AuthController {
         return await this.authService.loginAdmin(username, password)
     }
 
+    @Get('/admin/getRole')
+    async getRole(@Req() req) {
+        if (req.user.roles.includes(Role.Api)) {
+            return 'api'
+        }
+        return 'admin'
+    }
+
     @Get('/admin/user/:userId/token')
     @HasRoles(Role.Admin)
     async getToken(@Param('userId') userId: string) {

+ 4 - 2
src/chat-pdf/chat-pdf.module.ts

@@ -5,9 +5,11 @@ import { TypeOrmModule } from '@nestjs/typeorm'
 import { ConfigModule } from '@nestjs/config'
 import { SysConfigModule } from '../sys-config/sys-config.module'
 import { ApiUserModule } from '../api-users/api-user.module'
+import {ChatService} from "../chat/chat.service";
 @Module({
-    imports: [ConfigModule, SysConfigModule, ApiUserModule],
+    imports: [ConfigModule, ApiUserModule],
     providers: [ChatPdfService],
-    controllers: [ChatPdfController]
+    controllers: [ChatPdfController],
+    exports: [ChatPdfService]
 })
 export class ChatPdfModule {}

+ 47 - 27
src/chat-pdf/chat-pdf.service.ts

@@ -1,17 +1,16 @@
-import { BadRequestException, Injectable, InternalServerErrorException, Logger } from '@nestjs/common'
+import {BadRequestException, Injectable, InternalServerErrorException, Logger} from '@nestjs/common'
 import * as PdfParse from '@cyber2024/pdf-parse-fixed'
-import { createHash } from 'crypto'
-import { Tiktoken, get_encoding } from '@dqbd/tiktoken'
-import { Configuration, OpenAIApi } from 'azure-openai'
+import {createHash} from 'crypto'
+import {get_encoding, Tiktoken} from '@dqbd/tiktoken'
+import {Configuration, OpenAIApi} from 'azure-openai'
 import * as queue from 'fastq'
-import { setTimeout } from 'timers/promises'
+import {setTimeout} from 'timers/promises'
 import * as dedent from 'dedent'
-import { Sequelize, Model, DataTypes } from 'sequelize'
-import { ConfigService } from '@nestjs/config'
-import { ChatEmbedding } from './entities/chat-embedding.entity'
-import { VECTOR } from './pgvector'
-import { SysConfigService } from '../sys-config/sys-config.service'
-import { ApiUserService } from '../api-users/api-user.service'
+import {DataTypes, Sequelize} from 'sequelize'
+import {ConfigService} from '@nestjs/config'
+import {ChatEmbedding} from './entities/chat-embedding.entity'
+import {VECTOR} from './pgvector'
+import {ApiUserService} from '../api-users/api-user.service'
 
 function formatEmbedding(embedding: number[]) {
     return `[${embedding.join(', ')}]`
@@ -23,7 +22,7 @@ export class ChatPdfService {
     private readonly openai: OpenAIApi
     private readonly sequelize: Sequelize
     constructor(
-        private readonly sysConfigService: SysConfigService,
+        // private readonly sysConfigService: SysConfigService,
         private readonly configService: ConfigService,
         private readonly apiUserService: ApiUserService
     ) {
@@ -218,7 +217,30 @@ export class ChatPdfService {
         return context
     }
 
+    async getSystemConfig(q: string, name: string, code: string) {
+        const apiUser = await this.apiUserService.findByCode(code)
+        const keywords = await this.getKeywords(q)
+        const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
+        const context = this.cutContext((await this.searchEmbedding(name, keywordEmbedding)).map((item) => item.text))
+        if (!context || !context.length) {
+            return {
+                answer: '未找到相关内容'
+            }
+        }
+        return dedent`${apiUser.desc}'
+        根据相关性从高到低排序。 '
+        这是用户提出的问题:
+        ${q}
+        你只能根据用户的问题,以下面的内容为准进行回答:
+        \`\`\`
+        ${context.join('\n')}
+        \`\`\`
+        你要确保你的回答全部基于上面的问题和内容,如果无法找到答案,你需要根据给你设定的角色表述你不能回答这个问题。
+        你只能用中文回答.`
+    }
     async ask(q: string, name: string) {
+        // const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
+        const apiUser = await this.apiUserService.findByCode(name)
         const keywords = await this.getKeywords(q)
         const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
         const context = this.cutContext((await this.searchEmbedding(name, keywordEmbedding)).map((item) => item.text))
@@ -227,16 +249,15 @@ export class ChatPdfService {
                 answer: '未找到相关内容'
             }
         }
-        const content = dedent`You are a helpful AI article assistant. '
-        The following are the relevant article content fragments found from the article. 
-        The relevance is sorted from high to low. '
-        You can only answer according to the following content:
+        const content = dedent`${apiUser.desc}'
+        根据相关性从高到低排序。 '
+        这是用户提出的问题:
+        ${q}
+        你只能根据用户的问题,以下面的内容为准进行回答:
         \`\`\`
         ${context.join('\n')}
         \`\`\`
-        You need to carefully consider your answer to ensure that it is based on the context. 
-        If the context does not mention the content or it is uncertain whether it is correct, 
-        please answer "Current context cannot provide effective information."
+        你要确保你的回答全部基于上面的问题和内容,如果无法找到答案,你需要根据给你设定的角色表述你不能回答这个问题。
         You must use Chinese to respond.`
         try {
             const response = await this.openai.createChatCompletion({
@@ -257,11 +278,11 @@ export class ChatPdfService {
     }
 
     async customerAsk(q: string, name: string) {
-        let apiUser = this.apiUserService.findByCode(name)
+        let apiUser = await this.apiUserService.findByCode(name)
         if(!apiUser) {
             throw new BadRequestException("not a enabled api user")
         }
-        const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
+        // const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
         const keywords = await this.getKeywords(q)
         const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
         const context = this.cutContext((await this.searchEmbedding(name, keywordEmbedding)).map((item) => item.text))
@@ -270,16 +291,15 @@ export class ChatPdfService {
                 answer: '客服无法回答这个问题,换个试试吧'
             }
         }
-        const content = dedent`${defSysMsg}'
-        下面的内容是你从客服指南中学习到的。
+        const content = dedent`${apiUser.desc}'
         根据相关性从高到低排序。 '
-        You can only answer according to the following content:
+        这是用户提出的问题:
+        ${q}
+        你只能根据用户的问题,以下面的内容为准进行回答:
         \`\`\`
         ${context.join('\n')}
         \`\`\`
-        You need to carefully consider your answer to ensure that it is based on the context. 
-        If the context does not mention the content or it is uncertain whether it is correct, 
-        please answer "这个问题无法从客服指南中找到答案,请询问更准确的问题。"
+        你要确保你的回答全部基于上面的问题和内容,如果无法找到答案,你需要根据给你设定的角色表述你不能回答这个问题。
         You must use Chinese to respond.`
         try {
             const response = await this.openai.createChatCompletion({

+ 2 - 1
src/chat/chat.module.ts

@@ -7,9 +7,10 @@ import { TokenUsage } from './entities/token-usage.entity'
 import { MembershipModule } from '../membership/membership.module'
 import { HttpModule } from '@nestjs/axios'
 import { SysConfigModule } from '../sys-config/sys-config.module'
+import {ChatPdfModule} from "../chat-pdf/chat-pdf.module";
 
 @Module({
-    imports: [TypeOrmModule.forFeature([ChatHistory, TokenUsage]), MembershipModule, HttpModule, SysConfigModule],
+    imports: [TypeOrmModule.forFeature([ChatHistory, TokenUsage]), MembershipModule, HttpModule, SysConfigModule, ChatPdfModule],
     providers: [ChatService],
     controllers: [ChatController],
     exports: [ChatService]

+ 21 - 11
src/chat/chat.service.ts

@@ -15,6 +15,7 @@ import { fetchSSE } from '../chatapi/fetch-sse'
 import { HttpService } from '@nestjs/axios'
 import * as types from '../chatapi/types'
 import { SysConfigService } from '../sys-config/sys-config.service'
+import { ChatPdfService } from "../chat-pdf/chat-pdf.service";
 
 @Injectable()
 export class ChatService {
@@ -26,7 +27,8 @@ export class ChatService {
         private readonly tokenUsageRepository: Repository<TokenUsage>,
         private readonly membershipService: MembershipService,
         private readonly httpService: HttpService,
-        private readonly sysConfigService: SysConfigService
+        private readonly sysConfigService: SysConfigService,
+        private readonly chatPdfService: ChatPdfService
     ) { }
 
     public chat(req, res): Observable<any> {
@@ -64,18 +66,26 @@ export class ChatService {
         res.setHeader('Content-type', 'application/octet-stream')
         const defSysMsg = (await this.sysConfigService.findByName('system_message'))?.value
         const membership = await this.membershipService.getMembership(req.user.id)
-        if (!membership) {
-            throw new ForbiddenException('请先成为会员')
-        }
-        if (membership.memberType == MemberType.Trial && membership.tokenLeft <= 0) {
-            throw new ForbiddenException('您的试用额度已用完,请升级会员')
-        }
-        if (membership.isExpired) {
-            throw new ForbiddenException('您的会员已过期,请即时续费')
+        const { prompt, options = {}, systemMessage, code, temperature, top_p } = req.body as RequestProps
+        if (!code) {
+            if (!membership) {
+                throw new ForbiddenException('请先成为会员')
+            }
+            if (membership.memberType == MemberType.Trial && membership.tokenLeft <= 0) {
+                throw new ForbiddenException('您的试用额度已用完,请升级会员')
+            }
+            if (membership.isExpired) {
+                throw new ForbiddenException('您的会员已过期,请即时续费')
+            }
         }
         try {
+            let content = ''
             const promptTime = new Date()
-            const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
+            if (!!code) {
+                content = await this.chatPdfService.getSystemConfig(prompt, code, code)
+            }else {
+                content = systemMessage
+            }
             let firstChunk = true
             const result = await chatReplyProcess({
                 message: prompt,
@@ -84,7 +94,7 @@ export class ChatService {
                     res.write(firstChunk ? JSON.stringify(chat) : `\n${JSON.stringify(chat)}`)
                     firstChunk = false
                 },
-                systemMessage: systemMessage || defSysMsg,
+                systemMessage: content || defSysMsg,
                 temperature,
                 top_p
             })

+ 1 - 0
src/chat/types.ts

@@ -4,6 +4,7 @@ export interface RequestProps {
   prompt: string
   options?: ChatContext
   systemMessage: string
+  code: string
   temperature?: number
   top_p?: number
 }

+ 2 - 0
src/users/dto/user-create.dto.ts

@@ -28,6 +28,8 @@ export class UserCreateDto {
     @MaxLength(60)
     password: string
 
+    apiUserId: number
+
     @IsArray()
     readonly roles: Role[]
 }

+ 3 - 0
src/users/entities/users.entity.ts

@@ -37,4 +37,7 @@ export class Users {
 
     @Column({ nullable: true })
     iat: number
+
+    @Column({ nullable: true })
+    apiUserId: number
 }

+ 7 - 3
src/users/users.admin.controller.ts

@@ -17,6 +17,7 @@ import { UserUpdateDto } from './dto/user-update.dto'
 import { IUsers } from './interfaces/users.interface'
 import { ApiBearerAuth, ApiTags } from '@nestjs/swagger'
 import { HasRoles } from '../auth/roles.decorator'
+import { HasAnyRoles } from '../auth/roles.decorator'
 import { Role } from '../model/role.enum'
 import { IPaginationOptions } from 'nestjs-typeorm-paginate'
 import { PageRequest } from '../common/dto/page-request'
@@ -26,17 +27,20 @@ import { UserCreateDto } from './dto/user-create.dto'
 @ApiTags('users.admin')
 @Controller('/admin/users')
 @ApiBearerAuth()
-@HasRoles(Role.Admin)
+@HasAnyRoles(Role.Admin, Role.Api)
 export class UsersAdminController {
     constructor(private readonly usersService: UsersService) {}
 
     @Post()
-    public async list(@Body() page: PageRequest<Users>) {
+    public async list(@Body() page: PageRequest<Users>, @Req() req) {
         return await this.usersService.findAll(page)
     }
 
     @Put()
-    public async create(@Body() user: UserCreateDto) {
+    public async create(@Body() user: UserCreateDto, @Req() req) {
+        if (req.user.roles.includes(Role.Api)) {
+            return await this.usersService.createSubUser(user, req.user.id)
+        }
         return await this.usersService.create(user)
     }
 

+ 27 - 2
src/users/users.service.ts

@@ -22,6 +22,7 @@ import { ApiUserService } from '../api-users/api-user.service'
 import { paginate, Pagination } from 'nestjs-typeorm-paginate'
 import { Role } from '../model/role.enum'
 import { PageRequest } from '../common/dto/page-request'
+import { th } from 'date-fns/locale'
 
 @Injectable()
 export class UsersService {
@@ -105,9 +106,15 @@ export class UsersService {
         if (!isMatch) {
             throw new UnauthorizedException('用户名或密码错误')
         }
-        if (!user.roles.includes(Role.Admin)) {
+        if (!user.roles.includes(Role.Admin) && !user.roles.includes(Role.Api)) {
             throw new UnauthorizedException('用户名或密码错误')
         }
+        if (user.roles.includes(Role.Api)) {
+            let apiUser = await this.apiUserService.findById(user.apiUserId)
+            if (apiUser.userId != user.id) {
+                throw new UnauthorizedException('用户名或密码错误')
+            }
+        }
         return user
     }
 
@@ -118,8 +125,26 @@ export class UsersService {
             }
             let user = await this.userRepository.save(userDto)
             if (userDto.roles.includes(Role.Api)) {
-                this.apiUserService.create(user.id)
+                let apiUser = await this.apiUserService.create(user.id)
+                user.apiUserId = apiUser.id
+                user = await this.userRepository.save(user)
             }
+
+            return user
+        } catch (err) {
+            throw new InternalServerErrorException(err.message)
+        }
+    }
+
+    public async createSubUser(userDto: UserCreateDto, apiUserId: number): Promise<IUsers> {
+        try {
+            const apiUser = await this.findById(apiUserId)
+            if (userDto.password) {
+                userDto.password = await this.hashingService.hash(userDto.password)
+            }
+            userDto.apiUserId = apiUser.apiUserId
+            let user = await this.userRepository.save(userDto)
+
             return user
         } catch (err) {
             throw new InternalServerErrorException(err.message)