Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions server/api/region.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ func (h *regionHandler) GetRegionByID(w http.ResponseWriter, r *http.Request) {
}

regionInfo := rc.GetRegion(regionID)
failpoint.Inject("RejectGetRegionByIDWhenAccessLeader", func() {
if h.svr.GetMember().IsServing() {
regionInfo = nil
}
})
if regionInfo == nil {
h.rd.JSON(w, http.StatusNotFound, errs.ErrRegionNotFound.FastGenByArgs(regionID).Error())
return
Expand Down
18 changes: 16 additions & 2 deletions tools/pd-ctl/pdctl/command/global.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,11 @@ func doRequest(cmd *cobra.Command, prefix string, method string, customHeader ht
o(b)
}
var resp string
header := buildDirectHeader(cmd, customHeader)

endpoints := getEndpoints(cmd)
err := tryURLs(cmd, endpoints, func(endpoint string) error {
return do(endpoint, prefix, method, &resp, customHeader, b)
return do(endpoint, prefix, method, &resp, header, b)
})
return resp, err
}
Expand All @@ -181,13 +182,26 @@ func doRequestSingleEndpoint(cmd *cobra.Command, endpoint, prefix, method string
o(b)
}
var resp string
header := buildDirectHeader(cmd, customHeader)

err := requestURL(cmd, endpoint, func(endpoint string) error {
return do(endpoint, prefix, method, &resp, customHeader, b)
return do(endpoint, prefix, method, &resp, header, b)
})
return resp, err
}

func buildDirectHeader(cmd *cobra.Command, customHeader http.Header) http.Header {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please update other naming together

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

header := customHeader.Clone()
if header == nil {
header = http.Header{}
}

if direct, err := cmd.Flags().GetBool("no-forward"); err == nil && direct {
header.Set(apiutil.PDAllowFollowerHandleHeader, "true")
}
return header
}

func dial(req *http.Request) (string, error) {
resp, err := dialClient.Do(req)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions tools/pd-ctl/pdctl/command/region_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func NewRegionCommand() *cobra.Command {
r.AddCommand(scanRegion)

r.Flags().String("jq", "", "jq query")
r.PersistentFlags().Bool("no-forward", false, "direct request to the endpoint instead of forwarding by PD leader if the flag exists")

return r
}
Expand Down
55 changes: 55 additions & 0 deletions tools/pd-ctl/tests/global_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"sync"
"testing"

"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -85,3 +86,57 @@ func TestSendAndGetComponent(t *testing.T) {
re.NoError(err)
re.Equal(fmt.Sprintf("%s\n", command.PDControlCallerID), string(output))
}

