xiongzhu пре 2 година
родитељ
комит
2ddac47fb3
2 измењених фајлова са 30 додато и 8 уклоњено
  1. 28 6
      src/chat/chat.service.ts
  2. 2 2
      src/membership/membership.service.ts

+ 28 - 6
src/chat/chat.service.ts

@@ -13,6 +13,7 @@ import { MemberType } from '../membership/entities/membership.entity'
 import { get_encoding } from '@dqbd/tiktoken'
 import { fetchSSE } from 'src/chatapi/fetch-sse'
 import { HttpService } from '@nestjs/axios'
+import * as types from '../chatapi/types'
 
 @Injectable()
 export class ChatService {
@@ -134,6 +135,7 @@ export class ChatService {
                     'api-key': `${process.env.AZURE_OPENAI_KEY}`
                 }
             })
+            this.saveUsage(req.user.id, data.usage.total_tokens, false)
             return data
         } catch (e) {
             throw new InternalServerErrorException(e.response.data)
@@ -141,10 +143,10 @@ export class ChatService {
     }
 
     public streamChatProxy(req) {
-        Logger.log(req.body, 'CHAT PROXY')
         const url = `${process.env.AZURE_OPENAI_ENDPOINT}/openai/deployments/${process.env.AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version=${process.env.AZURE_OPENAI_VERSION}`
         req.body.stream = true
         return new Observable((subscriber) => {
+            let text = ''
             fetchSSE(url, {
                 body: JSON.stringify(req.body),
                 headers: {
@@ -152,11 +154,21 @@ export class ChatService {
                     'api-key': `${process.env.AZURE_OPENAI_KEY}`
                 },
                 method: 'POST',
-                onMessage: (msg) => {
-                    Logger.log(msg, 'CHAT PROXY')
+                onMessage: (msg: string) => {
                     subscriber.next(msg)
                     if ('[DONE]' === msg) {
-                        subscriber.complete()
+                        this.tiktokenAndSave(
+                            req.user.id,
+                            req.body.messages.map((message) => `${message.role}:\n${message.content}`).join('\n\n') +
+                                '\n\nassistant:\n' +
+                                text
+                        )
+                        return subscriber.complete()
+                    }
+                    const response: types.openai.CreateChatCompletionDeltaResponse = JSON.parse(msg)
+                    if (response.choices?.length) {
+                        const delta = response.choices[0].delta
+                        if (delta?.content) text += delta.content
                     }
                 },
                 onError: (err) => {
@@ -171,7 +183,15 @@ export class ChatService {
         })
     }
 
-    public async saveUsage(userId: number, usage: number) {
+    public async tiktokenAndSave(userId: number, text: string) {
+        // TODO: use a better fix in the tokenizer
+        text = text.replace(/<\|endoftext\|>/g, '')
+        const tokenizer = get_encoding('cl100k_base')
+        const token = tokenizer.encode(text).length
+        await this.saveUsage(userId, token, false)
+    }
+
+    public async saveUsage(userId: number, usage: number, member = true) {
         const date = format(new Date(), 'yyyy-MM-dd')
         const tokenUsage = await this.tokenUsageRepository.findOneBy({
             userId,
@@ -184,7 +204,9 @@ export class ChatService {
             await this.tokenUsageRepository.save(new TokenUsage({ userId, usage, date }))
         }
 
-        this.membershipService.saveUsage(userId, usage)
+        if (member) {
+            this.membershipService.saveUsage(userId, usage)
+        }
     }
 
     public async getUsage(userId: number, start?: string, end?: string): Promise<TokenUsage[]> {

+ 2 - 2
src/membership/membership.service.ts

@@ -41,9 +41,9 @@ export class MembershipService {
         }
         membership = new Membership()
         membership.userId = userId
-        membership.expireAt = addDays(new Date(), 7)
+        membership.expireAt = addDays(new Date(), 3)
         membership.memberType = MemberType.Trial
-        membership.tokenLeft = 10000
+        membership.tokenLeft = 2000
         return await this.memberShipRepository.save(membership)
     }