chat.service.ts 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. import { ForbiddenException, Injectable, InternalServerErrorException, Logger } from '@nestjs/common'
  2. import { Observable } from 'rxjs'
  3. import { ChatGPTAPI, ChatMessage } from '../chatapi'
  4. import type { RequestProps } from './types'
  5. import { chatReplyProcess } from './chatgpt'
  6. import { Repository, MoreThanOrEqual, LessThanOrEqual, And } from 'typeorm'
  7. import { ChatHistory } from './entities/chat.entity'
  8. import { InjectRepository } from '@nestjs/typeorm'
  9. import { TokenUsage } from './entities/token-usage.entity'
  10. import { format } from 'date-fns'
  11. import { MembershipService } from '../membership/membership.service'
  12. import { MemberType } from '../membership/entities/membership.entity'
  13. import { get_encoding } from '@dqbd/tiktoken'
  14. import { fetchSSE } from 'src/chatapi/fetch-sse'
  15. import { HttpService } from '@nestjs/axios'
  16. import * as types from '../chatapi/types'
  17. import { SysConfigService } from 'src/sys-config/sys-config.service'
  /**
   * Chat-related application service.
   *
   * Streams model replies to the client, proxies raw chat-completion requests
   * to the Azure OpenAI deployment configured through `AZURE_OPENAI_*`
   * environment variables, stores chat history rows and keeps a per-user,
   * per-day token usage record.
   */
  @Injectable()
  export class ChatService {
    /** Shared `cl100k_base` encoder, created once per service instance and reused for token counting. */
    tokenizer = get_encoding('cl100k_base')

    constructor(
      @InjectRepository(ChatHistory)
      private readonly chatHistoryRepository: Repository<ChatHistory>,
      @InjectRepository(TokenUsage)
      private readonly tokenUsageRepository: Repository<TokenUsage>,
      private readonly membershipService: MembershipService,
      private readonly httpService: HttpService,
      private readonly sysConfigService: SysConfigService
    ) {}

    /**
     * Runs a chat request and exposes the partial replies as an Observable.
     *
     * Every partial reply is emitted as a JSON string; every emission after the
     * first is prefixed with `\n`, so the concatenated stream is one JSON
     * document per line.
     *
     * @param req Request whose `body` is read as {@link RequestProps}.
     * @param res Response; only its `Content-Type` header is set here.
     * @returns Observable that completes when the reply is finished, or errors
     *   with whatever `chatReplyProcess` rejected with.
     */
    public chat(req, res): Observable<any> {
      res.setHeader('Content-Type', 'application/octet-stream')
      return new Observable((observer) => {
        const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
        let firstChunk = true
        chatReplyProcess({
          message: prompt,
          lastContext: options,
          process: (chat: ChatMessage) => {
            observer.next(firstChunk ? JSON.stringify(chat) : `\n${JSON.stringify(chat)}`)
            firstChunk = false
          },
          systemMessage,
          temperature,
          top_p
        })
          // Complete only on success: an Observable that has errored ignores any
          // later complete(), so there is no need to call it on the failure path.
          .then(() => observer.complete())
          .catch((error) => observer.error(error))
      })
    }

    /**
     * Streams a chat reply straight onto the HTTP response for a member.
     *
     * After the reply finishes, the user prompt and the assistant answer are
     * stored as chat history and the total token count is added to the user's
     * usage. Those writes run in the background; their failures are logged and
     * never affect the response.
     *
     * @param req Request carrying `user.id` and a {@link RequestProps} body.
     * @param res Response that receives newline-separated JSON chunks and is
     *   always ended once streaming has started.
     * @throws ForbiddenException when the user has no membership, is a trial
     *   member with no tokens left, or the membership has expired.
     */
    public async chat1(req, res): Promise<void> {
      res.setHeader('Content-type', 'application/octet-stream')
      const defSysMsg = (await this.sysConfigService.findByName('system_message'))?.value
      const membership = await this.membershipService.getMembership(req.user.id)
      if (!membership) {
        throw new ForbiddenException('请先成为会员')
      }
      if (membership.memberType === MemberType.Trial && membership.tokenLeft <= 0) {
        throw new ForbiddenException('您的试用额度已用完,请升级会员')
      }
      if (membership.isExpired) {
        throw new ForbiddenException('您的会员已过期,请即时续费')
      }
      try {
        const promptTime = new Date()
        const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
        let firstChunk = true
        const result = await chatReplyProcess({
          message: prompt,
          lastContext: options,
          process: (chat: ChatMessage) => {
            res.write(firstChunk ? JSON.stringify(chat) : `\n${JSON.stringify(chat)}`)
            firstChunk = false
          },
          // A system message supplied by the client wins over the configured default.
          systemMessage: systemMessage || defSysMsg,
          temperature,
          top_p
        })
        const chatMessage = result.data as ChatMessage
        const usage = chatMessage.detail.usage
        // The user's prompt is stored under the id the reply names as its parent.
        this.saveHistoryInBackground(
          new ChatHistory({
            messageId: chatMessage.parentMessageId,
            parentMessageId: options.parentMessageId,
            userId: req.user.id,
            message: prompt,
            role: 'user',
            token: usage.prompt_tokens,
            time: promptTime
          })
        )
        this.saveHistoryInBackground(
          new ChatHistory({
            messageId: chatMessage.id,
            parentMessageId: chatMessage.parentMessageId,
            userId: req.user.id,
            message: chatMessage.text,
            role: 'assistant',
            token: usage.completion_tokens,
            time: new Date()
          })
        )
        // Fire-and-forget, but never leave the rejection unhandled.
        this.saveUsage(req.user.id, usage.total_tokens).catch((e) => {
          Logger.error(e, 'SAVE USAGE')
        })
      } catch (error) {
        // `message` is not an enumerable property of Error instances, so a plain
        // JSON.stringify(error) would drop it (often yielding "{}").
        const payload = error instanceof Error ? { ...error, message: error.message } : error
        res.write(JSON.stringify(payload))
      } finally {
        res.end()
      }
    }

    /**
     * Forwards a non-streaming chat-completion request to Azure OpenAI.
     *
     * Forces `req.body.stream` to `false`, records `usage.total_tokens` from the
     * response against the user's daily usage (membership quota untouched) and
     * returns the upstream response body.
     *
     * @param req Request carrying `user.id` and the body to forward.
     * @returns The upstream response body.
     * @throws InternalServerErrorException wrapping the upstream error body, or
     *   the error message when the failure carried no HTTP response.
     */
    public async chatProxy(req) {
      const url = this.azureChatCompletionsUrl()
      req.body.stream = false
      try {
        const { data } = await this.httpService.axiosRef.post(url, req.body, {
          headers: this.azureRequestHeaders()
        })
        this.saveUsage(req.user.id, data.usage.total_tokens, false).catch((e) => {
          Logger.error(e, 'SAVE USAGE')
        })
        return data
      } catch (e) {
        // Failures without an HTTP response (network errors, unexpected payloads)
        // have no `response`; fall back to the message instead of throwing again here.
        throw new InternalServerErrorException(e?.response?.data ?? e?.message)
      }
    }

    /**
     * Forwards a streaming chat-completion request to Azure OpenAI.
     *
     * Forces `req.body.stream` to `true` and re-emits every server-sent message
     * unchanged. The `content` deltas are accumulated, and when the `[DONE]`
     * message arrives the request messages plus the accumulated reply are
     * tokenized and recorded as usage before the Observable completes.
     *
     * @param req Request carrying `user.id` and the body to forward.
     * @returns Observable of the raw server-sent message strings.
     */
    public streamChatProxy(req) {
      const url = this.azureChatCompletionsUrl()
      req.body.stream = true
      return new Observable((subscriber) => {
        let text = ''
        fetchSSE(url, {
          body: JSON.stringify(req.body),
          headers: this.azureRequestHeaders(),
          method: 'POST',
          onMessage: (msg: string) => {
            subscriber.next(msg)
            if ('[DONE]' === msg) {
              Logger.log('done', 'CHAT PROXY')
              const transcript =
                req.body.messages.map((message) => `${message.role}:\n${message.content}`).join('\n\n') +
                '\n\nassistant:\n' +
                text
              this.tiktokenAndSave(req.user.id, transcript).catch((e) => {
                Logger.error(e, 'SAVE USAGE')
              })
              return subscriber.complete()
            }
            const response: types.openai.CreateChatCompletionDeltaResponse = JSON.parse(msg)
            if (response.choices?.length) {
              const delta = response.choices[0].delta
              if (delta?.content) text += delta.content
            }
          },
          onError: (err) => {
            Logger.error(err, 'CHAT PROXY')
            // error() terminates the stream; a following complete() would be ignored.
            subscriber.error(err)
          }
        }).catch((e) => {
          Logger.error(e, 'CHAT PROXY')
          subscriber.error(e)
        })
      })
    }

    /**
     * Counts the tokens in `text` and adds them to the user's daily usage
     * without touching the membership quota.
     *
     * @param userId User the usage is recorded for.
     * @param text Text to tokenize.
     */
    public async tiktokenAndSave(userId: number, text: string): Promise<void> {
      // TODO: use a better fix in the tokenizer
      text = text.replace(/<\|endoftext\|>/g, '')
      // Reuse the shared encoder: get_encoding() builds a new encoder on every
      // call and the previous code never released it.
      const token = this.tokenizer.encode(text).length
      await this.saveUsage(userId, token, false)
    }

    /**
     * Adds `usage` tokens to the user's record for today (`yyyy-MM-dd`),
     * creating the record when none exists yet.
     *
     * @param userId User the usage belongs to.
     * @param usage Number of tokens to add.
     * @param member When true (default) the usage is also forwarded to the
     *   membership service.
     */
    public async saveUsage(userId: number, usage: number, member = true): Promise<void> {
      const date = format(new Date(), 'yyyy-MM-dd')
      const tokenUsage = await this.tokenUsageRepository.findOneBy({
        userId,
        date
      })
      if (tokenUsage) {
        tokenUsage.usage += usage
        await this.tokenUsageRepository.save(tokenUsage)
      } else {
        await this.tokenUsageRepository.save(new TokenUsage({ userId, usage, date }))
      }
      if (member) {
        // Awaited so a failure surfaces to the caller instead of being a floating promise.
        await this.membershipService.saveUsage(userId, usage)
      }
    }

    /**
     * Returns the user's daily usage records whose date lies in `[start, end]`.
     *
     * @param userId User whose records are fetched.
     * @param start Inclusive lower bound; defaults to `2023-04-01` when omitted or empty.
     * @param end Inclusive upper bound; defaults to today (`yyyy-MM-dd`) when omitted or empty.
     */
    public async getUsage(userId: number, start?: string, end?: string): Promise<TokenUsage[]> {
      const today = format(new Date(), 'yyyy-MM-dd')
      return this.tokenUsageRepository.findBy({
        userId,
        date: And(MoreThanOrEqual(start || '2023-04-01'), LessThanOrEqual(end || today))
      })
    }

    /** Builds the chat-completions URL of the configured Azure OpenAI deployment. */
    private azureChatCompletionsUrl(): string {
      return `${process.env.AZURE_OPENAI_ENDPOINT}/openai/deployments/${process.env.AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version=${process.env.AZURE_OPENAI_VERSION}`
    }

    /** Headers sent with every Azure OpenAI request. */
    private azureRequestHeaders(): Record<string, string> {
      return {
        'Content-Type': 'application/json',
        'api-key': `${process.env.AZURE_OPENAI_KEY}`
      }
    }

    /** Stores one chat history row without blocking the caller; a failure is only logged. */
    private saveHistoryInBackground(entry: ChatHistory): void {
      this.chatHistoryRepository.save(entry).catch((e) => {
        Logger.error(e, 'SAVE CHAT HISTORY')
      })
    }
  }