xiongzhu 2 лет назад
Родитель
Сommit
b304263a4e
2 измененных файлов с 109 добавлено и 7 удалено
  1. 17 7
      src/org/org.controller.ts
  2. 92 0
      src/org/org.service.ts

+ 17 - 7
src/org/org.controller.ts

@@ -1,4 +1,4 @@
-import { Body, Controller, ForbiddenException, Get, Param, Post, Put, Req } from '@nestjs/common'
+import { Body, Controller, ForbiddenException, Get, Param, Post, Put, Req, Res, Sse } from '@nestjs/common'
 import { Org } from './entities/org.entity'
 import { PageRequest } from '../common/dto/page-request'
 import { OrgService } from './org.service'
@@ -10,12 +10,6 @@ import { Public } from 'src/auth/public.decorator'
 export class OrgController {
     constructor(private readonly orgService: OrgService) {}
 
-    @Get('/:id')
-    @Public()
-    async get(@Param('id') id: string) {
-        return await this.orgService.findById(Number(id))
-    }
-
     @Get('/my')
     async my(@Req() req) {
         if (!req.user.orgId) {
@@ -24,6 +18,12 @@ export class OrgController {
         return await this.orgService.findById(req.user.orgId)
     }
 
+    @Get('/:id')
+    @Public()
+    async get(@Param('id') id: string) {
+        return await this.orgService.findById(Number(id))
+    }
+
     @Put('/my')
     async update(@Req() req, @Body() org: Org) {
         if (!req.user.orgId) {
@@ -45,4 +45,14 @@ export class OrgController {
         // }
         return await this.orgService.ask(body.question, orgId, body.knowledgeId, body.fileId)
     }
+
+    @Post('/:id/streamAsk')
+    @Sse()
+    async streamAsk(@Req() req, @Res() res, @Param('id') id: string, @Body() body: { prompt: string }) {
+        const orgId = Number(id)
+        // if (req.user.orgId != orgId) {
+        //     throw new ForbiddenException('You are not a member of this organization')
+        // }
+        await this.orgService.streamAsk(req, res, body.prompt, orgId)
+    }
 }

+ 92 - 0
src/org/org.service.ts

@@ -11,6 +11,9 @@ import { Role } from '../model/role.enum'
 import { KnowledgeBaseService } from 'src/knowledge-base/knowledge-base.service'
 import { OpenAIApi, Configuration } from 'azure-openai'
 import * as dedent from 'dedent'
+import { chatReplyProcess } from '../chat/chatgpt'
+import { fetchSSE } from '../chatapi/fetch-sse'
+import { v4 as uuidv4 } from 'uuid'
 
 @Injectable()
 export class OrgService {
@@ -122,4 +125,93 @@ export class OrgService {
             throw new InternalServerErrorException(error.message)
         }
     }
+
+    async streamAsk(req, res, question: string, orgId: number, knowledgeId?: number, fileId?: number) {
+        res.setHeader('Content-type', 'application/octet-stream')
+        try {
+            const org = await this.findById(orgId)
+            const context = await this.knowledgeService.askKnowledge(question, orgId, knowledgeId, fileId)
+            let content
+            if (org.contextTemplate) {
+                content = org.contextTemplate.replace('${context}', context.join('\n'))
+            } else {
+                content = dedent`You are a helpful AI article assistant. 
+            The following are the relevant article content fragments found from the article. 
+            The relevance is sorted from high to low. 
+            You can only answer according to the following content:
+            \`\`\`
+            ${context.join('\n')}
+            \`\`\`
+            You need to carefully consider your answer to ensure that it is based on the context. 
+            If the context does not mention the content or it is uncertain whether it is correct, 
+            please answer "Current context cannot provide effective information.
+            You must use Chinese to respond.`
+            }
+
+            try {
+                const url = `${process.env.AZURE_OPENAI_ENDPOINT}/openai/deployments/${process.env.AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version=${process.env.AZURE_OPENAI_VERSION}`
+                let firstChunk = true
+                const result: any = {
+                    role: 'assistant',
+                    id: uuidv4(),
+                    text: ''
+                }
+                await fetchSSE(url, {
+                    body: JSON.stringify({
+                        messages: [
+                            { role: 'system', content: org.systemPrompt },
+                            { role: 'user', content },
+                            { role: 'user', content: question }
+                        ],
+                        stream: true
+                    }),
+                    headers: {
+                        'Content-Type': 'application/json',
+                        'api-key': `${process.env.AZURE_OPENAI_KEY}`
+                    },
+                    method: 'POST',
+                    onMessage: (msg: string) => {
+                        if (msg === '[DONE]') return
+                        const response = JSON.parse(msg)
+                      
+                        result.id = response.id
+                        const delta = response.choices[0].delta
+                        result.delta = delta.content
+                        if (delta?.content) result.text += delta.content
+
+                        if (delta.role) {
+                            result.role = delta.role
+                        }
+
+                        result.detail = response
+                        res.write(firstChunk ? JSON.stringify(result) : `\n${JSON.stringify(result)}`)
+                        firstChunk = false
+                    },
+                    onError: (err) => {
+                        res.write(JSON.stringify(err))
+                    }
+                })
+
+                // const response = await this.openai.createChatCompletion({
+                //     model: 'gpt35',
+                //     messages: [
+                //         { role: 'system', content: org.systemPrompt },
+                //         { role: 'user', content },
+                //         { role: 'user', content: question }
+                //     ]
+                // })
+                // return { answer: response.data.choices[0].message.content }
+            } catch (error) {
+                Logger.error(error.message)
+                if (error.response) {
+                    Logger.error(error.response.data)
+                }
+                throw new InternalServerErrorException(error.message)
+            }
+        } catch (e) {
+            res.write(JSON.stringify(e))
+        } finally {
+            res.end()
+        }
+    }
 }