xiongzhu 2 anni fa
parent
commit
9e7c67efd6

+ 5 - 3
src/danmu/danmu.module.ts

@@ -6,14 +6,16 @@ import { Danmu } from './entities/danmu.enitity'
 import { Game } from '../game/entities/game.entity'
 import { AccessToken } from './entities/access-token.entity'
 import { DanmuUser } from './entities/danmu-user.entity'
-import { RoomModule } from 'src/room/room.module'
-import { FileModule } from 'src/file/file.module'
+import { RoomModule } from '../room/room.module'
+import { FileModule } from '../file/file.module'
+import { SysConfigModule } from '../sys-config/sys-config.module'
 
 @Module({
     imports: [
         TypeOrmModule.forFeature([Danmu, Game, AccessToken, DanmuUser]),
         forwardRef(() => RoomModule),
-        FileModule
+        FileModule,
+        SysConfigModule
     ],
     controllers: [DanmuController],
     providers: [DanmuService],

+ 12 - 3
src/danmu/danmu.service.ts

@@ -13,6 +13,7 @@ import { mkdtempSync, rmSync } from 'fs'
 import { FileService } from 'src/file/file.service'
 import path = require('path')
 import { HttpsProxyAgent } from 'https-proxy-agent'
+import { SysConfigService } from '../sys-config/sys-config.service'
 
 @Injectable()
 export class DanmuService implements OnModuleInit {
@@ -28,7 +29,8 @@ export class DanmuService implements OnModuleInit {
         private readonly danmuUserRepository: Repository<DanmuUser>,
         @Inject(forwardRef(() => RoomService))
         private readonly roomService: RoomService,
-        private readonly fileService: FileService
+        private readonly fileService: FileService,
+        private readonly sysConfigService: SysConfigService
     ) {}
 
     async onModuleInit() {
@@ -176,18 +178,25 @@ export class DanmuService implements OnModuleInit {
     }
 
     async pickCharactor(gameId: number, startTime: Date) {
+        const { value: search } = await this.sysConfigService.findByName('join_game_danmu')
         const res = await this.danmuRepository.find({
             where: {
                 gameId,
                 createdAt: MoreThan(startTime),
-                content: Like('我是%')
+                content: Like(search)
             }
         })
         if (res.length === 0) {
             return null
         }
         const index = Math.floor(Math.random() * res.length)
-        return res[index]
+        const danmuUser = await this.findDanmuUser(res[index].danmuUserId, true)
+        return {
+            name: danmuUser.name,
+            avatar: danmuUser.avatar,
+            danmuUserId: res[index].danmuUserId,
+            content: res[index].content.replace(search, '')
+        }
     }
 
     async findDanmuUser(danmuUserId: number, updateAvatar: boolean = false) {

+ 5 - 1
src/game/dto/update-game.dto.ts

@@ -1,4 +1,4 @@
-import { IsArray, IsNumber, IsOptional, IsString } from 'class-validator'
+import { IsArray, IsBoolean, IsNumber, IsOptional, IsString } from 'class-validator'
 import { Charactor } from '../models/charactor.model'
 import { Type } from 'class-transformer'
 
@@ -19,4 +19,8 @@ export class UpdateGameDto {
     @Type(() => Charactor)
     @IsOptional()
     charactors?: Charactor[]
+
+    @IsOptional()
+    @IsBoolean()
+    overridePrompt?: boolean
 }

+ 4 - 0
src/game/entities/game.entity.ts

@@ -20,6 +20,9 @@ export class Game {
     @Column()
     name: string
 
+    @Column({ default: '默认' })
+    type: string
+
     @Column()
     roomId: number
 
@@ -40,4 +43,5 @@ export class Game {
 
     @Column({ default: false })
     autoReset: boolean
+
 }

+ 2 - 7
src/game/game.controller.ts

@@ -42,8 +42,8 @@ export class GameController {
     }
 
     @Post('/genCharactor')
-    public async genCharactor(@Body() dto: { background: string; num: number }) {
-        return await this.gameService.genCharactor(dto.background, dto.num)
+    public async genCharactor(@Body() dto: { type: string; background: string; num: number }) {
+        return await this.gameService.genCharactor(dto.type, dto.background, dto.num)
     }
 
     @Post('/:id/init')
@@ -70,11 +70,6 @@ export class GameController {
         return await this.gameService.revert(Number(id))
     }
 
-    @Post('/:id/addCharactor')
-    public async addCharactor(@Param('id') id: string, @Body() body: { base: string }) {
-        return await this.gameService.createNewCharactor(Number(id), body.base)
-    }
-
     @Post('/:id/startRun')
     public async startRun(@Param('id') id: string) {
         return await this.gameService.startRun(Number(id))

+ 30 - 34
src/game/game.service.ts

@@ -222,7 +222,7 @@ export class GameService implements OnModuleInit {
         this.eventsGateway.emitEvent(id, event)
     }
 
-    async genCharactor(background: string, num: number) {
+    async genCharactor(type: string, background: string, num: number) {
         const parser = StructuredOutputParser.fromZodSchema(
             z.array(
                 z.object({
@@ -237,7 +237,7 @@ export class GameService implements OnModuleInit {
         )
         const formatInstructions = parser.getFormatInstructions()
         const prompt = new PromptTemplate({
-            template: await this.promptService.getPromptByName(PromptName.GenCharacter),
+            template: await this.promptService.getPromptByName(PromptName.GenCharacter, type),
             inputVariables: ['num', 'background'],
             partialVariables: { format_instructions: formatInstructions }
         })
@@ -350,7 +350,7 @@ export class GameService implements OnModuleInit {
             gameState.plot = initGame.plot
         } else {
             const prompt = new PromptTemplate({
-                template: await this.promptService.getPromptByName(PromptName.FirstPlot),
+                template: await this.promptService.getPromptByName(PromptName.FirstPlot, game.type),
                 inputVariables: ['background', 'charactors', 'datetime']
             })
             const input = await prompt.format({
@@ -414,7 +414,7 @@ export class GameService implements OnModuleInit {
 
         let summary = this.formatSummary(history)
         let prompt = new PromptTemplate({
-            template: await this.promptService.getPromptByName(PromptName.NextPlot),
+            template: await this.promptService.getPromptByName(PromptName.NextPlot, game.type),
             inputVariables: ['background', 'charactors', 'datetime', 'summary', 'choice', 'death', 'newCharactor']
         })
         const formatedCharactors = this.formatCharactors(charactors)
@@ -487,7 +487,10 @@ export class GameService implements OnModuleInit {
             const delay = voteDelay > 0 ? setTimeout(voteDelay * 1000) : Promise.resolve()
 
             options = (
-                await Promise.all([this.createOptions(charactors, `${summary}\n\n${formatedDatatime}\n${plot}`), delay])
+                await Promise.all([
+                    this.createOptions(game.type, charactors, `${summary}\n\n${formatedDatatime}\n${plot}`),
+                    delay
+                ])
             )[0]
             this.send(`${id}`, {
                 type: 'options',
@@ -498,7 +501,7 @@ export class GameService implements OnModuleInit {
         }
 
         const [newSummary, nextChoice] = await Promise.all([
-            this.refineSummary(history, this.formatDatetime(date, time) + '\n' + plot),
+            this.refineSummary(game.type, history, this.formatDatetime(date, time) + '\n' + plot),
             willEnd
                 ? Promise.resolve(null)
                 : this.makeVotes(id, options, new Date(), {
@@ -569,7 +572,7 @@ export class GameService implements OnModuleInit {
 
         let summary = this.formatSummary(history)
         let prompt = new PromptTemplate({
-            template: await this.promptService.getPromptByName(PromptName.Ending),
+            template: await this.promptService.getPromptByName(PromptName.Ending, game.type),
             inputVariables: ['background', 'summary']
         })
         const formatedDatatime = this.formatDatetime(date, time)
@@ -594,7 +597,7 @@ export class GameService implements OnModuleInit {
                 plot,
                 options: [],
                 charactors: [...lastState.charactors],
-                summary: await this.refineSummary(history, this.formatDatetime(date, time) + '\n' + plot),
+                summary: await this.refineSummary(game.type, history, this.formatDatetime(date, time) + '\n' + plot),
                 ending: true
             })
         )
@@ -751,7 +754,7 @@ export class GameService implements OnModuleInit {
         }
     }
 
-    async createOptions(charactors: Charactor[], summary: string) {
+    async createOptions(type: string, charactors: Charactor[], summary: string) {
         for (let i = 0; i < 5; i++) {
             try {
                 if (i > 0) await setTimeout(1000)
@@ -770,7 +773,7 @@ export class GameService implements OnModuleInit {
                 )
                 const formatInstructions = parser.getFormatInstructions()
                 const prompt = new PromptTemplate({
-                    template: await this.promptService.getPromptByName(PromptName.GenChoice),
+                    template: await this.promptService.getPromptByName(PromptName.GenChoice, type),
                     inputVariables: ['charactors', 'summary'],
                     partialVariables: { format_instructions: formatInstructions }
                 })
@@ -794,19 +797,16 @@ export class GameService implements OnModuleInit {
 
     async createCharactorFromDanmu(id: number, date: Date) {
         this.logger.log(`从弹幕中创建角色`)
-        const danmu = await this.danmuService.pickCharactor(id, date)
-        if (danmu) {
+        const game = await this.findById(id)
+        const c = await this.danmuService.pickCharactor(id, date)
+        if (c) {
             try {
-                const addCharactor = danmu.content.replace(/^我是/, '')
-                const danmuUser = await this.danmuService.findDanmuUser(danmu.danmuUserId)
-                if (danmuUser) {
-                    const charactor = await this.createNewCharactor(id, addCharactor)
-                    this.logger.log(`已创建角色: ${JSON.stringify(charactor)}`)
-                    charactor.name = danmuUser.name
-                    charactor.avatar = danmuUser.avatar
-                    charactor.danmuUserId = danmu.danmuUserId
-                    return charactor
-                }
+                const charactor = await this.createNewCharactor(game.type, c.content, game.background)
+                this.logger.log(`已创建角色: ${JSON.stringify(charactor)}`)
+                charactor.name = c.name
+                charactor.avatar = c.avatar
+                charactor.danmuUserId = c.danmuUserId
+                return charactor
             } catch (error) {
                 this.logger.error(error)
             }
@@ -832,12 +832,12 @@ export class GameService implements OnModuleInit {
         }
     }
 
-    async refineSummary(history: GameState[], newPlot: string) {
+    async refineSummary(type: string, history: GameState[], newPlot: string) {
         let i = history.map((i) => !!i.summary).lastIndexOf(true)
         let input
         if (i < 0) {
             const prompt = new PromptTemplate({
-                template: await this.promptService.getPromptByName(PromptName.Summarize),
+                template: await this.promptService.getPromptByName(PromptName.Summarize, type),
                 inputVariables: ['text']
             })
             input = await prompt.format({
@@ -848,7 +848,7 @@ export class GameService implements OnModuleInit {
             })
         } else {
             const prompt = new PromptTemplate({
-                template: await this.promptService.getPromptByName(PromptName.Refine),
+                template: await this.promptService.getPromptByName(PromptName.Refine, type),
                 inputVariables: ['existing_answer', 'text']
             })
             input = await prompt.format({
@@ -865,14 +865,10 @@ export class GameService implements OnModuleInit {
         return response.content
     }
 
-    async createNewCharactor(gameId: number, base: string) {
+    async createNewCharactor(type: string, base: string, background: string) {
         for (let i = 0; i < 5; i++) {
             try {
                 if (i > 0) await setTimeout(1000)
-                const game = await this.gameRepository.findOneBy({ id: gameId })
-                if (!game) {
-                    throw new NotFoundException(`game #${gameId} not found`)
-                }
                 const parser = StructuredOutputParser.fromZodSchema(
                     z.object({
                         name: z.string().describe('角色名称'),
@@ -885,13 +881,13 @@ export class GameService implements OnModuleInit {
                 )
                 const formatInstructions = parser.getFormatInstructions()
                 const prompt = new PromptTemplate({
-                    template: await this.promptService.getPromptByName(PromptName.NewCharacter),
+                    template: await this.promptService.getPromptByName(PromptName.NewCharacter, type),
                     inputVariables: ['base', 'background'],
                     partialVariables: { format_instructions: formatInstructions }
                 })
                 const input = await prompt.format({
                     base,
-                    background: game.background
+                    background
                 })
                 const response = await this.callLLM([
                     new SystemMessage('你是一个富有想象力的写作助手,你的任务是帮我想象一个小说里的角色。'),
@@ -906,7 +902,7 @@ export class GameService implements OnModuleInit {
         throw new InternalServerErrorException('无法生成新角色')
     }
 
-    async modifyHp(charactors: Charactor[], history: GameState[], newPlot: string) {
+    async modifyHp(type: string, charactors: Charactor[], history: GameState[], newPlot: string) {
         for (let i = 0; i < 5; i++) {
             try {
                 if (i > 0) await setTimeout(1000)
@@ -920,7 +916,7 @@ export class GameService implements OnModuleInit {
                 )
                 const formatInstructions = parser.getFormatInstructions()
                 const prompt = new PromptTemplate({
-                    template: await this.promptService.getPromptByName(PromptName.ModifyHp),
+                    template: await this.promptService.getPromptByName(PromptName.ModifyHp, type),
                     inputVariables: ['charactors', 'summary', 'text'],
                     partialVariables: { format_instructions: formatInstructions }
                 })

+ 9 - 2
src/prompt/entities/prompt.entity.ts

@@ -1,4 +1,4 @@
-import { Column, Entity, PrimaryColumn } from 'typeorm'
+import { Column, Entity, PrimaryColumn, PrimaryGeneratedColumn, Unique } from 'typeorm'
 
 export const PromptName = {
     GenCharacter: 'GenCharacter',
@@ -13,8 +13,12 @@ export const PromptName = {
 }
 
 @Entity()
+@Unique(['name', 'type'])
 export class Prompt {
-    @PrimaryColumn({ length: 90 })
+    @PrimaryGeneratedColumn()
+    id: number
+
+    @Column({ length: 90 })
     name: string
 
     @Column({ type: 'longtext' })
@@ -25,4 +29,7 @@ export class Prompt {
 
     @Column({ nullable: true })
     description: string
+
+    @Column({ default: '默认', length: 80 })
+    type: string
 }

+ 19 - 4
src/prompt/prompt.controller.ts

@@ -1,17 +1,32 @@
-import { Body, Controller, Get, Put } from '@nestjs/common'
+import { Body, Controller, Delete, Get, Param, Put } from '@nestjs/common'
 import { PromptService } from './prompt.service'
 
 @Controller('prompt')
 export class PromptController {
     constructor(private readonly promptService: PromptService) {}
 
-    @Get()
-    async getAllPrompts() {
-        return await this.promptService.getAllPrompts()
+    @Get('/types')
+    async findTypes() {
+        return await this.promptService.findTypes()
+    }
+
+    @Get('/:type')
+    async getAllPrompts(@Param('type') type: string) {
+        return await this.promptService.getAllPrompts(type)
     }
 
     @Put()
     async updatePrompt(@Body() prompt) {
         await this.promptService.updatePrompt(prompt)
     }
+
+    @Put('/types')
+    async createType(@Body() body: { type: string }) {
+        return await this.promptService.createType(body.type)
+    }
+
+    @Delete('/:type')
+    async deleteType(@Param('type') type: string) {
+        return await this.promptService.deleteType(type)
+    }
 }

+ 34 - 6
src/prompt/prompt.service.ts

@@ -1,4 +1,4 @@
-import { Injectable } from '@nestjs/common'
+import { Injectable, InternalServerErrorException } from '@nestjs/common'
 import { InjectRepository } from '@nestjs/typeorm'
 import { Prompt } from './entities/prompt.entity'
 import { Repository } from 'typeorm'
@@ -10,15 +10,43 @@ export class PromptService {
         private readonly promptRepository: Repository<Prompt>
     ) {}
 
-    async getPromptByName(name: string) {
-        return (await this.promptRepository.findOne({ where: { name } })).template
+    async getPromptByName(name: string, type: string = '默认') {
+        return (await this.promptRepository.findOne({ where: { name, type } })).template
     }
 
-    async getAllPrompts() {
-        return await this.promptRepository.find()
+    async getAllPrompts(type: string = '默认') {
+        return await this.promptRepository.findBy({ type })
     }
 
     async updatePrompt(prompt: Prompt) {
-        return await this.promptRepository.update(prompt.name, prompt)
+        return await this.promptRepository.update(prompt.id, prompt)
+    }
+
+    async findTypes() {
+        return ['默认'].concat(
+            (await this.promptRepository.query('select distinct type from prompt where type != "默认"')).map(
+                (i) => i.type
+            )
+        )
+    }
+
+    async createType(type: string) {
+        const types = await this.findTypes()
+        if (types.find((t) => t === type)) throw new InternalServerErrorException('已存在该类型')
+        const prompts = await this.promptRepository.findBy({ type: '默认' })
+        await this.promptRepository.save(
+            prompts.map((prompt) => ({
+                name: prompt.name,
+                template: prompt.template,
+                defaultTemplate: prompt.template,
+                description: prompt.description,
+                type
+            }))
+        )
+    }
+
+    async deleteType(type: string) {
+        if (type === '默认') throw new InternalServerErrorException('默认类型不可删除')
+        await this.promptRepository.delete({ type })
     }
 }

+ 8 - 0
src/sys-config/sys-config.service.ts

@@ -30,6 +30,14 @@ export class SysConfigService implements OnModuleInit {
                 remark: '投票等待时间'
             })
         }
+        if (!(await this.sysConfigRepository.findOneBy({ name: 'join_game_danmu' }))) {
+            await this.sysConfigRepository.save({
+                name: 'join_game_danmu',
+                value: '我是%',
+                type: SysConfigType.String,
+                remark: '加入游戏弹幕规则'
+            })
+        }
     }
 
     async findAll(req: PageRequest<SysConfig>) {