xiongzhu — 2 years ago
parent commit e8d589ea80
5 files changed, 122 insertions(+), 42 deletions(-)
  1. 4 4
      .env
  2. 34 20
      src/knowledge-base/knowledge-base.service.ts
  3. 10 2
      src/org/org.controller.ts
  4. 74 10
      src/org/org.service.ts
  5. 0 6
      src/redis.ts

+ 4 - 4
.env

@@ -48,10 +48,10 @@ ALIYUN_OSS_CDN=https://cdn.raex.vip
 ALIYUN_SMS_SIGN=走马信息
 ALIYUN_SMS_TEMPLATE_CODE=SMS_175485688
 
-AZURE_OPENAI_KEY=beb32e4625a94b65ba8bc0ba1688c4d2
-AZURE_OPENAI_ENDPOINT=https://zouma.openai.azure.com
-# AZURE_OPENAI_KEY=62dd8a1466524c64967810c692f0197e
-# AZURE_OPENAI_ENDPOINT=https://zouma1.openai.azure.com
+# AZURE_OPENAI_KEY=beb32e4625a94b65ba8bc0ba1688c4d2
+# AZURE_OPENAI_ENDPOINT=https://zouma.openai.azure.com
+AZURE_OPENAI_KEY=62dd8a1466524c64967810c692f0197e
+AZURE_OPENAI_ENDPOINT=https://zouma1.openai.azure.com
 AZURE_OPENAI_DEPLOYMENT=gpt-35-turbo
 AZURE_OPENAI_VERSION=2023-03-15-preview
 

+ 34 - 20
src/knowledge-base/knowledge-base.service.ts

@@ -8,7 +8,7 @@ import {
 } from '@nestjs/common'
 import { InjectRepository } from '@nestjs/typeorm'
 import { KnowledgeBase } from './entities/knowledge-base.entity'
-import { Repository } from 'typeorm'
+import { In, Repository } from 'typeorm'
 import { Tiktoken, get_encoding } from '@dqbd/tiktoken'
 import { Configuration, OpenAIApi } from 'azure-openai'
 import { DataTypes, Sequelize } from 'sequelize'
@@ -39,6 +39,7 @@ function formatEmbedding(embedding: number[]) {
 export class KnowledgeBaseService implements OnModuleInit {
     private readonly tokenizer: Tiktoken
     private readonly openai: OpenAIApi
+    private readonly embeddingApi: OpenAIApi
     private readonly sequelize: Sequelize
     private embeddings: OpenAIEmbeddings
     private vectorStore: TypeORMVectorStore
@@ -53,11 +54,23 @@ export class KnowledgeBaseService implements OnModuleInit {
         this.tokenizer = get_encoding('cl100k_base')
         this.openai = new OpenAIApi(
             new Configuration({
-                apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
+                apiKey: process.env.AZURE_OPENAI_KEY,
                 // add azure info into configuration
                 azure: {
-                    apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
-                    endpoint: 'https://zouma.openai.azure.com'
+                    apiKey: process.env.AZURE_OPENAI_KEY,
+                    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+                    deploymentName: process.env.AZURE_OPENAI_DEPLOYMENT
+                }
+            })
+        )
+        this.embeddingApi = new OpenAIApi(
+            new Configuration({
+                apiKey: process.env.AZURE_EMBEDDING_KEY,
+                // add azure info into configuration
+                azure: {
+                    apiKey: process.env.AZURE_EMBEDDING_KEY,
+                    endpoint: `https://${process.env.AZURE_EMBEDDING_INSTANCE}.openai.azure.com`,
+                    deploymentName: process.env.AZURE_EMBEDDING_DEPLOYMENT
                 }
             })
         )
@@ -125,7 +138,7 @@ export class KnowledgeBaseService implements OnModuleInit {
                 password: process.env.PG_PASSWORD,
                 database: process.env.PG_DATABASE
             },
-            verbose: true,
+            verbose: true
         })
         await this.vectorStore.ensureTableInDatabase()
     }
