wangqifan 2 лет назад
Родитель
Сommit
97b70b1655

+ 3 - 0
src/api-users/entities/api-user.entity.ts

@@ -31,6 +31,9 @@ export class ApiUser {
     @Column({ nullable: true })
     @Column({ nullable: true })
     code: string
     code: string
 
 
+    @Column({ nullable: true })
+    publicCode: string
+
     @Column({ type: 'enum', enum: ApiType })
     @Column({ type: 'enum', enum: ApiType })
     type: ApiType
     type: ApiType
 
 

+ 3 - 2
src/chat-pdf/chat-pdf.module.ts

@@ -5,9 +5,10 @@ import { TypeOrmModule } from '@nestjs/typeorm'
 import { ConfigModule } from '@nestjs/config'
 import { ConfigModule } from '@nestjs/config'
 import { SysConfigModule } from '../sys-config/sys-config.module'
 import { SysConfigModule } from '../sys-config/sys-config.module'
 import { ApiUserModule } from '../api-users/api-user.module'
 import { ApiUserModule } from '../api-users/api-user.module'
-import {ChatService} from "../chat/chat.service";
+import { ChatService } from '../chat/chat.service'
+import { UsersModule } from 'src/users/users.module'
 @Module({
 @Module({
-    imports: [ConfigModule, ApiUserModule],
+    imports: [ConfigModule, ApiUserModule, UsersModule],
     providers: [ChatPdfService],
     providers: [ChatPdfService],
     controllers: [ChatPdfController],
     controllers: [ChatPdfController],
     exports: [ChatPdfService]
     exports: [ChatPdfService]

+ 23 - 7
src/chat-pdf/chat-pdf.service.ts

@@ -11,6 +11,9 @@ import { ConfigService } from '@nestjs/config'
 import { ChatEmbedding } from './entities/chat-embedding.entity'
 import { ChatEmbedding } from './entities/chat-embedding.entity'
 import { VECTOR } from './pgvector'
 import { VECTOR } from './pgvector'
 import { ApiUserService } from '../api-users/api-user.service'
 import { ApiUserService } from '../api-users/api-user.service'
+import { UsersService } from 'src/users/users.service'
+import { Role } from 'src/model/role.enum'
+import { ApiUser } from 'src/api-users/entities/api-user.entity'
 
 
 function formatEmbedding(embedding: number[]) {
 function formatEmbedding(embedding: number[]) {
     return `[${embedding.join(', ')}]`
     return `[${embedding.join(', ')}]`
@@ -24,7 +27,8 @@ export class ChatPdfService {
     constructor(
     constructor(
         // private readonly sysConfigService: SysConfigService,
         // private readonly sysConfigService: SysConfigService,
         private readonly configService: ConfigService,
         private readonly configService: ConfigService,
-        private readonly apiUserService: ApiUserService
+        private readonly apiUserService: ApiUserService,
+        private readonly userService: UsersService
     ) {
     ) {
         this.tokenizer = get_encoding('cl100k_base')
         this.tokenizer = get_encoding('cl100k_base')
         this.openai = new OpenAIApi(
         this.openai = new OpenAIApi(
@@ -273,19 +277,31 @@ export class ChatPdfService {
         return context
         return context
     }
     }
 
 
-    async getSystemConfig(q: string, name: string, code: string) {
-        const apiUser = await this.apiUserService.findByCode(code)
+    async getUser(userId: number) {
+        return await this.userService.findById(userId)
+    }
+
+    async getSystemConfig(userId: number, q: string, name: string, code: string, sysMsg: string) {
+        const user = await this.userService.findById(userId)
+        if (!!user.apiUserId) {
+            return sysMsg
+        }
+        let apiUser = await this.apiUserService.findById(user.apiUserId)
+        if (user.roles.includes(Role.User) && !!user.apiUserId) {
+            code = apiUser.publicCode
+        } else {
+            code = apiUser.code
+        }
         const keywords = await this.getKeywords(q)
         const keywords = await this.getKeywords(q)
         const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
         const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
-        const context = this.cutContext((await this.searchEmbedding(name, keywordEmbedding)).map((item) => item.text))
+        const context = this.cutContext((await this.searchEmbedding(code, keywordEmbedding)).map((item) => item.text))
         // if (!context || !context.length) {
         // if (!context || !context.length) {
         //     return {
         //     return {
         //         answer: '未找到相关内容'
         //         answer: '未找到相关内容'
         //     }
         //     }
         // }
         // }
-        return dedent`
-        这是你的身份:
-        ${apiUser.desc}'
+        return dedent`${apiUser.desc}'
+        根据相关性从高到低排序。 '
         这是用户提出的问题:
         这是用户提出的问题:
         ${q}
         ${q}
         你只能根据用户的问题,以下面的内容和之前的聊天记录为准结合你的身份进行回答:
         你只能根据用户的问题,以下面的内容和之前的聊天记录为准结合你的身份进行回答:

+ 21 - 24
src/chat/chat.service.ts

@@ -15,7 +15,7 @@ import { fetchSSE } from '../chatapi/fetch-sse'
 import { HttpService } from '@nestjs/axios'
 import { HttpService } from '@nestjs/axios'
 import * as types from '../chatapi/types'
 import * as types from '../chatapi/types'
 import { SysConfigService } from '../sys-config/sys-config.service'
 import { SysConfigService } from '../sys-config/sys-config.service'
-import { ChatPdfService } from "../chat-pdf/chat-pdf.service";
+import { ChatPdfService } from '../chat-pdf/chat-pdf.service'
 
 
 @Injectable()
 @Injectable()
 export class ChatService {
 export class ChatService {
@@ -29,7 +29,7 @@ export class ChatService {
         private readonly httpService: HttpService,
         private readonly httpService: HttpService,
         private readonly sysConfigService: SysConfigService,
         private readonly sysConfigService: SysConfigService,
         private readonly chatPdfService: ChatPdfService
         private readonly chatPdfService: ChatPdfService
-    ) { }
+    ) {}
 
 
     public chat(req, res): Observable<any> {
     public chat(req, res): Observable<any> {
         res.setHeader('Content-Type', 'application/octet-stream')
         res.setHeader('Content-Type', 'application/octet-stream')
@@ -52,7 +52,7 @@ export class ChatService {
                 temperature,
                 temperature,
                 top_p
                 top_p
             })
             })
-                .then(() => { })
+                .then(() => {})
                 .catch((error) => {
                 .catch((error) => {
                     observer.error(error)
                     observer.error(error)
                 })
                 })
@@ -63,11 +63,12 @@ export class ChatService {
     }
     }
 
 
     public async chat1(req, res) {
     public async chat1(req, res) {
+        const user = await this.chatPdfService.getUser(req.user.id)
         res.setHeader('Content-type', 'application/octet-stream')
         res.setHeader('Content-type', 'application/octet-stream')
         const defSysMsg = (await this.sysConfigService.findByName('system_message'))?.value
         const defSysMsg = (await this.sysConfigService.findByName('system_message'))?.value
         const membership = await this.membershipService.getMembership(req.user.id)
         const membership = await this.membershipService.getMembership(req.user.id)
         const { prompt, options = {}, systemMessage, code, temperature, top_p } = req.body as RequestProps
         const { prompt, options = {}, systemMessage, code, temperature, top_p } = req.body as RequestProps
-        if (!code) {
+        if (!code && !user.apiUserId) {
             if (!membership) {
             if (!membership) {
                 throw new ForbiddenException('请先成为会员')
                 throw new ForbiddenException('请先成为会员')
             }
             }
@@ -81,11 +82,7 @@ export class ChatService {
         try {
         try {
             let content = ''
             let content = ''
             const promptTime = new Date()
             const promptTime = new Date()
-            if (!!code) {
-                content = await this.chatPdfService.getSystemConfig(prompt, code, code)
-            }else {
-                content = systemMessage
-            }
+            content = await this.chatPdfService.getSystemConfig(req.userId, prompt, code, code, systemMessage)
             let firstChunk = true
             let firstChunk = true
             const result = await chatReplyProcess({
             const result = await chatReplyProcess({
                 message: prompt,
                 message: prompt,
@@ -140,15 +137,15 @@ export class ChatService {
 
 
     public async sendMessage(prompt: string, parentMessageId: string, message: string): Promise<ChatHistory> {
     public async sendMessage(prompt: string, parentMessageId: string, message: string): Promise<ChatHistory> {
         const options = { parentMessageId: parentMessageId }
         const options = { parentMessageId: parentMessageId }
-        console.log("options:" + options.parentMessageId)
+        console.log('options:' + options.parentMessageId)
         try {
         try {
             const result = await chatReplyProcess({
             const result = await chatReplyProcess({
                 message: prompt,
                 message: prompt,
                 lastContext: options,
                 lastContext: options,
                 systemMessage: message,
                 systemMessage: message,
-                process: (chat: ChatMessage) => { },
-            });
-            const chatMessage = result.data as ChatMessage;
+                process: (chat: ChatMessage) => {}
+            })
+            const chatMessage = result.data as ChatMessage
             this.chatHistoryRepository.save(
             this.chatHistoryRepository.save(
                 new ChatHistory({
                 new ChatHistory({
                     messageId: chatMessage.parentMessageId,
                     messageId: chatMessage.parentMessageId,
@@ -157,9 +154,9 @@ export class ChatService {
                     message: message,
                     message: message,
                     role: 'system',
                     role: 'system',
                     token: chatMessage.detail.usage.prompt_tokens,
                     token: chatMessage.detail.usage.prompt_tokens,
-                    time: new Date(),
-                }),
-            );
+                    time: new Date()
+                })
+            )
             const chatHistory = this.chatHistoryRepository.save(
             const chatHistory = this.chatHistoryRepository.save(
                 new ChatHistory({
                 new ChatHistory({
                     messageId: chatMessage.id,
                     messageId: chatMessage.id,
@@ -168,14 +165,14 @@ export class ChatService {
                     message: chatMessage.text,
                     message: chatMessage.text,
                     role: 'assistant',
                     role: 'assistant',
                     token: chatMessage.detail.usage.completion_tokens,
                     token: chatMessage.detail.usage.completion_tokens,
-                    time: new Date(),
-                }),
-            );
+                    time: new Date()
+                })
+            )
 
 
-            Logger.log(`机器人回答:${chatMessage.text}`, 'SendMessage');
+            Logger.log(`机器人回答:${chatMessage.text}`, 'SendMessage')
             return chatHistory
             return chatHistory
         } catch (error) {
         } catch (error) {
-            Logger.error(error, 'SendMessage');
+            Logger.error(error, 'SendMessage')
         }
         }
     }
     }
 
 
@@ -215,8 +212,8 @@ export class ChatService {
                         this.tiktokenAndSave(
                         this.tiktokenAndSave(
                             req.user.id,
                             req.user.id,
                             req.body.messages.map((message) => `${message.role}:\n${message.content}`).join('\n\n') +
                             req.body.messages.map((message) => `${message.role}:\n${message.content}`).join('\n\n') +
-                            '\n\nassistant:\n' +
-                            text
+                                '\n\nassistant:\n' +
+                                text
                         )
                         )
                         return subscriber.complete()
                         return subscriber.complete()
                     }
                     }
@@ -271,4 +268,4 @@ export class ChatService {
             date: And(MoreThanOrEqual(start || '2023-04-01'), LessThanOrEqual(end || date))
             date: And(MoreThanOrEqual(start || '2023-04-01'), LessThanOrEqual(end || date))
         })
         })
     }
     }
-}
+}

+ 4 - 0
src/users/users.service.ts

@@ -89,6 +89,10 @@ export class UsersService {
             user.name = '0x' + randomstring.generate({ length: 8, charset: 'alphanumeric' })
             user.name = '0x' + randomstring.generate({ length: 8, charset: 'alphanumeric' })
             user.username = phone
             user.username = phone
             user.invitor = invitor || 48
             user.invitor = invitor || 48
+            const invitorUser = await this.findById(invitor)
+            if (!!invitorUser.apiUserId) {
+                user.apiUserId = invitorUser.apiUserId
+            }
         }
         }
         user = await this.userRepository.save(user)
         user = await this.userRepository.save(user)
         if (newRegister) {
         if (newRegister) {