Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions cmd/testserver/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestHelloHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/hello", nil)
rec := httptest.NewRecorder()
hello(rec, req)

res := rec.Result()
if res.StatusCode != http.StatusOK {
t.Errorf("status = %d; want 200", res.StatusCode)
}
if ct := res.Header.Get("Content-Type"); !strings.HasPrefix(ct, "text/html") {
t.Errorf("content-type = %q; want text/html", ct)
}
if !strings.Contains(rec.Body.String(), "Hello World") {
t.Errorf("body = %q; want it to contain 'Hello World'", rec.Body.String())
}
}

func TestHelloJSONHandler(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/hello_in_json", nil)
rec := httptest.NewRecorder()
helloJSON(rec, req)

res := rec.Result()
if ct := res.Header.Get("Content-Type"); ct != "application/json" {
t.Errorf("content-type = %q; want application/json", ct)
}
if got := strings.TrimSpace(rec.Body.String()); got != `{"msg":"hello world"}` {
t.Errorf("body = %q", got)
}
}
88 changes: 88 additions & 0 deletions color_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package main

import (
"os"
"strings"
"testing"
)

func TestColorEnabledHonorsNoColor(t *testing.T) {
t.Setenv("NO_COLOR", "1")
if colorEnabled(os.Stdout) {
t.Error("colorEnabled should be false when NO_COLOR is set")
}
}

func TestColorEnabledHonorsDumbTerm(t *testing.T) {
t.Setenv("NO_COLOR", "")
t.Setenv("TERM", "dumb")
if colorEnabled(os.Stdout) {
t.Error("colorEnabled should be false when TERM=dumb")
}
}

func TestIsTTYFalseForPipe(t *testing.T) {
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
defer r.Close()
defer w.Close()
if isTTY(w) {
t.Error("a pipe should not be reported as a TTY")
}
}

func TestNoColorIsIdentity(t *testing.T) {
if got := noColor("hello", cRed, cBold); got != "hello" {
t.Errorf("noColor = %q; want %q", got, "hello")
}
}

func TestPlainBreakdownsHaveNoANSI(t *testing.T) {
c := populatedReportCounter(t)

status := c.statusBreakdownPlain()
for _, code := range []string{"200", "404", "500"} {
if !strings.Contains(status, code) {
t.Errorf("plain status breakdown missing %s:\n%s", code, status)
}
}
if strings.Contains(status, "\x1b[") {
t.Errorf("plain status breakdown should have no ANSI codes:\n%q", status)
}

netErrs := c.netErrBreakdownPlain()
if !strings.Contains(netErrs, "timeout") {
t.Errorf("plain net-err breakdown missing 'timeout':\n%s", netErrs)
}
if strings.Contains(netErrs, "\x1b[") {
t.Errorf("plain net-err breakdown should have no ANSI codes:\n%q", netErrs)
}
}

func TestStatusColorRanges(t *testing.T) {
cases := map[int]string{
204: cGreen,
301: cCyan,
404: cYellow,
503: cRed,
100: cDim,
}
for code, want := range cases {
if got := statusColor(code); got != want {
t.Errorf("statusColor(%d) = %q; want %q", code, got, want)
}
}
}

func TestStatusBreakdownNoneWhenEmpty(t *testing.T) {
prof := loadOneCallProfile(t, "https://example.test/", "GET", "", "")
c := newTestCounter(t, prof, 0, 0)
if got := c.statusBreakdownPlain(); !strings.Contains(got, "(none)") {
t.Errorf("empty status breakdown = %q; want it to contain '(none)'", got)
}
if got := c.netErrBreakdownPlain(); got != "" {
t.Errorf("empty net-err breakdown = %q; want empty string", got)
}
}
212 changes: 212 additions & 0 deletions run_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package main

import (
"bytes"
"encoding/json"
"flag"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
)

