jolheiser
·
2025-04-07
pr.go
1package git
2
import (
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"strings"
	"time"

	"github.com/jmoiron/sqlx"
)
13
14var ErrPatchExists = errors.New("patch already exists for patch request")
15
16type PatchsetOp int
17
18const (
19 OpNormal PatchsetOp = iota
20 OpReview
21 OpAccept
22 OpClose
23)
24
25type GitPatchRequest interface {
26 GetUsers() ([]*User, error)
27 GetUserByID(userID int64) (*User, error)
28 GetUserByName(name string) (*User, error)
29 GetUserByPubkey(pubkey string) (*User, error)
30 GetRepos() ([]*Repo, error)
31 GetRepoByID(repoID int64) (*Repo, error)
32 GetRepoByName(user *User, repoName string) (*Repo, error)
33 CreateRepo(user *User, repoName string) (*Repo, error)
34 UpsertUser(pubkey, name string) (*User, error)
35 IsBanned(pubkey, ipAddress string) error
36 SubmitPatchRequest(repoID int64, userID int64, patchset io.Reader) (*PatchRequest, error)
37 SubmitPatchset(prID, userID int64, op PatchsetOp, patchset io.Reader) ([]*Patch, error)
38 GetPatchRequestByID(prID int64) (*PatchRequest, error)
39 GetPatchRequests() ([]*PatchRequest, error)
40 GetPatchRequestsByRepoID(repoID int64) ([]*PatchRequest, error)
41 GetPatchRequestsByPubkey(pubkey string) ([]*PatchRequest, error)
42 GetPatchsetsByPrID(prID int64) ([]*Patchset, error)
43 GetPatchsetByID(patchsetID int64) (*Patchset, error)
44 GetLatestPatchsetByPrID(prID int64) (*Patchset, error)
45 GetPatchesByPatchsetID(prID int64) ([]*Patch, error)
46 UpdatePatchRequestStatus(prID, userID int64, status string) error
47 UpdatePatchRequestName(prID, userID int64, name string) error
48 DeletePatchsetByID(userID, prID int64, patchsetID int64) error
49 CreateEventLog(tx *sqlx.Tx, eventLog EventLog) error
50 GetEventLogs() ([]*EventLog, error)
51 GetEventLogsByRepoName(user *User, repoName string) ([]*EventLog, error)
52 GetEventLogsByPrID(prID int64) ([]*EventLog, error)
53 GetEventLogsByUserID(userID int64) ([]*EventLog, error)
54 DiffPatchsets(aset *Patchset, bset *Patchset) ([]*RangeDiffOutput, error)
55}
56
57type PrCmd struct {
58 Backend *Backend
59}
60
61var (
62 _ GitPatchRequest = PrCmd{}
63 _ GitPatchRequest = (*PrCmd)(nil)
64)
65
66func (pr PrCmd) IsBanned(pubkey, ipAddress string) error {
67 acl := []*Acl{}
68 err := pr.Backend.DB.Select(
69 &acl,
70 "SELECT * FROM acl WHERE permission='banned' AND (pubkey=? OR ip_address=?)",
71 pubkey,
72 ipAddress,
73 )
74 if len(acl) > 0 {
75 return fmt.Errorf("user has been banned")
76 }
77 return err
78}
79
80func (pr PrCmd) GetUsers() ([]*User, error) {
81 users := []*User{}
82 err := pr.Backend.DB.Select(&users, "SELECT * FROM app_users")
83 return users, err
84}
85
86func (pr PrCmd) GetUserByName(name string) (*User, error) {
87 var user User
88 err := pr.Backend.DB.Get(&user, "SELECT * FROM app_users WHERE name=?", name)
89 return &user, err
90}
91
92func (pr PrCmd) GetUserByID(id int64) (*User, error) {
93 var user User
94 err := pr.Backend.DB.Get(&user, "SELECT * FROM app_users WHERE id=?", id)
95 return &user, err
96}
97
98func (pr PrCmd) GetUserByPubkey(pubkey string) (*User, error) {
99 var user User
100 err := pr.Backend.DB.Get(&user, "SELECT * FROM app_users WHERE pubkey=?", pubkey)
101 return &user, err
102}
103
// computeUserName picks the name to store for a new user.
//
// If no app_users row uses name yet, name is returned unchanged. On a
// collision a 4-character suffix from randSeq is appended; the suffixed name
// is not itself re-checked.
//
// Only sql.ErrNoRows counts as "name is free". Any other lookup error is
// returned together with the unmodified name, so the caller can log it and
// still proceed. Previously such an error was silently treated as "no
// collision".
func (pr PrCmd) computeUserName(name string) (string, error) {
	var user User
	err := pr.Backend.DB.Get(&user, "SELECT * FROM app_users WHERE name=?", name)
	switch {
	case err == nil:
		// collision, append a random suffix
		return fmt.Sprintf("%s%s", name, randSeq(4)), nil
	case errors.Is(err, sql.ErrNoRows):
		return name, nil
	default:
		return name, fmt.Errorf("checking whether user name %q is taken: %w", name, err)
	}
}
113
114func (pr PrCmd) CreateRepo(user *User, repoName string) (*Repo, error) {
115 var repoID int64
116 row := pr.Backend.DB.QueryRow(
117 "INSERT INTO repos (user_id, name) VALUES (?, ?) RETURNING id",
118 user.ID,
119 repoName,
120 )
121 err := row.Scan(&repoID)
122 if err != nil {
123 return nil, err
124 }
125
126 return pr.GetRepoByID(repoID)
127}
128
129func (pr PrCmd) GetRepoByID(repoID int64) (*Repo, error) {
130 var repo Repo
131 err := pr.Backend.DB.Get(&repo, "SELECT * FROM repos WHERE id=?", repoID)
132 return &repo, err
133}
134
// GetRepos lists every row of the repos table. An empty table is reported
// as a "no repos found" error.
func (pr PrCmd) GetRepos() (repos []*Repo, err error) {
	err = pr.Backend.DB.Select(&repos, "SELECT * from repos")
	switch {
	case err != nil:
		return repos, err
	case len(repos) == 0:
		return repos, fmt.Errorf("no repos found")
	default:
		return repos, nil
	}
}
148
// GetRepoByName looks up a repo by name.
//
// When user is non-nil, the lookup is restricted to repos owned by that
// user. When user is nil, any repo with a matching name may be returned; the
// query has no ORDER BY, so which one is unspecified if several users share
// the name.
//
// A missing repo yields the "repo not found: <name>" error. Any other
// database failure is wrapped and returned rather than being misreported as
// not-found.
func (pr PrCmd) GetRepoByName(user *User, repoName string) (*Repo, error) {
	var repo Repo
	var err error

	if user == nil {
		err = pr.Backend.DB.Get(&repo, "SELECT * FROM repos WHERE name=?", repoName)
	} else {
		err = pr.Backend.DB.Get(&repo, "SELECT * FROM repos WHERE user_id=? AND name=?", user.ID, repoName)
	}

	switch {
	case err == nil:
		return &repo, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, fmt.Errorf("repo not found: %s", repoName)
	default:
		return nil, fmt.Errorf("looking up repo %s: %w", repoName, err)
	}
}
165
// createUser inserts a new app_users row for pubkey and returns the stored
// record.
//
// Both pubkey and name are required. The name first goes through
// computeUserName so that a clash with an existing user gets a suffix; an
// error from that step is only logged, and the insert proceeds with whatever
// name it returned.
func (pr PrCmd) createUser(pubkey, name string) (*User, error) {
	switch {
	case pubkey == "":
		return nil, fmt.Errorf("must provide pubkey when creating user")
	case name == "":
		return nil, fmt.Errorf("must provide user name when creating user")
	}

	finalName, nameErr := pr.computeUserName(name)
	if nameErr != nil {
		pr.Backend.Logger.Error("could not compute username", "err", nameErr)
	}

	var newID int64
	if err := pr.Backend.DB.QueryRow(
		"INSERT INTO app_users (pubkey, name) VALUES (?, ?) RETURNING id",
		pubkey,
		finalName,
	).Scan(&newID); err != nil {
		return nil, err
	}
	if newID == 0 {
		return nil, fmt.Errorf("could not create user")
	}

	return pr.GetUserByID(newID)
}
196
// UpsertUser returns the user registered for pubkey. If none exists yet, it
// creates one whose name is strings.ToLower(name).
//
// Only a "no rows" result from the pubkey lookup triggers creation. Any
// other lookup error is returned, so a transient database failure does not
// lead to an attempt to register the same pubkey again.
func (pr PrCmd) UpsertUser(pubkey, name string) (*User, error) {
	if pubkey == "" {
		return nil, fmt.Errorf("must provide pubkey during upsert")
	}

	user, err := pr.GetUserByPubkey(pubkey)
	if err == nil {
		return user, nil
	}
	if !errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("looking up user by pubkey: %w", err)
	}

	return pr.createUser(pubkey, strings.ToLower(name))
}
208
209func (pr PrCmd) GetPatchsetsByPrID(prID int64) ([]*Patchset, error) {
210 patchsets := []*Patchset{}
211 err := pr.Backend.DB.Select(
212 &patchsets,
213 "SELECT * FROM patchsets WHERE patch_request_id=? ORDER BY created_at ASC",
214 prID,
215 )
216 if err != nil {
217 return patchsets, err
218 }
219 if len(patchsets) == 0 {
220 return patchsets, fmt.Errorf("no patchsets found for patch request: %d", prID)
221 }
222 return patchsets, nil
223}
224
225func (pr PrCmd) GetPatchsetByID(patchsetID int64) (*Patchset, error) {
226 var patchset Patchset
227 err := pr.Backend.DB.Get(
228 &patchset,
229 "SELECT * FROM patchsets WHERE id=?",
230 patchsetID,
231 )
232 return &patchset, err
233}
234
// GetLatestPatchsetByPrID returns the newest patchset of patch request prID.
// This is the last element of GetPatchsetsByPrID's result, which is ordered
// by created_at ascending.
func (pr PrCmd) GetLatestPatchsetByPrID(prID int64) (*Patchset, error) {
	patchsets, err := pr.GetPatchsetsByPrID(prID)
	if err != nil {
		return nil, err
	}
	// GetPatchsetsByPrID already reports an empty result as an error; this
	// guard keeps the index below safe regardless.
	if len(patchsets) == 0 {
		return nil, fmt.Errorf("no patchsets found for patch request: %d", prID)
	}
	return patchsets[len(patchsets)-1], nil
}
245
246func (pr PrCmd) GetPatchesByPatchsetID(patchsetID int64) ([]*Patch, error) {
247 patches := []*Patch{}
248 err := pr.Backend.DB.Select(
249 &patches,
250 "SELECT * FROM patches WHERE patchset_id=? ORDER BY created_at ASC, id ASC",
251 patchsetID,
252 )
253 return patches, err
254}
255
256func (cmd PrCmd) GetPatchRequests() ([]*PatchRequest, error) {
257 prs := []*PatchRequest{}
258 err := cmd.Backend.DB.Select(
259 &prs,
260 "SELECT * FROM patch_requests ORDER BY id DESC",
261 )
262 return prs, err
263}
264
265func (cmd PrCmd) GetPatchRequestsByRepoID(repoID int64) ([]*PatchRequest, error) {
266 prs := []*PatchRequest{}
267 err := cmd.Backend.DB.Select(
268 &prs,
269 "SELECT * FROM patch_requests WHERE repo_id=? ORDER BY id DESC",
270 repoID,
271 )
272 return prs, err
273}
274
275func (cmd PrCmd) GetPatchRequestsByPubkey(pubkey string) ([]*PatchRequest, error) {
276 prs := []*PatchRequest{}
277 err := cmd.Backend.DB.Select(
278 &prs,
279 "SELECT pr.* FROM patch_requests pr, app_users au WHERE pr.user_id=au.id AND au.pubkey=? ORDER BY id DESC",
280 pubkey,
281 )
282 return prs, err
283}
284
285func (cmd PrCmd) GetPatchRequestByID(prID int64) (*PatchRequest, error) {
286 pr := PatchRequest{}
287 err := cmd.Backend.DB.Get(
288 &pr,
289 "SELECT * FROM patch_requests WHERE id=? ORDER BY created_at DESC",
290 prID,
291 )
292 return &pr, err
293}
294
295// Status types: open, closed, accepted, reviewed.
296func (cmd PrCmd) UpdatePatchRequestStatus(prID int64, userID int64, status string) error {
297 tx, err := cmd.Backend.DB.Beginx()
298 if err != nil {
299 return err
300 }
301
302 defer func() {
303 _ = tx.Rollback()
304 }()
305
306 _, err = tx.Exec(
307 "UPDATE patch_requests SET status=? WHERE id=?",
308 status,
309 prID,
310 )
311 if err != nil {
312 return err
313 }
314
315 pr, err := cmd.GetPatchRequestByID(prID)
316 if err != nil {
317 return err
318 }
319
320 err = cmd.CreateEventLog(tx, EventLog{
321 UserID: userID,
322 RepoID: sql.NullInt64{Int64: pr.RepoID, Valid: true},
323 PatchRequestID: sql.NullInt64{Int64: prID, Valid: true},
324 Event: "pr_status_changed",
325 Data: fmt.Sprintf(`{"status":"%s"}`, status),
326 })
327 if err != nil {
328 return err
329 }
330
331 return tx.Commit()
332}
333
334func (cmd PrCmd) UpdatePatchRequestName(prID int64, userID int64, name string) error {
335 if name == "" {
336 return fmt.Errorf("must provide name or text in order to update patch request")
337 }
338
339 tx, err := cmd.Backend.DB.Beginx()
340 if err != nil {
341 return err
342 }
343
344 defer func() {
345 _ = tx.Rollback()
346 }()
347
348 _, err = tx.Exec(
349 "UPDATE patch_requests SET name=? WHERE id=?",
350 name,
351 prID,
352 )
353 if err != nil {
354 return err
355 }
356
357 pr, err := cmd.GetPatchRequestByID(prID)
358 if err != nil {
359 return err
360 }
361
362 err = cmd.CreateEventLog(tx, EventLog{
363 UserID: userID,
364 RepoID: sql.NullInt64{Int64: pr.RepoID, Valid: true},
365 PatchRequestID: sql.NullInt64{Int64: prID, Valid: true},
366 Event: "pr_name_changed",
367 Data: fmt.Sprintf(`{"name":"%s"}`, name),
368 })
369 if err != nil {
370 return err
371 }
372
373 return tx.Commit()
374}
375
// CreateEventLog inserts an event_logs row using tx.
//
// When eventLog carries both a valid RepoID and a valid PatchRequestID, the
// RepoID is replaced with the repo_id currently stored on that patch request
// (read through tx). If that lookup fails, the failure is logged and the
// event is skipped: nothing is inserted and nil is returned. An insert
// failure is logged and returned.
//
// NOTE(review): the lookup only runs when RepoID is already valid, so an
// event with a PatchRequestID but no RepoID is stored without a repo. If the
// intent was to derive a missing RepoID, the condition looks inverted;
// confirm.
//
// NOTE(review): PatchRequestID and PatchsetID are written via their Int64
// field, so an unset (invalid) value is stored as 0 rather than NULL. RepoID,
// by contrast, is passed as sql.NullInt64. Confirm against the schema.
func (cmd PrCmd) CreateEventLog(tx *sqlx.Tx, eventLog EventLog) error {
	if eventLog.PatchRequestID.Valid && eventLog.RepoID.Valid {
		var owner PatchRequest
		lookupErr := tx.Get(&owner, "SELECT repo_id FROM patch_requests WHERE id=?", eventLog.PatchRequestID)
		if lookupErr != nil {
			cmd.Backend.Logger.Error("could not find pr when creating eventLog", "err", lookupErr)
			return nil
		}
		eventLog.RepoID = sql.NullInt64{Int64: owner.RepoID, Valid: true}
	}

	_, insertErr := tx.Exec(
		"INSERT INTO event_logs (user_id, repo_id, patch_request_id, patchset_id, event, data) VALUES (?, ?, ?, ?, ?, ?)",
		eventLog.UserID,
		eventLog.RepoID,
		eventLog.PatchRequestID.Int64,
		eventLog.PatchsetID.Int64,
		eventLog.Event,
		eventLog.Data,
	)
	if insertErr != nil {
		cmd.Backend.Logger.Error("could not create eventLog", "err", insertErr)
	}
	return insertErr
}
411
// createPatch inserts patch into the patches table using tx and returns the
// new row's id.
//
// Before inserting, it looks for an existing patch with the same PatchsetID
// and ContentSha; if one is found it returns ErrPatchExists. That check is
// best-effort: its query error is ignored, and it runs on the shared DB
// handle rather than tx.
//
// NOTE(review): because the duplicate check does not go through tx, it
// cannot see patches inserted earlier in the same uncommitted transaction.
// Both callers in this file pass a patchset created inside that transaction,
// so the check rarely fires in practice. Confirm whether it should query via
// tx, or check a wider scope such as the patch request.
func (cmd PrCmd) createPatch(tx *sqlx.Tx, patch *Patch) (int64, error) {
	patchExists := []Patch{}
	_ = cmd.Backend.DB.Select(&patchExists, "SELECT * FROM patches WHERE patchset_id=? AND content_sha=?", patch.PatchsetID, patch.ContentSha)
	if len(patchExists) > 0 {
		return 0, ErrPatchExists
	}

	var patchID int64
	row := tx.QueryRow(
		"INSERT INTO patches (user_id, patchset_id, author_name, author_email, author_date, title, body, body_appendix, commit_sha, content_sha, base_commit_sha, raw_text) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) RETURNING id",
		patch.UserID,
		patch.PatchsetID,
		patch.AuthorName,
		patch.AuthorEmail,
		patch.AuthorDate,
		patch.Title,
		patch.Body,
		patch.BodyAppendix,
		patch.CommitSha,
		patch.ContentSha,
		patch.BaseCommitSha,
		patch.RawText,
	)
	if err := row.Scan(&patchID); err != nil {
		return 0, err
	}
	if patchID == 0 {
		return 0, fmt.Errorf("could not create patch")
	}
	return patchID, nil
}
444
// SubmitPatchRequest parses patchset and opens a new patch request for
// repoID owned by userID.
//
// In a single transaction it creates:
//   - an "open" patch_requests row, named after the first patch's title with
//     its body as the text;
//   - an initial patchset;
//   - one patches row per parsed patch;
//   - a "pr_created" event.
//
// An input with no patches is rejected. On success the patch request is
// re-read after commit and returned.
func (cmd PrCmd) SubmitPatchRequest(repoID int64, userID int64, patchset io.Reader) (*PatchRequest, error) {
	// Parsing needs no database access, so do it before opening the
	// transaction.
	patches, err := ParsePatchset(patchset)
	if err != nil {
		return nil, err
	}
	if len(patches) == 0 {
		return nil, fmt.Errorf("after parsing patchset we didn't find any patches, did you send us an empty patchset?")
	}

	tx, err := cmd.Backend.DB.Beginx()
	if err != nil {
		return nil, err
	}

	// Harmless after a successful Commit; undoes everything on early return.
	defer func() {
		_ = tx.Rollback()
	}()

	prName := patches[0].Title
	prText := patches[0].Body

	var prID int64
	row := tx.QueryRow(
		"INSERT INTO patch_requests (user_id, repo_id, name, text, status, updated_at) VALUES(?, ?, ?, ?, ?, ?) RETURNING id",
		userID,
		repoID,
		prName,
		prText,
		"open",
		time.Now(),
	)
	if err := row.Scan(&prID); err != nil {
		return nil, err
	}
	if prID == 0 {
		return nil, fmt.Errorf("could not create patch request")
	}

	var patchsetID int64
	row = tx.QueryRow(
		"INSERT INTO patchsets (user_id, patch_request_id) VALUES(?, ?) RETURNING id",
		userID,
		prID,
	)
	if err := row.Scan(&patchsetID); err != nil {
		return nil, err
	}
	if patchsetID == 0 {
		return nil, fmt.Errorf("could not create patchset")
	}

	for _, patch := range patches {
		patch.UserID = userID
		patch.PatchsetID = patchsetID
		if _, err := cmd.createPatch(tx, patch); err != nil {
			return nil, err
		}
	}

	err = cmd.CreateEventLog(tx, EventLog{
		UserID:         userID,
		RepoID:         sql.NullInt64{Int64: repoID, Valid: true},
		PatchRequestID: sql.NullInt64{Int64: prID, Valid: true},
		PatchsetID:     sql.NullInt64{Int64: patchsetID, Valid: true},
		Event:          "pr_created",
	})
	if err != nil {
		return nil, err
	}

	if err := tx.Commit(); err != nil {
		return nil, err
	}

	var pr PatchRequest
	err = cmd.Backend.DB.Get(&pr, "SELECT * FROM patch_requests WHERE id=?", prID)
	return &pr, err
}
532
// SubmitPatchset parses patchset and, in one transaction, appends it to
// patch request prID as a new patchset owned by userID.
//
// The patchset's review column is set when op is OpReview, OpAccept or
// OpClose. Patches for which createPatch reports ErrPatchExists are skipped;
// any other error aborts the submission. When at least one patch was stored,
// an event is recorded: "pr_reviewed" for OpReview, otherwise
// "pr_patchset_added".
//
// It returns the created patches with their IDs filled in.
//
// NOTE(review): the patchset row is committed even when no patch was created
// (for example, an empty input), leaving an empty patchset. Confirm this is
// intended.
func (cmd PrCmd) SubmitPatchset(prID int64, userID int64, op PatchsetOp, patchset io.Reader) ([]*Patch, error) {
	fin := []*Patch{}

	// Parsing needs no database access, so do it before opening the
	// transaction.
	patches, err := ParsePatchset(patchset)
	if err != nil {
		return fin, err
	}

	tx, err := cmd.Backend.DB.Beginx()
	if err != nil {
		return fin, err
	}

	// Harmless after a successful Commit; undoes everything on early return.
	defer func() {
		_ = tx.Rollback()
	}()

	isReview := op == OpReview || op == OpAccept || op == OpClose
	var patchsetID int64
	row := tx.QueryRow(
		"INSERT INTO patchsets (user_id, patch_request_id, review) VALUES(?, ?, ?) RETURNING id",
		userID,
		prID,
		isReview,
	)
	if err := row.Scan(&patchsetID); err != nil {
		return nil, err
	}
	if patchsetID == 0 {
		return nil, fmt.Errorf("could not create patchset")
	}

	for _, patch := range patches {
		patch.UserID = userID
		patch.PatchsetID = patchsetID
		patchID, err := cmd.createPatch(tx, patch)
		// errors.Is takes the error under test first and the target second.
		if errors.Is(err, ErrPatchExists) {
			continue
		}
		if err != nil {
			return fin, err
		}
		patch.ID = patchID
		fin = append(fin, patch)
	}

	if len(fin) > 0 {
		event := "pr_patchset_added"
		if op == OpReview {
			event = "pr_reviewed"
		}

		// Read through tx so no second connection is needed while the
		// write transaction is open.
		var pr PatchRequest
		if err := tx.Get(&pr, "SELECT * FROM patch_requests WHERE id=?", prID); err != nil {
			return fin, err
		}

		err = cmd.CreateEventLog(tx, EventLog{
			UserID:         userID,
			RepoID:         sql.NullInt64{Int64: pr.RepoID, Valid: true},
			PatchRequestID: sql.NullInt64{Int64: prID, Valid: true},
			PatchsetID:     sql.NullInt64{Int64: patchsetID, Valid: true},
			Event:          event,
		})
		if err != nil {
			return fin, err
		}
	}

	if err := tx.Commit(); err != nil {
		return fin, err
	}
	return fin, nil
}
609
610func (cmd PrCmd) DeletePatchsetByID(userID int64, prID int64, patchsetID int64) error {
611 tx, err := cmd.Backend.DB.Beginx()
612 if err != nil {
613 return err
614 }
615
616 defer func() {
617 _ = tx.Rollback()
618 }()
619
620 _, err = tx.Exec(
621 "DELETE FROM patchsets WHERE id=?", patchsetID,
622 )
623 if err != nil {
624 return err
625 }
626
627 pr, err := cmd.GetPatchRequestByID(prID)
628 if err != nil {
629 return err
630 }
631
632 err = cmd.CreateEventLog(tx, EventLog{
633 UserID: userID,
634 RepoID: sql.NullInt64{Int64: pr.RepoID, Valid: true},
635 PatchRequestID: sql.NullInt64{Int64: prID, Valid: true},
636 PatchsetID: sql.NullInt64{Int64: patchsetID, Valid: true},
637 Event: "pr_patchset_deleted",
638 })
639 if err != nil {
640 return err
641 }
642
643 return tx.Commit()
644}
645
646func (cmd PrCmd) GetEventLogs() ([]*EventLog, error) {
647 eventLogs := []*EventLog{}
648 err := cmd.Backend.DB.Select(
649 &eventLogs,
650 "SELECT * FROM event_logs ORDER BY created_at DESC",
651 )
652 return eventLogs, err
653}
654
655func (cmd PrCmd) GetEventLogsByRepoName(user *User, repoName string) ([]*EventLog, error) {
656 repo, err := cmd.GetRepoByName(user, repoName)
657 if err != nil {
658 return nil, err
659 }
660
661 eventLogs := []*EventLog{}
662 err = cmd.Backend.DB.Select(
663 &eventLogs,
664 "SELECT * FROM event_logs WHERE repo_id=? ORDER BY created_at DESC",
665 repo.ID,
666 )
667 return eventLogs, err
668}
669
670func (cmd PrCmd) GetEventLogsByPrID(prID int64) ([]*EventLog, error) {
671 eventLogs := []*EventLog{}
672 err := cmd.Backend.DB.Select(
673 &eventLogs,
674 "SELECT * FROM event_logs WHERE patch_request_id=? ORDER BY created_at DESC",
675 prID,
676 )
677 return eventLogs, err
678}
679
680func (cmd PrCmd) GetEventLogsByUserID(userID int64) ([]*EventLog, error) {
681 eventLogs := []*EventLog{}
682 query := `SELECT * FROM event_logs
683 WHERE user_id=?
684 OR patch_request_id IN (
685 SELECT id FROM patch_requests WHERE user_id=?
686 )
687 ORDER BY created_at DESC`
688 err := cmd.Backend.DB.Select(
689 &eventLogs,
690 query,
691 userID,
692 userID,
693 )
694 return eventLogs, err
695}
696
// DiffPatchsets computes a range-diff between the patches of prev and next.
//
// next's patches are always loaded, so a lookup failure for them is reported
// even when prev is nil. If prev is nil, an empty result is returned.
// Otherwise both patch lists have their RawText parsed into Files, and the
// result of RangeDiff(prevPatches, nextPatches) is returned. Patches whose
// text fails to parse are kept but left without Files. A nil next is
// rejected with an error instead of panicking.
func (cmd PrCmd) DiffPatchsets(prev *Patchset, next *Patchset) ([]*RangeDiffOutput, error) {
	output := []*RangeDiffOutput{}
	if next == nil {
		return output, fmt.Errorf("cannot diff patchsets: next patchset is nil")
	}

	// attachFiles parses each patch's RawText and stores the parsed files on
	// the patch. Every patch after the first gets startOfPatch prepended
	// before parsing; presumably this restores a header stripped when the
	// set was split. Confirm.
	attachFiles := func(patches []*Patch) {
		for idx, patch := range patches {
			patchStr := patch.RawText
			if idx > 0 {
				patchStr = startOfPatch + patch.RawText
			}
			diffFiles, _, err := ParsePatch(patchStr)
			if err != nil {
				// Best-effort: leave Files unset for unparsable patches.
				continue
			}
			patch.Files = diffFiles
		}
	}

	patches, err := cmd.GetPatchesByPatchsetID(next.ID)
	if err != nil {
		return output, err
	}

	if prev == nil {
		// Nothing to diff against; parsing next's patches would be wasted.
		return output, nil
	}

	prevPatches, err := cmd.GetPatchesByPatchsetID(prev.ID)
	if err != nil {
		return output, fmt.Errorf("cannot get previous patchset patches: %w", err)
	}

	attachFiles(patches)
	attachFiles(prevPatches)

	return RangeDiff(prevPatches, patches), nil
}