소스 검색

feat: pipe async support

Timothy J. Baek 10 달 전
부모
커밋
4370f233a1
1개의 변경된 파일11개의 추가작업 그리고 5개의 파일을 삭제
  1. 11 5
      backend/main.py

+ 11 - 5
backend/main.py

@@ -843,7 +843,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             "role": user.role,
         }
 
-        def job():
+        async def job():
             pipe_id = form_data["model"]
             if "." in pipe_id:
                 pipe_id, sub_pipe_id = pipe_id.split(".", 1)
@@ -852,8 +852,11 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
             if form_data["stream"]:
 
-                def stream_content():
-                    res = pipe(body=form_data)
+                async def stream_content():
+                    if inspect.iscoroutinefunction(pipe):
+                        res = await pipe(body=form_data)
+                    else:
+                        res = pipe(body=form_data)
 
                     if isinstance(res, str):
                         message = stream_message_template(form_data["model"], res)
@@ -898,7 +901,10 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
                     stream_content(), media_type="text/event-stream"
                 )
             else:
-                res = pipe(body=form_data)
+                if inspect.iscoroutinefunction(pipe):
+                    res = await pipe(body=form_data)
+                else:
+                    res = pipe(body=form_data)
 
                 if isinstance(res, dict):
                     return res
@@ -930,7 +936,7 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
                         ],
                     }
 
-        return await run_in_threadpool(job)
+        return await job()
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
     else: