server_test.go 1.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. package llm
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "strings"
  7. "testing"
  8. "github.com/ollama/ollama/api"
  9. "golang.org/x/sync/semaphore"
  10. )
  11. func TestLLMServerCompletionFormat(t *testing.T) {
  12. // This test was written to fix an already deployed issue. It is a bit
  13. // of a mess, and but it's good enough, until we can refactoring the
  14. // Completion method to be more testable.
  15. ctx, cancel := context.WithCancel(context.Background())
  16. s := &llmServer{
  17. sem: semaphore.NewWeighted(1), // required to prevent nil panic
  18. }
  19. checkInvalid := func(format string) {
  20. t.Helper()
  21. err := s.Completion(ctx, CompletionRequest{
  22. Options: new(api.Options),
  23. Format: []byte(format),
  24. }, nil)
  25. want := fmt.Sprintf("invalid format: %q; expected \"json\" or a valid JSON Schema", format)
  26. if err == nil || !strings.Contains(err.Error(), want) {
  27. t.Fatalf("err = %v; want %q", err, want)
  28. }
  29. }
  30. checkInvalid("X") // invalid format
  31. checkInvalid(`"X"`) // invalid JSON Schema
  32. cancel() // prevent further processing if request makes it past the format check
  33. checkValid := func(err error) {
  34. t.Helper()
  35. if !errors.Is(err, context.Canceled) {
  36. t.Fatalf("Completion: err = %v; expected context.Canceled", err)
  37. }
  38. }
  39. valids := []string{
  40. // "missing"
  41. ``,
  42. `""`,
  43. `null`,
  44. // JSON
  45. `"json"`,
  46. `{"type":"object"}`,
  47. }
  48. for _, valid := range valids {
  49. err := s.Completion(ctx, CompletionRequest{
  50. Options: new(api.Options),
  51. Format: []byte(valid),
  52. }, nil)
  53. checkValid(err)
  54. }
  55. err := s.Completion(ctx, CompletionRequest{
  56. Options: new(api.Options),
  57. Format: nil, // missing format
  58. }, nil)
  59. checkValid(err)
  60. }