llama_test.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. package llm
  2. import (
  3. "bytes"
  4. "testing"
  5. )
  6. func TestCheckStopConditions(t *testing.T) {
  7. tests := map[string]struct {
  8. b string
  9. stop []string
  10. wantB string
  11. wantStop bool
  12. wantEndsWithStopPrefix bool
  13. }{
  14. "not present": {
  15. b: "abc",
  16. stop: []string{"x"},
  17. wantStop: false,
  18. wantEndsWithStopPrefix: false,
  19. },
  20. "exact": {
  21. b: "abc",
  22. stop: []string{"abc"},
  23. wantStop: true,
  24. wantEndsWithStopPrefix: false,
  25. },
  26. "substring": {
  27. b: "abc",
  28. stop: []string{"b"},
  29. wantB: "a",
  30. wantStop: true,
  31. wantEndsWithStopPrefix: false,
  32. },
  33. "prefix 1": {
  34. b: "abc",
  35. stop: []string{"abcd"},
  36. wantStop: false,
  37. wantEndsWithStopPrefix: true,
  38. },
  39. "prefix 2": {
  40. b: "abc",
  41. stop: []string{"bcd"},
  42. wantStop: false,
  43. wantEndsWithStopPrefix: true,
  44. },
  45. "prefix 3": {
  46. b: "abc",
  47. stop: []string{"cd"},
  48. wantStop: false,
  49. wantEndsWithStopPrefix: true,
  50. },
  51. "no prefix": {
  52. b: "abc",
  53. stop: []string{"bx"},
  54. wantStop: false,
  55. wantEndsWithStopPrefix: false,
  56. },
  57. }
  58. for name, test := range tests {
  59. t.Run(name, func(t *testing.T) {
  60. var b bytes.Buffer
  61. b.WriteString(test.b)
  62. stop, endsWithStopPrefix := handleStopSequences(&b, test.stop)
  63. if test.wantB != "" {
  64. gotB := b.String()
  65. if gotB != test.wantB {
  66. t.Errorf("got b %q, want %q", gotB, test.wantB)
  67. }
  68. }
  69. if stop != test.wantStop {
  70. t.Errorf("got stop %v, want %v", stop, test.wantStop)
  71. }
  72. if endsWithStopPrefix != test.wantEndsWithStopPrefix {
  73. t.Errorf("got endsWithStopPrefix %v, want %v", endsWithStopPrefix, test.wantEndsWithStopPrefix)
  74. }
  75. })
  76. }
  77. }