checker.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. package module
  2. import (
  3. "reflect"
  4. "regexp"
  5. "time"
  6. "unicode/utf8"
  7. "github.com/envoyproxy/protoc-gen-validate/validate"
  8. "github.com/lyft/protoc-gen-star"
  9. "google.golang.org/protobuf/proto"
  10. "google.golang.org/protobuf/types/known/durationpb"
  11. "google.golang.org/protobuf/types/known/timestamppb"
  12. )
  13. var unknown = ""
  14. var httpHeaderName = "^:?[0-9a-zA-Z!#$%&'*+-.^_|~\x60]+$"
  15. var httpHeaderValue = "^[^\u0000-\u0008\u000A-\u001F\u007F]*$"
  16. var headerString = "^[^\u0000\u000A\u000D]*$" // For non-strict validation.
  17. // Map from well known regex to regex pattern.
  18. var regex_map = map[string]*string{
  19. "UNKNOWN": &unknown,
  20. "HTTP_HEADER_NAME": &httpHeaderName,
  21. "HTTP_HEADER_VALUE": &httpHeaderValue,
  22. "HEADER_STRING": &headerString,
  23. }
  24. type FieldType interface {
  25. ProtoType() pgs.ProtoType
  26. Embed() pgs.Message
  27. }
  28. type Repeatable interface {
  29. IsRepeated() bool
  30. }
  31. func (m *Module) CheckRules(msg pgs.Message) {
  32. m.Push("msg: " + msg.Name().String())
  33. defer m.Pop()
  34. var disabled bool
  35. _, err := msg.Extension(validate.E_Disabled, &disabled)
  36. m.CheckErr(err, "unable to read validation extension from message")
  37. if disabled {
  38. m.Debug("validation disabled, skipping checks")
  39. return
  40. }
  41. for _, f := range msg.Fields() {
  42. m.Push(f.Name().String())
  43. var rules validate.FieldRules
  44. _, err = f.Extension(validate.E_Rules, &rules)
  45. m.CheckErr(err, "unable to read validation rules from field")
  46. if rules.GetMessage() != nil {
  47. m.MustType(f.Type(), pgs.MessageT, pgs.UnknownWKT)
  48. m.CheckMessage(f, &rules)
  49. }
  50. m.CheckFieldRules(f.Type(), &rules)
  51. m.Pop()
  52. }
  53. }
  54. func (m *Module) CheckFieldRules(typ FieldType, rules *validate.FieldRules) {
  55. if rules == nil {
  56. return
  57. }
  58. switch r := rules.Type.(type) {
  59. case *validate.FieldRules_Float:
  60. m.MustType(typ, pgs.FloatT, pgs.FloatValueWKT)
  61. m.CheckFloat(r.Float)
  62. case *validate.FieldRules_Double:
  63. m.MustType(typ, pgs.DoubleT, pgs.DoubleValueWKT)
  64. m.CheckDouble(r.Double)
  65. case *validate.FieldRules_Int32:
  66. m.MustType(typ, pgs.Int32T, pgs.Int32ValueWKT)
  67. m.CheckInt32(r.Int32)
  68. case *validate.FieldRules_Int64:
  69. m.MustType(typ, pgs.Int64T, pgs.Int64ValueWKT)
  70. m.CheckInt64(r.Int64)
  71. case *validate.FieldRules_Uint32:
  72. m.MustType(typ, pgs.UInt32T, pgs.UInt32ValueWKT)
  73. m.CheckUInt32(r.Uint32)
  74. case *validate.FieldRules_Uint64:
  75. m.MustType(typ, pgs.UInt64T, pgs.UInt64ValueWKT)
  76. m.CheckUInt64(r.Uint64)
  77. case *validate.FieldRules_Sint32:
  78. m.MustType(typ, pgs.SInt32, pgs.UnknownWKT)
  79. m.CheckSInt32(r.Sint32)
  80. case *validate.FieldRules_Sint64:
  81. m.MustType(typ, pgs.SInt64, pgs.UnknownWKT)
  82. m.CheckSInt64(r.Sint64)
  83. case *validate.FieldRules_Fixed32:
  84. m.MustType(typ, pgs.Fixed32T, pgs.UnknownWKT)
  85. m.CheckFixed32(r.Fixed32)
  86. case *validate.FieldRules_Fixed64:
  87. m.MustType(typ, pgs.Fixed64T, pgs.UnknownWKT)
  88. m.CheckFixed64(r.Fixed64)
  89. case *validate.FieldRules_Sfixed32:
  90. m.MustType(typ, pgs.SFixed32, pgs.UnknownWKT)
  91. m.CheckSFixed32(r.Sfixed32)
  92. case *validate.FieldRules_Sfixed64:
  93. m.MustType(typ, pgs.SFixed64, pgs.UnknownWKT)
  94. m.CheckSFixed64(r.Sfixed64)
  95. case *validate.FieldRules_Bool:
  96. m.MustType(typ, pgs.BoolT, pgs.BoolValueWKT)
  97. case *validate.FieldRules_String_:
  98. m.MustType(typ, pgs.StringT, pgs.StringValueWKT)
  99. m.CheckString(r.String_)
  100. case *validate.FieldRules_Bytes:
  101. m.MustType(typ, pgs.BytesT, pgs.BytesValueWKT)
  102. m.CheckBytes(r.Bytes)
  103. case *validate.FieldRules_Enum:
  104. m.MustType(typ, pgs.EnumT, pgs.UnknownWKT)
  105. m.CheckEnum(typ, r.Enum)
  106. case *validate.FieldRules_Repeated:
  107. m.CheckRepeated(typ, r.Repeated)
  108. case *validate.FieldRules_Map:
  109. m.CheckMap(typ, r.Map)
  110. case *validate.FieldRules_Any:
  111. m.CheckAny(typ, r.Any)
  112. case *validate.FieldRules_Duration:
  113. m.CheckDuration(typ, r.Duration)
  114. case *validate.FieldRules_Timestamp:
  115. m.CheckTimestamp(typ, r.Timestamp)
  116. case nil: // noop
  117. default:
  118. m.Failf("unknown rule type (%T)", rules.Type)
  119. }
  120. }
  121. func (m *Module) MustType(typ FieldType, pt pgs.ProtoType, wrapper pgs.WellKnownType) {
  122. if emb := typ.Embed(); emb != nil && emb.IsWellKnown() && emb.WellKnownType() == wrapper {
  123. m.MustType(emb.Fields()[0].Type(), pt, pgs.UnknownWKT)
  124. return
  125. }
  126. if typ, ok := typ.(Repeatable); ok {
  127. m.Assert(!typ.IsRepeated(),
  128. "repeated rule should be used for repeated fields")
  129. }
  130. m.Assert(typ.ProtoType() == pt,
  131. " expected rules for ",
  132. typ.ProtoType().Proto(),
  133. " but got ",
  134. pt.Proto(),
  135. )
  136. }
  137. func (m *Module) CheckFloat(r *validate.FloatRules) {
  138. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  139. }
  140. func (m *Module) CheckDouble(r *validate.DoubleRules) {
  141. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  142. }
  143. func (m *Module) CheckInt32(r *validate.Int32Rules) {
  144. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  145. }
  146. func (m *Module) CheckInt64(r *validate.Int64Rules) {
  147. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  148. }
  149. func (m *Module) CheckUInt32(r *validate.UInt32Rules) {
  150. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  151. }
  152. func (m *Module) CheckUInt64(r *validate.UInt64Rules) {
  153. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  154. }
  155. func (m *Module) CheckSInt32(r *validate.SInt32Rules) {
  156. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  157. }
  158. func (m *Module) CheckSInt64(r *validate.SInt64Rules) {
  159. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  160. }
  161. func (m *Module) CheckFixed32(r *validate.Fixed32Rules) {
  162. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  163. }
  164. func (m *Module) CheckFixed64(r *validate.Fixed64Rules) {
  165. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  166. }
  167. func (m *Module) CheckSFixed32(r *validate.SFixed32Rules) {
  168. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  169. }
  170. func (m *Module) CheckSFixed64(r *validate.SFixed64Rules) {
  171. m.checkNums(len(r.In), len(r.NotIn), r.Const, r.Lt, r.Lte, r.Gt, r.Gte)
  172. }
  173. func (m *Module) CheckString(r *validate.StringRules) {
  174. m.checkLen(r.Len, r.MinLen, r.MaxLen)
  175. m.checkLen(r.LenBytes, r.MinBytes, r.MaxBytes)
  176. m.checkMinMax(r.MinLen, r.MaxLen)
  177. m.checkMinMax(r.MinBytes, r.MaxBytes)
  178. m.checkIns(len(r.In), len(r.NotIn))
  179. m.checkWellKnownRegex(r.GetWellKnownRegex(), r)
  180. m.checkPattern(r.Pattern, len(r.In))
  181. if r.MaxLen != nil {
  182. max := int(r.GetMaxLen())
  183. m.Assert(utf8.RuneCountInString(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_len`")
  184. m.Assert(utf8.RuneCountInString(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_len`")
  185. m.Assert(utf8.RuneCountInString(r.GetContains()) <= max, "`contains` length exceeds the `max_len`")
  186. m.Assert(
  187. r.MaxBytes == nil || r.GetMaxBytes() >= r.GetMaxLen(),
  188. "`max_len` cannot exceed `max_bytes`")
  189. }
  190. if r.MaxBytes != nil {
  191. max := int(r.GetMaxBytes())
  192. m.Assert(len(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_bytes`")
  193. m.Assert(len(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_bytes`")
  194. m.Assert(len(r.GetContains()) <= max, "`contains` length exceeds the `max_bytes`")
  195. }
  196. }
  197. func (m *Module) CheckBytes(r *validate.BytesRules) {
  198. m.checkMinMax(r.MinLen, r.MaxLen)
  199. m.checkIns(len(r.In), len(r.NotIn))
  200. m.checkPattern(r.Pattern, len(r.In))
  201. if r.MaxLen != nil {
  202. max := int(r.GetMaxLen())
  203. m.Assert(len(r.GetPrefix()) <= max, "`prefix` length exceeds the `max_len`")
  204. m.Assert(len(r.GetSuffix()) <= max, "`suffix` length exceeds the `max_len`")
  205. m.Assert(len(r.GetContains()) <= max, "`contains` length exceeds the `max_len`")
  206. }
  207. }
  208. func (m *Module) CheckEnum(ft FieldType, r *validate.EnumRules) {
  209. m.checkIns(len(r.In), len(r.NotIn))
  210. if r.GetDefinedOnly() && len(r.In) > 0 {
  211. typ, ok := ft.(interface {
  212. Enum() pgs.Enum
  213. })
  214. if !ok {
  215. m.Failf("unexpected field type (%T)", ft)
  216. }
  217. defined := typ.Enum().Values()
  218. vals := make(map[int32]struct{}, len(defined))
  219. for _, val := range defined {
  220. vals[val.Value()] = struct{}{}
  221. }
  222. for _, in := range r.In {
  223. if _, ok = vals[in]; !ok {
  224. m.Failf("undefined `in` value (%d) conflicts with `defined_only` rule")
  225. }
  226. }
  227. }
  228. }
  229. func (m *Module) CheckMessage(f pgs.Field, rules *validate.FieldRules) {
  230. m.Assert(f.Type().IsEmbed(), "field is not embedded but got message rules")
  231. emb := f.Type().Embed()
  232. if emb != nil && emb.IsWellKnown() {
  233. switch emb.WellKnownType() {
  234. case pgs.AnyWKT:
  235. m.Failf("Any rules should be used for Any fields")
  236. case pgs.DurationWKT:
  237. m.Failf("Duration rules should be used for Duration fields")
  238. case pgs.TimestampWKT:
  239. m.Failf("Timestamp rules should be used for Timestamp fields")
  240. }
  241. }
  242. if rules.Type != nil && rules.GetMessage().GetSkip() {
  243. m.Failf("Skip should not be used with WKT scalar rules")
  244. }
  245. }
  246. func (m *Module) CheckRepeated(ft FieldType, r *validate.RepeatedRules) {
  247. typ := m.mustFieldType(ft)
  248. m.Assert(typ.IsRepeated(), "field is not repeated but got repeated rules")
  249. m.checkMinMax(r.MinItems, r.MaxItems)
  250. if r.GetUnique() {
  251. m.Assert(
  252. !typ.Element().IsEmbed(),
  253. "unique rule is only applicable for scalar types")
  254. }
  255. m.Push("items")
  256. m.CheckFieldRules(typ.Element(), r.Items)
  257. m.Pop()
  258. }
  259. func (m *Module) CheckMap(ft FieldType, r *validate.MapRules) {
  260. typ := m.mustFieldType(ft)
  261. m.Assert(typ.IsMap(), "field is not a map but got map rules")
  262. m.checkMinMax(r.MinPairs, r.MaxPairs)
  263. if r.GetNoSparse() {
  264. m.Assert(
  265. typ.Element().IsEmbed(),
  266. "no_sparse rule is only applicable for embedded message types",
  267. )
  268. }
  269. m.Push("keys")
  270. m.CheckFieldRules(typ.Key(), r.Keys)
  271. m.Pop()
  272. m.Push("values")
  273. m.CheckFieldRules(typ.Element(), r.Values)
  274. m.Pop()
  275. }
  276. func (m *Module) CheckAny(ft FieldType, r *validate.AnyRules) {
  277. m.checkIns(len(r.In), len(r.NotIn))
  278. }
  279. func (m *Module) CheckDuration(ft FieldType, r *validate.DurationRules) {
  280. m.checkNums(
  281. len(r.GetIn()),
  282. len(r.GetNotIn()),
  283. m.checkDur(r.GetConst()),
  284. m.checkDur(r.GetLt()),
  285. m.checkDur(r.GetLte()),
  286. m.checkDur(r.GetGt()),
  287. m.checkDur(r.GetGte()))
  288. for _, v := range r.GetIn() {
  289. m.Assert(v != nil, "cannot have nil values in `in`")
  290. m.checkDur(v)
  291. }
  292. for _, v := range r.GetNotIn() {
  293. m.Assert(v != nil, "cannot have nil values in `not_in`")
  294. m.checkDur(v)
  295. }
  296. }
  297. func (m *Module) CheckTimestamp(ft FieldType, r *validate.TimestampRules) {
  298. m.checkNums(0, 0,
  299. m.checkTS(r.GetConst()),
  300. m.checkTS(r.GetLt()),
  301. m.checkTS(r.GetLte()),
  302. m.checkTS(r.GetGt()),
  303. m.checkTS(r.GetGte()))
  304. m.Assert(
  305. (r.LtNow == nil && r.GtNow == nil) || (r.Lt == nil && r.Lte == nil && r.Gt == nil && r.Gte == nil),
  306. "`now` rules cannot be mixed with absolute `lt/gt` rules")
  307. m.Assert(
  308. r.Within == nil || (r.Lt == nil && r.Lte == nil && r.Gt == nil && r.Gte == nil),
  309. "`within` rule cannot be used with absolute `lt/gt` rules")
  310. m.Assert(
  311. r.LtNow == nil || r.GtNow == nil,
  312. "both `now` rules cannot be used together")
  313. dur := m.checkDur(r.Within)
  314. m.Assert(
  315. dur == nil || *dur > 0,
  316. "`within` rule must be positive and non-zero")
  317. }
  318. func (m *Module) mustFieldType(ft FieldType) pgs.FieldType {
  319. typ, ok := ft.(pgs.FieldType)
  320. if !ok {
  321. m.Failf("unexpected field type (%T)", ft)
  322. }
  323. return typ
  324. }
  325. func (m *Module) checkNums(in, notIn int, ci, lti, ltei, gti, gtei interface{}) {
  326. m.checkIns(in, notIn)
  327. c := reflect.ValueOf(ci)
  328. lt, lte := reflect.ValueOf(lti), reflect.ValueOf(ltei)
  329. gt, gte := reflect.ValueOf(gti), reflect.ValueOf(gtei)
  330. m.Assert(
  331. c.IsNil() ||
  332. in == 0 && notIn == 0 &&
  333. lt.IsNil() && lte.IsNil() &&
  334. gt.IsNil() && gte.IsNil(),
  335. "`const` can be the only rule on a field",
  336. )
  337. m.Assert(
  338. in == 0 ||
  339. lt.IsNil() && lte.IsNil() &&
  340. gt.IsNil() && gte.IsNil(),
  341. "cannot have both `in` and range constraint rules on the same field",
  342. )
  343. m.Assert(
  344. lt.IsNil() || lte.IsNil(),
  345. "cannot have both `lt` and `lte` rules on the same field",
  346. )
  347. m.Assert(
  348. gt.IsNil() || gte.IsNil(),
  349. "cannot have both `gt` and `gte` rules on the same field",
  350. )
  351. if !lt.IsNil() {
  352. m.Assert(gt.IsNil() || !reflect.DeepEqual(lti, gti),
  353. "cannot have equal `gt` and `lt` rules on the same field")
  354. m.Assert(gte.IsNil() || !reflect.DeepEqual(lti, gtei),
  355. "cannot have equal `gte` and `lt` rules on the same field")
  356. } else if !lte.IsNil() {
  357. m.Assert(gt.IsNil() || !reflect.DeepEqual(ltei, gti),
  358. "cannot have equal `gt` and `lte` rules on the same field")
  359. m.Assert(gte.IsNil() || !reflect.DeepEqual(ltei, gtei),
  360. "use `const` instead of equal `lte` and `gte` rules")
  361. }
  362. }
  363. func (m *Module) checkIns(in, notIn int) {
  364. m.Assert(
  365. in == 0 || notIn == 0,
  366. "cannot have both `in` and `not_in` rules on the same field")
  367. }
  368. func (m *Module) checkMinMax(min, max *uint64) {
  369. if min == nil || max == nil {
  370. return
  371. }
  372. m.Assert(
  373. *min <= *max,
  374. "`min` value is greater than `max` value")
  375. }
  376. func (m *Module) checkLen(len, min, max *uint64) {
  377. if len == nil {
  378. return
  379. }
  380. m.Assert(
  381. min == nil,
  382. "cannot have both `len` and `min_len` rules on the same field")
  383. m.Assert(
  384. max == nil,
  385. "cannot have both `len` and `max_len` rules on the same field")
  386. }
  387. func (m *Module) checkWellKnownRegex(wk validate.KnownRegex, r *validate.StringRules) {
  388. if wk != 0 {
  389. m.Assert(r.Pattern == nil, "regex `well_known_regex` and regex `pattern` are incompatible")
  390. var non_strict = r.Strict != nil && *r.Strict == false
  391. if (wk.String() == "HTTP_HEADER_NAME" || wk.String() == "HTTP_HEADER_VALUE") && non_strict {
  392. // Use non-strict header validation.
  393. r.Pattern = regex_map["HEADER_STRING"]
  394. } else {
  395. r.Pattern = regex_map[wk.String()]
  396. }
  397. }
  398. }
  399. func (m *Module) checkPattern(p *string, in int) {
  400. if p != nil {
  401. m.Assert(in == 0, "regex `pattern` and `in` rules are incompatible")
  402. _, err := regexp.Compile(*p)
  403. m.CheckErr(err, "unable to parse regex `pattern`")
  404. }
  405. }
  406. func (m *Module) checkDur(d *durationpb.Duration) *time.Duration {
  407. if d == nil {
  408. return nil
  409. }
  410. dur, err := d.AsDuration(), d.CheckValid()
  411. m.CheckErr(err, "could not resolve duration")
  412. return &dur
  413. }
  414. func (m *Module) checkTS(ts *timestamppb.Timestamp) *int64 {
  415. if ts == nil {
  416. return nil
  417. }
  418. t, err := ts.AsTime(), ts.CheckValid()
  419. m.CheckErr(err, "could not resolve timestamp")
  420. return proto.Int64(t.UnixNano())
  421. }