
add llama.cpp go bindings

Jeffrey Morgan 1 year ago
parent
commit
6093a88c1a
18 changed files with 841 additions and 79 deletions
  1. README.md (+8 −1)
  2. api/client.go (+1 −1)
  3. cmd/cmd.go (+4 −5)
  4. go.mod (+2 −4)
  5. go.sum (+8 −10)
  6. lib/.gitignore (+0 −1)
  7. lib/README.md (+0 −10)
  8. lib/binding.h (+0 −41)
  9. llama/.gitignore (+1 −0)
  10. llama/CMakeLists.txt (+9 −3)
  11. llama/binding/binding.cpp (+22 −0)
  12. llama/binding/binding.h (+71 −0)
  13. llama/llama.go (+302 −0)
  14. llama/llama_cublas.go (+9 −0)
  15. llama/llama_openblas.go (+9 −0)
  16. llama/options.go (+392 −0)
  17. main.go (+1 −1)
  18. server/routes.go (+2 −2)

+ 8 - 1
README.md

@@ -1,6 +1,6 @@
 # Ollama
 
-A fast runtime for large language models, powered by [llama.cpp](https://github.com/ggerganov/llama.cpp).
+An easy, fast runtime for large language models, powered by `llama.cpp`.
 
 > _Note: this project is a work in progress. Certain models that can be run with `ollama` are intended for research and/or non-commercial use only._
 
@@ -38,6 +38,13 @@ Or directly via downloaded model files:
 ollama run ~/Downloads/orca-mini-13b.ggmlv3.q4_0.bin
 ```
 
+## Building
+
+```
+go generate ./...
+go build .
+```
+
 ## Documentation
 
 - [Development](docs/development.md)
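
These build steps work because of the `//go:generate cmake` directives added in `llama/llama.go` below: `go generate ./...` runs CMake to compile the bindings and llama.cpp into static libraries under `llama/build`, which `go build` then links via cgo.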

+ 1 - 1
api/client.go

@@ -8,7 +8,7 @@ import (
 	"io"
 	"net/http"
 
-	"github.com/ollama/ollama/signature"
+	"github.com/jmorganca/ollama/signature"
 )
 
 type Client struct {

+ 4 - 5
cmd/cmd.go

@@ -3,7 +3,6 @@ package cmd
 import (
 	"context"
 	"fmt"
-	"io/ioutil"
 	"log"
 	"net"
 	"net/http"
@@ -13,8 +12,8 @@ import (
 
 	"github.com/spf13/cobra"
 
-	"github.com/ollama/ollama/api"
-	"github.com/ollama/ollama/server"
+	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/server"
 )
 
 func NewAPIClient(cmd *cobra.Command) (*api.Client, error) {
@@ -36,7 +35,7 @@ func NewAPIClient(cmd *cobra.Command) (*api.Client, error) {
 
 	if k != "" {
 		fn := path.Join(home, ".ollama/keys/", k)
-		rawKey, err = ioutil.ReadFile(fn)
+		rawKey, err = os.ReadFile(fn)
 		if err != nil {
 			return nil, err
 		}
@@ -59,7 +58,7 @@ func NewCLI() *cobra.Command {
 	log.SetFlags(log.LstdFlags | log.Lshortfile)
 
 	rootCmd := &cobra.Command{
-		Use:   "gollama",
+		Use:   "ollama",
 		Short: "Run any large language model on any machine.",
 		CompletionOptions: cobra.CompletionOptions{
 			DisableDefaultCmd: true,
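
The `ioutil.ReadFile` call becomes `os.ReadFile`, the drop-in replacement since `io/ioutil` was deprecated in Go 1.16. A minimal sketch of the pattern (`readKey` is a hypothetical helper mirroring the key-loading code above):

```go
package main

import (
	"fmt"
	"os"
	"path"
)

// readKey shows the os.ReadFile pattern: identical signature and
// behavior to the deprecated ioutil.ReadFile.
func readKey(home, name string) ([]byte, error) {
	fn := path.Join(home, ".ollama/keys/", name)
	return os.ReadFile(fn)
}

func main() {
	home, err := os.UserHomeDir()
	if err != nil {
		panic(err)
	}
	key, err := readKey(home, "example-key")
	if err != nil {
		fmt.Println("no key:", err)
		return
	}
	fmt.Printf("read %d bytes\n", len(key))
}
```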

+ 2 - 4
go.mod

@@ -1,11 +1,9 @@
-module github.com/ollama/ollama
+module github.com/jmorganca/ollama
 
 go 1.20
 
 require (
 	github.com/gin-gonic/gin v1.9.1
-	github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144
-	github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc
 	github.com/spf13/cobra v1.7.0
 	golang.org/x/crypto v0.10.0
 )
@@ -19,6 +17,7 @@ require (
 	github.com/go-playground/universal-translator v0.18.1 // indirect
 	github.com/go-playground/validator/v10 v10.14.0 // indirect
 	github.com/goccy/go-json v0.10.2 // indirect
+	github.com/google/go-cmp v0.5.9 // indirect
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
@@ -35,6 +34,5 @@ require (
 	golang.org/x/sys v0.9.0 // indirect
 	golang.org/x/text v0.10.0 // indirect
 	google.golang.org/protobuf v1.30.0 // indirect
-	gopkg.in/cenkalti/backoff.v1 v1.1.0 // indirect
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
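
Renaming the module from `github.com/ollama/ollama` to `github.com/jmorganca/ollama` means every internal import path changes in lockstep, which is why `api/client.go`, `cmd/cmd.go`, `main.go`, and `server/routes.go` are all touched in this commit. The external `go-skynet/go-llama.cpp` and `r3labs/sse` dependencies are dropped; the former is replaced by the in-repo `llama` package added below.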

+ 8 - 10
go.sum

@@ -6,6 +6,7 @@ github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhD
 github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
 github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
 github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
@@ -13,18 +14,19 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE
 github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
 github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
 github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
+github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
 github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
 github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
 github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
 github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
 github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
 github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144 h1:fszkmZG3pW9/bqhuWB6sfJMArJPx1RPzjZSqNdhuSQ0=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230630201504-ecd358d2f144/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
 github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
 github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
 github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
+github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
+github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
@@ -44,8 +46,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
 github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
 github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
 github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
-github.com/r3labs/sse v0.0.0-20210224172625-26fe804710bc/go.mod h1:S8xSOnV3CgpNrWd0GQ/OoQfMtlg2uPRSuTzcSGrzwK8=
 github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
 github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
 github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
@@ -55,12 +57,12 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
-github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
 github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
 github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
 github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
+github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
 github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
@@ -69,27 +71,23 @@ github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
 golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
 golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM=
 golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I=
-golang.org/x/net v0.0.0-20191116160921-f9c825593386/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
 golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/term v0.9.0 h1:GRRCnKYhdQrD8kfRAdQ6Zcw1P0OcELxGLKJvtjVMZ28=
 golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58=
 golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
 google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
 google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
-gopkg.in/cenkalti/backoff.v1 v1.1.0/go.mod h1:J6Vskwqd+OMVJl8C33mmtxTBs2gyzfv7UDAkHu8BrjI=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
-gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
 gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
 gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

+ 0 - 1
lib/.gitignore

@@ -1 +0,0 @@
-build

+ 0 - 10
lib/README.md

@@ -1,10 +0,0 @@
-# Bindings
-
-These are Llama.cpp bindings
-
-## Build
-
-```
-cmake -S . -B build
-cmake --build build
-```

+ 0 - 41
lib/binding.h

@@ -1,41 +0,0 @@
-#ifdef __cplusplus
-#include <vector>
-#include <string>
-extern "C" {
-#endif
-
-#include <stdbool.h>
-
-extern unsigned char tokenCallback(void *, char *);
-
-int load_state(void *ctx, char *statefile, char*modes);
-
-int eval(void* params_ptr, void *ctx, char*text);
-
-void save_state(void *ctx, char *dst, char*modes);
-
-void* load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16, bool mlock, bool embeddings, bool mmap, bool low_vram, bool vocab_only, int n_gpu, int n_batch, const char *maingpu, const char *tensorsplit, bool numa);
-
-int get_embeddings(void* params_ptr, void* state_pr, float * res_embeddings);
-
-int get_token_embeddings(void* params_ptr, void* state_pr,  int *tokens, int tokenSize, float * res_embeddings);
-
-void* llama_allocate_params(const char *prompt, int seed, int threads, int tokens,
-                            int top_k, float top_p, float temp, float repeat_penalty, 
-                            int repeat_last_n, bool ignore_eos, bool memory_f16, 
-                            int n_batch, int n_keep, const char** antiprompt, int antiprompt_count,
-                            float tfs_z, float typical_p, float frequency_penalty, float presence_penalty, int mirostat, float mirostat_eta, float mirostat_tau, bool penalize_nl, const char *logit_bias, const char *session_file, bool prompt_cache_all, bool mlock, bool mmap, const char *maingpu, const char *tensorsplit , bool prompt_cache_ro);
-
-void llama_free_params(void* params_ptr);
-
-void llama_binding_free_model(void* state);
-
-int llama_predict(void* params_ptr, void* state_pr, char* result, bool debug);
-
-#ifdef __cplusplus
-}
-
-
-std::vector<std::string> create_vector(const char** strings, int count);
-void delete_vector(std::vector<std::string>* vec);
-#endif

+ 1 - 0
llama/.gitignore

@@ -0,0 +1 @@
+build

+ 9 - 3
lib/CMakeLists.txt → llama/CMakeLists.txt

@@ -9,13 +9,19 @@ FetchContent_Declare(
 
 FetchContent_MakeAvailable(llama_cpp)
 
-project(binding)
+if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
+    set(LLAMA_METAL ON)
+    add_compile_definitions(GGML_USE_METAL)
+endif()
 
-set(LLAMA_METAL ON CACHE BOOL "Enable Llama Metal by default on macOS")
+project(binding)
 
-add_library(binding binding.cpp ${llama_cpp_SOURCE_DIR}/examples/common.cpp)
+add_library(binding ${CMAKE_CURRENT_SOURCE_DIR}/binding/binding.cpp ${llama_cpp_SOURCE_DIR}/examples/common.cpp)
 target_compile_features(binding PRIVATE cxx_std_11)
 target_include_directories(binding PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
 target_include_directories(binding PRIVATE ${llama_cpp_SOURCE_DIR})
 target_include_directories(binding PRIVATE ${llama_cpp_SOURCE_DIR}/examples)
 target_link_libraries(binding llama ggml_static)
+
+configure_file(${llama_cpp_BINARY_DIR}/libllama.a ${CMAKE_CURRENT_BINARY_DIR}/libllama.a COPYONLY)
+configure_file(${llama_cpp_BINARY_DIR}/libggml_static.a ${CMAKE_CURRENT_BINARY_DIR}/libggml_static.a COPYONLY)
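
The two `configure_file` copies place `libllama.a` and `libggml_static.a` next to the `binding` library, so the single `-Lbuild` search path in the cgo link line of `llama/llama.go` below can resolve every archive it names.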

+ 22 - 0
lib/binding.cpp → llama/binding/binding.cpp

@@ -1,3 +1,25 @@
+// MIT License
+
+// Copyright (c) 2023 go-skynet authors
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
 #include "common.h"
 #include "llama.h"
 

+ 71 - 0
llama/binding/binding.h

@@ -0,0 +1,71 @@
+// MIT License
+
+// Copyright (c) 2023 go-skynet authors
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+#ifdef __cplusplus
+#include <string>
+#include <vector>
+extern "C" {
+#endif
+
+#include <stdbool.h>
+
+extern unsigned char tokenCallback(void *, char *);
+
+int load_state(void *ctx, char *statefile, char *modes);
+
+int eval(void *params_ptr, void *ctx, char *text);
+
+void save_state(void *ctx, char *dst, char *modes);
+
+void *load_model(const char *fname, int n_ctx, int n_seed, bool memory_f16,
+                 bool mlock, bool embeddings, bool mmap, bool low_vram,
+                 bool vocab_only, int n_gpu, int n_batch, const char *maingpu,
+                 const char *tensorsplit, bool numa);
+
+int get_embeddings(void *params_ptr, void *state_pr, float *res_embeddings);
+
+int get_token_embeddings(void *params_ptr, void *state_pr, int *tokens,
+                         int tokenSize, float *res_embeddings);
+
+void *llama_allocate_params(
+    const char *prompt, int seed, int threads, int tokens, int top_k,
+    float top_p, float temp, float repeat_penalty, int repeat_last_n,
+    bool ignore_eos, bool memory_f16, int n_batch, int n_keep,
+    const char **antiprompt, int antiprompt_count, float tfs_z, float typical_p,
+    float frequency_penalty, float presence_penalty, int mirostat,
+    float mirostat_eta, float mirostat_tau, bool penalize_nl,
+    const char *logit_bias, const char *session_file, bool prompt_cache_all,
+    bool mlock, bool mmap, const char *maingpu, const char *tensorsplit,
+    bool prompt_cache_ro);
+
+void llama_free_params(void *params_ptr);
+
+void llama_binding_free_model(void *state);
+
+int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug);
+
+#ifdef __cplusplus
+}
+
+std::vector<std::string> create_vector(const char **strings, int count);
+void delete_vector(std::vector<std::string> *vec);
+#endif
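
The declarations sit inside an `extern "C"` block so the C++ implementation exports C-linkage symbols that cgo can call, while `tokenCallback` goes the other way: declared here, defined in Go via `//export`. Functions such as `llama_predict` write their output into a caller-supplied buffer and return a status code. A self-contained toy sketch of that out-buffer convention (`toy_predict` is hypothetical, standing in for the linked binding):

```go
package main

/*
#include <string.h>

// Stand-in for the binding's exported symbols: write a result into a
// caller-supplied buffer and return 0 on success, like llama_predict.
static int toy_predict(char *result) {
	strcpy(result, "hello from C");
	return 0;
}
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	out := make([]byte, 64)
	if C.toy_predict((*C.char)(unsafe.Pointer(&out[0]))) != 0 {
		panic("toy inference failed")
	}
	fmt.Println(C.GoString((*C.char)(unsafe.Pointer(&out[0]))))
}
```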

+ 302 - 0
llama/llama.go

@@ -0,0 +1,302 @@
+// MIT License
+
+// Copyright (c) 2023 go-skynet authors
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+//go:generate cmake -S . -B build
+//go:generate cmake --build build
+package llama
+
+// #cgo LDFLAGS: -Lbuild -lbinding -lllama -lggml_static -lstdc++
+// #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
+// #cgo darwin CXXFLAGS: -std=c++11
+// #include "binding/binding.h"
+import "C"
+import (
+	"fmt"
+	"os"
+	"strings"
+	"sync"
+	"unsafe"
+)
+
+type LLama struct {
+	state       unsafe.Pointer
+	embeddings  bool
+	contextSize int
+}
+
+func New(model string, opts ...ModelOption) (*LLama, error) {
+	mo := NewModelOptions(opts...)
+	modelPath := C.CString(model)
+	result := C.load_model(modelPath, C.int(mo.ContextSize), C.int(mo.Seed), C.bool(mo.F16Memory), C.bool(mo.MLock), C.bool(mo.Embeddings), C.bool(mo.MMap), C.bool(mo.LowVRAM), C.bool(mo.VocabOnly), C.int(mo.NGPULayers), C.int(mo.NBatch), C.CString(mo.MainGPU), C.CString(mo.TensorSplit), C.bool(mo.NUMA))
+	if result == nil {
+		return nil, fmt.Errorf("failed loading model")
+	}
+
+	ll := &LLama{state: result, contextSize: mo.ContextSize, embeddings: mo.Embeddings}
+
+	return ll, nil
+}
+
+func (l *LLama) Free() {
+	C.llama_binding_free_model(l.state)
+}
+
+func (l *LLama) LoadState(state string) error {
+	d := C.CString(state)
+	w := C.CString("rb")
+
+	result := C.load_state(l.state, d, w)
+	if result != 0 {
+		return fmt.Errorf("error while loading state")
+	}
+
+	return nil
+}
+
+func (l *LLama) SaveState(dst string) error {
+	d := C.CString(dst)
+	w := C.CString("wb")
+
+	C.save_state(l.state, d, w)
+
+	_, err := os.Stat(dst)
+	return err
+}
+
+// TokenEmbeddings computes embeddings for the given tokens; the model must be loaded with embeddings enabled.
+func (l *LLama) TokenEmbeddings(tokens []int, opts ...PredictOption) ([]float32, error) {
+	if !l.embeddings {
+		return []float32{}, fmt.Errorf("model loaded without embeddings")
+	}
+
+	po := NewPredictOptions(opts...)
+
+	outSize := po.Tokens
+	if po.Tokens == 0 {
+		outSize = 9999999
+	}
+
+	floats := make([]float32, outSize)
+
+	myArray := (*C.int)(C.malloc(C.size_t(len(tokens)) * C.sizeof_int))
+
+	// Copy the values from the Go slice to the C array
+	for i, v := range tokens {
+		(*[1<<31 - 1]int32)(unsafe.Pointer(myArray))[i] = int32(v)
+	}
+
+	params := C.llama_allocate_params(C.CString(""), C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
+		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
+		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
+		C.int(po.Batch), C.int(po.NKeep), nil, C.int(0),
+		C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
+		C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
+		C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
+		C.CString(po.MainGPU), C.CString(po.TensorSplit),
+		C.bool(po.PromptCacheRO),
+	)
+	ret := C.get_token_embeddings(params, l.state, myArray, C.int(len(tokens)), (*C.float)(&floats[0]))
+	if ret != 0 {
+		return floats, fmt.Errorf("embedding inference failed")
+	}
+	return floats, nil
+}
+
+// Embeddings computes embeddings for the given text; the model must be loaded with embeddings enabled.
+func (l *LLama) Embeddings(text string, opts ...PredictOption) ([]float32, error) {
+	if !l.embeddings {
+		return []float32{}, fmt.Errorf("model loaded without embeddings")
+	}
+
+	po := NewPredictOptions(opts...)
+
+	input := C.CString(text)
+	if po.Tokens == 0 {
+		po.Tokens = 99999999
+	}
+	floats := make([]float32, po.Tokens)
+	reverseCount := len(po.StopPrompts)
+	reversePrompt := make([]*C.char, reverseCount)
+	var pass **C.char
+	for i, s := range po.StopPrompts {
+		cs := C.CString(s)
+		reversePrompt[i] = cs
+		pass = &reversePrompt[0]
+	}
+
+	params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
+		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
+		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
+		C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
+		C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
+		C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
+		C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
+		C.CString(po.MainGPU), C.CString(po.TensorSplit),
+		C.bool(po.PromptCacheRO),
+	)
+
+	ret := C.get_embeddings(params, l.state, (*C.float)(&floats[0]))
+	if ret != 0 {
+		return floats, fmt.Errorf("embedding inference failed")
+	}
+
+	return floats, nil
+}
+
+func (l *LLama) Eval(text string, opts ...PredictOption) error {
+	po := NewPredictOptions(opts...)
+
+	input := C.CString(text)
+	if po.Tokens == 0 {
+		po.Tokens = 99999999
+	}
+
+	reverseCount := len(po.StopPrompts)
+	reversePrompt := make([]*C.char, reverseCount)
+	var pass **C.char
+	for i, s := range po.StopPrompts {
+		cs := C.CString(s)
+		reversePrompt[i] = cs
+		pass = &reversePrompt[0]
+	}
+
+	params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
+		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
+		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
+		C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
+		C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
+		C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
+		C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
+		C.CString(po.MainGPU), C.CString(po.TensorSplit),
+		C.bool(po.PromptCacheRO),
+	)
+	ret := C.eval(params, l.state, input)
+	if ret != 0 {
+		return fmt.Errorf("inference failed")
+	}
+
+	C.llama_free_params(params)
+
+	return nil
+}
+
+func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
+	po := NewPredictOptions(opts...)
+
+	if po.TokenCallback != nil {
+		setCallback(l.state, po.TokenCallback)
+	}
+
+	input := C.CString(text)
+	if po.Tokens == 0 {
+		po.Tokens = 99999999
+	}
+	out := make([]byte, po.Tokens)
+
+	reverseCount := len(po.StopPrompts)
+	reversePrompt := make([]*C.char, reverseCount)
+	var pass **C.char
+	for i, s := range po.StopPrompts {
+		cs := C.CString(s)
+		reversePrompt[i] = cs
+		pass = &reversePrompt[0]
+	}
+
+	params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
+		C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
+		C.bool(po.IgnoreEOS), C.bool(po.F16KV),
+		C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
+		C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
+		C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
+		C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
+		C.CString(po.MainGPU), C.CString(po.TensorSplit),
+		C.bool(po.PromptCacheRO),
+	)
+	ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode))
+	if ret != 0 {
+		return "", fmt.Errorf("inference failed")
+	}
+	res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
+
+	res = strings.TrimPrefix(res, " ")
+	res = strings.TrimPrefix(res, text)
+	res = strings.TrimPrefix(res, "\n")
+
+	for _, s := range po.StopPrompts {
+		res = strings.TrimRight(res, s)
+	}
+
+	C.llama_free_params(params)
+
+	if po.TokenCallback != nil {
+		setCallback(l.state, nil)
+	}
+
+	return res, nil
+}
+
+// Cgo only allows static calls from C to Go; we can't dynamically pass in functions.
+// This is the next best thing: we register the callbacks in this map and call tokenCallback
+// from the C code. We also attach a finalizer to LLama, so it will unregister the callback
+// when the garbage collector frees it.
+
+// SetTokenCallback registers a callback for the individual tokens created when running Predict. It
+// will be called once for each token. The callback shall return true as long as the model should
+// continue predicting the next token. When the callback returns false the predictor will return.
+// The tokens are converted into Go strings as-is: they are not trimmed or otherwise changed,
+// and they may not be valid UTF-8.
+// Pass in nil to remove a callback.
+//
+// It is safe to call this method while a prediction is running.
+func (l *LLama) SetTokenCallback(callback func(token string) bool) {
+	setCallback(l.state, callback)
+}
+
+var (
+	m         sync.Mutex
+	callbacks = map[uintptr]func(string) bool{}
+)
+
+//export tokenCallback
+func tokenCallback(statePtr unsafe.Pointer, token *C.char) bool {
+	m.Lock()
+	defer m.Unlock()
+
+	if callback, ok := callbacks[uintptr(statePtr)]; ok {
+		return callback(C.GoString(token))
+	}
+
+	return true
+}
+
+// setCallback can be used to register a token callback for LLama. Pass in a nil callback to
+// remove the callback.
+func setCallback(statePtr unsafe.Pointer, callback func(string) bool) {
+	m.Lock()
+	defer m.Unlock()
+
+	if callback == nil {
+		delete(callbacks, uintptr(statePtr))
+	} else {
+		callbacks[uintptr(statePtr)] = callback
+	}
+}
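
A hedged usage sketch of the new package (the model path and option values are illustrative; the functions are the ones added above):

```go
package main

import (
	"fmt"

	"github.com/jmorganca/ollama/llama"
)

func main() {
	// Load a ggml model file; unset options fall back to the model defaults.
	l, err := llama.New("orca-mini-13b.ggmlv3.q4_0.bin", llama.SetContext(2048))
	if err != nil {
		panic(err)
	}
	defer l.Free()

	// Stream tokens as they are generated; returning false stops prediction.
	if _, err := l.Predict("Why is the sky blue?",
		llama.SetTokens(256),
		llama.SetTokenCallback(func(token string) bool {
			fmt.Print(token)
			return true
		}),
	); err != nil {
		panic(err)
	}
}
```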

+ 9 - 0
llama/llama_cublas.go

@@ -0,0 +1,9 @@
+//go:build cublas
+// +build cublas
+
+package llama
+
+/*
+#cgo LDFLAGS: -lcublas -lcudart -L/usr/local/cuda/lib64/
+*/
+import "C"

+ 9 - 0
llama/llama_openblas.go

@@ -0,0 +1,9 @@
+//go:build openblas
+// +build openblas
+
+package llama
+
+/*
+#cgo LDFLAGS: -lopenblas
+*/
+import "C"
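
Together with `llama_cublas.go`, this makes acceleration backends opt-in at build time: `go build -tags cublas .` adds the CUDA/cuBLAS linker flags, `go build -tags openblas .` adds OpenBLAS, and a plain `go build .` uses only the flags in `llama.go`.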

+ 392 - 0
llama/options.go

@@ -0,0 +1,392 @@
+// MIT License
+
+// Copyright (c) 2023 go-skynet authors
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in all
+// copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+package llama
+
+type ModelOptions struct {
+	ContextSize int
+	Seed        int
+	NBatch      int
+	F16Memory   bool
+	MLock       bool
+	MMap        bool
+	VocabOnly   bool
+	LowVRAM     bool
+	Embeddings  bool
+	NUMA        bool
+	NGPULayers  int
+	MainGPU     string
+	TensorSplit string
+}
+
+type PredictOptions struct {
+	Seed, Threads, Tokens, TopK, Repeat, Batch, NKeep int
+	TopP, Temperature, Penalty                        float64
+	F16KV                                             bool
+	DebugMode                                         bool
+	StopPrompts                                       []string
+	IgnoreEOS                                         bool
+
+	TailFreeSamplingZ float64
+	TypicalP          float64
+	FrequencyPenalty  float64
+	PresencePenalty   float64
+	Mirostat          int
+	MirostatETA       float64
+	MirostatTAU       float64
+	PenalizeNL        bool
+	LogitBias         string
+	TokenCallback     func(string) bool
+
+	PathPromptCache             string
+	MLock, MMap, PromptCacheAll bool
+	PromptCacheRO               bool
+	MainGPU                     string
+	TensorSplit                 string
+}
+
+type PredictOption func(p *PredictOptions)
+
+type ModelOption func(p *ModelOptions)
+
+var DefaultModelOptions ModelOptions = ModelOptions{
+	ContextSize: 512,
+	Seed:        0,
+	F16Memory:   false,
+	MLock:       false,
+	Embeddings:  false,
+	MMap:        true,
+	LowVRAM:     false,
+}
+
+var DefaultOptions PredictOptions = PredictOptions{
+	Seed:              -1,
+	Threads:           4,
+	Tokens:            128,
+	Penalty:           1.1,
+	Repeat:            64,
+	Batch:             512,
+	NKeep:             64,
+	TopK:              40,
+	TopP:              0.95,
+	TailFreeSamplingZ: 1.0,
+	TypicalP:          1.0,
+	Temperature:       0.8,
+	FrequencyPenalty:  0.0,
+	PresencePenalty:   0.0,
+	Mirostat:          0,
+	MirostatTAU:       5.0,
+	MirostatETA:       0.1,
+	MMap:              true,
+}
+
+// SetContext sets the context size.
+func SetContext(c int) ModelOption {
+	return func(p *ModelOptions) {
+		p.ContextSize = c
+	}
+}
+
+func SetModelSeed(c int) ModelOption {
+	return func(p *ModelOptions) {
+		p.Seed = c
+	}
+}
+
+// SetMMap sets whether to memory-map the model file.
+func SetMMap(b bool) ModelOption {
+	return func(p *ModelOptions) {
+		p.MMap = b
+	}
+}
+
+// SetNBatch sets the batch size (n_batch).
+func SetNBatch(n_batch int) ModelOption {
+	return func(p *ModelOptions) {
+		p.NBatch = n_batch
+	}
+}
+
+// SetTensorSplit sets the tensor split across GPUs.
+func SetTensorSplit(maingpu string) ModelOption {
+	return func(p *ModelOptions) {
+		p.TensorSplit = maingpu
+	}
+}
+
+// SetMainGPU sets the main_gpu
+func SetMainGPU(maingpu string) ModelOption {
+	return func(p *ModelOptions) {
+		p.MainGPU = maingpu
+	}
+}
+
+// SetPredictionTensorSplit sets the tensor split for the GPU
+func SetPredictionTensorSplit(maingpu string) PredictOption {
+	return func(p *PredictOptions) {
+		p.TensorSplit = maingpu
+	}
+}
+
+// SetPredictionMainGPU sets the main_gpu
+func SetPredictionMainGPU(maingpu string) PredictOption {
+	return func(p *PredictOptions) {
+		p.MainGPU = maingpu
+	}
+}
+
+var VocabOnly ModelOption = func(p *ModelOptions) {
+	p.VocabOnly = true
+}
+
+var EnabelLowVRAM ModelOption = func(p *ModelOptions) {
+	p.LowVRAM = true
+}
+
+var EnableNUMA ModelOption = func(p *ModelOptions) {
+	p.NUMA = true
+}
+
+var EnableEmbeddings ModelOption = func(p *ModelOptions) {
+	p.Embeddings = true
+}
+
+var EnableF16Memory ModelOption = func(p *ModelOptions) {
+	p.F16Memory = true
+}
+
+var EnableF16KV PredictOption = func(p *PredictOptions) {
+	p.F16KV = true
+}
+
+var Debug PredictOption = func(p *PredictOptions) {
+	p.DebugMode = true
+}
+
+var EnablePromptCacheAll PredictOption = func(p *PredictOptions) {
+	p.PromptCacheAll = true
+}
+
+var EnablePromptCacheRO PredictOption = func(p *PredictOptions) {
+	p.PromptCacheRO = true
+}
+
+var EnableMLock ModelOption = func(p *ModelOptions) {
+	p.MLock = true
+}
+
+// NewModelOptions creates a new ModelOptions object with the given options.
+func NewModelOptions(opts ...ModelOption) ModelOptions {
+	p := DefaultModelOptions
+	for _, opt := range opts {
+		opt(&p)
+	}
+	return p
+}
+
+var IgnoreEOS PredictOption = func(p *PredictOptions) {
+	p.IgnoreEOS = true
+}
+
+// SetMlock sets the memory lock.
+func SetMlock(b bool) PredictOption {
+	return func(p *PredictOptions) {
+		p.MLock = b
+	}
+}
+
+// SetMemoryMap sets memory mapping.
+func SetMemoryMap(b bool) PredictOption {
+	return func(p *PredictOptions) {
+		p.MMap = b
+	}
+}
+
+// SetGPULayers sets the number of GPU layers to use to offload computation
+func SetGPULayers(n int) ModelOption {
+	return func(p *ModelOptions) {
+		p.NGPULayers = n
+	}
+}
+
+// SetTokenCallback sets the callback that is invoked for each generated token.
+func SetTokenCallback(fn func(string) bool) PredictOption {
+	return func(p *PredictOptions) {
+		p.TokenCallback = fn
+	}
+}
+
+// SetStopWords sets the prompts that will stop predictions.
+func SetStopWords(stop ...string) PredictOption {
+	return func(p *PredictOptions) {
+		p.StopPrompts = stop
+	}
+}
+
+// SetSeed sets the random seed for sampling text generation.
+func SetSeed(seed int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Seed = seed
+	}
+}
+
+// SetThreads sets the number of threads to use for text generation.
+func SetThreads(threads int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Threads = threads
+	}
+}
+
+// SetTokens sets the number of tokens to generate.
+func SetTokens(tokens int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Tokens = tokens
+	}
+}
+
+// SetTopK sets the value for top-K sampling.
+func SetTopK(topk int) PredictOption {
+	return func(p *PredictOptions) {
+		p.TopK = topk
+	}
+}
+
+// SetTopP sets the value for nucleus sampling.
+func SetTopP(topp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.TopP = topp
+	}
+}
+
+// SetTemperature sets the temperature value for text generation.
+func SetTemperature(temp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.Temperature = temp
+	}
+}
+
+// SetPathPromptCache sets the session file to store the prompt cache.
+func SetPathPromptCache(f string) PredictOption {
+	return func(p *PredictOptions) {
+		p.PathPromptCache = f
+	}
+}
+
+// SetPenalty sets the repetition penalty for text generation.
+func SetPenalty(penalty float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.Penalty = penalty
+	}
+}
+
+// SetRepeat sets the number of times to repeat text generation.
+func SetRepeat(repeat int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Repeat = repeat
+	}
+}
+
+// SetBatch sets the batch size.
+func SetBatch(size int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Batch = size
+	}
+}
+
+// SetNKeep sets the number of tokens from the initial prompt to keep.
+func SetNKeep(n int) PredictOption {
+	return func(p *PredictOptions) {
+		p.NKeep = n
+	}
+}
+
+// NewPredictOptions creates a new PredictOptions object with the given options.
+func NewPredictOptions(opts ...PredictOption) PredictOptions {
+	p := DefaultOptions
+	for _, opt := range opts {
+		opt(&p)
+	}
+	return p
+}
+
+// SetTailFreeSamplingZ sets the tail free sampling, parameter z.
+func SetTailFreeSamplingZ(tfz float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.TailFreeSamplingZ = tfz
+	}
+}
+
+// SetTypicalP sets the typicality parameter, p_typical.
+func SetTypicalP(tp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.TypicalP = tp
+	}
+}
+
+// SetFrequencyPenalty sets the frequency penalty parameter, freq_penalty.
+func SetFrequencyPenalty(fp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.FrequencyPenalty = fp
+	}
+}
+
+// SetPresencePenalty sets the presence penalty parameter, presence_penalty.
+func SetPresencePenalty(pp float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.PresencePenalty = pp
+	}
+}
+
+// SetMirostat sets the mirostat parameter.
+func SetMirostat(m int) PredictOption {
+	return func(p *PredictOptions) {
+		p.Mirostat = m
+	}
+}
+
+// SetMirostatETA sets the mirostat ETA parameter.
+func SetMirostatETA(me float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.MirostatETA = me
+	}
+}
+
+// SetMirostatTAU sets the mirostat TAU parameter.
+func SetMirostatTAU(mt float64) PredictOption {
+	return func(p *PredictOptions) {
+		p.MirostatTAU = mt
+	}
+}
+
+// SetPenalizeNL sets whether to penalize newlines or not.
+func SetPenalizeNL(pnl bool) PredictOption {
+	return func(p *PredictOptions) {
+		p.PenalizeNL = pnl
+	}
+}
+
+// SetLogitBias sets the logit bias parameter.
+func SetLogitBias(lb string) PredictOption {
+	return func(p *PredictOptions) {
+		p.LogitBias = lb
+	}
+}
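
Both option types follow the functional-options pattern: each `Set*` helper returns a closure that mutates the options struct, and the `New*Options` constructors fold those closures over the defaults. A small illustrative sketch:

```go
package main

import (
	"fmt"

	"github.com/jmorganca/ollama/llama"
)

func main() {
	// Start from DefaultOptions, then apply each closure in order.
	po := llama.NewPredictOptions(
		llama.SetSeed(42),
		llama.SetTemperature(0.7),
		llama.SetStopWords("User:"),
	)
	fmt.Printf("seed=%d temp=%.1f stop=%v\n", po.Seed, po.Temperature, po.StopPrompts)
}
```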

+ 1 - 1
main.go

@@ -1,7 +1,7 @@
 package main
 
 import (
-	"github.com/ollama/ollama/cmd"
+	"github.com/jmorganca/ollama/cmd"
 )
 
 func main() {

+ 2 - 2
server/routes.go

@@ -9,9 +9,9 @@ import (
 	"runtime"
 
 	"github.com/gin-gonic/gin"
-	llama "github.com/go-skynet/go-llama.cpp"
+	llama "github.com/jmorganca/ollama/llama"
 
-	"github.com/ollama/ollama/api"
+	"github.com/jmorganca/ollama/api"
 )
 
 func Serve(ln net.Listener) error {