server_test.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. package llm
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "testing"
  8. "github.com/ollama/ollama/api"
  9. "golang.org/x/sync/semaphore"
  10. )
  11. func TestLLMServerCompletionFormat(t *testing.T) {
  12. // This test was written to fix an already deployed issue. It is a bit
  13. // of a mess, and but it's good enough, until we can refactoring the
  14. // Completion method to be more testable.
  15. ctx, cancel := context.WithCancel(context.Background())
  16. s := &llmServer{
  17. sem: semaphore.NewWeighted(1), // required to prevent nil panic
  18. }
  19. checkInvalid := func(format string) {
  20. t.Helper()
  21. err := s.Completion(ctx, CompletionRequest{
  22. Options: new(api.Options),
  23. Format: []byte(format),
  24. }, nil)
  25. want := fmt.Sprintf("invalid format: %q; expected \"json\" or a valid JSON Schema", format)
  26. if err == nil || !strings.Contains(err.Error(), want) {
  27. t.Fatalf("err = %v; want %q", err, want)
  28. }
  29. }
  30. checkInvalid("X") // invalid format
  31. checkInvalid(`"X"`) // invalid JSON Schema
  32. cancel() // prevent further processing if request makes it past the format check
  33. checkCanceled := func(err error) {
  34. t.Helper()
  35. if !errors.Is(err, context.Canceled) {
  36. t.Fatalf("Completion: err = %v; expected context.Canceled", err)
  37. }
  38. }
  39. valids := []string{`"json"`, `{"type":"object"}`, ``, `""`, `null`}
  40. for _, valid := range valids {
  41. err := s.Completion(ctx, CompletionRequest{
  42. Options: new(api.Options),
  43. Format: []byte(valid),
  44. }, nil)
  45. checkCanceled(err)
  46. }
  47. err := s.Completion(ctx, CompletionRequest{
  48. Options: new(api.Options),
  49. Format: nil, // missing format
  50. }, nil)
  51. checkCanceled(err)
  52. }