wangqifan 3 лет назад
Родитель
Сommit
8a3905c967

+ 4 - 2
src/chat-pdf/chat-pdf.module.ts

@@ -5,9 +5,11 @@ import { TypeOrmModule } from '@nestjs/typeorm'
 import { ConfigModule } from '@nestjs/config'
 import { SysConfigModule } from '../sys-config/sys-config.module'
 import { ApiUserModule } from '../api-users/api-user.module'
+import {ChatService} from "../chat/chat.service";
 @Module({
-    imports: [ConfigModule, SysConfigModule, ApiUserModule],
+    imports: [ConfigModule, ApiUserModule],
     providers: [ChatPdfService],
-    controllers: [ChatPdfController]
+    controllers: [ChatPdfController],
+    exports: [ChatPdfService]
 })
 export class ChatPdfModule {}

+ 38 - 17
src/chat-pdf/chat-pdf.service.ts

@@ -1,17 +1,16 @@
-import { BadRequestException, Injectable, InternalServerErrorException, Logger } from '@nestjs/common'
+import {BadRequestException, Injectable, InternalServerErrorException, Logger} from '@nestjs/common'
 import * as PdfParse from '@cyber2024/pdf-parse-fixed'
-import { createHash } from 'crypto'
-import { Tiktoken, get_encoding } from '@dqbd/tiktoken'
-import { Configuration, OpenAIApi } from 'azure-openai'
+import {createHash} from 'crypto'
+import {get_encoding, Tiktoken} from '@dqbd/tiktoken'
+import {Configuration, OpenAIApi} from 'azure-openai'
 import * as queue from 'fastq'
-import { setTimeout } from 'timers/promises'
+import {setTimeout} from 'timers/promises'
 import * as dedent from 'dedent'
