file.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. // Package model implements the Modelfile and Path formats.
  2. package model
  3. import (
  4. "bufio"
  5. "io"
  6. "iter"
  7. "strings"
  8. )
  // ParameterPragma is a single key/value setting taken from a PARAMETER
// line of a Modelfile. As built by ParseFile, Key is lower-cased and Value
// is the first field that follows the key.
type ParameterPragma struct {
	Key   string // parameter name
	Value string // parameter value, kept as unparsed text
}
  // MessagePragma is one MESSAGE entry of a Modelfile: a role and the text
// associated with it. ParseFile takes Role from the pragma's first argument
// and Content from its second.
type MessagePragma struct {
	Role    string // e.g. "user"
	Content string // message text
}
  17. type File struct {
  18. // From is a required pragma that specifies the source of the model,
  19. // either on disk, or by reference (see blob.ParseRef).
  20. From string
  21. // Optional
  22. Params []ParameterPragma
  23. Template string
  24. System string
  25. Adapter string
  26. Messages []MessagePragma
  27. License string
  28. }
  // FileError is an error attached to a Modelfile pragma. Its text is the
// pragma name followed by ": " and the message.
type FileError struct {
	Pragma  string // pragma name printed before the message
	Message string // human-readable description
}

// Compile-time check that *FileError satisfies the error interface.
var _ error = (*FileError)(nil)

// Error implements the error interface, formatting the error as
// "<Pragma>: <Message>".
func (e *FileError) Error() string {
	return strings.Join([]string{e.Pragma, e.Message}, ": ")
}
  36. // Pragma represents a single pragma in a Modelfile.
  37. type Pragma struct {
  38. // The pragma name
  39. Name string
  40. // Args contains the user-defined arguments for the pragma. If no
  41. // arguments were provided, it is nil.
  42. Args []string
  43. }
  44. func (p Pragma) Arg(i int) string {
  45. if i >= len(p.Args) {
  46. return ""
  47. }
  48. return p.Args[i]
  49. }
  50. func FilePragmas(r io.Reader) iter.Seq2[Pragma, error] {
  51. return func(yield func(Pragma, error) bool) {
  52. sc := bufio.NewScanner(r)
  53. for sc.Scan() {
  54. line := sc.Text()
  55. // TODO(bmizerany): set a max num fields/args to
  56. // prevent mem bloat
  57. args := strings.Fields(line)
  58. if len(args) == 0 {
  59. continue
  60. }
  61. p := Pragma{
  62. Name: strings.ToUpper(args[0]),
  63. }
  64. if p.Name == "MESSAGE" {
  65. // handle special case where message content
  66. // is space separated on the _rest_ of the
  67. // line like: `MESSAGE user Is Ontario in
  68. // Canada?`
  69. panic("TODO")
  70. }
  71. if len(args) > 1 {
  72. p.Args = args[1:]
  73. }
  74. if !yield(p, nil) {
  75. return
  76. }
  77. }
  78. if sc.Err() != nil {
  79. yield(Pragma{}, sc.Err())
  80. }
  81. }
  82. }
  83. func ParseFile(r io.Reader) (File, error) {
  84. var f File
  85. for p, err := range FilePragmas(r) {
  86. if err != nil {
  87. return File{}, err
  88. }
  89. switch p.Name {
  90. case "FROM":
  91. f.From = p.Arg(0)
  92. case "PARAMETER":
  93. f.Params = append(f.Params, ParameterPragma{
  94. Key: strings.ToLower(p.Arg(0)),
  95. Value: p.Arg(1),
  96. })
  97. case "TEMPLATE":
  98. f.Template = p.Arg(0)
  99. case "SYSTEM":
  100. f.System = p.Arg(0)
  101. case "ADAPTER":
  102. f.Adapter = p.Arg(0)
  103. case "MESSAGE":
  104. f.Messages = append(f.Messages, MessagePragma{
  105. Role: p.Arg(0),
  106. Content: p.Arg(1),
  107. })
  108. case "LICENSE":
  109. f.License = p.Arg(0)
  110. }
  111. }
  112. return f, nil
  113. }