file.go 5.7 KB


  1. package model
  2. import (
  3. "bufio"
  4. "bytes"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "strconv"
  9. "strings"
  10. )
  11. type File struct {
  12. Commands []Command
  13. }
  14. func (f File) String() string {
  15. var sb strings.Builder
  16. for _, cmd := range f.Commands {
  17. fmt.Fprintln(&sb, cmd.String())
  18. }
  19. return sb.String()
  20. }
  21. type Command struct {
  22. Name string
  23. Args string
  24. }
  25. func (c Command) String() string {
  26. var sb strings.Builder
  27. switch c.Name {
  28. case "model":
  29. fmt.Fprintf(&sb, "FROM %s", c.Args)
  30. case "license", "template", "system", "adapter":
  31. fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
  32. case "message":
  33. role, message, _ := strings.Cut(c.Args, ": ")
  34. fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
  35. default:
  36. fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
  37. }
  38. return sb.String()
  39. }
  // state is the current mode of ParseFile's rune-by-rune state machine.
type state int

const (
	// stateNil: between commands; blanks and newlines are skipped until a
	// keyword or a '#' comment starts.
	stateNil state = iota
	// stateName: reading a command keyword (letters only).
	stateName
	// stateValue: reading a command's value, which may be quoted.
	stateValue
	// stateParameter: reading the name that follows PARAMETER.
	stateParameter
	// stateMessage: reading the role that follows MESSAGE.
	stateMessage
	// stateComment: inside a '#' comment, until the next newline.
	stateComment
)
  // Sentinel errors produced while parsing.
var (
	// errMissingFrom: the input parsed cleanly but had no FROM command.
	errMissingFrom = errors.New("no FROM line")
	// errInvalidMessageRole: a MESSAGE role other than the three accepted ones.
	errInvalidMessageRole = errors.New(`message role must be one of "system", "user", or "assistant"`)
	// errInvalidCommand: an unknown keyword, or a non-letter inside a keyword.
	errInvalidCommand = errors.New(`command must be one of "from", "license", "template", "system", "adapter", "parameter", or "message"`)
)
  54. func ParseFile(r io.Reader) (*File, error) {
  55. var cmd Command
  56. var curr state
  57. var b bytes.Buffer
  58. var role string
  59. var f File
  60. br := bufio.NewReader(r)
  61. for {
  62. r, _, err := br.ReadRune()
  63. if errors.Is(err, io.EOF) {
  64. break
  65. } else if err != nil {
  66. return nil, err
  67. }
  68. next, r, err := parseRuneForState(r, curr)
  69. if errors.Is(err, io.ErrUnexpectedEOF) {
  70. return nil, fmt.Errorf("%w: %s", err, b.String())
  71. } else if err != nil {
  72. return nil, err
  73. }
  74. // process the state transition, some transitions need to be intercepted and redirected
  75. if next != curr {
  76. switch curr {
  77. case stateName:
  78. if !isValidCommand(b.String()) {
  79. return nil, errInvalidCommand
  80. }
  81. // next state sometimes depends on the current buffer value
  82. switch s := strings.ToLower(b.String()); s {
  83. case "from":
  84. cmd.Name = "model"
  85. case "parameter":
  86. // transition to stateParameter which sets command name
  87. next = stateParameter
  88. case "message":
  89. // transition to stateMessage which validates the message role
  90. next = stateMessage
  91. fallthrough
  92. default:
  93. cmd.Name = s
  94. }
  95. case stateParameter:
  96. cmd.Name = b.String()
  97. case stateMessage:
  98. if !isValidMessageRole(b.String()) {
  99. return nil, errInvalidMessageRole
  100. }
  101. role = b.String()
  102. case stateComment, stateNil:
  103. // pass
  104. case stateValue:
  105. s, ok := unquote(b.String())
  106. if !ok || isSpace(r) {
  107. if _, err := b.WriteRune(r); err != nil {
  108. return nil, err
  109. }
  110. continue
  111. }
  112. if role != "" {
  113. s = role + ": " + s
  114. role = ""
  115. }
  116. cmd.Args = s
  117. f.Commands = append(f.Commands, cmd)
  118. }
  119. b.Reset()
  120. curr = next
  121. }
  122. if strconv.IsPrint(r) {
  123. if _, err := b.WriteRune(r); err != nil {
  124. return nil, err
  125. }
  126. }
  127. }
  128. // flush the buffer
  129. switch curr {
  130. case stateComment, stateNil:
  131. // pass; nothing to flush
  132. case stateValue:
  133. s, ok := unquote(b.String())
  134. if !ok {
  135. return nil, io.ErrUnexpectedEOF
  136. }
  137. if role != "" {
  138. s = role + ": " + s
  139. }
  140. cmd.Args = s
  141. f.Commands = append(f.Commands, cmd)
  142. default:
  143. return nil, io.ErrUnexpectedEOF
  144. }
  145. for _, cmd := range f.Commands {
  146. if cmd.Name == "model" {
  147. return &f, nil
  148. }
  149. }
  150. return nil, errMissingFrom
  151. }
  152. func parseRuneForState(r rune, cs state) (state, rune, error) {
  153. switch cs {
  154. case stateNil:
  155. switch {
  156. case r == '#':
  157. return stateComment, 0, nil
  158. case isSpace(r), isNewline(r):
  159. return stateNil, 0, nil
  160. default:
  161. return stateName, r, nil
  162. }
  163. case stateName:
  164. switch {
  165. case isAlpha(r):
  166. return stateName, r, nil
  167. case isSpace(r):
  168. return stateValue, 0, nil
  169. default:
  170. return stateNil, 0, errInvalidCommand
  171. }
  172. case stateValue:
  173. switch {
  174. case isNewline(r):
  175. return stateNil, r, nil
  176. case isSpace(r):
  177. return stateNil, r, nil
  178. default:
  179. return stateValue, r, nil
  180. }
  181. case stateParameter:
  182. switch {
  183. case isAlpha(r), isNumber(r), r == '_':
  184. return stateParameter, r, nil
  185. case isSpace(r):
  186. return stateValue, 0, nil
  187. default:
  188. return stateNil, 0, io.ErrUnexpectedEOF
  189. }
  190. case stateMessage:
  191. switch {
  192. case isAlpha(r):
  193. return stateMessage, r, nil
  194. case isSpace(r):
  195. return stateValue, 0, nil
  196. default:
  197. return stateNil, 0, io.ErrUnexpectedEOF
  198. }
  199. case stateComment:
  200. switch {
  201. case isNewline(r):
  202. return stateNil, 0, nil
  203. default:
  204. return stateComment, 0, nil
  205. }
  206. default:
  207. return stateNil, 0, errors.New("")
  208. }
  209. }
  // quote returns s in a form intended to be read back as s by the parser's