func TestRegionNoProxyHeader(t *testing.T) {
re := require.New(t)
var (
mu sync.Mutex
count = 0
)
handler := func(context.Context, *server.Server) (http.Handler, apiutil.APIServiceGroup, error) {
mux := http.NewServeMux()
mux.HandleFunc("/pd/api/v1/regions", func(w http.ResponseWriter, r *http.Request) {
mu.Lock()
if vals := r.Header.Values(apiutil.PDAllowFollowerHandleHeader); len(vals) > 0 {
count++
}
mu.Unlock()
fmt.Fprint(w, `{}`)
})
info := apiutil.APIServiceGroup{IsCore: true}
return mux, info, nil
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

cluster, err := tests.NewTestClusterWithHandlers(ctx, 1, []server.HandlerBuilder{handler})
re.NoError(err)
defer cluster.Destroy()

err = cluster.RunInitialServers()
re.NoError(err)

leaderName := cluster.WaitLeader()
re.NotEmpty(leaderName)
pdAddr := cluster.GetLeaderServer().GetAddr()

cmd := cmd.GetRootCmd()
_, err = ExecuteCommand(cmd, "-u", pdAddr, "region")
re.NoError(err)
re.Equal(0, count)

// PD-Allow-follower-handle is only added when --no-forward=true.
_, err = ExecuteCommand(cmd, "-u", pdAddr, "region", "--no-forward")
re.NoError(err)
re.Equal(1, count)

// --no-forward=false should not add PD-Allow-follower-handle.
_, err = ExecuteCommand(cmd, "-u", pdAddr, "region", "--no-forward=false")
re.NoError(err)
re.Equal(1, count)

_, err = ExecuteCommand(cmd, "-u", pdAddr, "region", "--no-forward=true")
re.NoError(err)
re.Equal(2, count)
}
68 changes: 68 additions & 0 deletions tools/pd-ctl/tests/region/region_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"fmt"
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -47,6 +48,10 @@ func TestRegionTestSuite(t *testing.T) {

func (suite *regionTestSuite) SetupSuite() {
suite.env = pdTests.NewSchedulingTestEnvironment(suite.T())
suite.env.PDCount = 3
suite.env.RunFunc(func(cluster *pdTests.TestCluster) {
suite.NotEmpty(cluster.WaitLeader())
})
}

func (suite *regionTestSuite) TearDownSuite() {
Expand All @@ -64,6 +69,7 @@ func (suite *regionTestSuite) TestRegionKeyFormat() {
func (suite *regionTestSuite) checkRegionKeyFormat(cluster *pdTests.TestCluster) {
re := suite.Require()
url := cluster.GetConfig().GetClientURL()
cluster.WaitLeader()
store := &metapb.Store{
Id: 1,
State: metapb.StoreState_Up,
Expand Down Expand Up @@ -497,3 +503,65 @@ func (suite *regionTestSuite) checkPatrolWithLimit(cluster *pdTests.TestCluster)
re.Empty(res["results"])
}
}

func (suite *regionTestSuite) TestFollowerDirect() {
suite.env.RunTestInNonMicroserviceEnv(suite.followerDirect)
}

func (suite *regionTestSuite) followerDirect(cluster *pdTests.TestCluster) {
re := suite.Require()
re.NotEmpty(cluster.WaitLeader())
if err := cluster.GetLeaderServer().BootstrapCluster(); err != nil {
re.Contains(err.Error(), "already bootstrapped")
}
cmd := ctl.GetRootCmd()
stores := []*metapb.Store{
{
Id: 1,
State: metapb.StoreState_Up,
},
{
Id: 2,
State: metapb.StoreState_Up,
},
{
Id: 3,
State: metapb.StoreState_Up,
},
}

for i := range stores {
pdTests.MustPutStore(re, cluster, stores[i])
}
metaRegion := &metapb.Region{
Id: 100,
StartKey: []byte(""),
EndKey: []byte(""),
Peers: []*metapb.Peer{
{Id: 1, StoreId: 1},
{Id: 5, StoreId: 2},
{Id: 6, StoreId: 3}},
RegionEpoch: &metapb.RegionEpoch{ConfVer: 1, Version: 1},
}
region := core.NewRegionInfo(metaRegion, metaRegion.Peers[0])
re.NoError(cluster.HandleRegionHeartbeat(region))
pdAddr := cluster.GetLeaderServer().GetAddr()
re.NoError(failpoint.Enable("github.com/tikv/pd/server/api/RejectGetRegionByIDWhenAccessLeader", "return(true)"))
defer func() {
re.NoError(failpoint.Disable("github.com/tikv/pd/server/api/RejectGetRegionByIDWhenAccessLeader"))
}()
for _, server := range cluster.GetServers() {
serverAddr := server.GetAddr()
// leader reject any region info request with --no-forward, followers should work normally.
if serverAddr == pdAddr {
output, err := tests.ExecuteCommand(cmd, "-u", serverAddr, "region", "100", "--no-forward")
re.NoError(err)
re.Contains(string(output), "Failed to get region")
} else {
output, err := tests.ExecuteCommand(cmd, "-u", serverAddr, "region", "100", "--no-forward")
re.NoError(err)
outputStr := string(output)
re.True(strings.Contains(outputStr, "\"id\":100") || strings.Contains(outputStr, "TiKV cluster not bootstrapped"), outputStr)
}
}
}
Loading