|
@@ -12,6 +12,7 @@ import (
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
"github.com/ollama/ollama/api"
|
|
|
+ "github.com/ollama/ollama/types/model"
|
|
|
)
|
|
|
|
|
|
type Error struct {
|
|
@@ -85,6 +86,18 @@ type ChatCompletionChunk struct {
|
|
|
Choices []ChunkChoice `json:"choices"`
|
|
|
}
|
|
|
|
|
|
+type Model struct {
|
|
|
+ Id string `json:"id"`
|
|
|
+ Object string `json:"object"`
|
|
|
+ Created int64 `json:"created"`
|
|
|
+ OwnedBy string `json:"owned_by"`
|
|
|
+}
|
|
|
+
|
|
|
+type ListCompletion struct {
|
|
|
+ Object string `json:"object"`
|
|
|
+ Data []Model `json:"data"`
|
|
|
+}
|
|
|
+
|
|
|
func NewError(code int, message string) ErrorResponse {
|
|
|
var etype string
|
|
|
switch code {
|
|
@@ -145,7 +158,33 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
|
|
+func toListCompletion(r api.ListResponse) ListCompletion {
|
|
|
+ var data []Model
|
|
|
+ for _, m := range r.Models {
|
|
|
+ data = append(data, Model{
|
|
|
+ Id: m.Name,
|
|
|
+ Object: "model",
|
|
|
+ Created: m.ModifiedAt.Unix(),
|
|
|
+ OwnedBy: model.ParseName(m.Name).Namespace,
|
|
|
+ })
|
|
|
+ }
|
|
|
+
|
|
|
+ return ListCompletion{
|
|
|
+ Object: "list",
|
|
|
+ Data: data,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func toModel(r api.ShowResponse, m string) Model {
|
|
|
+ return Model{
|
|
|
+ Id: m,
|
|
|
+ Object: "model",
|
|
|
+ Created: r.ModifiedAt.Unix(),
|
|
|
+ OwnedBy: model.ParseName(m).Namespace,
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
|
|
|
var messages []api.Message
|
|
|
for _, msg := range r.Messages {
|
|
|
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
|
|
@@ -208,13 +247,26 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-type writer struct {
|
|
|
+type BaseWriter struct {
|
|
|
+ gin.ResponseWriter
|
|
|
+}
|
|
|
+
|
|
|
+type ChatWriter struct {
|
|
|
stream bool
|
|
|
id string
|
|
|
- gin.ResponseWriter
|
|
|
+ BaseWriter
|
|
|
}
|
|
|
|
|
|
-func (w *writer) writeError(code int, data []byte) (int, error) {
|
|
|
+type ListWriter struct {
|
|
|
+ BaseWriter
|
|
|
+}
|
|
|
+
|
|
|
+type RetrieveWriter struct {
|
|
|
+ BaseWriter
|
|
|
+ model string
|
|
|
+}
|
|
|
+
|
|
|
+func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
|
|
|
var serr api.StatusError
|
|
|
err := json.Unmarshal(data, &serr)
|
|
|
if err != nil {
|
|
@@ -230,7 +282,7 @@ func (w *writer) writeError(code int, data []byte) (int, error) {
|
|
|
return len(data), nil
|
|
|
}
|
|
|
|
|
|
-func (w *writer) writeResponse(data []byte) (int, error) {
|
|
|
+func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
|
|
var chatResponse api.ChatResponse
|
|
|
err := json.Unmarshal(data, &chatResponse)
|
|
|
if err != nil {
|
|
@@ -270,7 +322,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
|
|
|
return len(data), nil
|
|
|
}
|
|
|
|
|
|
-func (w *writer) Write(data []byte) (int, error) {
|
|
|
+func (w *ChatWriter) Write(data []byte) (int, error) {
|
|
|
code := w.ResponseWriter.Status()
|
|
|
if code != http.StatusOK {
|
|
|
return w.writeError(code, data)
|
|
@@ -279,7 +331,92 @@ func (w *writer) Write(data []byte) (int, error) {
|
|
|
return w.writeResponse(data)
|
|
|
}
|
|
|
|
|
|
-func Middleware() gin.HandlerFunc {
|
|
|
+func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
|
|
+ var listResponse api.ListResponse
|
|
|
+ err := json.Unmarshal(data, &listResponse)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
|
+ err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return len(data), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (w *ListWriter) Write(data []byte) (int, error) {
|
|
|
+ code := w.ResponseWriter.Status()
|
|
|
+ if code != http.StatusOK {
|
|
|
+ return w.writeError(code, data)
|
|
|
+ }
|
|
|
+
|
|
|
+ return w.writeResponse(data)
|
|
|
+}
|
|
|
+
|
|
|
+func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
|
|
+ var showResponse api.ShowResponse
|
|
|
+ err := json.Unmarshal(data, &showResponse)
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ // retrieve completion
|
|
|
+ w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
|
|
+ err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
|
|
+ if err != nil {
|
|
|
+ return 0, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return len(data), nil
|
|
|
+}
|
|
|
+
|
|
|
+func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
|
|
+ code := w.ResponseWriter.Status()
|
|
|
+ if code != http.StatusOK {
|
|
|
+ return w.writeError(code, data)
|
|
|
+ }
|
|
|
+
|
|
|
+ return w.writeResponse(data)
|
|
|
+}
|
|
|
+
|
|
|
+func ListMiddleware() gin.HandlerFunc {
|
|
|
+ return func(c *gin.Context) {
|
|
|
+ w := &ListWriter{
|
|
|
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
+ }
|
|
|
+
|
|
|
+ c.Writer = w
|
|
|
+
|
|
|
+ c.Next()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func RetrieveMiddleware() gin.HandlerFunc {
|
|
|
+ return func(c *gin.Context) {
|
|
|
+ var b bytes.Buffer
|
|
|
+ if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
|
|
+ c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.Request.Body = io.NopCloser(&b)
|
|
|
+
|
|
|
+ // response writer
|
|
|
+ w := &RetrieveWriter{
|
|
|
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
+ model: c.Param("model"),
|
|
|
+ }
|
|
|
+
|
|
|
+ c.Writer = w
|
|
|
+
|
|
|
+ c.Next()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func ChatMiddleware() gin.HandlerFunc {
|
|
|
return func(c *gin.Context) {
|
|
|
var req ChatCompletionRequest
|
|
|
err := c.ShouldBindJSON(&req)
|
|
@@ -294,17 +431,17 @@ func Middleware() gin.HandlerFunc {
|
|
|
}
|
|
|
|
|
|
var b bytes.Buffer
|
|
|
- if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
|
|
|
+ if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
|
|
|
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
|
|
return
|
|
|
}
|
|
|
|
|
|
c.Request.Body = io.NopCloser(&b)
|
|
|
|
|
|
- w := &writer{
|
|
|
- ResponseWriter: c.Writer,
|
|
|
- stream: req.Stream,
|
|
|
- id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
|
+ w := &ChatWriter{
|
|
|
+ BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
|
|
+ stream: req.Stream,
|
|
|
+ id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
|
|
}
|
|
|
|
|
|
c.Writer = w
|