Skip to content

Commit d481f76

Browse files
committed
Consolidated limit and scroll logic across handlers for better code reuse.
Also added some tests to check that the scroll is properly exhausted.
1 parent 87fb2c2 commit d481f76

File tree

2 files changed

+131
-109
lines changed

2 files changed

+131
-109
lines changed

handlers.go

Lines changed: 75 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,33 @@ func dumpHttpErrorResponse(w http.ResponseWriter, err error) {
7171
dumpErrorResponse(w, status_code, err.Error())
7272
}
7373

74+
func getPageLimit(params url.Values, limit_name string, upper_limit int) (int, error) {
75+
if !params.Has(limit_name) {
76+
return upper_limit, nil
77+
}
78+
79+
limit, err := strconv.Atoi(params.Get(limit_name))
80+
if err != nil || limit <= 0 {
81+
return 0, newHttpError(http.StatusBadRequest, errors.New("invalid '" + limit_name + "'"))
82+
}
83+
84+
if limit > upper_limit {
85+
return upper_limit, nil
86+
}
87+
88+
return limit, nil
89+
}
90+
91+
func encodeNextParameters(scroll_name string, scroll_value string, params url.Values, other_names []string) string {
92+
gathered := url.Values{ scroll_name: []string{ scroll_value } }
93+
for _, p := range other_names {
94+
if params.Has(p) {
95+
gathered[p] = params[p]
96+
}
97+
}
98+
return gathered.Encode()
99+
}
100+
74101
func checkVerificationCode(path string, verifier *verificationRegistry, timeout time.Duration) (fs.FileInfo, error) {
75102
expected_code, ok := verifier.Pop(path)
76103
if !ok {
@@ -411,17 +438,12 @@ func newQueryHandler(db *sql.DB, tokenizer *unicodeTokenizer, wild_tokenizer *un
411438
options.Scroll = &queryScroll{ Time: time, Pid: pid }
412439
}
413440

414-
options.PageLimit = 100
415-
if params.Has("limit") {
416-
limit, err := strconv.Atoi(params.Get("limit"))
417-
if err != nil || limit <= 0 {
418-
dumpJsonResponse(w, http.StatusBadRequest, map[string]string{ "status": "ERROR", "reason": "invalid 'limit'" })
419-
return
420-
}
421-
if limit < options.PageLimit {
422-
options.PageLimit = limit
423-
}
441+
limit, err := getPageLimit(params, "limit", 100)
442+
if err != nil {
443+
dumpHttpErrorResponse(w, err)
444+
return
424445
}
446+
options.PageLimit = limit
425447

426448
if params.Has("metadata") {
427449
options.IncludeMetadata = (params.Get("metadata") != "false")
@@ -439,7 +461,7 @@ func newQueryHandler(db *sql.DB, tokenizer *unicodeTokenizer, wild_tokenizer *un
439461
query := &searchClause{}
440462
restricted := http.MaxBytesReader(w, r.Body, 1048576)
441463
dec := json.NewDecoder(restricted)
442-
err := dec.Decode(query)
464+
err = dec.Decode(query)
443465
if err != nil {
444466
dumpJsonResponse(w, http.StatusBadRequest, map[string]string{ "status": "ERROR", "reason": fmt.Sprintf("failed to parse request body; %v", err) })
445467
return
@@ -467,11 +489,12 @@ func newQueryHandler(db *sql.DB, tokenizer *unicodeTokenizer, wild_tokenizer *un
467489
respbody := map[string]interface{} { "results": res }
468490
if len(res) == options.PageLimit {
469491
last := &(res[options.PageLimit - 1])
470-
next := endpoint + "?scroll=" + strconv.FormatInt(last.Time, 10) + "," + strconv.FormatInt(last.Pid, 10)
471-
if translate {
472-
next += "&translate=true"
473-
}
474-
respbody["next"] = next
492+
respbody["next"] = endpoint + "?" + encodeNextParameters(
493+
"scroll",
494+
strconv.FormatInt(last.Time, 10) + "," + strconv.FormatInt(last.Pid, 10),
495+
params,
496+
[]string{ "limit", "metadata", "translate" },
497+
)
475498
}
476499

477500
dumpJsonResponse(w, http.StatusOK, respbody)
@@ -642,7 +665,7 @@ func newListFilesHandler(db *sql.DB, whitelist linkWhitelist) func(http.Response
642665

643666
func newListRegisteredDirectoriesHandler(db *sql.DB, endpoint string) func(http.ResponseWriter, *http.Request) {
644667
return func(w http.ResponseWriter, r *http.Request) {
645-
options := listRegisteredDirectoriesOptions{ PageLimit: 100 }
668+
options := listRegisteredDirectoriesOptions{}
646669
params := r.URL.Query()
647670

648671
if params.Has("user") {
@@ -682,16 +705,12 @@ func newListRegisteredDirectoriesHandler(db *sql.DB, endpoint string) func(http.
682705
options.Exists = &exists
683706
}
684707

685-
if params.Has("limit") {
686-
limit, err := strconv.Atoi(params.Get("limit"))
687-
if err != nil || limit <= 0 {
688-
dumpJsonResponse(w, http.StatusBadRequest, map[string]string{ "status": "ERROR", "reason": "invalid 'limit'" })
689-
return
690-
}
691-
if (limit < options.PageLimit) {
692-
options.PageLimit = limit
693-
}
708+
limit, err := getPageLimit(params, "limit", 100)
709+
if err != nil {
710+
dumpHttpErrorResponse(w, err)
711+
return
694712
}
713+
options.PageLimit = limit
695714

696715
if params.Has("scroll") {
697716
val := params.Get("scroll")
@@ -725,13 +744,12 @@ func newListRegisteredDirectoriesHandler(db *sql.DB, endpoint string) func(http.
725744
respbody := map[string]interface{} { "results": output }
726745
if len(output) == options.PageLimit {
727746
last := output[options.PageLimit - 1]
728-
next := endpoint + "?scroll=" + strconv.FormatInt(last.Time, 10) + "," + strconv.FormatInt(last.Did, 10)
729-
for _, extra := range []string { "user", "path_prefix", "within_path", "contains_path", "exists", "limit" } {
730-
if params.Has(extra) {
731-
next += "&" + extra + "=" + url.QueryEscape(params.Get(extra))
732-
}
733-
}
734-
respbody["next"] = next
747+
respbody["next"] = endpoint + "?" + encodeNextParameters(
748+
"scroll",
749+
strconv.FormatInt(last.Time, 10) + "," + strconv.FormatInt(last.Did, 10),
750+
params,
751+
[]string { "user", "path_prefix", "within_path", "contains_path", "exists", "limit" },
752+
)
735753
}
736754

737755
dumpJsonResponse(w, http.StatusOK, respbody)
@@ -742,7 +760,7 @@ func newListRegisteredDirectoriesHandler(db *sql.DB, endpoint string) func(http.
742760

743761
func newListFieldsHandler(db *sql.DB, endpoint string) func(http.ResponseWriter, *http.Request) {
744762
return func(w http.ResponseWriter, r *http.Request) {
745-
options := listFieldsOptions{ PageLimit: 1000 }
763+
options := listFieldsOptions{}
746764
params := r.URL.Query()
747765

748766
if params.Has("pattern") {
@@ -754,16 +772,12 @@ func newListFieldsHandler(db *sql.DB, endpoint string) func(http.ResponseWriter,
754772
options.Count = params.Get("count") == "true"
755773
}
756774

757-
if params.Has("limit") {
758-
limit, err := strconv.Atoi(params.Get("limit"))
759-
if err != nil || limit <= 0 {
760-
dumpJsonResponse(w, http.StatusBadRequest, map[string]string{ "status": "ERROR", "reason": "invalid 'limit'" })
761-
return
762-
}
763-
if (limit < options.PageLimit) {
764-
options.PageLimit = limit
765-
}
775+
limit, err := getPageLimit(params, "limit", 1000)
776+
if err != nil {
777+
dumpHttpErrorResponse(w, err)
778+
return
766779
}
780+
options.PageLimit = limit
767781

768782
if params.Has("scroll") {
769783
options.Scroll = &listFieldsScroll{ Field: params.Get("scroll") }
@@ -777,13 +791,12 @@ func newListFieldsHandler(db *sql.DB, endpoint string) func(http.ResponseWriter,
777791

778792
respbody := map[string]interface{} { "results": output }
779793
if len(output) == options.PageLimit {
780-
next := endpoint + "?scroll=" + output[options.PageLimit - 1].Field
781-
for _, extra := range []string { "pattern", "count", "limit" } {
782-
if params.Has(extra) {
783-
next += "&" + extra + "=" + url.QueryEscape(params.Get(extra))
784-
}
785-
}
786-
respbody["next"] = next
794+
respbody["next"] = endpoint + "?" + encodeNextParameters(
795+
"scroll",
796+
output[options.PageLimit - 1].Field,
797+
params,
798+
[]string { "pattern", "count", "limit" },
799+
)
787800
}
788801

789802
dumpJsonResponse(w, http.StatusOK, respbody)
@@ -792,7 +805,7 @@ func newListFieldsHandler(db *sql.DB, endpoint string) func(http.ResponseWriter,
792805

793806
func newListTokensHandler(db *sql.DB, endpoint string) func(http.ResponseWriter, *http.Request) {
794807
return func(w http.ResponseWriter, r *http.Request) {
795-
options := listTokensOptions{ PageLimit: 1000 }
808+
options := listTokensOptions{}
796809
params := r.URL.Query()
797810

798811
if params.Has("pattern") {
@@ -809,16 +822,12 @@ func newListTokensHandler(db *sql.DB, endpoint string) func(http.ResponseWriter,
809822
options.Count = params.Get("count") == "true"
810823
}
811824

812-
if params.Has("limit") {
813-
limit, err := strconv.Atoi(params.Get("limit"))
814-
if err != nil || limit <= 0 {
815-
dumpJsonResponse(w, http.StatusBadRequest, map[string]string{ "status": "ERROR", "reason": "invalid 'limit'" })
816-
return
817-
}
818-
if (limit < options.PageLimit) {
819-
options.PageLimit = limit
820-
}
825+
limit, err := getPageLimit(params, "limit", 1000)
826+
if err != nil {
827+
dumpHttpErrorResponse(w, err)
828+
return
821829
}
830+
options.PageLimit = limit
822831

823832
if params.Has("scroll") {
824833
options.Scroll = &listTokensScroll{ Token: params.Get("scroll") }
@@ -832,13 +841,12 @@ func newListTokensHandler(db *sql.DB, endpoint string) func(http.ResponseWriter,
832841

833842
respbody := map[string]interface{} { "results": output }
834843
if len(output) == options.PageLimit {
835-
next := endpoint + "?scroll=" + output[options.PageLimit - 1].Token
836-
for _, extra := range []string { "pattern", "field", "count", "limit" } {
837-
if params.Has(extra) {
838-
next += "&" + extra + "=" + url.QueryEscape(params.Get(extra))
839-
}
840-
}
841-
respbody["next"] = next
844+
respbody["next"] = endpoint + "?" + encodeNextParameters(
845+
"scroll",
846+
output[options.PageLimit - 1].Token,
847+
params,
848+
[]string { "pattern", "field", "count", "limit" },
849+
)
842850
}
843851

844852
dumpJsonResponse(w, http.StatusOK, respbody)

handlers_test.go

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1859,7 +1859,7 @@ func TestListFieldsHandler(t *testing.T) {
18591859
if len(r.Results) != 5 {
18601860
t.Errorf("unexpected results; %v", r)
18611861
}
1862-
if r.Next == "" || !strings.HasPrefix(r.Next, "/fields?scroll=") || !strings.HasSuffix(r.Next, "&limit=5") {
1862+
if r.Next == "" || !strings.Contains(r.Next, "scroll=") || !strings.Contains(r.Next, "limit=5") {
18631863
t.Errorf("expected a next string; %v", r)
18641864
}
18651865

@@ -1868,29 +1868,36 @@ func TestListFieldsHandler(t *testing.T) {
18681868
found[res.Field] = true
18691869
}
18701870

1871-
// Hitting up the scroll.
1872-
req, err = http.NewRequest("GET", r.Next, nil)
1873-
if err != nil {
1874-
t.Fatal(err)
1875-
}
1871+
// Hitting up the scroll until exhaustion.
1872+
for i := 0; i < 100; i++ {
1873+
req, err = http.NewRequest("GET", r.Next, nil)
1874+
if err != nil {
1875+
t.Fatal(err)
1876+
}
18761877

1877-
handler.ServeHTTP(rr, req)
1878-
if rr.Code != http.StatusOK {
1879-
t.Fatalf("should have succeeded (got %d)", rr.Code)
1880-
}
1878+
handler.ServeHTTP(rr, req)
1879+
if rr.Code != http.StatusOK {
1880+
t.Fatalf("should have succeeded (got %d)", rr.Code)
1881+
}
18811882

1882-
dec = json.NewDecoder(rr.Body)
1883-
err = dec.Decode(&r)
1884-
if err != nil {
1885-
t.Fatal(err)
1886-
}
1887-
if len(r.Results) != 5 {
1888-
t.Errorf("unexpected results; %v", r)
1889-
}
1883+
dec = json.NewDecoder(rr.Body)
1884+
err = dec.Decode(&r)
1885+
if err != nil {
1886+
t.Fatal(err)
1887+
}
18901888

1891-
for _, res := range r.Results {
1892-
if _, ok := found[res.Field]; ok {
1893-
t.Errorf("detected duplicate entries from scroll; %v", res.Field)
1889+
for _, res := range r.Results {
1890+
if _, ok := found[res.Field]; ok {
1891+
t.Errorf("detected duplicate entries from scroll; %v", res.Field)
1892+
}
1893+
}
1894+
1895+
if len(r.Results) != 5 {
1896+
if r.Next != "" {
1897+
t.Errorf("unexpected results; %v", r)
1898+
} else {
1899+
break
1900+
}
18941901
}
18951902
}
18961903
})
@@ -2039,7 +2046,7 @@ func TestListTokensHandler(t *testing.T) {
20392046
if len(r.Results) != 5 {
20402047
t.Errorf("unexpected results; %v", r)
20412048
}
2042-
if r.Next == "" || !strings.HasPrefix(r.Next, "/tokens?scroll=") || !strings.HasSuffix(r.Next, "&limit=5") {
2049+
if r.Next == "" || !strings.Contains(r.Next, "scroll=") || !strings.Contains(r.Next, "limit=5") {
20432050
t.Errorf("expected a next string; %v", r)
20442051
}
20452052

@@ -2048,29 +2055,36 @@ func TestListTokensHandler(t *testing.T) {
20482055
found[res.Token] = true
20492056
}
20502057

2051-
// Hitting up the scroll.
2052-
req, err = http.NewRequest("GET", r.Next, nil)
2053-
if err != nil {
2054-
t.Fatal(err)
2055-
}
2058+
// Hitting up the scroll until exhaustion.
2059+
for it := 0; it < 100; it++ {
2060+
req, err = http.NewRequest("GET", r.Next, nil)
2061+
if err != nil {
2062+
t.Fatal(err)
2063+
}
20562064

2057-
handler.ServeHTTP(rr, req)
2058-
if rr.Code != http.StatusOK {
2059-
t.Fatalf("should have succeeded (got %d)", rr.Code)
2060-
}
2065+
handler.ServeHTTP(rr, req)
2066+
if rr.Code != http.StatusOK {
2067+
t.Fatalf("should have succeeded (got %d)", rr.Code)
2068+
}
20612069

2062-
dec = json.NewDecoder(rr.Body)
2063-
err = dec.Decode(&r)
2064-
if err != nil {
2065-
t.Fatal(err)
2066-
}
2067-
if len(r.Results) != 5 {
2068-
t.Errorf("unexpected results; %v", r)
2069-
}
2070+
dec = json.NewDecoder(rr.Body)
2071+
err = dec.Decode(&r)
2072+
if err != nil {
2073+
t.Fatal(err)
2074+
}
20702075

2071-
for _, res := range r.Results {
2072-
if _, ok := found[res.Token]; ok {
2073-
t.Errorf("detected duplicate entries from scroll; %v", res.Token)
2076+
for _, res := range r.Results {
2077+
if _, ok := found[res.Token]; ok {
2078+
t.Errorf("detected duplicate entries from scroll; %v", res.Token)
2079+
}
2080+
}
2081+
2082+
if len(r.Results) != 5 {
2083+
if r.Next != "" {
2084+
t.Errorf("unexpected results; %v", r)
2085+
} else {
2086+
break
2087+
}
20742088
}
20752089
}
20762090
})

0 commit comments

Comments
 (0)