// runMain invokes run() with a fresh flag set and os.Args, capturing stdout and
// stderr. It restores all global state on cleanup so tests don't interfere.
func runMain(t *testing.T, args ...string) (code int, stdout, stderr string) {
t.Helper()
origArgs := os.Args
origCL := flag.CommandLine
origUsage := flag.Usage
origStdout := os.Stdout
origStderr := os.Stderr
t.Cleanup(func() {
os.Args = origArgs
flag.CommandLine = origCL
flag.Usage = origUsage
os.Stdout = origStdout
os.Stderr = origStderr
})

flag.CommandLine = flag.NewFlagSet("hammer", flag.ContinueOnError)
flag.CommandLine.SetOutput(io.Discard)
os.Args = append([]string{"hammer"}, args...)

outR, outW, _ := os.Pipe()
errR, errW, _ := os.Pipe()
os.Stdout = outW
os.Stderr = errW

outCh := make(chan string, 1)
errCh := make(chan string, 1)
go func() { var b bytes.Buffer; _, _ = io.Copy(&b, outR); outCh <- b.String() }()
go func() { var b bytes.Buffer; _, _ = io.Copy(&b, errR); errCh <- b.String() }()

code = run()

_ = outW.Close()
_ = errW.Close()
stdout = <-outCh
stderr = <-errCh
_ = outR.Close()
_ = errR.Close()
return code, stdout, stderr
}

func TestRun_versionFlag(t *testing.T) {
code, stdout, _ := runMain(t, "-version")
if code != exitOK {
t.Errorf("exit = %d; want %d", code, exitOK)
}
if !strings.Contains(stdout, "hammer ") {
t.Errorf("stdout = %q; want it to contain 'hammer '", stdout)
}
}

func TestRun_happyPathJSON(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

code, stdout, _ := runMain(t, "-url", srv.URL, "-rps", "200", "-duration", "300ms",
"-quiet", "-output", "json")
if code != exitOK {
t.Fatalf("exit = %d; want %d", code, exitOK)
}
var rep JSONReport
if err := json.Unmarshal([]byte(stdout), &rep); err != nil {
t.Fatalf("stdout is not valid JSON: %v\n%s", err, stdout)
}
if rep.Received == 0 {
t.Errorf("expected some received responses, got %d", rep.Received)
}
if rep.TargetRPS != 200 {
t.Errorf("target_rps = %d; want 200", rep.TargetRPS)
}
}

func TestRun_sloViolationExits1(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
defer srv.Close()

code, stdout, _ := runMain(t, "-url", srv.URL, "-rps", "200", "-duration", "300ms",
"-quiet", "-output", "json", "-max-error-rate", "0")
if code != exitChecks {
t.Fatalf("exit = %d; want %d (SLO violation)", code, exitChecks)
}
var rep JSONReport
if err := json.Unmarshal([]byte(stdout), &rep); err != nil {
t.Fatalf("invalid JSON: %v", err)
}
if rep.Checks == nil || rep.Checks.Passed {
t.Errorf("expected failed checks, got %+v", rep.Checks)
}
}

func TestRun_sloPassExits0(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

code, _, _ := runMain(t, "-url", srv.URL, "-rps", "200", "-duration", "300ms",
"-quiet", "-output", "json", "-max-error-rate", "0.5")
if code != exitOK {
t.Errorf("exit = %d; want %d (SLO satisfied)", code, exitOK)
}
}

func TestRun_jsonOutFile(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

out := filepath.Join(t.TempDir(), "report.json")
code, _, _ := runMain(t, "-url", srv.URL, "-rps", "100", "-duration", "250ms",
"-quiet", "-json-out", out)
if code != exitOK {
t.Fatalf("exit = %d; want %d", code, exitOK)
}
data, err := os.ReadFile(out)
if err != nil {
t.Fatalf("read json-out: %v", err)
}
var rep JSONReport
if err := json.Unmarshal(data, &rep); err != nil {
t.Fatalf("json-out file invalid: %v", err)
}
}

func TestRun_jsonOutUnwritableExitsSetup(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

bad := filepath.Join(t.TempDir(), "no-such-dir", "report.json")
code, _, _ := runMain(t, "-url", srv.URL, "-rps", "100", "-duration", "200ms",
"-quiet", "-json-out", bad)
if code != exitSetup {
t.Errorf("exit = %d; want %d", code, exitSetup)
}
}

