Skip to content

Commit dd7cdd2

Browse files
committed
Introduce discoballed directive and granular enforcement of discoballing
1 parent 6f3f5f3 commit dd7cdd2

8 files changed

Lines changed: 328 additions & 30 deletions

File tree

cmd/action/comment.go

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package action
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"os"
67

@@ -27,27 +28,56 @@ func (c *Comment) Run(cc *cli.Context) (err error) {
2728
}
2829

2930
var t *raymond.Template
30-
t, err = templates.LoadTemplate(templateName)
31+
switch templateName {
32+
case "disco/combined":
33+
t, err = templates.LoadDiscoCombinedTemplate()
34+
case "disco/unpatched":
35+
t, err = templates.LoadDiscoUnpatchedTemplate()
36+
default:
37+
t, err = templates.LoadTemplate(templateName)
38+
}
3139

3240
if err != nil {
3341
err = fmt.Errorf("error loading template %s: %w", templateName, err)
3442
return
3543
}
3644

3745
templateContext := make(map[string]any)
38-
templateContext["patch_url"] = githubactions.GetInput("patch_url")
3946

40-
patchPath := githubactions.GetInput("patch_path")
47+
templateDataStr := githubactions.GetInput("template_data")
48+
if templateDataStr != "" {
49+
var templateData map[string]any
50+
err = json.Unmarshal([]byte(templateDataStr), &templateData)
51+
if err != nil {
52+
err = fmt.Errorf("error parsing template_data JSON: %w", err)
53+
return
54+
}
55+
for k, v := range templateData {
56+
templateContext[k] = v
57+
}
58+
}
4159

