diff --git a/ai/README.md b/ai/README.md index ccd4401080..5dafe7fb7e 100644 --- a/ai/README.md +++ b/ai/README.md @@ -15,7 +15,3 @@ - `$env:AOAI_COMPLETIONS_ENDPOINT = Read-Host 'Enter AOAI_COMPLETIONS_ENDPOINT'` - `$env:AOAI_DEPLOYMENT_NAME = Read-Host 'Enter AOAI_DEPLOYMENT_NAME'` - `go run main.go` - -## Development - -Modify prompts in the folders within *pkg/analysis/* (e.g. *pkg/analysis/flows/prompt.go* or *analyzer.go*) diff --git a/ai/main.go b/ai/main.go index d058fb19a0..d891c05cc3 100644 --- a/ai/main.go +++ b/ai/main.go @@ -1,14 +1,11 @@ package main import ( - "context" - + "github.com/microsoft/retina/ai/pkg/chat" "github.com/microsoft/retina/ai/pkg/lm" - flowscenario "github.com/microsoft/retina/ai/pkg/scenarios/flows" "github.com/sirupsen/logrus" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" ) @@ -43,26 +40,8 @@ func main() { } log.Info("initialized Azure OpenAI model") - handleChat(log, config, clientset, model) -} - -// pretend there's input from chat interface -func handleChat(log logrus.FieldLogger, config *rest.Config, clientset *kubernetes.Clientset, model lm.Model) { - question := "what's wrong with my app?" - var chat lm.ChatHistory - - h := flowscenario.NewHandler(log, config, clientset, model) - params := &flowscenario.ScenarioParams{ - Scenario: flowscenario.DropScenario, - Namespace1: "default", - Namespace2: "default", + bot := chat.NewBot(log, config, clientset, model) + if err := bot.Loop(); err != nil { + log.WithError(err).Fatal("error running chat loop") } - - ctx := context.TODO() - response, err := h.Handle(ctx, question, chat, params) - if err != nil { - log.WithError(err).Fatal("error running flows scenario") - } - - _ = response } diff --git a/ai/pkg/analysis/flows/analyzer.go b/ai/pkg/analysis/flows/analyzer.go deleted file mode 100644 index 7a30ab2450..0000000000 --- a/ai/pkg/analysis/flows/analyzer.go +++ /dev/null @@ -1,26 +0,0 @@ -package flows - -import ( - "context" - "fmt" - - "github.com/microsoft/retina/ai/pkg/lm" - "github.com/sirupsen/logrus" -) - -type Analyzer struct { - log logrus.FieldLogger - model lm.Model -} - -func NewAnalyzer(log logrus.FieldLogger, model lm.Model) *Analyzer { - return &Analyzer{ - log: logrus.WithField("component", "flow-analyzer"), - model: model, - } -} - -func (a *Analyzer) Analyze(ctx context.Context, query string, chat lm.ChatHistory, summary FlowSummary) (string, error) { - message := fmt.Sprintf(messagePromptTemplate, query, summary.FormatForLM()) - return a.model.Generate(ctx, systemPrompt, chat, message) -} diff --git a/ai/pkg/analysis/flows/types.go b/ai/pkg/analysis/flows/types.go deleted file mode 100644 index 3fa5a4f66b..0000000000 --- a/ai/pkg/analysis/flows/types.go +++ /dev/null @@ -1,78 +0,0 @@ -package flows - -import ( - "errors" - "fmt" - "strings" - - flowpb "github.com/cilium/cilium/api/v1/flow" -) - -var ErrNoEndpointName = errors.New("no endpoint name") - -type Connection struct { - Pod1 string - Pod2 string - Key string - Flows []*flowpb.Flow -} - -type FlowSummary map[string]*Connection - -func (fs FlowSummary) FormatForLM() string { - // FIXME hacky right now - forwards := fs.connStrings(flowpb.Verdict_FORWARDED) - drops := fs.connStrings(flowpb.Verdict_DROPPED) - other := fs.connStrings(flowpb.Verdict_VERDICT_UNKNOWN) - - return fmt.Sprintf("SUCCESSFUL CONNECTIONS:\n%s\n\nDROPPED CONNECTIONS:\n%s\n\nOTHER CONNECTIONS:\n%s", forwards, drops, other) -} - -func (fs FlowSummary) connStrings(verdict flowpb.Verdict) string { - connStrings := 
make([]string, 0, len(fs)) - for _, conn := range fs { - match := false - for _, f := range conn.Flows { - // FIXME hacky right now - if f.GetVerdict() == verdict || (verdict == flowpb.Verdict_VERDICT_UNKNOWN && f.GetVerdict() != flowpb.Verdict_FORWARDED && f.GetVerdict() != flowpb.Verdict_DROPPED) { - match = true - break - } - } - - if !match { - continue - } - - connString := "" - if verdict == flowpb.Verdict_FORWARDED && conn.Flows[0].L4.GetTCP() != nil { - successful := false - rst := false - for _, f := range conn.Flows { - if f.GetVerdict() == flowpb.Verdict_FORWARDED && f.L4.GetTCP().GetFlags().GetSYN() && f.L4.GetTCP().GetFlags().GetACK() { - successful = true - continue - } - - if f.GetVerdict() == flowpb.Verdict_FORWARDED && f.L4.GetTCP().GetFlags().GetRST() { - rst = true - continue - } - } - _ = successful - connString = fmt.Sprintf("Connection: %s -> %s, Number of Flows: %d. Was Reset: %v", conn.Pod1, conn.Pod2, len(conn.Flows), rst) - - } else { - - connString = fmt.Sprintf("Connection: %s -> %s, Number of Flows: %d", conn.Pod1, conn.Pod2, len(conn.Flows)) - } - - connStrings = append(connStrings, connString) - } - - if len(connStrings) == 0 { - return "none" - } - - return strings.Join(connStrings, "\n") -} diff --git a/ai/pkg/chat/chat.go b/ai/pkg/chat/chat.go new file mode 100644 index 0000000000..55e2d5daa1 --- /dev/null +++ b/ai/pkg/chat/chat.go @@ -0,0 +1,91 @@ +package chat + +import ( + "context" + "fmt" + + "github.com/microsoft/retina/ai/pkg/lm" + flowretrieval "github.com/microsoft/retina/ai/pkg/retrieval/flows" + "github.com/microsoft/retina/ai/pkg/scenarios" + "github.com/microsoft/retina/ai/pkg/scenarios/dns" + "github.com/microsoft/retina/ai/pkg/scenarios/drops" + + "github.com/sirupsen/logrus" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +var ( + definitions = []*scenarios.Definition{ + drops.Definition, + dns.Definition, + } +) + +type Bot struct { + log logrus.FieldLogger + config *rest.Config + clientset *kubernetes.Clientset + model lm.Model +} + +// input log, config, clientset, model +func NewBot(log logrus.FieldLogger, config *rest.Config, clientset *kubernetes.Clientset, model lm.Model) *Bot { + return &Bot{ + log: log.WithField("component", "chat"), + config: config, + clientset: clientset, + model: model, + } +} + +func (b *Bot) Loop() error { + var history lm.ChatHistory + flowRetriever := flowretrieval.NewRetriever(b.log, b.config, b.clientset) + + for { + // TODO get user input + question := "what's wrong with my app?" 
+ + // select scenario and get parameters + definition, params, err := b.selectScenario(question, history) + if err != nil { + return fmt.Errorf("error selecting scenario: %w", err) + } + + // cfg.FlowRetriever.UseFile() + + cfg := &scenarios.Config{ + Log: b.log, + Config: b.config, + Clientset: b.clientset, + Model: b.model, + FlowRetriever: flowRetriever, + } + + ctx := context.TODO() + response, err := definition.Handle(ctx, cfg, params, question, history) + if err != nil { + return fmt.Errorf("error handling scenario: %w", err) + } + + fmt.Println(response) + + // TODO keep chat loop going + break + } + + return nil +} + +func (b *Bot) selectScenario(question string, history lm.ChatHistory) (*scenarios.Definition, map[string]string, error) { + // TODO use chat interface + // FIXME hard-coding the scenario and params for now + d := definitions[0] + params := map[string]string{ + scenarios.Namespace1.Name: "default", + scenarios.Namespace2.Name: "default", + } + + return d, params, nil +} diff --git a/ai/pkg/lm/azure-openai.go b/ai/pkg/lm/azure-openai.go index 0eb43e5186..0fd846c205 100644 --- a/ai/pkg/lm/azure-openai.go +++ b/ai/pkg/lm/azure-openai.go @@ -65,11 +65,11 @@ func NewAzureOpenAI() (*AzureOpenAI, error) { return aoai, nil } -func (m *AzureOpenAI) Generate(ctx context.Context, systemPrompt string, chat ChatHistory, message string) (string, error) { +func (m *AzureOpenAI) Generate(ctx context.Context, systemPrompt string, history ChatHistory, message string) (string, error) { messages := []azopenai.ChatRequestMessageClassification{ &azopenai.ChatRequestSystemMessage{Content: to.Ptr(systemPrompt)}, } - for _, pair := range chat { + for _, pair := range history { messages = append(messages, &azopenai.ChatRequestUserMessage{Content: azopenai.NewChatRequestUserMessageContent(pair.User)}) messages = append(messages, &azopenai.ChatRequestAssistantMessage{Content: to.Ptr(pair.Assistant)}) } diff --git a/ai/pkg/lm/echo.go b/ai/pkg/lm/echo.go index 24f53b37b4..e4c1143212 100644 --- a/ai/pkg/lm/echo.go +++ b/ai/pkg/lm/echo.go @@ -13,9 +13,9 @@ func NewEchoModel() *EchoModel { return &EchoModel{} } -func (m *EchoModel) Generate(ctx context.Context, systemPrompt string, chat ChatHistory, message string) (string, error) { - chatStrings := make([]string, 0, len(chat)) - for _, pair := range chat { +func (m *EchoModel) Generate(ctx context.Context, systemPrompt string, history ChatHistory, message string) (string, error) { + chatStrings := make([]string, 0, len(history)) + for _, pair := range history { chatStrings = append(chatStrings, fmt.Sprintf("USER: %s\nASSISTANT: %s\n", pair.User, pair.Assistant)) } resp := fmt.Sprintf("systemPrompt: %s\nhistory: %s\nmessage: %s", systemPrompt, strings.Join(chatStrings, "\n"), message) diff --git a/ai/pkg/lm/model.go b/ai/pkg/lm/model.go index b848e30fb7..7e9eaca3d9 100644 --- a/ai/pkg/lm/model.go +++ b/ai/pkg/lm/model.go @@ -10,5 +10,5 @@ type MessagePair struct { type ChatHistory []MessagePair type Model interface { - Generate(ctx context.Context, systemPrompt string, chat ChatHistory, message string) (string, error) + Generate(ctx context.Context, systemPrompt string, history ChatHistory, message string) (string, error) } diff --git a/ai/pkg/analysis/flows/parser.go b/ai/pkg/parse/flows/parser.go similarity index 83% rename from ai/pkg/analysis/flows/parser.go rename to ai/pkg/parse/flows/parser.go index 2862e5a0c6..bea8789791 100644 --- a/ai/pkg/analysis/flows/parser.go +++ b/ai/pkg/parse/flows/parser.go @@ -9,19 +9,19 @@ import ( ) type Parser 
struct { - log logrus.FieldLogger - summary FlowSummary + log logrus.FieldLogger + connections Connections } func NewParser(log logrus.FieldLogger) *Parser { return &Parser{ - log: log.WithField("component", "flow-parser"), - summary: make(map[string]*Connection), + log: log.WithField("component", "flow-parser"), + connections: make(map[string]*Connection), } } -func (p *Parser) Summary() FlowSummary { - return p.summary +func (p *Parser) Connections() Connections { + return p.connections } func (p *Parser) Parse(flows []*flowpb.Flow) { @@ -57,7 +57,7 @@ func (p *Parser) addFlow(f *flowpb.Flow) error { pod1, pod2 := pods[0], pods[1] key := pod1 + "#" + pod2 - conn, exists := p.summary[key] + conn, exists := p.connections[key] if !exists { conn = &Connection{ Pod1: pod1, @@ -65,7 +65,7 @@ func (p *Parser) addFlow(f *flowpb.Flow) error { Key: key, Flows: []*flowpb.Flow{}, } - p.summary[key] = conn + p.connections[key] = conn } conn.Flows = append(conn.Flows, f) diff --git a/ai/pkg/parse/flows/types.go b/ai/pkg/parse/flows/types.go new file mode 100644 index 0000000000..02f6a8e833 --- /dev/null +++ b/ai/pkg/parse/flows/types.go @@ -0,0 +1,149 @@ +package flows + +import ( + "errors" + + flowpb "github.com/cilium/cilium/api/v1/flow" +) + +var ( + ErrNoEndpointName = errors.New("no endpoint name") + ErrNilEndpoint = errors.New("nil endpoint") +) + +type Connection struct { + Pod1 string + Pod2 string + Key string + + // UDP *UdpSummary + // TCP *TcpSummary + Flows []*flowpb.Flow +} + +type Connections map[string]*Connection + +// func + +// type UdpSummary struct { +// MinLatency time.Duration +// MaxLatency time.Duration +// AvgLatency time.Duration +// TotalPackets int +// TotalBytes int +// } + +// type TcpSummary struct { +// MinLatency time.Duration +// MaxLatency time.Duration +// AvgLatency time.Duration +// TotalPackets int +// TotalBytes int +// *TcpFlagSummary +// } + +// type TcpFlagSummary struct { +// SynCount int +// AckCount int +// SynAckCount int +// FinCount int +// RstCount int +// } + +// type FlowSummary map[string]*Connection + +// func (fs FlowSummary) Aggregate() { +// for _, conn := range fs { +// udpTimestamps := make(map[string][]*timestamppb.Timestamp) +// tcpTimestamps := make(map[string][]*timestamppb.Timestamp) +// for _, f := range conn.Flows { +// l4 := f.GetL4() +// if l4 == nil { +// continue +// } + +// udp := l4.GetUDP() +// if udp != nil { +// if conn.UDP == nil { +// conn.UDP = &UdpSummary{} +// } + +// conn.UDP.TotalPackets += 1 + +// src, err := endpointName(f.GetSource()) +// if err != nil { +// // FIXME warn and continue +// log.Fatalf("bad src endpoint while aggregating: %w", err) +// } +// dst, err := endpointName(f.GetDestination()) +// if err != nil { +// // FIXME warn and continue +// log.Fatalf("bad dst endpoint while aggregating: %w", err) +// } + +// tuple := fmt.Sprintf("%s:%d -> %s:%d", src, udp.GetSourcePort(), dst, udp.GetDestinationPort()) + +// time := f.GetTime() +// if time == nil { +// // FIXME warn and continue +// log.Fatalf("nil time while aggregating") +// } + +// udpTimestamps[tuple] = append(udpTimestamps[tuple], f.GetTime()) +// } + +// tcp := l4.GetTCP() +// if tcp != nil { +// if conn.TCP == nil { +// conn.TCP = &TcpSummary{} +// } + +// conn.TCP.TotalPackets += 1 + +// if conn.TCP.TcpFlagSummary == nil { +// conn.TCP.TcpFlagSummary = &TcpFlagSummary{} +// } + +// flags := tcp.GetFlags() +// if flags == nil { +// // FIXME warn and continue +// log.Fatalf("nil flags while aggregating") +// } + +// switch { +// case flags.SYN && 
flags.ACK: +// conn.TCP.TcpFlagSummary.SynAckCount += 1 +// case flags.SYN: +// conn.TCP.TcpFlagSummary.SynCount += 1 +// case flags.ACK: +// conn.TCP.TcpFlagSummary.AckCount += 1 +// case flags.FIN: +// conn.TCP.TcpFlagSummary.FinCount += 1 +// case flags.RST: +// conn.TCP.TcpFlagSummary.RstCount += 1 +// } + +// src, err := endpointName(f.GetSource()) +// if err != nil { +// // FIXME warn and continue +// log.Fatalf("bad src endpoint while aggregating: %w", err) +// } +// dst, err := endpointName(f.GetDestination()) +// if err != nil { +// // FIXME warn and continue +// log.Fatalf("bad dst endpoint while aggregating: %w", err) +// } + +// tuple := fmt.Sprintf("%s:%d -> %s:%d", src, udp.GetSourcePort(), dst, udp.GetDestinationPort()) + +// time := f.GetTime() +// if time == nil { +// // FIXME warn and continue +// log.Fatalf("nil time while aggregating") +// } + +// tcpTimestamps[tuple] = append(tcpTimestamps[tuple], f.GetTime()) +// } +// } +// } +// } diff --git a/ai/pkg/retrieval/flows/retriever.go b/ai/pkg/retrieval/flows/retriever.go index aeb0922c0b..bb1cb18191 100644 --- a/ai/pkg/retrieval/flows/retriever.go +++ b/ai/pkg/retrieval/flows/retriever.go @@ -20,6 +20,8 @@ import ( "k8s.io/client-go/rest" ) +const MaxFlowsFromHubbleRelay = 30000 + type Retriever struct { log logrus.FieldLogger config *rest.Config @@ -90,7 +92,6 @@ func (r *Retriever) Observe(ctx context.Context, req *observerpb.GetFlowsRequest observeCtx, observeCancel := context.WithTimeout(ctx, 30*time.Second) defer observeCancel() - // FIXME don't use maxFlows anymore? check for EOF? then remove this constant: MaxFlowsToAnalyze maxFlows := req.Number flows, err := r.observeFlowsGRPC(observeCtx, req, int(maxFlows)) if err != nil { diff --git a/ai/pkg/scenarios/common.go b/ai/pkg/scenarios/common.go new file mode 100644 index 0000000000..4ec8713f9c --- /dev/null +++ b/ai/pkg/scenarios/common.go @@ -0,0 +1,49 @@ +package scenarios + +import "regexp" + +// common parameters +var ( + k8sNameRegex = regexp.MustCompile(`^[a-zA-Z][-a-zA-Z0-9]*$`) + nodesRegex = regexp.MustCompile(`^\[[a-zA-Z][-a-zA-Z0-9_,]*\]$`) + + Namespace1 = &ParameterSpec{ + Name: "namespace1", + DataType: "string", + Description: "Namespace 1", + Optional: false, + Regex: k8sNameRegex, + } + + PodPrefix1 = &ParameterSpec{ + Name: "podPrefix1", + DataType: "string", + Description: "Pod prefix 1", + Optional: true, + Regex: k8sNameRegex, + } + + Namespace2 = &ParameterSpec{ + Name: "namespace2", + DataType: "string", + Description: "Namespace 2", + Optional: false, + Regex: k8sNameRegex, + } + + PodPrefix2 = &ParameterSpec{ + Name: "podPrefix2", + DataType: "string", + Description: "Pod prefix 2", + Optional: true, + Regex: k8sNameRegex, + } + + Nodes = &ParameterSpec{ + Name: "nodes", + DataType: "[]string", + Description: "Nodes", + Optional: true, + Regex: nodesRegex, + } +) diff --git a/ai/pkg/scenarios/dns/dns.go b/ai/pkg/scenarios/dns/dns.go new file mode 100644 index 0000000000..36ab7424a0 --- /dev/null +++ b/ai/pkg/scenarios/dns/dns.go @@ -0,0 +1,39 @@ +package dns + +import ( + "context" + + "github.com/microsoft/retina/ai/pkg/lm" + "github.com/microsoft/retina/ai/pkg/scenarios" +) + +var ( + Definition = scenarios.NewDefinition("DNS", "DNS", parameterSpecs, &handler{}) + + dnsQuery = &scenarios.ParameterSpec{ + Name: "dnsQuery", + DataType: "string", + Description: "DNS query", + Optional: true, + } + + parameterSpecs = []*scenarios.ParameterSpec{ + scenarios.Namespace1, + scenarios.Namespace2, + dnsQuery, + } +) + +// mirrored with 
parameterSpecs +type params struct { + Namespace1 string + Namespace2 string + DNSQuery string +} + +type handler struct{} + +func (h *handler) Handle(ctx context.Context, cfg *scenarios.Config, typedParams map[string]any, question string, history lm.ChatHistory) (string, error) { + // TODO + return "", nil +} diff --git a/ai/pkg/scenarios/drops/drops.go b/ai/pkg/scenarios/drops/drops.go new file mode 100644 index 0000000000..8373748508 --- /dev/null +++ b/ai/pkg/scenarios/drops/drops.go @@ -0,0 +1,251 @@ +package drops + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/microsoft/retina/ai/pkg/lm" + flowparsing "github.com/microsoft/retina/ai/pkg/parse/flows" + flowretrieval "github.com/microsoft/retina/ai/pkg/retrieval/flows" + "github.com/microsoft/retina/ai/pkg/scenarios" + + flowpb "github.com/cilium/cilium/api/v1/flow" + observerpb "github.com/cilium/cilium/api/v1/observer" +) + +var ( + Definition = scenarios.NewDefinition("DROPS", "DROPS", parameterSpecs, &handler{}) + + parameterSpecs = []*scenarios.ParameterSpec{ + scenarios.Namespace1, + scenarios.PodPrefix1, + scenarios.Namespace2, + scenarios.PodPrefix2, + scenarios.Nodes, + } +) + +// mirrored with parameterSpecs +type params struct { + Namespace1 string + PodPrefix1 string + Namespace2 string + PodPrefix2 string + Nodes []string +} + +type handler struct{} + +func (h *handler) Handle(ctx context.Context, cfg *scenarios.Config, typedParams map[string]any, question string, history lm.ChatHistory) (string, error) { + l := cfg.Log.WithField("scenario", "drops") + l.Info("handling drops scenario...") + + if err := cfg.FlowRetriever.Init(); err != nil { + return "", fmt.Errorf("error initializing flow retriever: %w", err) + } + + params := ¶ms{ + Namespace1: anyToString(typedParams[scenarios.Namespace1.Name]), + PodPrefix1: anyToString(typedParams[scenarios.PodPrefix1.Name]), + Namespace2: anyToString(typedParams[scenarios.Namespace2.Name]), + PodPrefix2: anyToString(typedParams[scenarios.PodPrefix2.Name]), + Nodes: anyToStringSlice(typedParams[scenarios.Nodes.Name]), + } + + req := flowsRequest(params) + flows, err := cfg.FlowRetriever.Observe(ctx, req) + if err != nil { + return "", fmt.Errorf("error observing flows: %w", err) + } + l.Info("observed flows") + + // analyze flows + p := flowparsing.NewParser(l) + p.Parse(flows) + connections := p.Connections() + + formattedFlowLogs := formatFlowLogs(connections) + + message := fmt.Sprintf(messagePromptTemplate, question, formattedFlowLogs) + analyzeCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + resp, err := cfg.Model.Generate(analyzeCtx, systemPrompt, history, message) + if err != nil { + return "", fmt.Errorf("error analyzing flows: %w", err) + } + l.Info("analyzed flows") + + return resp, nil +} + +// cast to string without nil panics +func anyToString(a any) string { + if a == nil { + return "" + } + return a.(string) +} + +// cast to []string without nil panics +func anyToStringSlice(a any) []string { + if a == nil { + return nil + } + return a.([]string) +} + +func flowsRequest(params *params) *observerpb.GetFlowsRequest { + req := &observerpb.GetFlowsRequest{ + Number: flowretrieval.MaxFlowsFromHubbleRelay, + Follow: true, + } + + protocol := []string{"TCP", "UDP"} + + if params.Namespace1 == "" && params.PodPrefix1 == "" && params.Namespace2 == "" && params.PodPrefix2 == "" { + req.Whitelist = []*flowpb.FlowFilter{ + { + NodeName: params.Nodes, + Protocol: protocol, + }, + } + + return req + } + + var prefix1 []string + if 
params.Namespace1 != "" || params.PodPrefix1 != "" { + prefix1 = append(prefix1, fmt.Sprintf("%s/%s", params.Namespace1, params.PodPrefix1)) + } + + var prefix2 []string + if params.Namespace2 != "" || params.PodPrefix2 != "" { + prefix2 = append(prefix2, fmt.Sprintf("%s/%s", params.Namespace2, params.PodPrefix2)) + } + + filterDirection1 := &flowpb.FlowFilter{ + NodeName: params.Nodes, + SourcePod: prefix1, + DestinationPod: prefix2, + Protocol: protocol, + } + + filterDirection2 := &flowpb.FlowFilter{ + NodeName: params.Nodes, + SourcePod: prefix2, + DestinationPod: prefix1, + Protocol: protocol, + } + + // filterPod1ToIP := &flowpb.FlowFilter{ + // NodeName: params.Nodes, + // SourcePod: prefix1, + // DestinationIp: []string{"10.224.1.214"}, + // Protocol: protocol, + // } + + // filterPod1FromIP := &flowpb.FlowFilter{ + // NodeName: params.Nodes, + // SourceIp: []string{"10.224.1.214"}, + // DestinationPod: prefix1, + // Protocol: protocol, + // } + + // includes services + // world := []string{"reserved:world"} + + // filterPod1ToWorld := &flowpb.FlowFilter{ + // NodeName: params.Nodes, + // SourcePod: prefix1, + // DestinationLabel: world, + // Protocol: protocol, + // } + + // filterPod1FromWorld := &flowpb.FlowFilter{ + // NodeName: params.Nodes, + // SourceLabel: world, + // DestinationPod: prefix1, + // Protocol: protocol, + // } + + req.Whitelist = []*flowpb.FlowFilter{ + filterDirection1, + filterDirection2, + // filterPod1FromIP, + // filterPod1ToIP, + } + + req.Whitelist = nil + + req.Blacklist = []*flowpb.FlowFilter{ + { + SourcePod: []string{"kube-system/"}, + }, + { + DestinationPod: []string{"kube-system/"}, + }, + } + + return req +} + +func formatFlowLogs(connections flowparsing.Connections) string { + // FIXME hacky right now + forwards := connStrings(connections, flowpb.Verdict_FORWARDED) + + drops := connStrings(connections, flowpb.Verdict_DROPPED) + other := connStrings(connections, flowpb.Verdict_VERDICT_UNKNOWN) + + return fmt.Sprintf("SUCCESSFUL CONNECTIONS:\n%s\n\nDROPPED CONNECTIONS:\n%s\n\nOTHER CONNECTIONS:\n%s", forwards, drops, other) +} + +func connStrings(connections flowparsing.Connections, verdict flowpb.Verdict) string { + connStrings := make([]string, 0, len(connections)) + for _, conn := range connections { + match := false + for _, f := range conn.Flows { + // FIXME hacky right now + if f.GetVerdict() == verdict || (verdict == flowpb.Verdict_VERDICT_UNKNOWN && f.GetVerdict() != flowpb.Verdict_FORWARDED && f.GetVerdict() != flowpb.Verdict_DROPPED) { + match = true + break + } + } + + if !match { + continue + } + + connString := "" + if verdict == flowpb.Verdict_FORWARDED && conn.Flows[0].L4.GetTCP() != nil { + successful := false + rst := false + for _, f := range conn.Flows { + if f.GetVerdict() == flowpb.Verdict_FORWARDED && f.L4.GetTCP().GetFlags().GetSYN() && f.L4.GetTCP().GetFlags().GetACK() { + successful = true + continue + } + + if f.GetVerdict() == flowpb.Verdict_FORWARDED && f.L4.GetTCP().GetFlags().GetRST() { + rst = true + continue + } + } + _ = successful + connString = fmt.Sprintf("Connection: %s -> %s, Number of Flows: %d. 
Was Reset: %v", conn.Pod1, conn.Pod2, len(conn.Flows), rst) + + } else { + + connString = fmt.Sprintf("Connection: %s -> %s, Number of Flows: %d", conn.Pod1, conn.Pod2, len(conn.Flows)) + } + + connStrings = append(connStrings, connString) + } + + if len(connStrings) == 0 { + return "none" + } + + return strings.Join(connStrings, "\n") +} diff --git a/ai/pkg/analysis/flows/prompt.go b/ai/pkg/scenarios/drops/prompt.go similarity index 96% rename from ai/pkg/analysis/flows/prompt.go rename to ai/pkg/scenarios/drops/prompt.go index fac5113207..a6d0fd1f03 100644 --- a/ai/pkg/analysis/flows/prompt.go +++ b/ai/pkg/scenarios/drops/prompt.go @@ -1,4 +1,4 @@ -package flows +package drops const ( systemPrompt = `You are an assistant with expertise in Kubernetes Networking. The user is debugging networking issues on their Pods and/or Nodes. Provide a succinct summary identifying any issues in the "summary of network flow logs" provided by the user.` diff --git a/ai/pkg/scenarios/flows/handler.go b/ai/pkg/scenarios/flows/handler.go deleted file mode 100644 index 71d64fd5e0..0000000000 --- a/ai/pkg/scenarios/flows/handler.go +++ /dev/null @@ -1,196 +0,0 @@ -package flows - -import ( - "context" - "fmt" - "time" - - flowanalysis "github.com/microsoft/retina/ai/pkg/analysis/flows" - "github.com/microsoft/retina/ai/pkg/lm" - flowretrieval "github.com/microsoft/retina/ai/pkg/retrieval/flows" - "github.com/microsoft/retina/ai/pkg/util" - - flowpb "github.com/cilium/cilium/api/v1/flow" - observerpb "github.com/cilium/cilium/api/v1/observer" - "github.com/sirupsen/logrus" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/rest" -) - -type FlowScenario string - -const ( - AnyScenario FlowScenario = "Any" - DropScenario FlowScenario = "Drops" - DnsScenario FlowScenario = "DNS" -) - -type ScenarioParams struct { - Scenario FlowScenario - - // parameters (all optional?) 
- DnsQuery string - Nodes []string - Namespace1 string - PodPrefix1 string - Namespace2 string - PodPrefix2 string -} - -type Handler struct { - log logrus.FieldLogger - r *flowretrieval.Retriever - p *flowanalysis.Parser - a *flowanalysis.Analyzer -} - -func NewHandler(log logrus.FieldLogger, config *rest.Config, clientset *kubernetes.Clientset, model lm.Model) *Handler { - return &Handler{ - log: log.WithField("component", "flow-handler"), - r: flowretrieval.NewRetriever(log, config, clientset), - p: flowanalysis.NewParser(log), - a: flowanalysis.NewAnalyzer(log, model), - } -} - -func (h *Handler) Handle(ctx context.Context, question string, chat lm.ChatHistory, params *ScenarioParams) (string, error) { - h.log.Info("handling flows scenario...") - - // get flows - // h.r.UseFile() - - if err := h.r.Init(); err != nil { - return "", fmt.Errorf("error initializing flow retriever: %w", err) - } - - req := flowsRequest(params) - flows, err := h.r.Observe(ctx, req) - if err != nil { - return "", fmt.Errorf("error observing flows: %w", err) - } - h.log.Info("observed flows") - - // analyze flows - h.p.Parse(flows) - summary := h.p.Summary() - - analyzeCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - resp, err := h.a.Analyze(analyzeCtx, question, chat, summary) - if err != nil { - return "", fmt.Errorf("error analyzing flows: %w", err) - } - h.log.Info("analyzed flows") - - // temporary printing - fmt.Println("flow summary:") - fmt.Println(summary.FormatForLM()) - fmt.Println() - fmt.Println("response:") - fmt.Println(resp) - - return resp, nil -} - -// TODO DNS should not have a destination Pod (except maybe a specific coredns pod) -func flowsRequest(params *ScenarioParams) *observerpb.GetFlowsRequest { - req := &observerpb.GetFlowsRequest{ - Number: util.MaxFlowsFromHubbleRelay, - Follow: true, - } - - if len(params.Nodes) == 0 { - params.Nodes = nil - } - - protocol := []string{"TCP", "UDP"} - if params.Scenario == DnsScenario { - protocol = []string{"DNS"} - } - - if params.Namespace1 == "" && params.PodPrefix1 == "" && params.Namespace2 == "" && params.PodPrefix2 == "" { - req.Whitelist = []*flowpb.FlowFilter{ - { - NodeName: params.Nodes, - Protocol: protocol, - }, - } - - return req - } - - var prefix1 []string - if params.Namespace1 != "" || params.PodPrefix1 != "" { - prefix1 = append(prefix1, fmt.Sprintf("%s/%s", params.Namespace1, params.PodPrefix1)) - } - - var prefix2 []string - if params.Namespace2 != "" || params.PodPrefix2 != "" { - prefix2 = append(prefix2, fmt.Sprintf("%s/%s", params.Namespace2, params.PodPrefix2)) - } - - filterDirection1 := &flowpb.FlowFilter{ - NodeName: params.Nodes, - SourcePod: prefix1, - DestinationPod: prefix2, - Protocol: protocol, - } - - filterDirection2 := &flowpb.FlowFilter{ - NodeName: params.Nodes, - SourcePod: prefix2, - DestinationPod: prefix1, - Protocol: protocol, - } - - // filterPod1ToIP := &flowpb.FlowFilter{ - // NodeName: params.Nodes, - // SourcePod: prefix1, - // DestinationIp: []string{"10.224.1.214"}, - // Protocol: protocol, - // } - - // filterPod1FromIP := &flowpb.FlowFilter{ - // NodeName: params.Nodes, - // SourceIp: []string{"10.224.1.214"}, - // DestinationPod: prefix1, - // Protocol: protocol, - // } - - // includes services - // world := []string{"reserved:world"} - - // filterPod1ToWorld := &flowpb.FlowFilter{ - // NodeName: params.Nodes, - // SourcePod: prefix1, - // DestinationLabel: world, - // Protocol: protocol, - // } - - // filterPod1FromWorld := &flowpb.FlowFilter{ - // NodeName: 
params.Nodes, - // SourceLabel: world, - // DestinationPod: prefix1, - // Protocol: protocol, - // } - - req.Whitelist = []*flowpb.FlowFilter{ - filterDirection1, - filterDirection2, - // filterPod1FromIP, - // filterPod1ToIP, - } - - req.Whitelist = nil - - req.Blacklist = []*flowpb.FlowFilter{ - { - SourcePod: []string{"kube-system/"}, - }, - { - DestinationPod: []string{"kube-system/"}, - }, - } - - return req -} diff --git a/ai/pkg/scenarios/types.go b/ai/pkg/scenarios/types.go new file mode 100644 index 0000000000..0ae48fe17c --- /dev/null +++ b/ai/pkg/scenarios/types.go @@ -0,0 +1,95 @@ +package scenarios + +import ( + "context" + "fmt" + "regexp" + "strconv" + "strings" + + "github.com/microsoft/retina/ai/pkg/lm" + flowretrieval "github.com/microsoft/retina/ai/pkg/retrieval/flows" + + "github.com/sirupsen/logrus" + "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" +) + +type Definition struct { + Name string + Description string + Specs []*ParameterSpec + handler +} + +func NewDefinition(name, description string, specs []*ParameterSpec, handler handler) *Definition { + return &Definition{ + Name: name, + Description: description, + Specs: specs, + handler: handler, + } +} + +func (d *Definition) Handle(ctx context.Context, cfg *Config, rawParams map[string]string, question string, history lm.ChatHistory) (string, error) { + typedParams := make(map[string]any) + + // validate params + for _, p := range d.Specs { + raw, ok := rawParams[p.Name] + if !ok { + if !p.Optional { + return "", fmt.Errorf("missing required parameter %s", p.Name) + } + + continue + } + + if p.Regex != nil && !p.Regex.MatchString(raw) { + return "", fmt.Errorf("parameter %s does not match regex format", p.Name) + } + + switch p.DataType { + case "string": + typedParams[p.Name] = raw + case "int": + i, err := strconv.Atoi(raw) + if err != nil { + return "", fmt.Errorf("parameter %s is not an integer", p.Name) + } + typedParams[p.Name] = i + case "[]string": + // make sure the format is like [a,b,c] + if raw == "" || raw[0] != '[' || raw[len(raw)-1] != ']' || strings.Count(raw, "[") != 1 || strings.Count(raw, "]") != 1 { + return "", fmt.Errorf("invalid array format for parameter %s", p.Name) + } + // remove brackets + raw = raw[1 : len(raw)-1] + typedParams[p.Name] = strings.Split(raw, ",") + default: + return "", fmt.Errorf("unsupported data type %s", p.DataType) + } + } + + return d.handler.Handle(ctx, cfg, typedParams, question, history) +} + +type ParameterSpec struct { + Name string + DataType string + Description string + Optional bool + Regex *regexp.Regexp +} + +type handler interface { + Handle(ctx context.Context, cfg *Config, typedParams map[string]any, question string, history lm.ChatHistory) (string, error) +} + +type Config struct { + Log logrus.FieldLogger + Config *rest.Config + Clientset *kubernetes.Clientset + Model lm.Model + FlowRetriever *flowretrieval.Retriever +}
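
Editorial sketch (not part of the patch): the new ai/pkg/scenarios package is designed so additional scenarios can be plugged into the chat bot by pairing a Definition with a handler. The following minimal Go sketch shows what one hypothetical extra scenario could look like, reusing only the APIs introduced above (scenarios.NewDefinition, scenarios.ParameterSpec, scenarios.Config, and the lm.Model Generate method). The "LATENCY" scenario name, the package name, and the prompt text are illustrative assumptions, not part of this change.

    // Package latency is a hypothetical scenario, mirroring ai/pkg/scenarios/dns.
    package latency

    import (
        "context"
        "fmt"

        "github.com/microsoft/retina/ai/pkg/lm"
        "github.com/microsoft/retina/ai/pkg/scenarios"
    )

    var (
        // Definition pairs the scenario name and description with its parameter
        // specs and handler, as drops.Definition and dns.Definition do above.
        Definition = scenarios.NewDefinition("LATENCY", "Pod-to-Pod latency", parameterSpecs, &handler{})

        // Reuse the shared parameter specs from ai/pkg/scenarios/common.go.
        parameterSpecs = []*scenarios.ParameterSpec{
            scenarios.Namespace1,
            scenarios.PodPrefix1,
            scenarios.Namespace2,
            scenarios.PodPrefix2,
        }
    )

    type handler struct{}

    // Handle receives parameters already validated and typed by Definition.Handle
    // (string, int, or []string according to each ParameterSpec.DataType).
    func (h *handler) Handle(ctx context.Context, cfg *scenarios.Config, typedParams map[string]any, question string, history lm.ChatHistory) (string, error) {
        ns1, _ := typedParams[scenarios.Namespace1.Name].(string)

        // Illustrative prompt only; a real handler would retrieve and summarize
        // flows via cfg.FlowRetriever before calling the model.
        message := fmt.Sprintf("latency question %q for namespace %s", question, ns1)
        return cfg.Model.Generate(ctx, "You are an assistant with expertise in Kubernetes Networking.", history, message)
    }

As with drops.Definition and dns.Definition, the bot would only consider such a scenario if it were appended to the definitions slice in ai/pkg/chat/chat.go, which selectScenario currently indexes (hard-coded to definitions[0] for now).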