11package cmd
22
33import (
4+ "bytes"
5+ "errors"
46 "fmt"
57 "html/template"
8+ "io"
69 "os"
710 "path/filepath"
811 "strings"
912
1013 "github.com/go-git/go-git/v5"
1114 "github.com/spf13/cobra"
1215 "github.com/spf13/viper"
16+ "gopkg.in/yaml.v3"
1317
14- "github.com/revanite-io/sci/pkg/layer2"
18+ "github.com/ossf/gemara/layer2"
19+ sdkutils "github.com/privateerproj/privateer-sdk/utils"
1520)
1621
1722type CatalogData struct {
1823 layer2.Catalog
19- ServiceName string
20- TestSuites map [string ][]string
24+ ServiceName string
25+ Requirements []string
26+ ApplicabilityCategories []string
27+ StrippedName string
2128}
2229
23- var TemplatesDir string
24- var SourcePath string
25- var OutputDir string
26-
27- // versionCmd represents the version command
28- var genPluginCmd = & cobra.Command {
29- Use : "generate-plugin" ,
30- Short : "Generate a new plugin" ,
31- Run : func (cmd * cobra.Command , args []string ) {
32- generatePlugin ()
33- },
34- }
30+ var (
31+ TemplatesDir string
32+ SourcePath string
33+ OutputDir string
34+ ServiceName string
35+
36+ // versionCmd represents the version command
37+ genPluginCmd = & cobra.Command {
38+ Use : "generate-plugin" ,
39+ Short : "Generate a new plugin" ,
40+ Run : func (cmd * cobra.Command , args []string ) {
41+ generatePlugin ()
42+ },
43+ }
44+ )
3545
3646func init () {
3747 genPluginCmd .PersistentFlags ().StringP ("source-path" , "p" , "" , "The source file to generate the plugin from." )
@@ -53,14 +63,18 @@ func generatePlugin() {
5363 logger .Error (err .Error ())
5464 return
5565 }
56- data , err := readData ()
66+ data := CatalogData {}
67+ data .ServiceName = ServiceName
68+
69+ err = data .LoadFile ("file://" + SourcePath )
5770 if err != nil {
5871 logger .Error (err .Error ())
5972 return
6073 }
61- data .ServiceName = viper .GetString ("service-name" )
62- if data .ServiceName == "" {
63- logger .Error ("--service-name is required to generate a plugin." )
74+
75+ err = data .getAssessmentRequirements ()
76+ if err != nil {
77+ logger .Error (err .Error ())
6478 return
6579 }
6680
@@ -83,6 +97,11 @@ func generatePlugin() {
8397 if err != nil {
8498 logger .Error ("Error walking through templates directory: %s" , err )
8599 }
100+
101+ err = writeCatalogFile (& data .Catalog )
102+ if err != nil {
103+ logger .Error ("Failed to write catalog to file: %s" , err )
104+ }
86105}
87106
88107func setupTemplatingEnvironment () error {
@@ -91,6 +110,11 @@ func setupTemplatingEnvironment() error {
91110 return fmt .Errorf ("--source-path is required to generate a plugin from a control set from local file or URL" )
92111 }
93112
113+ ServiceName = viper .GetString ("service-name" )
114+ if ServiceName == "" {
115+ return fmt .Errorf ("--service-name is required to generate a plugin." )
116+ }
117+
94118 if viper .GetString ("local-templates" ) != "" {
95119 TemplatesDir = viper .GetString ("local-templates" )
96120 } else {
@@ -130,26 +154,36 @@ func generateFileFromTemplate(data CatalogData, templatePath, OutputDir string)
130154 return fmt .Errorf ("error reading template file %s: %w" , templatePath , err )
131155 }
132156
157+ // Determine relative path from templates dir so we can preserve subdirs in output
158+ relativePath , err := filepath .Rel (TemplatesDir , templatePath )
159+ if err != nil {
160+ return fmt .Errorf ("error calculating relative path for %s: %w" , templatePath , err )
161+ }
162+
163+ // If the template is not a text template, copy it over as-is (preserve mode)
164+ if filepath .Ext (templatePath ) != ".txt" {
165+ return copyNonTemplateFile (templatePath , filepath .Join (OutputDir , relativePath ))
166+ }
167+
133168 tmpl , err := template .New ("plugin" ).Funcs (template.FuncMap {
134- "as_text" : func (s string ) template.HTML {
135- s = strings .TrimSpace (strings .ReplaceAll (s , "\n " , " " ))
136- return template .HTML (s )
169+ "as_text" : func (in string ) template.HTML {
170+ return template .HTML (
171+ strings .TrimSpace (
172+ strings .ReplaceAll (in , "\n " , " " )))
137173 },
138- "as_id" : func (s string ) string {
139- return strings .TrimSpace (
140- strings .ReplaceAll (
141- strings .ReplaceAll (s , "." , "_" ), "-" , "_" ))
174+ "default" : func (in string , out string ) string {
175+ if in != "" {
176+ return in
177+ }
178+ return out
142179 },
180+ "snake_case" : snakeCase ,
181+ "simplifiedName" : simplifiedName ,
143182 }).Parse (string (templateContent ))
144183 if err != nil {
145184 return fmt .Errorf ("error parsing template file %s: %w" , templatePath , err )
146185 }
147186
148- relativePath , err := filepath .Rel (TemplatesDir , templatePath )
149- if err != nil {
150- return err
151- }
152-
153187 outputPath := filepath .Join (OutputDir , strings .TrimSuffix (relativePath , ".txt" ))
154188
155189 err = os .MkdirAll (filepath .Dir (outputPath ), os .ModePerm )
@@ -177,26 +211,89 @@ func generateFileFromTemplate(data CatalogData, templatePath, OutputDir string)
177211 return nil
178212}
179213
180- func readData () (data CatalogData , err error ) {
181- err = data .LoadControlFamiliesFile (SourcePath )
182- if err != nil {
183- return
184- }
185-
186- data .TestSuites = make (map [string ][]string )
187-
188- for i , family := range data .ControlFamilies {
189- for j := range family .Controls {
190- for _ , testReq := range data .ControlFamilies [i ].Controls [j ].Requirements {
191- // Add the test ID to the TestSuites map for each TLP level
192- for _ , tlpLevel := range testReq .Applicability {
193- if data .TestSuites [tlpLevel ] == nil {
194- data .TestSuites [tlpLevel ] = []string {}
214+ func (c * CatalogData ) getAssessmentRequirements () error {
215+ for _ , family := range c .ControlFamilies {
216+ for _ , control := range family .Controls {
217+ for _ , requirement := range control .AssessmentRequirements {
218+ c .Requirements = append (c .Requirements , requirement .Id )
219+ // Add applicability categories if unique
220+ for _ , a := range requirement .Applicability {
221+ if ! sdkutils .StringSliceContains (c .ApplicabilityCategories , a ) {
222+ c .ApplicabilityCategories = append (c .ApplicabilityCategories , a )
195223 }
196- data .TestSuites [tlpLevel ] = append (data .TestSuites [tlpLevel ], testReq .ID )
197224 }
198225 }
199226 }
200227 }
201- return
228+ if len (c .Requirements ) == 0 {
229+ return errors .New ("No requirements retrieved from catalog" )
230+ }
231+ return nil
232+ }
233+
234+ func writeCatalogFile (catalog * layer2.Catalog ) error {
235+ var b bytes.Buffer
236+ yamlEncoder := yaml .NewEncoder (& b )
237+ yamlEncoder .SetIndent (2 ) // this is the line that sets the indentation
238+ err := yamlEncoder .Encode (catalog )
239+ if err != nil {
240+ return fmt .Errorf ("error marshaling YAML: %w" , err )
241+ }
242+
243+ dirPath := filepath .Join (OutputDir , "data" , simplifiedName (catalog .Metadata .Id , catalog .Metadata .Version ))
244+ filePath := filepath .Join (dirPath , "catalog.yaml" )
245+
246+ err = os .MkdirAll (dirPath , os .ModePerm )
247+ if err != nil {
248+ return fmt .Errorf ("error creating directories for %s: %w" , filePath , err )
249+ }
250+
251+ if err := os .WriteFile (filePath , b .Bytes (), 0644 ); err != nil {
252+ return fmt .Errorf ("error writing YAML file: %w" , err )
253+ }
254+
255+ return nil
256+ }
257+
258+ func snakeCase (in string ) string {
259+ return strings .TrimSpace (
260+ strings .ReplaceAll (
261+ strings .ReplaceAll (in , "." , "_" ), "-" , "_" ))
262+ }
263+
264+ func simplifiedName (catalogId string , catalogVersion string ) string {
265+ return fmt .Sprintf ("%s_%s" , snakeCase (catalogId ), snakeCase (catalogVersion ))
266+ }
267+
268+ func copyNonTemplateFile (templatePath , relativePath string ) error {
269+ outputPath := filepath .Join (OutputDir , relativePath )
270+ if err := os .MkdirAll (filepath .Dir (outputPath ), os .ModePerm ); err != nil {
271+ return fmt .Errorf ("error creating directories for %s: %w" , outputPath , err )
272+ }
273+
274+ // Copy file contents
275+ srcFile , err := os .Open (templatePath )
276+ if err != nil {
277+ return fmt .Errorf ("error opening source file %s: %w" , templatePath , err )
278+ }
279+ defer srcFile .Close ()
280+
281+ dstFile , err := os .Create (outputPath )
282+ if err != nil {
283+ return fmt .Errorf ("error creating destination file %s: %w" , outputPath , err )
284+ }
285+ defer func () {
286+ _ = dstFile .Close ()
287+ }()
288+
289+ if _ , err := io .Copy (dstFile , srcFile ); err != nil {
290+ return fmt .Errorf ("error copying file to %s: %w" , outputPath , err )
291+ }
292+
293+ // Try to preserve file mode from source
294+ if fi , err := os .Stat (templatePath ); err == nil {
295+ _ = os .Chmod (outputPath , fi .Mode ())
296+ }
297+
298+ return nil
202299}
0 commit comments