@@ -203,8 +216,8 @@ export class KnowledgeBaseService implements OnModuleInit {
                 this.processExcelKnowledgeFile(knowledgeFile, buffer)
                 break
             case 'application/pdf':
-                //this.processPdfKnowledgeFile(knowledgeFile, buffer)
-                this.processPdfKnowledgeFile1(knowledgeFile, buffer)
+                this.processPdfKnowledgeFile(knowledgeFile, buffer)
+                // this.processPdfKnowledgeFile1(knowledgeFile, buffer)
                 break
         }
         return knowledgeFile
@@ -384,7 +397,7 @@ export class KnowledgeBaseService implements OnModuleInit {
 
     async getEmbedding(content: string, retry = 0) {
         try {
-            const response = await this.openai.createEmbedding({
+            const response = await this.embeddingApi.createEmbedding({
                 model: 'embedding',
                 input: content
             })
@@ -407,7 +420,7 @@ export class KnowledgeBaseService implements OnModuleInit {
     async getKeywords(text: string) {
         try {
             const res = await this.openai.createChatCompletion({
-                model: 'gpt35',
+                model: 'gpt-35-turbo',
                 messages: [
                     {
                         role: 'user',
@@ -425,16 +438,6 @@ export class KnowledgeBaseService implements OnModuleInit {
         }
     }
 
-    async searchEmbedding(name: string, embedding: number[]) {
-        return await KnowledgeEmbedding.findAll({
-            where: {
-                name
-            },
-            order: this.sequelize.literal(`embedding <-> '${formatEmbedding(embedding)}'`),
-            limit: 100
-        })
-    }
-
     cutContext(context: string[]) {
         if (!context || !context.length) return []
         let max = 4096 - 1024
@@ -457,7 +460,18 @@ export class KnowledgeBaseService implements OnModuleInit {
             order: this.sequelize.literal(`embedding <-> '${formatEmbedding(keywordEmbedding)}'`),
             limit: 100
         })
-        const context = await this.cutContext(relatedEmbeddings.map((item) => item.text))
+        const files = await this.knowledgeFileRepository.findBy({
+            fileHash: In([...new Set(relatedEmbeddings.map((item) => item.fileHash))])
+        })
+        const context = await this.cutContext(
+            relatedEmbeddings.map((item) => {
+                const file = files.find((i) => i.fileHash === item.fileHash)
+                if (file) {
+                    return `***\n[${file.fileName}](${file.fileUrl}):\n\n${item.text}`
+                }
+                return item.text
+            })
+        )
         return context
     }
 }

+ 10 - 2
src/org/org.controller.ts

@@ -43,13 +43,21 @@ export class OrgController {
     async ask(
         @Req() req,
         @Param('id') id: string,
-        @Body() body: { question: string; knowledgeId?: number; fileId?: number }
+        @Body()
+        body: {
+            prompt: string
+            options: {
+                parentMessageId?: string
+            }
+            knowledgeId?: number
+            fileId?: number
+        }
     ) {
         const orgId = Number(id)
         // if (req.user.orgId != orgId) {
         //     throw new ForbiddenException('You are not a member of this organization')
         // }
-        return await this.orgService.ask(body.question, orgId, body.knowledgeId, body.fileId)
+        return await this.orgService.ask(body.prompt, orgId, body.options?.parentMessageId, body.knowledgeId, body.fileId)
     }
 
     @Post('/:id/streamAsk')

+ 74 - 10
src/org/org.service.ts

@@ -14,10 +14,12 @@ import * as dedent from 'dedent'
 import { chatReplyProcess } from '../chat/chatgpt'
 import { fetchSSE } from '../chatapi/fetch-sse'
 import { v4 as uuidv4 } from 'uuid'
+import Redis from 'ioredis'
 
 @Injectable()
 export class OrgService {
     private readonly openai: OpenAIApi
+    private readonly redis: Redis
     constructor(
         @InjectRepository(Org)
         private readonly orgRepository: Repository<Org>,
@@ -28,14 +30,16 @@ export class OrgService {
     ) {
         this.openai = new OpenAIApi(
             new Configuration({
-                apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
+                apiKey: process.env.AZURE_OPENAI_KEY,
                 // add azure info into configuration
                 azure: {
-                    apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
-                    endpoint: 'https://zouma.openai.azure.com'
+                    apiKey: process.env.AZURE_OPENAI_KEY,
+                    endpoint: process.env.AZURE_OPENAI_ENDPOINT,
+                    deploymentName: 'gpt-35-turbo-16k'
                 }
             })
         )
+        this.redis = new Redis(process.env.REDIS_URI)
     }
 
     async findByUrl(url: string): Promise<Org> {
@@ -105,7 +109,13 @@ export class OrgService {
         await this.userService.updateUser(orgUser.userId, { orgId: null })
     }
 
-    async buildMessages(question: string, orgId: number, knowledgeId?: number, fileId?: number): Promise<any[]> {
+    async buildMessages(
+        question: string,
+        orgId: number,
+        parentMessageId?: string,
+        knowledgeId?: number,
+        fileId?: number
+    ): Promise<any[]> {
         const org = await this.findById(orgId)
         const context = await this.knowledgeService.askKnowledge(question, orgId, knowledgeId, fileId)
         if (org.questionTemplate) {
@@ -118,6 +128,13 @@ export class OrgService {
                 content: org.systemPrompt.replace('${context}', context.join('\n')).replace('${question}', question)
             })
         }
+        if (parentMessageId) {
+            let history = (await this.getChatHistory(parentMessageId)).map((i) => ({
+                role: i.role,
+                content: i.text
+            }))
+            messages.push(...history)
+        }
         if (!/\$\{context\}/.test(org.systemPrompt)) {
             messages.push({
                 role: 'user',
@@ -148,13 +165,52 @@ export class OrgService {
         return messages
     }
 
-    async ask(question: string, orgId: number, knowledgeId?: number, fileId?: number) {
+    async getChatHistory(parentMessageId, history = []) {
+        let parent: any = await this.redis.get(parentMessageId)
+        if (!parent) {
+            return history
+        }
+        parent = JSON.parse(parent)
+        history.unshift(parent)
+        if (parent.parentMessageId) {
+            return await this.getChatHistory(parent.parentMessageId, history)
+        }
+        return history
+    }
+
+    async ask(question: string, orgId: number, parentMessageId?: string, knowledgeId?: number, fileId?: number) {
         try {
             const response = await this.openai.createChatCompletion({
-                model: 'gpt35',
-                messages: await this.buildMessages(question, orgId, knowledgeId, fileId)
+                model: 'gpt-35-turbo-16k',
+                messages: await this.buildMessages(question, orgId, parentMessageId, knowledgeId, fileId),
+                temperature: 0.1,
             })
-            return { answer: response.data.choices[0].message.content }
+            const id = uuidv4()
+            await this.redis.set(
+                id,
+                JSON.stringify({
+                    role: 'user',
+                    id,
+                    parentMessageId,
+                    text: question
+                })
+            )
+            await this.redis.set(
+                response.data.id,
+                JSON.stringify({
+                    role: 'assistant',
+                    id: response.data.id,
+                    parentMessageId: id,
+                    text: response.data.choices[0].message.content
+                })
+            )
+            return {
+                role: 'assistant',
+                id: response.data.id,
+                text: response.data.choices[0].message.content,
+                detail: response.data
+            }
+            // return { answer: response.data.choices[0].message.content }
         } catch (error) {
             Logger.error(error.message)
             if (error.response) {
@@ -164,7 +220,15 @@ export class OrgService {
         }
     }
 
-    async streamAsk(req, res, question: string, orgId: number, knowledgeId?: number, fileId?: number) {
+    async streamAsk(
+        req,
+        res,
+        question: string,
+        orgId: number,
+        parentMessageId?: string,
+        knowledgeId?: number,
+        fileId?: number
+    ) {
         res.setHeader('Content-type', 'application/octet-stream')
         try {
             try {
@@ -177,7 +241,7 @@ export class OrgService {
                 }
                 await fetchSSE(url, {
                     body: JSON.stringify({
-                        messages: await this.buildMessages(question, orgId, knowledgeId, fileId),
+                        messages: await this.buildMessages(question, orgId, parentMessageId, knowledgeId, fileId),
                         stream: true
                     }),
                     headers: {

+ 0 - 6
src/redis.ts

@@ -1,6 +0,0 @@
-const Redis = require("ioredis");
-
-// Create a Redis instance.
-// By default, it will connect to localhost:6379.
-// We are going to cover how to specify connection options soon.
-const redis = new Redis();