42-
if patchPath != "" {
43-
var patch []byte
44-
patch, err = os.ReadFile(patchPath)
60+
fileDataStr := githubactions.GetInput("file_data")
61+
if fileDataStr != "" {
62+
var fileData map[string]string
63+
err = json.Unmarshal([]byte(fileDataStr), &fileData)
4564
if err != nil {
46-
err = fmt.Errorf("error loading patch file: %w", err)
65+
err = fmt.Errorf("error parsing file_data JSON: %w", err)
4766
return
4867
}
49-
if len(patch) < 60000 {
50-
templateContext["diff"] = string(patch)
68+
for k, path := range fileData {
69+
if path == "" {
70+
continue
71+
}
72+
var content []byte
73+
content, err = os.ReadFile(path)
74+
if err != nil {
75+
err = fmt.Errorf("error reading file %s for key %s: %w", path, k, err)
76+
return
77+
}
78+
if len(content) < 60000 {
79+
templateContext[k] = string(content)
80+
}
5181
}
5282
}
5383

cmd/action/disco.go

Lines changed: 84 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@ package action
33
import (
44
"bytes"
55
"fmt"
6+
"io"
67
"log/slog"
78
"os"
89
"path/filepath"
10+
"strings"
911

1012
"github.com/project-chip/alchemy/asciidoc"
1113
"github.com/project-chip/alchemy/cmd/action/github"
@@ -18,6 +20,14 @@ import (
1820
"github.com/sethvargo/go-githubactions"
1921
)
2022

23+
type EnforcementLevel string
24+
25+
const (
26+
// EnforcementMandatory applies to new files with suggestions or existing discoballed files with suggestions.
27+
EnforcementMandatory EnforcementLevel = "discoball-mandatory"
28+
EnforcementOptional EnforcementLevel = "discoball-optional"
29+
)
30+
2131
type Disco struct {
2232
}
2333

@@ -40,8 +50,8 @@ func (c *Disco) Run(cc *cli.Context) (err error) {
4050
if pr == nil {
4151
return nil
4252
}
43-
var changedFiles []string
44-
changedFiles, err = github.GetPRChangedFiles(cc, githubContext, action, pr)
53+
var changedFiles map[string]github.FileStatus
54+
changedFiles, err = github.GetPRChangedFilesWithStatus(cc, githubContext, action, pr)
4555
if err != nil {
4656
return fmt.Errorf("failed on getting pull request changes: %w", err)
4757
}
@@ -52,9 +62,26 @@ func (c *Disco) Run(cc *cli.Context) (err error) {
5262
}
5363

5464
var changedDocs []string
55-
for _, path := range changedFiles {
65+
fileEnforcementLevel := make(map[string]EnforcementLevel)
66+
for path, status := range changedFiles {
5667
if filepath.Ext(path) == ".adoc" {
5768
changedDocs = append(changedDocs, path)
69+
if status == github.FileStatusAdded {
70+
fileEnforcementLevel[path] = EnforcementMandatory
71+
continue
72+
}
73+
fullPath := filepath.Join(githubContext.Workspace, path)
74+
b, err := os.ReadFile(fullPath)
75+
if err != nil {
76+
slog.Warn("failed to read original file to check for discoballed marker", "path", fullPath, "error", err)
77+
fileEnforcementLevel[path] = EnforcementOptional
78+
continue
79+
}
80+
if strings.Contains(string(b), ":alchemy-discoballed:") {
81+
fileEnforcementLevel[path] = EnforcementMandatory
82+
} else {
83+
fileEnforcementLevel[path] = EnforcementOptional
84+
}
5885
}
5986
}
6087

@@ -66,8 +93,15 @@ func (c *Disco) Run(cc *cli.Context) (err error) {
6693

6794
pipelineOptions := pipeline.ProcessingOptions{NoProgress: true}
6895

69-
var out bytes.Buffer
70-
writer := files.NewPatcher[string]("Generating patch file...", &out)
96+
var outMandatory bytes.Buffer
97+
var outOptional bytes.Buffer
98+
writer := files.NewPatcher[string]("Generating patch file...", &outMandatory)
99+
writer.GetWriter = func(path string) io.Writer {
100+
if fileEnforcementLevel[path] == EnforcementMandatory {
101+
return &outMandatory
102+
}
103+
return &outOptional
104+
}
71105
writer.Root = githubContext.Workspace
72106

73107
parserOptions := spec.ParserOptions{
@@ -90,17 +124,54 @@ func (c *Disco) Run(cc *cli.Context) (err error) {
90124
return fmt.Errorf("failed disco-balling: %s", message)
91125
}
92126

93-
if out.Len() > 0 {
94-
slog.Info("Setting disco_status to patched")
95-
action.SetOutput("disco_status", "patched")
127+
hasMandatory := outMandatory.Len() > 0
128+
hasOptional := outOptional.Len() > 0
129+
130+
if hasMandatory {
131+
action.SetOutput("has_violations", "true")
132+
var violations []string
133+
for _, path := range writer.ModifiedFiles {
134+
if fileEnforcementLevel[path] == EnforcementMandatory {
135+
status := changedFiles[path]
136+
var msg string
137+
if status == github.FileStatusAdded {
138+
msg = fmt.Sprintf("a new file is added to PR and it has alchemy discoball suggestions: %s", path)
139+
} else {
140+
msg = fmt.Sprintf("a file that already had :alchemy-discoballed: has fixes suggested: %s", path)
141+
}
142+
violations = append(violations, msg)
143+
}
144+
}
145+
for _, v := range violations {
146+
action.Errorf("%s", v)
147+
}
148+
action.SetOutput("violation_reason", fmt.Sprintf("Found %d files with violations. See logs for details.", len(violations)))
149+
} else {
150+
action.SetOutput("has_violations", "false")
151+
}
96152

97-
err = os.WriteFile("disco.patch", out.Bytes(), os.ModeAppend|0644)
153+
if hasMandatory {
154+
slog.Info("Setting mandatory patch outputs")
155+
err = os.WriteFile("disco-mandatory.patch", outMandatory.Bytes(), os.ModeAppend|0644)
98156
if err != nil {
99-
return fmt.Errorf("failed saving patch: %v", err)
157+
return fmt.Errorf("failed saving mandatory patch: %v", err)
100158
}
101-
action.SetOutput("patch_name", "disco_patch")
102-
action.SetOutput("patch_path", "disco.patch")
103-
action.SetOutput("template_name", "disco/patched")
159+
action.SetOutput("mandatory_patch_name", "disco_mandatory_patch")
160+
action.SetOutput("mandatory_patch_path", "disco-mandatory.patch")
161+
}
162+
163+
if hasOptional {
164+
slog.Info("Setting optional patch outputs")
165+
err = os.WriteFile("disco-optional.patch", outOptional.Bytes(), os.ModeAppend|0644)
166+
if err != nil {
167+
return fmt.Errorf("failed saving optional patch: %v", err)
168+
}
169+
action.SetOutput("optional_patch_name", "disco_optional_patch")
170+
action.SetOutput("optional_patch_path", "disco-optional.patch")
171+
}
172+
173+
if hasMandatory || hasOptional {
174+
action.SetOutput("template_name", "disco/combined")
104175
} else {
105176
slog.Info("Setting disco_status to unpatched")
106177
action.SetOutput("disco_status", "unpatched")

cmd/action/github/pr.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ import (
1616
"github.com/walle/targz"
1717
)
1818

19+
type FileStatus string
20+
21+
const (
22+
FileStatusAdded FileStatus = "added"
23+
FileStatusModified FileStatus = "modified"
24+
FileStatusDeleted FileStatus = "deleted"
25+
)
26+
1927
func GetPRChangedFiles(cxt context.Context, githubContext *githubactions.GitHubContext, action *githubactions.Action, pr *github.PullRequest) (changedFiles []string, err error) {
2028

2129
token := action.Getenv("GITHUB_AUTH_TOKEN")
@@ -53,6 +61,44 @@ func GetPRChangedFiles(cxt context.Context, githubContext *githubactions.GitHubC
5361
return
5462
}
5563

64+
func GetPRChangedFilesWithStatus(cxt context.Context, githubContext *githubactions.GitHubContext, action *githubactions.Action, pr *github.PullRequest) (changedFiles map[string]FileStatus, err error) {
65+
66+
token := action.Getenv("GITHUB_AUTH_TOKEN")
67+
if token == "" {
68+
return nil, fmt.Errorf("missing github token")
69+
}
70+
client := github.NewClient(nil).WithAuthToken(token)
71+
72+
owner, repo := githubContext.Repo()
73+
74+
lo := github.ListOptions{
75+
PerPage: 100,
76+
}
77+
changedFiles = make(map[string]FileStatus)
78+
for {
79+
action.Infof("Fetching PR from: %s/%s; page %d\n", owner, repo, lo.Page)
80+
var files []*github.CommitFile
81+
var resp *github.Response
82+
files, resp, err = client.PullRequests.ListFiles(cxt, owner, repo, pr.GetNumber(), &lo)
83+
if err != nil {
84+
err = fmt.Errorf("failed listing files in PR: %w", err)
85+
return
86+
}
87+
for _, file := range files {
88+
if *file.Status == "deleted" {
89+
continue
90+
}
91+
slog.Info("changed file", "file", *file.Filename, "status", *file.Status)
92+
changedFiles[*file.Filename] = FileStatus(*file.Status)
93+
}
94+
if resp.NextPage == 0 {
95+
break
96+
}
97+
lo.Page = resp.NextPage
98+
}
99+
return
100+
}
101+
56102
func GetPR(cxt context.Context, githubContext *githubactions.GitHubContext, action *githubactions.Action, pr *github.PullRequest) (pullRequest *github.PullRequest, err error) {
57103

58104
token := action.Getenv("GITHUB_AUTH_TOKEN")

cmd/action/github/templates/disco.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,12 @@ import (
1111
//go:embed disco
1212
var discoTemplateFiles embed.FS
1313

14-
var discoPatchedTemplate pipeline.Once[*raymond.Template]
14+
var discoCombinedTemplate pipeline.Once[*raymond.Template]
1515

16-
func LoadDiscoPatchedTemplate() (*raymond.Template, error) {
17-
t, err := discoPatchedTemplate.Do(func() (*raymond.Template, error) {
16+
func LoadDiscoCombinedTemplate() (*raymond.Template, error) {
17+
t, err := discoCombinedTemplate.Do(func() (*raymond.Template, error) {
1818

19-
t, err := handlebars.LoadTemplate("{{> disco/patched}}", discoTemplateFiles)
19+
t, err := handlebars.LoadTemplate("{{> disco/combined}}", discoTemplateFiles)
2020
if err != nil {
2121
return nil, err
2222
}

disco/baller.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,29 @@ func (b *Baller) disco(cxt context.Context, doc *asciidoc.Document) error {
134134
}
135135

136136
func (b *Baller) discoBallTopLevelSection(dc *discoContext, top *asciidoc.Section, docType matter.DocType) error {
137+
if b.options.AddDiscoballed {
138+
var existingIndex = -1
139+
var existingAE *asciidoc.AttributeEntry
140+
for i, el := range dc.doc.Elements {
141+
if ae, ok := el.(*asciidoc.AttributeEntry); ok && ae.Name == "alchemy-discoballed" {
142+
existingIndex = i
143+
existingAE = ae
144+
break
145+
}
146+
}
147+
148+
if existingIndex > 0 {
149+
// Remove it
150+
dc.doc.Elements = append(dc.doc.Elements[:existingIndex], dc.doc.Elements[existingIndex+1:]...)
151+
// Prepend it
152+
dc.doc.Elements = append(asciidoc.Elements{existingAE, &asciidoc.NewLine{}}, dc.doc.Elements...)
153+
} else if existingIndex == -1 {
154+
// Create and prepend
155+
ae := asciidoc.NewAttributeEntry("alchemy-discoballed")
156+
dc.doc.Elements = append(asciidoc.Elements{ae, &asciidoc.NewLine{}}, dc.doc.Elements...)
157+
}
158+
}
159+
137160
if b.options.XrefStyleOnlyInRoot {
138161
// Logic:
139162
// 1. For non-root, no xrefstyle at all anywhere in doc.

0 commit comments

Comments
 (0)