...

Source file src/net/http/httptest/recorder_test.go

Documentation: net/http/httptest

		 1  // Copyright 2012 The Go Authors. All rights reserved.
		 2  // Use of this source code is governed by a BSD-style
		 3  // license that can be found in the LICENSE file.
		 4  
		 5  package httptest
		 6  
		 7  import (
		 8  	"fmt"
		 9  	"io"
		10  	"net/http"
		11  	"testing"
		12  )
		13  
		14  func TestRecorder(t *testing.T) {
		15  	type checkFunc func(*ResponseRecorder) error
		16  	check := func(fns ...checkFunc) []checkFunc { return fns }
		17  
		18  	hasStatus := func(wantCode int) checkFunc {
		19  		return func(rec *ResponseRecorder) error {
		20  			if rec.Code != wantCode {
		21  				return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
		22  			}
		23  			return nil
		24  		}
		25  	}
		26  	hasResultStatus := func(want string) checkFunc {
		27  		return func(rec *ResponseRecorder) error {
		28  			if rec.Result().Status != want {
		29  				return fmt.Errorf("Result().Status = %q; want %q", rec.Result().Status, want)
		30  			}
		31  			return nil
		32  		}
		33  	}
		34  	hasResultStatusCode := func(wantCode int) checkFunc {
		35  		return func(rec *ResponseRecorder) error {
		36  			if rec.Result().StatusCode != wantCode {
		37  				return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
		38  			}
		39  			return nil
		40  		}
		41  	}
		42  	hasResultContents := func(want string) checkFunc {
		43  		return func(rec *ResponseRecorder) error {
		44  			contentBytes, err := io.ReadAll(rec.Result().Body)
		45  			if err != nil {
		46  				return err
		47  			}
		48  			contents := string(contentBytes)
		49  			if contents != want {
		50  				return fmt.Errorf("Result().Body = %s; want %s", contents, want)
		51  			}
		52  			return nil
		53  		}
		54  	}
		55  	hasContents := func(want string) checkFunc {
		56  		return func(rec *ResponseRecorder) error {
		57  			if rec.Body.String() != want {
		58  				return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
		59  			}
		60  			return nil
		61  		}
		62  	}
		63  	hasFlush := func(want bool) checkFunc {
		64  		return func(rec *ResponseRecorder) error {
		65  			if rec.Flushed != want {
		66  				return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
		67  			}
		68  			return nil
		69  		}
		70  	}
		71  	hasOldHeader := func(key, want string) checkFunc {
		72  		return func(rec *ResponseRecorder) error {
		73  			if got := rec.HeaderMap.Get(key); got != want {
		74  				return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
		75  			}
		76  			return nil
		77  		}
		78  	}
		79  	hasHeader := func(key, want string) checkFunc {
		80  		return func(rec *ResponseRecorder) error {
		81  			if got := rec.Result().Header.Get(key); got != want {
		82  				return fmt.Errorf("final header %s = %q; want %q", key, got, want)
		83  			}
		84  			return nil
		85  		}
		86  	}
		87  	hasNotHeaders := func(keys ...string) checkFunc {
		88  		return func(rec *ResponseRecorder) error {
		89  			for _, k := range keys {
		90  				v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
		91  				if ok {
		92  					return fmt.Errorf("unexpected header %s with value %q", k, v)
		93  				}
		94  			}
		95  			return nil
		96  		}
		97  	}
		98  	hasTrailer := func(key, want string) checkFunc {
		99  		return func(rec *ResponseRecorder) error {
	 100  			if got := rec.Result().Trailer.Get(key); got != want {
	 101  				return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
	 102  			}
	 103  			return nil
	 104  		}
	 105  	}
	 106  	hasNotTrailers := func(keys ...string) checkFunc {
	 107  		return func(rec *ResponseRecorder) error {
	 108  			trailers := rec.Result().Trailer
	 109  			for _, k := range keys {
	 110  				_, ok := trailers[http.CanonicalHeaderKey(k)]
	 111  				if ok {
	 112  					return fmt.Errorf("unexpected trailer %s", k)
	 113  				}
	 114  			}
	 115  			return nil
	 116  		}
	 117  	}
	 118  	hasContentLength := func(length int64) checkFunc {
	 119  		return func(rec *ResponseRecorder) error {
	 120  			if got := rec.Result().ContentLength; got != length {
	 121  				return fmt.Errorf("ContentLength = %d; want %d", got, length)
	 122  			}
	 123  			return nil
	 124  		}
	 125  	}
	 126  
	 127  	for _, tt := range [...]struct {
	 128  		name	 string
	 129  		h			func(w http.ResponseWriter, r *http.Request)
	 130  		checks []checkFunc
	 131  	}{
	 132  		{
	 133  			"200 default",
	 134  			func(w http.ResponseWriter, r *http.Request) {},
	 135  			check(hasStatus(200), hasContents("")),
	 136  		},
	 137  		{
	 138  			"first code only",
	 139  			func(w http.ResponseWriter, r *http.Request) {
	 140  				w.WriteHeader(201)
	 141  				w.WriteHeader(202)
	 142  				w.Write([]byte("hi"))
	 143  			},
	 144  			check(hasStatus(201), hasContents("hi")),
	 145  		},
	 146  		{
	 147  			"write sends 200",
	 148  			func(w http.ResponseWriter, r *http.Request) {
	 149  				w.Write([]byte("hi first"))
	 150  				w.WriteHeader(201)
	 151  				w.WriteHeader(202)
	 152  			},
	 153  			check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
	 154  		},
	 155  		{
	 156  			"write string",
	 157  			func(w http.ResponseWriter, r *http.Request) {
	 158  				io.WriteString(w, "hi first")
	 159  			},
	 160  			check(
	 161  				hasStatus(200),
	 162  				hasContents("hi first"),
	 163  				hasFlush(false),
	 164  				hasHeader("Content-Type", "text/plain; charset=utf-8"),
	 165  			),
	 166  		},
	 167  		{
	 168  			"flush",
	 169  			func(w http.ResponseWriter, r *http.Request) {
	 170  				w.(http.Flusher).Flush() // also sends a 200
	 171  				w.WriteHeader(201)
	 172  			},
	 173  			check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
	 174  		},
	 175  		{
	 176  			"Content-Type detection",
	 177  			func(w http.ResponseWriter, r *http.Request) {
	 178  				io.WriteString(w, "<html>")
	 179  			},
	 180  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
	 181  		},
	 182  		{
	 183  			"no Content-Type detection with Transfer-Encoding",
	 184  			func(w http.ResponseWriter, r *http.Request) {
	 185  				w.Header().Set("Transfer-Encoding", "some encoding")
	 186  				io.WriteString(w, "<html>")
	 187  			},
	 188  			check(hasHeader("Content-Type", "")), // no header
	 189  		},
	 190  		{
	 191  			"no Content-Type detection if set explicitly",
	 192  			func(w http.ResponseWriter, r *http.Request) {
	 193  				w.Header().Set("Content-Type", "some/type")
	 194  				io.WriteString(w, "<html>")
	 195  			},
	 196  			check(hasHeader("Content-Type", "some/type")),
	 197  		},
	 198  		{
	 199  			"Content-Type detection doesn't crash if HeaderMap is nil",
	 200  			func(w http.ResponseWriter, r *http.Request) {
	 201  				// Act as if the user wrote new(httptest.ResponseRecorder)
	 202  				// rather than using NewRecorder (which initializes
	 203  				// HeaderMap)
	 204  				w.(*ResponseRecorder).HeaderMap = nil
	 205  				io.WriteString(w, "<html>")
	 206  			},
	 207  			check(hasHeader("Content-Type", "text/html; charset=utf-8")),
	 208  		},
	 209  		{
	 210  			"Header is not changed after write",
	 211  			func(w http.ResponseWriter, r *http.Request) {
	 212  				hdr := w.Header()
	 213  				hdr.Set("Key", "correct")
	 214  				w.WriteHeader(200)
	 215  				hdr.Set("Key", "incorrect")
	 216  			},
	 217  			check(hasHeader("Key", "correct")),
	 218  		},
	 219  		{
	 220  			"Trailer headers are correctly recorded",
	 221  			func(w http.ResponseWriter, r *http.Request) {
	 222  				w.Header().Set("Non-Trailer", "correct")
	 223  				w.Header().Set("Trailer", "Trailer-A")
	 224  				w.Header().Add("Trailer", "Trailer-B")
	 225  				w.Header().Add("Trailer", "Trailer-C")
	 226  				io.WriteString(w, "<html>")
	 227  				w.Header().Set("Non-Trailer", "incorrect")
	 228  				w.Header().Set("Trailer-A", "valuea")
	 229  				w.Header().Set("Trailer-C", "valuec")
	 230  				w.Header().Set("Trailer-NotDeclared", "should be omitted")
	 231  				w.Header().Set("Trailer:Trailer-D", "with prefix")
	 232  			},
	 233  			check(
	 234  				hasStatus(200),
	 235  				hasHeader("Content-Type", "text/html; charset=utf-8"),
	 236  				hasHeader("Non-Trailer", "correct"),
	 237  				hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"),
	 238  				hasTrailer("Trailer-A", "valuea"),
	 239  				hasTrailer("Trailer-C", "valuec"),
	 240  				hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
	 241  				hasTrailer("Trailer-D", "with prefix"),
	 242  			),
	 243  		},
	 244  		{
	 245  			"Header set without any write", // Issue 15560
	 246  			func(w http.ResponseWriter, r *http.Request) {
	 247  				w.Header().Set("X-Foo", "1")
	 248  
	 249  				// Simulate somebody using
	 250  				// new(ResponseRecorder) instead of
	 251  				// using the constructor which sets
	 252  				// this to 200
	 253  				w.(*ResponseRecorder).Code = 0
	 254  			},
	 255  			check(
	 256  				hasOldHeader("X-Foo", "1"),
	 257  				hasStatus(0),
	 258  				hasHeader("X-Foo", "1"),
	 259  				hasResultStatus("200 OK"),
	 260  				hasResultStatusCode(200),
	 261  			),
	 262  		},
	 263  		{
	 264  			"HeaderMap vs FinalHeaders", // more for Issue 15560
	 265  			func(w http.ResponseWriter, r *http.Request) {
	 266  				h := w.Header()
	 267  				h.Set("X-Foo", "1")
	 268  				w.Write([]byte("hi"))
	 269  				h.Set("X-Foo", "2")
	 270  				h.Set("X-Bar", "2")
	 271  			},
	 272  			check(
	 273  				hasOldHeader("X-Foo", "2"),
	 274  				hasOldHeader("X-Bar", "2"),
	 275  				hasHeader("X-Foo", "1"),
	 276  				hasNotHeaders("X-Bar"),
	 277  			),
	 278  		},
	 279  		{
	 280  			"setting Content-Length header",
	 281  			func(w http.ResponseWriter, r *http.Request) {
	 282  				body := "Some body"
	 283  				contentLength := fmt.Sprintf("%d", len(body))
	 284  				w.Header().Set("Content-Length", contentLength)
	 285  				io.WriteString(w, body)
	 286  			},
	 287  			check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
	 288  		},
	 289  		{
	 290  			"nil ResponseRecorder.Body", // Issue 26642
	 291  			func(w http.ResponseWriter, r *http.Request) {
	 292  				w.(*ResponseRecorder).Body = nil
	 293  				io.WriteString(w, "hi")
	 294  			},
	 295  			check(hasResultContents("")), // check we don't crash reading the body
	 296  
	 297  		},
	 298  	} {
	 299  		t.Run(tt.name, func(t *testing.T) {
	 300  			r, _ := http.NewRequest("GET", "http://foo.com/", nil)
	 301  			h := http.HandlerFunc(tt.h)
	 302  			rec := NewRecorder()
	 303  			h.ServeHTTP(rec, r)
	 304  			for _, check := range tt.checks {
	 305  				if err := check(rec); err != nil {
	 306  					t.Error(err)
	 307  				}
	 308  			}
	 309  		})
	 310  	}
	 311  }
	 312  
	 313  // issue 39017 - disallow Content-Length values such as "+3"
	 314  func TestParseContentLength(t *testing.T) {
	 315  	tests := []struct {
	 316  		cl	 string
	 317  		want int64
	 318  	}{
	 319  		{
	 320  			cl:	 "3",
	 321  			want: 3,
	 322  		},
	 323  		{
	 324  			cl:	 "+3",
	 325  			want: -1,
	 326  		},
	 327  		{
	 328  			cl:	 "-3",
	 329  			want: -1,
	 330  		},
	 331  		{
	 332  			// max int64, for safe conversion before returning
	 333  			cl:	 "9223372036854775807",
	 334  			want: 9223372036854775807,
	 335  		},
	 336  		{
	 337  			cl:	 "9223372036854775808",
	 338  			want: -1,
	 339  		},
	 340  	}
	 341  
	 342  	for _, tt := range tests {
	 343  		if got := parseContentLength(tt.cl); got != tt.want {
	 344  			t.Errorf("%q:\n\tgot=%d\n\twant=%d", tt.cl, got, tt.want)
	 345  		}
	 346  	}
	 347  }
	 348  
	 349  // Ensure that httptest.Recorder panics when given a non-3 digit (XXX)
	 350  // status HTTP code. See https://golang.org/issues/45353
	 351  func TestRecorderPanicsOnNonXXXStatusCode(t *testing.T) {
	 352  	badCodes := []int{
	 353  		-100, 0, 99, 1000, 20000,
	 354  	}
	 355  	for _, badCode := range badCodes {
	 356  		badCode := badCode
	 357  		t.Run(fmt.Sprintf("Code=%d", badCode), func(t *testing.T) {
	 358  			defer func() {
	 359  				if r := recover(); r == nil {
	 360  					t.Fatal("Expected a panic")
	 361  				}
	 362  			}()
	 363  
	 364  			handler := func(rw http.ResponseWriter, _ *http.Request) {
	 365  				rw.WriteHeader(badCode)
	 366  			}
	 367  			r, _ := http.NewRequest("GET", "http://example.org/", nil)
	 368  			rw := NewRecorder()
	 369  			handler(rw, r)
	 370  		})
	 371  	}
	 372  }
	 373  

View as plain text