parser.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. package parser
  2. import (
  3. "bufio"
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strconv"
  9. "strings"
  10. )
  11. type Command struct {
  12. Name string
  13. Args string
  14. }
  15. type state int
  16. const (
  17. stateNil state = iota
  18. stateName
  19. stateValue
  20. stateParameter
  21. stateMessage
  22. stateComment
  23. )
  24. var (
  25. errMissingFrom = errors.New("no FROM line")
  26. errInvalidRole = errors.New("role must be one of \"system\", \"user\", or \"assistant\"")
  27. )
  28. func Parse(r io.Reader) (cmds []Command, err error) {
  29. var cmd Command
  30. var curr state
  31. var b bytes.Buffer
  32. var role string
  33. br := bufio.NewReader(r)
  34. for {
  35. r, _, err := br.ReadRune()
  36. if errors.Is(err, io.EOF) {
  37. break
  38. } else if err != nil {
  39. return nil, err
  40. }
  41. next, r, err := parseRuneForState(r, curr)
  42. if errors.Is(err, io.ErrUnexpectedEOF) {
  43. return nil, fmt.Errorf("%w: %s", err, b.String())
  44. } else if err != nil {
  45. return nil, err
  46. }
  47. // process the state transition, some transitions need to be intercepted and redirected
  48. if next != curr {
  49. switch curr {
  50. case stateName, stateParameter:
  51. // next state sometimes depends on the current buffer value
  52. switch s := strings.ToLower(b.String()); s {
  53. case "from":
  54. cmd.Name = "model"
  55. case "parameter":
  56. // transition to stateParameter which sets command name
  57. next = stateParameter
  58. case "message":
  59. // transition to stateMessage which validates the message role
  60. next = stateMessage
  61. fallthrough
  62. default:
  63. cmd.Name = s
  64. }
  65. case stateMessage:
  66. if !isValidMessageRole(b.String()) {
  67. return nil, errInvalidRole
  68. }
  69. role = b.String()
  70. case stateComment, stateNil:
  71. // pass
  72. case stateValue:
  73. s, ok := unquote(b.String())
  74. if !ok || isSpace(r) {
  75. if _, err := b.WriteRune(r); err != nil {
  76. return nil, err
  77. }
  78. continue
  79. }
  80. if role != "" {
  81. s = role + ": " + s
  82. role = ""
  83. }
  84. cmd.Args = s
  85. cmds = append(cmds, cmd)
  86. }
  87. b.Reset()
  88. curr = next
  89. }
  90. if strconv.IsPrint(r) {
  91. if _, err := b.WriteRune(r); err != nil {
  92. return nil, err
  93. }
  94. }
  95. }
  96. // flush the buffer
  97. switch curr {
  98. case stateComment, stateNil:
  99. // pass; nothing to flush
  100. case stateValue:
  101. s, ok := unquote(b.String())
  102. if !ok {
  103. return nil, io.ErrUnexpectedEOF
  104. }
  105. if role != "" {
  106. s = role + ": " + s
  107. }
  108. cmd.Args = s
  109. cmds = append(cmds, cmd)
  110. default:
  111. return nil, io.ErrUnexpectedEOF
  112. }
  113. for _, cmd := range cmds {
  114. if cmd.Name == "model" {
  115. return cmds, nil
  116. }
  117. }
  118. return nil, errMissingFrom
  119. }
  120. func parseRuneForState(r rune, cs state) (state, rune, error) {
  121. switch cs {
  122. case stateNil:
  123. switch {
  124. case r == '#':
  125. return stateComment, 0, nil
  126. case isSpace(r), isNewline(r):
  127. return stateNil, 0, nil
  128. default:
  129. return stateName, r, nil
  130. }
  131. case stateName:
  132. switch {
  133. case isAlpha(r):
  134. return stateName, r, nil
  135. case isSpace(r):
  136. return stateValue, 0, nil
  137. default:
  138. return stateNil, 0, errors.New("invalid")
  139. }
  140. case stateValue:
  141. switch {
  142. case isNewline(r):
  143. return stateNil, r, nil
  144. case isSpace(r):
  145. return stateNil, r, nil
  146. default:
  147. return stateValue, r, nil
  148. }
  149. case stateParameter:
  150. switch {
  151. case isAlpha(r), isNumber(r), r == '_':
  152. return stateParameter, r, nil
  153. case isSpace(r):
  154. return stateValue, 0, nil
  155. default:
  156. return stateNil, 0, io.ErrUnexpectedEOF
  157. }
  158. case stateMessage:
  159. switch {
  160. case isAlpha(r):
  161. return stateMessage, r, nil
  162. case isSpace(r):
  163. return stateValue, 0, nil
  164. default:
  165. return stateNil, 0, io.ErrUnexpectedEOF
  166. }
  167. case stateComment:
  168. switch {
  169. case isNewline(r):
  170. return stateNil, 0, nil
  171. default:
  172. return stateComment, 0, nil
  173. }
  174. default:
  175. return stateNil, 0, errors.New("")
  176. }
  177. }
  178. func unquote(s string) (string, bool) {
  179. if len(s) == 0 {
  180. return "", false
  181. }
  182. // TODO: single quotes
  183. if len(s) >= 3 && s[:3] == `"""` {
  184. if len(s) >= 6 && s[len(s)-3:] == `"""` {
  185. return s[3 : len(s)-3], true
  186. }
  187. return "", false
  188. }
  189. if len(s) >= 1 && s[0] == '"' {
  190. if len(s) >= 2 && s[len(s)-1] == '"' {
  191. return s[1 : len(s)-1], true
  192. }
  193. return "", false
  194. }
  195. return s, true
  196. }
  197. func isAlpha(r rune) bool {
  198. return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z'
  199. }
  200. func isNumber(r rune) bool {
  201. return r >= '0' && r <= '9'
  202. }
  203. func isSpace(r rune) bool {
  204. return r == ' ' || r == '\t'
  205. }
  206. func isNewline(r rune) bool {
  207. return r == '\r' || r == '\n'
  208. }
  209. func isValidMessageRole(role string) bool {
  210. return role == "system" || role == "user" || role == "assistant"
  211. }