func TestRun_validationErrors(t *testing.T) {
cases := []struct {
name string
args []string
}{
{"no target", []string{"-rps", "10"}},
{"both url and profile", []string{"-url", "http://x", "-profile", "p.json"}},
{"rps zero", []string{"-url", "http://x", "-rps", "0"}},
{"bad output", []string{"-url", "http://x", "-output", "yaml"}},
{"error rate over one", []string{"-url", "http://x", "-max-error-rate", "2"}},
{"bad ok codes", []string{"-url", "http://x", "-ok", "999"}},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
code, _, _ := runMain(t, c.args...)
if code != exitUsage {
t.Errorf("exit = %d; want %d", code, exitUsage)
}
})
}
}

func TestRun_missingProfileExitsSetup(t *testing.T) {
code, _, _ := runMain(t, "-profile", filepath.Join(t.TempDir(), "missing.json"),
"-duration", "100ms", "-quiet")
if code != exitSetup {
t.Errorf("exit = %d; want %d", code, exitSetup)
}
}

func TestRun_invalidProxyExitsSetup(t *testing.T) {
code, _, _ := runMain(t, "-url", "http://x", "-proxy", "://bad", "-duration", "100ms", "-quiet")
if code != exitSetup {
t.Errorf("exit = %d; want %d", code, exitSetup)
}
}

func TestRun_statsEndpointServesLiveCounts(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

// Bind the stats server to an ephemeral port and scrape it mid-run.
code, _, _ := runMain(t, "-url", srv.URL, "-rps", "100", "-duration", "300ms",
"-quiet", "-output", "json", "-stats-addr", "127.0.0.1:0")
// We can't easily know the chosen port without parsing logs (suppressed by
// -quiet), so this test just asserts the run still succeeds with the stats
// server enabled.
if code != exitOK {
t.Errorf("exit = %d; want %d", code, exitOK)
}
}
14 changes: 11 additions & 3 deletions update.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ import (
// be overridden with HAMMER_REPO so forks can self-update too.
const defaultUpdateRepo = "chenchaoyi/hammer"

// apiBaseURL and downloadBaseURL are the GitHub endpoints the updater talks to.
// They are package vars (not consts) only so tests can redirect them at a local
// httptest server; production always uses the real GitHub hosts.
var (
apiBaseURL = "https://api.github.com"
downloadBaseURL = "https://github.com"
)

// mirrorProxies mirrors install.sh: when GitHub is unreachable (common in
// mainland China), each proxy is prepended to the GitHub URL as a fallback.
var mirrorProxies = []string{
Expand Down Expand Up @@ -156,7 +164,7 @@ func runUpdate(args []string) int {

// --- Download the archive + checksums ------------------------------
assetName, isZip := assetNameFor(runtime.GOOS, runtime.GOARCH)
base := fmt.Sprintf("https://github.com/%s/releases/download/%s", repo, latest)
base := fmt.Sprintf("%s/%s/releases/download/%s", downloadBaseURL, repo, latest)

fmt.Fprintf(os.Stderr, "Downloading %s\n", po(assetName, cCyan))
archive, err := downloadWithFallback(ctx, client, base+"/"+assetName, opt.mirrorMode)
Expand Down Expand Up @@ -206,12 +214,12 @@ type ghRelease struct {
}

func fetchLatestRelease(ctx context.Context, client *http.Client, repo string) (*ghRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
url := fmt.Sprintf("%s/repos/%s/releases/latest", apiBaseURL, repo)
return getRelease(ctx, client, url)
}

func fetchReleaseByTag(ctx context.Context, client *http.Client, repo, tag string) (*ghRelease, error) {
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/tags/%s", repo, tag)
url := fmt.Sprintf("%s/repos/%s/releases/tags/%s", apiBaseURL, repo, tag)
return getRelease(ctx, client, url)
}

Expand Down
Loading