images_test.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package server
  2. import (
  3. "strings"
  4. "testing"
  5. "github.com/jmorganca/ollama/api"
  6. )
  7. func TestChat(t *testing.T) {
  8. tests := []struct {
  9. name string
  10. template string
  11. msgs []api.Message
  12. want string
  13. wantErr string
  14. }{
  15. {
  16. name: "Single Message",
  17. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  18. msgs: []api.Message{
  19. {
  20. Role: "system",
  21. Content: "You are a Wizard.",
  22. },
  23. {
  24. Role: "user",
  25. Content: "What are the potion ingredients?",
  26. },
  27. },
  28. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
  29. },
  30. {
  31. name: "Message History",
  32. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  33. msgs: []api.Message{
  34. {
  35. Role: "system",
  36. Content: "You are a Wizard.",
  37. },
  38. {
  39. Role: "user",
  40. Content: "What are the potion ingredients?",
  41. },
  42. {
  43. Role: "assistant",
  44. Content: "sugar",
  45. },
  46. {
  47. Role: "user",
  48. Content: "Anything else?",
  49. },
  50. },
  51. want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
  52. },
  53. {
  54. name: "Assistant Only",
  55. template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
  56. msgs: []api.Message{
  57. {
  58. Role: "assistant",
  59. Content: "everything nice",
  60. },
  61. },
  62. want: "[INST] [/INST]everything nice",
  63. },
  64. {
  65. name: "Invalid Role",
  66. msgs: []api.Message{
  67. {
  68. Role: "not-a-role",
  69. Content: "howdy",
  70. },
  71. },
  72. wantErr: "invalid role: not-a-role",
  73. },
  74. }
  75. for _, tt := range tests {
  76. m := Model{
  77. Template: tt.template,
  78. }
  79. t.Run(tt.name, func(t *testing.T) {
  80. got, err := m.ChatPrompt(tt.msgs)
  81. if tt.wantErr != "" {
  82. if err == nil {
  83. t.Errorf("ChatPrompt() expected error, got nil")
  84. }
  85. if !strings.Contains(err.Error(), tt.wantErr) {
  86. t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
  87. }
  88. }
  89. if got != tt.want {
  90. t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
  91. }
  92. })
  93. }
  94. }