wangqifan 3 years ago
parent
commit
1bdf49da71

+ 6 - 0
src/api-users/api-user.service.ts

@@ -48,6 +48,12 @@ export class ApiUserService {
         return user
     }
 
+    public async findByUserId(id: number): Promise<ApiUser> {
+        const user = await this.apiUserRepository.findOneBy({
+            userId: +id
+        })
+        return user
+    }
 
     public async findByCode(code: string): Promise<ApiUser> {
         const apiUser = await this.apiUserRepository.findOneBy({

+ 9 - 9
src/chat-pdf/chat-pdf.controller.ts

@@ -1,4 +1,4 @@
-import { Body, Controller, Get, Post, Render, UploadedFile, UseInterceptors } from '@nestjs/common'
+import { Body, Controller, Get, Post, Render, Req, UploadedFile, UseInterceptors } from '@nestjs/common'
 import { Public } from '../auth/public.decorator'
 import { FileInterceptor } from '@nestjs/platform-express'
 import { ChatPdfService } from './chat-pdf.service'
@@ -10,16 +10,16 @@ export class ChatPdfController {
     @Public()
     @Post('upload')
     @UseInterceptors(FileInterceptor('file'))
-    public async uploadFile(@UploadedFile() file: Express.Multer.File) {
-        return await this.chatPdfService.upload(file)
+    public async uploadFile(@UploadedFile() file: Express.Multer.File, @Req() req) {
+        return await this.chatPdfService.upload(file,req)
     }
 
-    @Public()
-    @Post('apiUpload')
-    @UseInterceptors(FileInterceptor('file'))
-    public async apiUpload(@UploadedFile() file: Express.Multer.File,userId: number) {
-        return await this.chatPdfService.upload(file)
-    }
+    // @Public()
+    // @Post('apiUpload')
+    // @UseInterceptors(FileInterceptor('file'))
+    // public async apiUpload(@UploadedFile() file: Express.Multer.File, userId: number) {
+    //     return await this.chatPdfService.upload(file)
+    // }
 
     @Public()
     @Post('ask')

+ 28 - 23
src/chat-pdf/chat-pdf.service.ts

@@ -1,16 +1,16 @@
-import {BadRequestException, Injectable, InternalServerErrorException, Logger} from '@nestjs/common'
+import { BadRequestException, Injectable, InternalServerErrorException, Logger } from '@nestjs/common'
 import * as PdfParse from '@cyber2024/pdf-parse-fixed'
-import {createHash} from 'crypto'
-import {get_encoding, Tiktoken} from '@dqbd/tiktoken'
-import {Configuration, OpenAIApi} from 'azure-openai'
+import { createHash } from 'crypto'
+import { get_encoding, Tiktoken } from '@dqbd/tiktoken'
+import { Configuration, OpenAIApi } from 'azure-openai'
 import * as queue from 'fastq'
-import {setTimeout} from 'timers/promises'
+import { setTimeout } from 'timers/promises'
 import * as dedent from 'dedent'
-import {DataTypes, Sequelize} from 'sequelize'
-import {ConfigService} from '@nestjs/config'
-import {ChatEmbedding} from './entities/chat-embedding.entity'
-import {VECTOR} from './pgvector'
-import {ApiUserService} from '../api-users/api-user.service'
+import { DataTypes, Sequelize } from 'sequelize'
+import { ConfigService } from '@nestjs/config'
+import { ChatEmbedding } from './entities/chat-embedding.entity'
+import { VECTOR } from './pgvector'
+import { ApiUserService } from '../api-users/api-user.service'
 
 function formatEmbedding(embedding: number[]) {
     return `[${embedding.join(', ')}]`
@@ -73,18 +73,23 @@ export class ChatPdfService {
         this.sequelize.sync()
     }
 
-    public async upload(file: Express.Multer.File) {
+    public async upload(file: Express.Multer.File, req) {
         const { originalname, buffer, mimetype } = file
-        const md5 = this.calculateMD5(buffer)
-        const res = await ChatEmbedding.findAll({
-            where: {
-                name: md5
-            }
-        })
-        if (res.length) {
-            return {
-                name: md5
+        let md5 = this.calculateMD5(buffer)
+        const user = await this.apiUserService.findByUserId(req.user.id)
+        if (!user) {
+            const res = await ChatEmbedding.findAll({
+                where: {
+                    name: md5
+                }
+            })
+            if (res.length) {
+                return {
+                    name: md5
+                }
             }
+        } else {
+            md5 = user.code
         }
         const pdf = await PdfParse(buffer)
         const contents = []
@@ -231,7 +236,7 @@ export class ChatPdfService {
         根据相关性从高到低排序。 '
         这是用户提出的问题:
         ${q}
-        你只能根据用户的问题,以下面的内容为准进行回答:
+        你只能根据用户的问题,以下面的内容和之前的聊天记录为准进行回答:
         \`\`\`
         ${context.join('\n')}
         \`\`\`
@@ -279,8 +284,8 @@ export class ChatPdfService {
 
     async customerAsk(q: string, name: string) {
         let apiUser = await this.apiUserService.findByCode(name)
-        if(!apiUser) {
-            throw new BadRequestException("not a enabled api user")
+        if (!apiUser) {
+            throw new BadRequestException('not a enabled api user')
         }
         // const defSysMsg = (await this.sysConfigService.findByName('customer_system_message'))?.value
         const keywords = await this.getKeywords(q)