xiongzhu 2 лет назад
Родитель
Сommit
4b5c9f5e65

+ 10 - 17
src/chat/chat.service.ts

@@ -66,28 +66,21 @@ export class ChatService {
     }
 
     public async chat1(req, res) {
-        const user = await this.chatPdfService.getUser(req.user.id)
         res.setHeader('Content-type', 'application/octet-stream')
         const defSysMsg = (await this.sysConfigService.findByName('system_message'))?.value
         const membership = await this.membershipService.getMembership(req.user.id)
-        const { prompt, options = {}, systemMessage, code, temperature, top_p } = req.body as RequestProps
-        if (!user.apiUserId) {
-            if (!membership) {
-                throw new ForbiddenException('请先成为会员')
-            }
-            if (membership.memberType == MemberType.Trial && membership.tokenLeft <= 0) {
-                throw new ForbiddenException('您的试用额度已用完,请升级会员')
-            }
-            if (membership.isExpired) {
-                throw new ForbiddenException('您的会员已过期,请即时续费')
-            }
+        const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
+        if (!membership) {
+            throw new ForbiddenException('请先成为会员')
+        }
+        if (membership.memberType == MemberType.Trial && membership.tokenLeft <= 0) {
+            throw new ForbiddenException('您的试用额度已用完,请升级会员')
+        }
+        if (membership.isExpired) {
+            throw new ForbiddenException('您的会员已过期,请即时续费')
         }
         try {
-            let content = ''
             const promptTime = new Date()
-            if (!!user.apiUserId) {
-                content = await this.chatPdfService.getSystemConfig(req.user.id, prompt, code, code, systemMessage)
-            }
             let firstChunk = true
             const result = await chatReplyProcess({
                 message: prompt,
@@ -96,7 +89,7 @@ export class ChatService {
                     res.write(firstChunk ? JSON.stringify(chat) : `\n${JSON.stringify(chat)}`)
                     firstChunk = false
                 },
-                systemMessage: content || defSysMsg,
+                systemMessage: systemMessage || defSysMsg,
                 temperature,
                 top_p
             })

+ 4 - 0
src/knowledge-base/entities/knowledge-embedding.entity.ts

@@ -3,6 +3,8 @@ import { Model } from 'sequelize'
 export class KnowledgeEmbedding extends Model {
     id: number
 
+    orgId: number
+
     knowledgeId: number
 
     fileId: number
@@ -16,6 +18,7 @@ export class KnowledgeEmbedding extends Model {
     index: number
 
     constructor(model?: {
+        orgId: number
         knowledgeId: number
         fileId: number
         fileHash: string
@@ -25,6 +28,7 @@ export class KnowledgeEmbedding extends Model {
     }) {
         super()
         if (model) {
+            this.orgId = model.orgId
             this.knowledgeId = model.knowledgeId
             this.fileId = model.fileId
             this.fileHash = model.fileHash

+ 4 - 4
src/knowledge-base/enums/file-status.enum.ts

@@ -1,6 +1,6 @@
 export enum FileStatus {
-    PENDING = 'PENDING',
-    PROCESSING = 'PROCESSING',
-    DONE = 'DONE',
-    FAILED = 'FAILED'
+    PENDING = 'pending',
+    PROCESSING = 'processing',
+    DONE = 'done',
+    FAILED = 'failed'
 }

+ 5 - 0
src/knowledge-base/knowledge-base.controller.ts

@@ -42,4 +42,9 @@ export class KnowledgeBaseController {
     public async uploadFile(@UploadedFile() file: Express.Multer.File, @Param('id') id: string) {
         return await this.knowledgeBaseService.uploadKnowledgeFile(file, Number(id))
     }
+
+    @Delete('/file/:id')
+    async deleteKnowledgeFile(@Param('id') id: string) {
+        return await this.knowledgeBaseService.deleteKnowledgeFile(Number(id))
+    }
 }

+ 13 - 5
src/knowledge-base/knowledge-base.service.ts

@@ -69,6 +69,9 @@ export class KnowledgeBaseService {
                     autoIncrement: true,
                     type: DataTypes.INTEGER
                 },
+                orgId: {
+                    type: DataTypes.INTEGER 
+                },
                 knowledgeId: {
                     type: DataTypes.INTEGER
                 },
@@ -141,12 +144,13 @@ export class KnowledgeBaseService {
     public async uploadKnowledgeFile(file: Express.Multer.File, knowledgeId: number) {
         const knowledgeBase = await this.getKnowledgeBaseById(knowledgeId)
         const { originalname, buffer, mimetype, size } = file
+        const fileName = Buffer.from(originalname, 'latin1').toString('utf8')
         let fileHash = this.calculateMD5(buffer)
         let knowledgeFile = await this.knowledgeFileRepository.findOneBy({
             fileHash
         })
         if (knowledgeFile) {
-            throw new ConflictException(`File ${originalname} already exists`)
+            throw new ConflictException(`File ${fileName} already exists`)
         }
         const { url: fileUrl } = await this.fileService.uploadBuffer(
             buffer,
@@ -158,7 +162,7 @@ export class KnowledgeBaseService {
         knowledgeFile.knowledgeId = knowledgeId
         knowledgeFile.fileHash = fileHash
         knowledgeFile.fileType = mimetype
-        knowledgeFile.fileName = originalname
+        knowledgeFile.fileName = fileName
         knowledgeFile.size = size
         knowledgeFile.fileUrl = fileUrl
         await this.knowledgeFileRepository.save(knowledgeFile)
@@ -201,6 +205,7 @@ export class KnowledgeBaseService {
             for (const item of embeddings) {
                 try {
                     await KnowledgeEmbedding.create({
+                        orgId: knowledgeFile.orgId,
                         knowledgeId: knowledgeFile.knowledgeId,
                         fileId: knowledgeFile.id,
                         fileHash: knowledgeFile.fileHash,
@@ -314,14 +319,17 @@ export class KnowledgeBaseService {
         return context
     }
 
-    async searchKnowledge(question: string, orgId: number, knowledgeId?: number, fileId?: number) {
+    async askKnowledge(question: string, orgId: number, knowledgeId?: number, fileId?: number) {
         const keywords = await this.getKeywords(question)
         const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
-        const context = await KnowledgeEmbedding.findAll({
-            where: { orgId, knowledgeId, fileId },
+        const where = { orgId, knowledgeId, fileId }
+        Object.keys(where).forEach((key) => (where[key] === undefined ? delete where[key] : {}))
+        const relatedEmbeddings = await KnowledgeEmbedding.findAll({
+            where,
             order: this.sequelize.literal(`embedding <-> '${formatEmbedding(keywordEmbedding)}'`),
             limit: 100
         })
+        const context = await this.cutContext(relatedEmbeddings.map((item) => item.text))
         return context
     }
 }

+ 22 - 0
src/org/entities/org-user.entity.ts

@@ -0,0 +1,22 @@
+import { Column, CreateDateColumn, Entity, PrimaryGeneratedColumn } from 'typeorm'
+
+@Entity()
+export class OrgUser {
+    @PrimaryGeneratedColumn()
+    id: number
+
+    @CreateDateColumn()
+    createdAt: Date
+
+    @Column()
+    orgId: number
+
+    @Column()
+    userId: number
+
+    @Column()
+    phone: string
+
+    @Column()
+    name: string
+}

+ 1 - 1
src/org/entities/org.entity.ts

@@ -24,5 +24,5 @@ export class Org {
     systemPrompt: string
 
     @Column({ type: 'text' })
-    questionTemplate: string
+    contextTemplate: string
 }

+ 19 - 3
src/org/org.admin.controller.ts

@@ -1,18 +1,34 @@
-import { Body, Controller, Post, Put } from '@nestjs/common'
+import { Body, Controller, Delete, Param, Post, Put } from '@nestjs/common'
 import { Org } from './entities/org.entity'
 import { PageRequest } from '../common/dto/page-request'
 import { OrgService } from './org.service'
 import { HasAnyRoles, HasRoles } from '../auth/roles.decorator'
 import { Role } from '../model/role.enum'
 import { ApiBearerAuth, ApiTags } from '@nestjs/swagger'
+import { OrgUser } from './entities/org-user.entity'
 
-@ApiTags('users.admin')
+@ApiTags('org.admin')
 @Controller('/admin/org')
 @ApiBearerAuth()
-@HasAnyRoles(Role.Admin, Role.Api)
+@HasAnyRoles(Role.Admin, Role.Org)
 export class OrgAdminController {
     constructor(private readonly orgService: OrgService) {}
 
+    @Post('/users')
+    async orgUsers(@Body() page: PageRequest<OrgUser>) {
+        return await this.orgService.findAllUsers(page)
+    }
+
+    @Put('/users')
+    async addUsers(@Body() orgUsers: OrgUser[]) {
+        return await this.orgService.addUsers(orgUsers)
+    }
+
+    @Delete('/users/:id')
+    async removeUser(@Param('id') id) {
+        return await this.orgService.removeUser(id)
+    }
+
     @Post()
     async list(@Body() page: PageRequest<Org>) {
         return await this.orgService.findAll(page)

+ 22 - 2
src/org/org.controller.ts

@@ -1,16 +1,23 @@
-import { Body, Controller, ForbiddenException, Get, Post, Put, Req } from '@nestjs/common'
+import { Body, Controller, ForbiddenException, Get, Param, Post, Put, Req } from '@nestjs/common'
 import { Org } from './entities/org.entity'
 import { PageRequest } from '../common/dto/page-request'
 import { OrgService } from './org.service'
 import { HasRoles } from '../auth/roles.decorator'
 import { Role } from '../model/role.enum'
+import { Public } from 'src/auth/public.decorator'
 
 @Controller('org')
 export class OrgController {
     constructor(private readonly orgService: OrgService) {}
 
+    @Get('/:id')
+    @Public()
+    async get(@Param('id') id: string) {
+        return await this.orgService.findById(Number(id))
+    }
+
     @Get('/my')
-    async get(@Req() req) {
+    async my(@Req() req) {
         if (!req.user.orgId) {
             throw new ForbiddenException('You are not a member of any organization')
         }
@@ -25,4 +32,17 @@ export class OrgController {
         org.id = req.user.orgId
         return await this.orgService.update(org)
     }
+
+    @Post('/:id/ask')
+    async ask(
+        @Req() req,
+        @Param('id') id: string,
+        @Body() body: { question: string; knowledgeId?: number; fileId?: number }
+    ) {
+        const orgId = Number(id)
+        // if (req.user.orgId != orgId) {
+        //     throw new ForbiddenException('You are not a member of this organization')
+        // }
+        return await this.orgService.ask(body.question, orgId, body.knowledgeId, body.fileId)
+    }
 }

+ 3 - 1
src/org/org.module.ts

@@ -5,9 +5,11 @@ import { OrgController } from './org.controller'
 import { OrgService } from './org.service'
 import { OrgAdminController } from './org.admin.controller'
 import { UsersModule } from 'src/users/users.module'
+import { OrgUser } from './entities/org-user.entity'
+import { KnowledgeBaseModule } from 'src/knowledge-base/knowledge-base.module'
 
 @Module({
-    imports: [TypeOrmModule.forFeature([Org]), UsersModule],
+    imports: [TypeOrmModule.forFeature([Org, OrgUser]), UsersModule, KnowledgeBaseModule],
     controllers: [OrgController, OrgAdminController],
     providers: [OrgService]
 })

+ 94 - 3
src/org/org.service.ts

@@ -1,16 +1,39 @@
-import { Injectable } from '@nestjs/common'
+import { Injectable, InternalServerErrorException, Logger } from '@nestjs/common'
 import { InjectRepository } from '@nestjs/typeorm'
 import { Org } from './entities/org.entity'
 import { Repository } from 'typeorm'
 import { PageRequest } from '../common/dto/page-request'
 import { Pagination, paginate } from 'nestjs-typeorm-paginate'
+import { OrgUser } from './entities/org-user.entity'
+import { UsersService } from 'src/users/users.service'
+import * as randomstring from 'randomstring'
+import { Role } from '../model/role.enum'
+import { KnowledgeBaseService } from 'src/knowledge-base/knowledge-base.service'
+import { OpenAIApi, Configuration } from 'azure-openai'
+import * as dedent from 'dedent'
 
 @Injectable()
 export class OrgService {
+    private readonly openai: OpenAIApi
     constructor(
         @InjectRepository(Org)
-        private readonly orgRepository: Repository<Org>
-    ) {}
+        private readonly orgRepository: Repository<Org>,
+        @InjectRepository(OrgUser)
+        private readonly orgUserRepository: Repository<OrgUser>,
+        private readonly userService: UsersService,
+        private readonly knowledgeService: KnowledgeBaseService
+    ) {
+        this.openai = new OpenAIApi(
+            new Configuration({
+                apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
+                // add azure info into configuration
+                azure: {
+                    apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
+                    endpoint: 'https://zouma.openai.azure.com'
+                }
+            })
+        )
+    }
 
     async findById(orgId: number): Promise<Org> {
         return await this.orgRepository.findOneOrFail({
@@ -31,4 +54,72 @@ export class OrgService {
     async update(org: Org): Promise<Org> {
         return await this.orgRepository.save(org)
     }
+
+    async findAllUsers(req: PageRequest<OrgUser>): Promise<Pagination<OrgUser>> {
+        return await paginate<OrgUser>(this.orgUserRepository, req.page, req.search)
+    }
+
+    async addUsers(orgUsers: OrgUser[]): Promise<void> {
+        const users = await this.userService.findByPhone(orgUsers.map((x) => x.phone))
+        for (let x of orgUsers) {
+            let user = users.find((y) => y.phone === x.phone)
+            if (!user) {
+                const name = '0x' + randomstring.generate({ length: 8, charset: 'alphanumeric' })
+                user = await this.userService.create({
+                    name: name,
+                    username: name,
+                    phone: x.phone,
+                    roles: [Role.User]
+                })
+            }
+            await this.userService.updateUser(user.id, { orgId: x.orgId })
+            x.userId = user.id
+            await this.orgUserRepository.save(x)
+        }
+    }
+
+    async removeUser(orgUserId): Promise<void> {
+        const orgUser = await this.orgUserRepository.findOneOrFail({ where: { id: orgUserId } })
+        await this.orgUserRepository.delete(orgUserId)
+        await this.userService.updateUser(orgUser.userId, { orgId: null })
+    }
+
+    async ask(question: string, orgId: number, knowledgeId?: number, fileId?: number) {
+        const org = await this.findById(orgId)
+        const context = await this.knowledgeService.askKnowledge(question, orgId, knowledgeId, fileId)
+        let content
+        if (org.contextTemplate) {
+            content = org.contextTemplate.replace('${context}', context.join('\n'))
+        } else {
+            content = dedent`You are a helpful AI article assistant. 
+            The following are the relevant article content fragments found from the article. 
+            The relevance is sorted from high to low. 
+            You can only answer according to the following content:
+            \`\`\`
+            ${context.join('\n')}
+            \`\`\`
+            You need to carefully consider your answer to ensure that it is based on the context. 
+            If the context does not mention the content or it is uncertain whether it is correct, 
+            please answer "Current context cannot provide effective information.
+            You must use Chinese to respond.`
+        }
+
+        try {
+            const response = await this.openai.createChatCompletion({
+                model: 'gpt35',
+                messages: [
+                    { role: 'system', content: org.systemPrompt },
+                    { role: 'user', content },
+                    { role: 'user', content: question }
+                ]
+            })
+            return { answer: response.data.choices[0].message.content }
+        } catch (error) {
+            Logger.error(error.message)
+            if (error.response) {
+                Logger.error(error.response.data)
+            }
+            throw new InternalServerErrorException(error.message)
+        }
+    }
 }

+ 7 - 2
src/users/dto/user-create.dto.ts

@@ -26,10 +26,15 @@ export class UserCreateDto {
     @IsNotEmpty()
     @IsString()
     @MaxLength(60)
-    password: string
+    @IsOptional()
+    password?: string
 
-    apiUserId: number
+    @IsOptional()
+    apiUserId?: number
 
     @IsArray()
     readonly roles: Role[]
+
+    @IsOptional()
+    orgId?: number
 }

+ 2 - 2
src/users/dto/user-update.dto.ts

@@ -20,12 +20,12 @@ export class UserUpdateDto {
     @IsString()
     @MaxLength(30)
     @IsOptional()
-    readonly name: string
+    readonly name?: string
 
     @IsString()
     @MaxLength(40)
     @IsOptional()
-    readonly username: string
+    readonly username?: string
 
     @IsString()
     @IsOptional()

+ 7 - 1
src/users/users.service.ts

@@ -122,7 +122,7 @@ export class UsersService {
         return user
     }
 
-    public async create(userDto: UserCreateDto): Promise<IUsers> {
+    public async create(userDto: UserCreateDto) {
         try {
             if (userDto.password) {
                 userDto.password = await this.hashingService.hash(userDto.password)
@@ -271,4 +271,10 @@ export class UsersService {
             return item.id
         })
     }
+
+    public async findByPhone(phone: string[]) {
+        return await this.userRepository.findBy({
+            phone: In(phone)
+        })
+    }
 }