// unquote step.
//
// A value is wrapped in quotes when writing it bare would not survive a
// re-parse. That covers values that:
//   - are empty (a bare empty value swallows the following line),
//   - contain a line break,
//   - have leading or trailing whitespace,
//   - begin with a double quote (the parser would strip it).
//
// A value that needs quotes uses triple quotes if it contains a double
// quote, and plain double quotes otherwise. Every other value is returned
// unchanged.
//
// NOTE(review): there is no escape syntax, so a value containing `"""`
// directly before a line break cannot be represented.
func quote(s string) string {
	needsQuotes := s == "" ||
		strings.ContainsAny(s, "\r\n") ||
		strings.TrimSpace(s) != s ||
		strings.HasPrefix(s, `"`)
	if !needsQuotes {
		return s
	}
	if strings.Contains(s, `"`) {
		return `"""` + s + `"""`
	}
	return `"` + s + `"`
}
  // unquote strips one layer of quoting from s. It accepts triple quotes
// (`"""…"""`) or double quotes (`"…"`). ok is false when s is empty, or
// when it opens a quote that it does not close; the parser treats that as
// "value not finished yet". A string that does not start with '"' is
// returned as is.
func unquote(s string) (string, bool) {
	const triple = `"""`
	n := len(s)
	switch {
	case n == 0:
		return "", false
	case n >= 3 && s[:3] == triple:
		// TODO: single quotes
		if n < 6 || s[n-3:] != triple {
			return "", false
		}
		return s[3 : n-3], true
	case s[0] == '"':
		if n < 2 || s[n-1] != '"' {
			return "", false
		}
		return s[1 : n-1], true
	default:
		return s, true
	}
}
  // isAlpha reports whether r is an ASCII letter (a-z or A-Z).
func isAlpha(r rune) bool {
	isLower := 'a' <= r && r <= 'z'
	isUpper := 'A' <= r && r <= 'Z'
	return isLower || isUpper
}
  // isNumber reports whether r is an ASCII decimal digit.
func isNumber(r rune) bool {
	return '0' <= r && r <= '9'
}
  // isSpace reports whether r is a space or a horizontal tab. Line breaks are
// handled separately by isNewline.
func isSpace(r rune) bool {
	switch r {
	case ' ', '\t':
		return true
	default:
		return false
	}
}
  // isNewline reports whether r is a carriage return or a line feed.
func isNewline(r rune) bool {
	switch r {
	case '\r', '\n':
		return true
	default:
		return false
	}
}
  // isValidMessageRole reports whether role is one of the accepted MESSAGE
// roles. The match is exact and case-sensitive.
func isValidMessageRole(role string) bool {
	switch role {
	case "system", "user", "assistant":
		return true
	}
	return false
}
  // isValidCommand reports whether cmd, lower-cased with strings.ToLower, is
// one of the recognized command keywords.
func isValidCommand(cmd string) bool {
	name := strings.ToLower(cmd)
	for _, known := range [...]string{"from", "license", "template", "system", "adapter", "parameter", "message"} {
		if name == known {
			return true
		}
	}
	return false
}