Bruce MacDonald 1 年之前
父節點
當前提交
7454900733
共有 3 個文件被更改,包括 62 次插入和 36 次删除
  1. 17 9
      ollama/cmd/cli.py
  2. 43 23
      ollama/cmd/server.py
  3. 2 4
      requirements.txt

+ 17 - 9
ollama/cmd/cli.py

@@ -8,24 +8,28 @@ from ollama.cmd import server
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument('--models-home', default=Path.home() / '.ollama' / 'models')
+    parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models")
 
     subparsers = parser.add_subparsers()
 
-    server.set_parser(subparsers.add_parser('serve'))
+    server.set_parser(subparsers.add_parser("serve"))
 
-    list_parser = subparsers.add_parser('list')
+    list_parser = subparsers.add_parser("list")
     list_parser.set_defaults(fn=list)
 
-    generate_parser = subparsers.add_parser('generate')
-    generate_parser.add_argument('model')
-    generate_parser.add_argument('prompt')
+    generate_parser = subparsers.add_parser("generate")
+    generate_parser.add_argument("model")
+    generate_parser.add_argument("prompt")
     generate_parser.set_defaults(fn=generate)
 
+    add_parser = subparsers.add_parser("add")
+    add_parser.add_argument("model_file")
+    add_parser.set_defaults(fn=add)
+
     args = parser.parse_args()
     args = vars(args)
 
-    fn = args.pop('fn')
+    fn = args.pop("fn")
     fn(**args)
 
 
@@ -38,6 +42,10 @@ def generate(*args, **kwargs):
     for output in engine.generate(*args, **kwargs):
         output = json.loads(output)
 
-        choices = output.get('choices', [])
+        choices = output.get("choices", [])
         if len(choices) > 0:
-            print(choices[0].get('text', ''), end='')
+            print(choices[0].get("text", ""), end="")
+
+
+def add(*args, **kwargs):
+    model.add(*args, **kwargs)

+ 43 - 23
ollama/cmd/server.py

@@ -1,39 +1,59 @@
 from aiohttp import web
+import aiohttp_cors
 
 from ollama import engine
 
 
 def set_parser(parser):
-    parser.add_argument('--host', default='127.0.0.1')
-    parser.add_argument('--port', default=7734)
+    parser.add_argument("--host", default="127.0.0.1")
+    parser.add_argument("--port", default=7734)
     parser.set_defaults(fn=serve)
 
 
-def serve(models_home='.', *args, **kwargs):
+def serve(models_home=".", *args, **kwargs):
     app = web.Application()
-    app.add_routes([
-        web.post('/load', load),
-        web.post('/unload', unload),
-        web.post('/generate', generate),
-    ])
 
-    app.update({
-        'llms': {},
-        'models_home': models_home,
-    })
+    cors = aiohttp_cors.setup(
+        app,
+        defaults={
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True,
+                expose_headers="*",
+                allow_headers="*",
+            )
+        },
+    )
+
+    app.add_routes(
+        [
+            web.post("/load", load),
+            web.post("/unload", unload),
+            web.post("/generate", generate),
+        ]
+    )
+
+    for route in app.router.routes():
+        cors.add(route)
+
+    app.update(
+        {
+            "llms": {},
+            "models_home": models_home,
+        }
+    )
 
     web.run_app(app, **kwargs)
 
 
 async def load(request):
     body = await request.json()
-    model = body.get('model')
+    model = body.get("model")
     if not model:
         raise web.HTTPBadRequest()
 
     kwargs = {
-        'llms': request.app.get('llms'),
-        'models_home': request.app.get('models_home'),
+        "llms": request.app.get("llms"),
+        "models_home": request.app.get("models_home"),
     }
 
     engine.load(model, **kwargs)
@@ -42,21 +62,21 @@ async def load(request):
 
 async def unload(request):
     body = await request.json()
-    model = body.get('model')
+    model = body.get("model")
     if not model:
         raise web.HTTPBadRequest()
 
-    engine.unload(model, llms=request.app.get('llms'))
+    engine.unload(model, llms=request.app.get("llms"))
     return web.Response()
 
 
 async def generate(request):
     body = await request.json()
-    model = body.get('model')
+    model = body.get("model")
     if not model:
         raise web.HTTPBadRequest()
 
-    prompt = body.get('prompt')
+    prompt = body.get("prompt")
     if not prompt:
         raise web.HTTPBadRequest()
 
@@ -64,12 +84,12 @@ async def generate(request):
     await response.prepare(request)
 
     kwargs = {
-        'llms': request.app.get('llms'),
-        'models_home': request.app.get('models_home'),
+        "llms": request.app.get("llms"),
+        "models_home": request.app.get("models_home"),
     }
 
     for output in engine.generate(model, prompt, **kwargs):
-        await response.write(output.encode('utf-8'))
-        await response.write(b'\n')
+        await response.write(output.encode("utf-8"))
+        await response.write(b"\n")
 
     return response

+ 2 - 4
requirements.txt

@@ -1,7 +1,5 @@
-click==8.1.3
-Flask==2.3.2
-Flask_Cors==3.0.10
+aiohttp==3.8.4
+aiohttp_cors==0.7.0
 llama_cpp_python==0.1.65
 pyinstaller==5.13.0
 setuptools==65.6.3
-tqdm==4.65.0