-import { Sequelize, Model, DataTypes } from 'sequelize'
-import { ConfigService } from '@nestjs/config'
-import { ChatEmbedding } from './entities/chat-embedding.entity'
-import { VECTOR } from './pgvector'
-import { SysConfigService } from '../sys-config/sys-config.service'
-import { ApiUserService } from '../api-users/api-user.service'
+import {DataTypes, Sequelize} from 'sequelize'
+import {ConfigService} from '@nestjs/config'
+import {ChatEmbedding} from './entities/chat-embedding.entity'
+import {VECTOR} from './pgvector'
+import {ApiUserService} from '../api-users/api-user.service'
 
 function formatEmbedding(embedding: number[]) {
     return `[${embedding.join(', ')}]`
@@ -23,7 +22,7 @@ export class ChatPdfService {
     private readonly openai: OpenAIApi
     private readonly sequelize: Sequelize
     constructor(
-        private readonly sysConfigService: SysConfigService,
+        // private readonly sysConfigService: SysConfigService,
         private readonly configService: ConfigService,
         private readonly apiUserService: ApiUserService
     ) {
@@ -218,8 +217,30 @@ export class ChatPdfService {
         return context
     }
 
+    async getSystemConfig(q: string, name: string, code: string) {
+        const apiUser = await this.apiUserService.findByCode(code)
+        const keywords = await this.getKeywords(q)
+        const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
+        const context = this.cutContext((await this.searchEmbedding(name, keywordEmbedding)).map((item) => item.text))
+        if (!context || !context.length) {
+            return {
+                answer: '未找到相关内容'
+            }
+        }
+        return dedent`${apiUser.desc}'
+        根据相关性从高到低排序。 '
+        这是用户提出的问题:
+        ${q}
+        你只能根据用户的问题,以下面的内容为准进行回答:
+        \`\`\`
+        ${context.join('\n')}
+        \`\`\`
+        你要确保你的回答全部基于上面的问题和内容,如果无法找到答案,你需要根据给你设定的角色表述你不能回答这个问题。
+        你只能用中文回答.`
+    }
     async ask(q: string, name: string) {
-        const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
+        // const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
+        const apiUser = await this.apiUserService.findByCode(name)
         const keywords = await this.getKeywords(q)
         const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
         const context = this.cutContext((await this.searchEmbedding(name, keywordEmbedding)).map((item) => item.text))
@@ -228,7 +249,7 @@ export class ChatPdfService {
                 answer: '未找到相关内容'
             }
         }
-        const content = dedent`${defSysMsg}'
+        const content = dedent`${apiUser.desc}'
         根据相关性从高到低排序。 '
         这是用户提出的问题:
         ${q}
@@ -257,11 +278,11 @@ export class ChatPdfService {
     }
 
     async customerAsk(q: string, name: string) {
-        let apiUser = this.apiUserService.findByCode(name)
+        let apiUser = await this.apiUserService.findByCode(name)
         if(!apiUser) {
             throw new BadRequestException("not a enabled api user")
         }
-        const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
+        // const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
         const keywords = await this.getKeywords(q)
         const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
         const context = this.cutContext((await this.searchEmbedding(name, keywordEmbedding)).map((item) => item.text))
@@ -270,7 +291,7 @@ export class ChatPdfService {
                 answer: '客服无法回答这个问题,换个试试吧'
             }
         }
-        const content = dedent`${defSysMsg}'
+        const content = dedent`${apiUser.desc}'
         根据相关性从高到低排序。 '
         这是用户提出的问题:
         ${q}

+ 2 - 1
src/chat/chat.module.ts

@@ -7,9 +7,10 @@ import { TokenUsage } from './entities/token-usage.entity'
 import { MembershipModule } from '../membership/membership.module'
 import { HttpModule } from '@nestjs/axios'
 import { SysConfigModule } from '../sys-config/sys-config.module'
+import {ChatPdfModule} from "../chat-pdf/chat-pdf.module";
 
 @Module({
-    imports: [TypeOrmModule.forFeature([ChatHistory, TokenUsage]), MembershipModule, HttpModule, SysConfigModule],
+    imports: [TypeOrmModule.forFeature([ChatHistory, TokenUsage]), MembershipModule, HttpModule, SysConfigModule, ChatPdfModule],
     providers: [ChatService],
     controllers: [ChatController],
     exports: [ChatService]

+ 21 - 11
src/chat/chat.service.ts

@@ -15,6 +15,7 @@ import { fetchSSE } from '../chatapi/fetch-sse'
 import { HttpService } from '@nestjs/axios'
 import * as types from '../chatapi/types'
 import { SysConfigService } from '../sys-config/sys-config.service'
+import { ChatPdfService } from "../chat-pdf/chat-pdf.service";
 
 @Injectable()
 export class ChatService {
@@ -26,7 +27,8 @@ export class ChatService {
         private readonly tokenUsageRepository: Repository<TokenUsage>,
         private readonly membershipService: MembershipService,
         private readonly httpService: HttpService,
-        private readonly sysConfigService: SysConfigService
+        private readonly sysConfigService: SysConfigService,
+        private readonly chatPdfService: ChatPdfService
     ) { }
 
     public chat(req, res): Observable<any> {
@@ -64,18 +66,26 @@ export class ChatService {
         res.setHeader('Content-type', 'application/octet-stream')
         const defSysMsg = (await this.sysConfigService.findByName('system_message'))?.value
         const membership = await this.membershipService.getMembership(req.user.id)
-        if (!membership) {
-            throw new ForbiddenException('请先成为会员')
-        }
-        if (membership.memberType == MemberType.Trial && membership.tokenLeft <= 0) {
-            throw new ForbiddenException('您的试用额度已用完,请升级会员')
-        }
-        if (membership.isExpired) {
-            throw new ForbiddenException('您的会员已过期,请即时续费')
+        const { prompt, options = {}, systemMessage, code, temperature, top_p } = req.body as RequestProps
+        if (!code) {
+            if (!membership) {
+                throw new ForbiddenException('请先成为会员')
+            }
+            if (membership.memberType == MemberType.Trial && membership.tokenLeft <= 0) {
+                throw new ForbiddenException('您的试用额度已用完,请升级会员')
+            }
+            if (membership.isExpired) {
+                throw new ForbiddenException('您的会员已过期,请即时续费')
+            }
         }
         try {
+            let content = ''
             const promptTime = new Date()
-            const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
+            if (!!code) {
+                content = await this.chatPdfService.getSystemConfig(prompt, code, code)
+            }else {
+                content = systemMessage
+            }
             let firstChunk = true
             const result = await chatReplyProcess({
                 message: prompt,
@@ -84,7 +94,7 @@ export class ChatService {
                     res.write(firstChunk ? JSON.stringify(chat) : `\n${JSON.stringify(chat)}`)
                     firstChunk = false
                 },
-                systemMessage: systemMessage || defSysMsg,
+                systemMessage: content || defSysMsg,
                 temperature,
                 top_p
             })

+ 1 - 0
src/chat/types.ts

@@ -4,6 +4,7 @@ export interface RequestProps {
   prompt: string
   options?: ChatContext
   systemMessage: string
+  code: string
   temperature?: number
   top_p?: number
 }