Roy Han 10 månader sedan
förälder
incheckning
00a4cb26ca
6 ändrade filer med 27 tillägg och 27 borttagningar
  1. 7 7
      api/types.go
  2. 4 4
      format/normalize.go
  3. 8 8
      format/normalize_test.go
  4. 3 3
      llm/server.go
  5. 3 3
      server/routes.go
  6. 2 2
      server/sched_test.go

+ 7 - 7
api/types.go

@@ -210,7 +210,7 @@ type EmbedRequest struct {
 	Model string `json:"model"`
 
 	// Input is the input to embed.
-	Input any `json:"input,omitempty"`
+	Input any `json:"input"`
 
 	// KeepAlive controls how long the model will stay loaded in memory following
 	// this request.
@@ -222,6 +222,12 @@ type EmbedRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
+// EmbedResponse is the response from [Client.Embed].
+type EmbedResponse struct {
+	Model      string      `json:"model"`
+	Embeddings [][]float32 `json:"embeddings,omitempty"`
+}
+
 // EmbeddingRequest is the request passed to [Client.Embeddings].
 type EmbeddingRequest struct {
 	// Model is the model name.
@@ -238,12 +244,6 @@ type EmbeddingRequest struct {
 	Options map[string]interface{} `json:"options"`
 }
 
-// EmbedResponse is the response from [Client.Embed].
-type EmbedResponse struct {
-	Model      string      `json:"model"`
-	Embeddings [][]float64 `json:"embeddings,omitempty"`
-}
-
 // EmbeddingResponse is the response from [Client.Embeddings].
 type EmbeddingResponse struct {
 	Embedding []float64 `json:"embedding"`

+ 4 - 4
format/normalize.go

@@ -2,18 +2,18 @@ package format
 
 import "math"
 
-func Normalize(vec []float64) []float64 {
+func Normalize(vec []float32) []float32 {
 	var sum float64
 	for _, v := range vec {
-		sum += v * v
+		sum += float64(v * v)
 	}
 
 	sum = math.Sqrt(sum)
 
-	var norm float64
+	var norm float32
 
 	if sum > 0 {
-		norm = 1.0 / sum
+		norm = float32(1.0 / sum)
 	} else {
 		norm = 0.0
 	}

+ 8 - 8
format/normalize_test.go

@@ -7,21 +7,21 @@ import (
 
 func TestNormalize(t *testing.T) {
 	type testCase struct {
-		input []float64
+		input []float32
 	}
 
 	testCases := []testCase{
-		{input: []float64{1}},
-		{input: []float64{0, 1, 2, 3}},
-		{input: []float64{0.1, 0.2, 0.3}},
-		{input: []float64{-0.1, 0.2, 0.3, -0.4}},
-		{input: []float64{0, 0, 0}},
+		{input: []float32{1}},
+		{input: []float32{0, 1, 2, 3}},
+		{input: []float32{0.1, 0.2, 0.3}},
+		{input: []float32{-0.1, 0.2, 0.3, -0.4}},
+		{input: []float32{0, 0, 0}},
 	}
 
-	assertNorm := func(vec []float64) (res bool) {
+	assertNorm := func(vec []float32) (res bool) {
 		sum := 0.0
 		for _, v := range vec {
-			sum += v * v
+			sum += float64(v * v)
 		}
 		if math.Abs(sum-1) > 1e-6 {
 			return sum == 0

+ 3 - 3
llm/server.go

@@ -34,7 +34,7 @@ type LlamaServer interface {
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
 	Embedding(ctx context.Context, prompt string) ([]float64, error)
-	Embed(ctx context.Context, input []string) ([][]float64, error)
+	Embed(ctx context.Context, input []string) ([][]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -847,10 +847,10 @@ type EmbedRequest struct {
 }
 
 type EmbedResponse struct {
-	Embedding [][]float64 `json:"embedding"`
+	Embedding [][]float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float64, error) {
+func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err

+ 3 - 3
server/routes.go

@@ -414,12 +414,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		return s, nil
 	}
 
-	embeddings := [][]float64{}
+	embeddings := [][]float32{}
 
 	switch reqEmbed := req.Input.(type) {
 	case string:
 		if reqEmbed == "" {
-			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
+			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
 			return
 		}
 		reqEmbed, err = checkFit(reqEmbed, *req.Truncate)
@@ -430,7 +430,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
 		embeddings, err = runner.llama.Embed(c.Request.Context(), []string{reqEmbed})
 	case []any:
 		if reqEmbed == nil {
-			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float64{}})
+			c.JSON(http.StatusOK, api.EmbedResponse{Embeddings: [][]float32{}})
 			return
 		}
 

+ 2 - 2
server/sched_test.go

@@ -610,7 +610,7 @@ type mockLlm struct {
 	completionResp     error
 	embeddingResp      []float64
 	embeddingRespErr   error
-	embedResp          [][]float64
+	embedResp          [][]float32
 	embedRespErr       error
 	tokenizeResp       []int
 	tokenizeRespErr    error
@@ -631,7 +631,7 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
 func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
 	return s.embeddingResp, s.embeddingRespErr
 }
-func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float64, error) {
+func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) {
 	return s.embedResp, s.embedRespErr
 }
 func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {