Skip to content

Commit 9f6eb9d

Browse files
Add secret source support for git and huggingface
1 parent 3891683 commit 9f6eb9d

10 files changed

Lines changed: 323 additions & 25 deletions

File tree

pkg/cmd/agent/submit.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ import (
3030

3131
func NewSubmitCmd(logger *log.Logger) *cobra.Command {
3232
dump := false
33-
submissionConfig := diambra.NewSubmissionConfig(logger)
33+
submissionConfig := diambra.SubmissionConfig{}
34+
submissionConfig.RegisterCredentialsProviders()
3435
c, err := diambra.NewConfig(logger)
3536
if err != nil {
3637
level.Error(logger).Log("msg", err.Error())
@@ -46,7 +47,7 @@ func NewSubmitCmd(logger *log.Logger) *cobra.Command {
4647
level.Error(logger).Log("msg", err.Error())
4748
os.Exit(1)
4849
}
49-
submission, err := submissionConfig.Submission(c.CredPath, args)
50+
submission, err := submissionConfig.Submission(c, args)
5051
if err != nil {
5152
level.Error(logger).Log("msg", "failed to configure manifest", "err", err.Error())
5253
os.Exit(1)

pkg/cmd/agent/test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ const (
2424
)
2525

2626
func NewTestCmd(logger *log.Logger) *cobra.Command {
27-
submissionConfig := diambra.NewSubmissionConfig(logger)
27+
submissionConfig := diambra.SubmissionConfig{}
28+
submissionConfig.RegisterCredentialsProviders()
2829
c, err := diambra.NewConfig(logger)
2930
if err != nil {
3031
level.Error(logger).Log("msg", err.Error())
@@ -37,7 +38,7 @@ func NewTestCmd(logger *log.Logger) *cobra.Command {
3738
Long: `This takes a docker image or submission manifest and runs it in the same way as it would be run when submitted
3839
to DIAMBRA. This is useful for testing your agent before submitting it. Optionally, you can pass in commands to run instead of the configured entrypoint.`,
3940
Run: func(cmd *cobra.Command, args []string) {
40-
submission, err := submissionConfig.Submission(c.CredPath, args)
41+
submission, err := submissionConfig.Submission(c, args)
4142
if err != nil {
4243
level.Error(logger).Log("msg", "failed to configure manifest", "err", err.Error())
4344
os.Exit(1)

pkg/container/docker.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func NewDockerRunner(logger log.Logger, client *client.Client, autoRemove bool)
6060
func (r *DockerRunner) Pull(c *Container, output *os.File) error {
6161
reader, err := r.Client.ImagePull(context.TODO(), c.Image, types.ImagePullOptions{})
6262
if err != nil {
63-
return fmt.Errorf("couldn't pull image %s: %w:\nTo disable pulling the image on start, retry with --images.pull=false", c.Image, err)
63+
return fmt.Errorf("couldn't pull image %s: %w:\nTo disable pulling the image on start, retry with --images.no-pull", c.Image, err)
6464
}
6565
defer reader.Close()
6666

pkg/diambra/config.go

Lines changed: 62 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727

2828
"github.com/diambra/cli/pkg/container"
2929
"github.com/diambra/cli/pkg/diambra/client"
30+
"github.com/diambra/cli/pkg/secretsources"
3031
"github.com/diambra/init/initializer"
3132
"github.com/go-kit/log"
3233
"github.com/go-kit/log/level"
@@ -234,22 +235,28 @@ const (
234235
var ErrInvalidArgs = errors.New("either image, manifest path or submission id must be provided")
235236

236237
type SubmissionConfig struct {
237-
logger log.Logger
238-
239238
Mode string
240239
Difficulty string
241240
EnvVars map[string]string
242241
Sources map[string]string
243242
Secrets map[string]string
243+
SecretsFrom string
244244
ArgsIsCommand bool
245245
ManifestPath string
246246
SubmissionID int
247+
248+
credentialsProvider map[string]secretsources.CredentialProvider
247249
}
248250

249-
func NewSubmissionConfig(logger log.Logger) *SubmissionConfig {
250-
return &SubmissionConfig{
251-
logger: logger,
251+
func (c *SubmissionConfig) RegisterCredentialsProvider(name string, provider secretsources.CredentialProvider) {
252+
if c.credentialsProvider == nil {
253+
c.credentialsProvider = make(map[string]secretsources.CredentialProvider)
252254
}
255+
c.credentialsProvider[name] = provider
256+
}
257+
func (c *SubmissionConfig) RegisterCredentialsProviders() {
258+
c.RegisterCredentialsProvider("git", &secretsources.GitCredentials{})
259+
c.RegisterCredentialsProvider("huggingface", &secretsources.HuggingfaceCredentials{})
253260
}
254261

255262
func (c *SubmissionConfig) AddFlags(flags *pflag.FlagSet) {
@@ -258,20 +265,21 @@ func (c *SubmissionConfig) AddFlags(flags *pflag.FlagSet) {
258265
flags.StringToStringVarP(&c.EnvVars, "submission.env", "e", nil, "Environment variables to pass to the agent")
259266
flags.StringToStringVarP(&c.Sources, "submission.source", "u", nil, "Source urls to pass to the agent")
260267
flags.StringToStringVar(&c.Secrets, "submission.secret", nil, "Secrets to pass to the agent")
268+
flags.StringVar(&c.SecretsFrom, "submission.secrets-from", "", "Automatically add secrets. Supported values: git, huggingface")
261269
flags.StringVar(&c.ManifestPath, "submission.manifest", "", "Path to manifest file.")
262270
flags.IntVar(&c.SubmissionID, "submission.id", 0, "Submission ID to retrieve manifest from")
263271
flags.BoolVar(&c.ArgsIsCommand, "submission.set-command", false, "Treat positional arguments are command instead of entrypoint")
264272
}
265273

266-
func (c *SubmissionConfig) Submission(credPath string, args []string) (*client.Submission, error) {
274+
func (c *SubmissionConfig) Submission(config *EnvConfig, args []string) (*client.Submission, error) {
267275
var (
268276
nargs = len(args)
269277
manifest *client.Manifest
270278
)
271279

272280
switch {
273281
case c.SubmissionID != 0:
274-
cl, err := client.NewClient(c.logger, credPath)
282+
cl, err := client.NewClient(config.logger, config.CredPath)
275283
if err != nil {
276284
return nil, fmt.Errorf("failed to create client: %w", err)
277285
}
@@ -320,22 +328,62 @@ func (c *SubmissionConfig) Submission(credPath string, args []string) (*client.S
320328
}
321329

322330
if c.Sources != nil {
323-
level.Debug(c.logger).Log("msg", "Using sources", "sources", c.Sources)
331+
level.Debug(config.logger).Log("msg", "Using sources", "sources", c.Sources)
324332
manifest.Sources = make(map[string]string)
325333
for k, v := range c.Sources {
326334
manifest.Sources[k] = v
327335
}
328336
}
329337

330-
if manifest.Sources != nil {
331-
init, err := initializer.NewInitializer(c.logger, manifest.Sources, c.Secrets, map[string]string{}, "")
332-
if err != nil {
333-
return nil, err
338+
if c.SecretsFrom != "" {
339+
if c.Secrets == nil {
340+
c.Secrets = make(map[string]string)
334341
}
342+
}
335343

336-
if err := init.Validate(); err != nil {
337-
return nil, err
344+
if c.SecretsFrom != "" {
345+
ss, ok := c.credentialsProvider[c.SecretsFrom]
346+
if !ok {
347+
return nil, fmt.Errorf("invalid value for --submission.secrets-from: %s", c.SecretsFrom)
338348
}
349+
switch c.SecretsFrom {
350+
case "git":
351+
secrets, err := secretsources.CredentialsFill(ss, manifest.Sources)
352+
if err != nil {
353+
return nil, err
354+
}
355+
if manifest.Sources == nil {
356+
return nil, fmt.Errorf("sources are required to use --submission.secrets-from=git")
357+
}
358+
level.Debug(config.logger).Log("msg", "Adding git secrets")
359+
for k, v := range secrets {
360+
level.Info(config.logger).Log("msg", "Adding git secret", "key", k)
361+
c.Secrets[k] = v
362+
}
363+
case "huggingface":
364+
level.Debug(config.logger).Log("msg", "Adding huggingface secrets")
365+
secrets, err := ss.Credentials("")
366+
if err != nil {
367+
return nil, err
368+
}
369+
c.Secrets["HF_TOKEN"] = secrets["HF_TOKEN"]
370+
if manifest.Env == nil {
371+
manifest.Env = make(map[string]string)
372+
}
373+
manifest.Env["HF_TOKEN"] = "{{ .Secrets.HF_TOKEN }}"
374+
case "":
375+
default:
376+
return nil, fmt.Errorf("invalid value for --submission.secrets-from: %s", c.SecretsFrom)
377+
}
378+
}
379+
380+
init, err := initializer.NewInitializer(config.logger, manifest.Sources, c.Secrets, map[string]string{}, "")
381+
if err != nil {
382+
return nil, err
383+
}
384+
385+
if err := init.Validate(); err != nil {
386+
return nil, err
339387
}
340388

341389
return &client.Submission{

pkg/diambra/config_test.go

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,13 @@
1616
package diambra
1717

1818
import (
19+
"os"
20+
"path/filepath"
1921
"testing"
2022

2123
"github.com/diambra/cli/pkg/diambra/client"
24+
"github.com/diambra/cli/pkg/secretsources"
25+
"github.com/go-kit/log"
2226
"github.com/stretchr/testify/assert"
2327
)
2428

@@ -56,6 +60,13 @@ func TestAppArgs(t *testing.T) {
5660
}
5761

5862
func TestSubmissionConfig(t *testing.T) {
63+
envConfig := &EnvConfig{
64+
logger: log.NewNopLogger(),
65+
CredPath: "",
66+
}
67+
cwd, err := os.Getwd()
68+
assert.NoError(t, err)
69+
5970
for _, tc := range []struct {
6071
name string
6172
config SubmissionConfig
@@ -113,20 +124,60 @@ func TestSubmissionConfig(t *testing.T) {
113124
nil,
114125
},
115126
{
116-
"from args, with secrets",
117-
SubmissionConfig{},
118-
[]string{"diambra/agent-random-1:main", "--gameId", "doapp"},
127+
"from args with sources and secrets",
128+
SubmissionConfig{
129+
ManifestPath: "testdata/manifest.yaml",
130+
ArgsIsCommand: true,
131+
Sources: map[string]string{"model.zip": "https://user:{{ .Secrets.foo }}@example.com/model.zip"},
132+
Secrets: map[string]string{
133+
"foo": "bar",
134+
},
135+
},
136+
[]string{"python", "agent.py"},
119137
&client.Submission{
120138
Manifest: client.Manifest{
121-
Image: "diambra/agent-random-1:main",
122-
Args: []string{"--gameId", "doapp"},
139+
Image: "diambra/agent-random-1:main",
140+
Command: []string{"python", "agent.py"},
141+
Args: []string{"--gameId", "doapp"},
142+
Sources: map[string]string{
143+
"model.zip": "https://user:{{ .Secrets.foo }}@example.com/model.zip",
144+
},
145+
},
146+
Secrets: map[string]string{
147+
"foo": "bar",
148+
},
149+
},
150+
nil,
151+
},
152+
{
153+
"from args with sources and secrets from git",
154+
SubmissionConfig{
155+
ManifestPath: "testdata/manifest.yaml",
156+
ArgsIsCommand: true,
157+
Sources: map[string]string{"model.zip": "https://example.com/mode.zip"},
158+
SecretsFrom: "git",
159+
},
160+
[]string{"python", "agent.py"},
161+
&client.Submission{
162+
Manifest: client.Manifest{
163+
Image: "diambra/agent-random-1:main",
164+
Command: []string{"python", "agent.py"},
165+
Args: []string{"--gameId", "doapp"},
166+
Sources: map[string]string{
167+
"model.zip": "https://{{ .Secrets.git_username_1 }}:{{ .Secrets.git_password_1 }}@example.com/mode.zip",
168+
},
169+
},
170+
Secrets: map[string]string{
171+
"git_username_1": "user1",
172+
"git_password_1": "pass1",
123173
},
124174
},
125175
nil,
126176
},
127177
} {
128178
t.Run(tc.name, func(t *testing.T) {
129-
submission, err := tc.config.Submission("", tc.args)
179+
tc.config.RegisterCredentialsProvider("git", &secretsources.GitCredentials{Helper: filepath.Join(cwd, "../../test/mock-credential-helper.sh")})
180+
submission, err := tc.config.Submission(envConfig, tc.args)
130181
assert.Equal(t, tc.expectedErr, err)
131182
assert.Equal(t, tc.expected, submission)
132183
})

pkg/secretsources/credentials.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package secretsources
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"net/url"
7+
"os/exec"
8+
"strings"
9+
)
10+
11+
type CredentialProvider interface {
12+
Credentials(url string) (map[string]string, error)
13+
}
14+
15+
type GitCredentials struct {
16+
Helper string
17+
}
18+
19+
func (c *GitCredentials) Credentials(url string) (map[string]string, error) {
20+
args := []string{}
21+
if c.Helper != "" {
22+
args = append(args, "-c", fmt.Sprintf("credential.helper=%s", c.Helper))
23+
}
24+
args = append(args, "credential", "fill")
25+
cmd := exec.Command("git", args...)
26+
cmd.Stdin = strings.NewReader("url=" + url + "\n")
27+
28+
var stdout bytes.Buffer
29+
cmd.Stdout = &stdout
30+
if err := cmd.Run(); err != nil {
31+
return nil, fmt.Errorf("failed to run %v: %w", cmd, err)
32+
}
33+
34+
credentials := make(map[string]string)
35+
lines := strings.Split(stdout.String(), "\n")
36+
for _, line := range lines {
37+
parts := strings.SplitN(line, "=", 2)
38+
if len(parts) == 2 {
39+
credentials[parts[0]] = parts[1]
40+
}
41+
}
42+
43+
return credentials, nil
44+
}
45+
46+
// CredentialsFill calls the CredentialsProvider for each source and returns
47+
// a new source map with templating as well as a map of credentials for the templated values.
48+
func CredentialsFill(provider CredentialProvider, sources map[string]string) (map[string]string, error) {
49+
secrets := make(map[string]string)
50+
i := 0
51+
for k, v := range sources {
52+
i++
53+
u, err := url.Parse(v)
54+
if err != nil {
55+
return nil, fmt.Errorf("failed to parse url %s: %w", v, err)
56+
}
57+
credentials, err := provider.Credentials(v)
58+
if err != nil {
59+
return nil, err
60+
}
61+
if credentials["password"] == "" {
62+
continue
63+
}
64+
65+
if credentials["host"] != u.Host {
66+
return nil, fmt.Errorf("host %s does not match %s (this should never happend)", credentials["host"], u.Host)
67+
}
68+
69+
var (
70+
uservar = fmt.Sprintf("git_username_%d", i)
71+
passvar = fmt.Sprintf("git_password_%d", i)
72+
)
73+
74+
u.User = url.UserPassword(fmt.Sprintf("{{ %s }}", uservar), fmt.Sprintf("{{ %s }}", passvar))
75+
secrets[uservar] = credentials["username"]
76+
secrets[passvar] = credentials["password"]
77+
sources[k] = fmt.Sprintf("%s://{{ .Secrets.%s }}:{{ .Secrets.%s }}@%s%s", u.Scheme, uservar, passvar, u.Host, u.Path)
78+
}
79+
return secrets, nil
80+
}

0 commit comments

Comments
 (0)