diff --git a/pkg/restapi/cluster.go b/pkg/restapi/cluster.go index cd1cd7ae38..3be459e127 100644 --- a/pkg/restapi/cluster.go +++ b/pkg/restapi/cluster.go @@ -149,6 +149,9 @@ func (h clusterHandler) updateCluster(w http.ResponseWriter, r *http.Request) { return } newCluster.ID = c.ID + // Cluster.KnownHosts are not part of REST API definitions, + // so we need to fill them based on current cluster state. + newCluster.KnownHosts = c.KnownHosts if err := h.svc.PutCluster(r.Context(), newCluster); err != nil { respondError(w, r, errors.Wrapf(err, "update cluster %q", c.ID)) diff --git a/pkg/service/cluster/service.go b/pkg/service/cluster/service.go index 60a929bbae..e2167b0b44 100644 --- a/pkg/service/cluster/service.go +++ b/pkg/service/cluster/service.go @@ -442,6 +442,17 @@ func (s *Service) PutCluster(ctx context.Context, c *Cluster) (err error) { s.logger.Info(ctx, "Adding new cluster", "cluster_id", c.ID) } else { s.logger.Info(ctx, "Updating cluster", "cluster_id", c.ID) + // Putting cluster should theoretically just set cluster state to the provided one. + // The problem is that Cluster.KnownHosts are not part of REST API definitions, + // so it's possible that cluster passed to PutCluster doesn't have it set by accident. + // This shouldn't be the case, but since it never makes sense for Cluster.KnownHosts + // to be empty (they should at least contain resolved Cluster.Host), we can add additional + // safety net here and load them if they are missing. + if len(c.KnownHosts) == 0 { + if err := s.loadKnownHosts(c); err != nil && !errors.Is(err, gocql.ErrNotFound) { + return errors.Wrap(err, "load known hosts") + } + } } // Validate cluster model diff --git a/pkg/service/cluster/service_integration_test.go b/pkg/service/cluster/service_integration_test.go index 7a9264983a..571c4463e5 100644 --- a/pkg/service/cluster/service_integration_test.go +++ b/pkg/service/cluster/service_integration_test.go @@ -850,25 +850,45 @@ func TestServiceStorageIntegration(t *testing.T) { if err = s.PutCluster(ctx, &initialCluster); err != nil { t.Fatal(err) } + Print("Known hosts are set after cluster creation") + getCluster, err := s.GetClusterByID(t.Context(), initialCluster.ID) + if err != nil { + t.Fatal(err) + } + if len(getCluster.KnownHosts) != len(hosts) { + t.Fatalf("Expected %d known hosts, got %d", len(hosts), len(getCluster.KnownHosts)) + } - unavailableHost := hosts[1] - Print("Block connectivity to host: " + unavailableHost) - if err := RunIptablesCommand(t, unavailableHost, CmdBlockScyllaREST); err != nil { + Print("Block connectivity to host: " + clusterHost) + if err := RunIptablesCommand(t, clusterHost, CmdBlockScyllaREST); err != nil { t.Fatal(err) } - defer RunIptablesCommand(t, unavailableHost, CmdUnblockScyllaREST) + defer RunIptablesCommand(t, clusterHost, CmdUnblockScyllaREST) Print("Expect connectivity failure when adding new cluster") putCluster := *validCluster() putCluster.Host = clusterHost + // Simulate missing known hosts and expect that + // they won't overwrite existing known hosts. + putCluster.KnownHosts = nil if err := s.PutCluster(ctx, &putCluster); err == nil { t.Fatal("Expected connectivity failure when adding new cluster, got nil") } - newClusterHost := hosts[2] + Print("Known hosts are set after failed cluster update") + getCluster, err = s.GetClusterByID(t.Context(), initialCluster.ID) + if err != nil { + t.Fatal(err) + } + if len(getCluster.KnownHosts) != len(hosts) { + t.Fatalf("Expected %d known hosts, got %d", len(hosts), len(getCluster.KnownHosts)) + } + + newClusterHost := hosts[1] Print("Expect connectivity failure when updating existing cluster host param: " + newClusterHost) putCluster = initialCluster putCluster.Host = newClusterHost + putCluster.KnownHosts = nil if err := s.PutCluster(ctx, &putCluster); err == nil { t.Fatal("Expected connectivity failure when updating existing cluster host param, got nil") } @@ -876,9 +896,19 @@ func TestServiceStorageIntegration(t *testing.T) { Print("Expect success when updating existing cluster labels param") putCluster = initialCluster putCluster.Labels = map[string]string{"foo": "bar"} + putCluster.KnownHosts = nil if err := s.PutCluster(ctx, &putCluster); err != nil { t.Fatal(err) } + + Print("Known hosts are set after successful cluster update") + getCluster, err = s.GetClusterByID(t.Context(), initialCluster.ID) + if err != nil { + t.Fatal(err) + } + if len(getCluster.KnownHosts) != len(hosts) { + t.Fatalf("Expected %d known hosts, got %d", len(hosts), len(getCluster.KnownHosts)) + } }) t.Run("no --host in SM DB", func(t *testing.T) {