ponzi2

Testing Templates: Table-driven Tests

By Brian Muramatsu

2019-04-14

Overview

Table-driven tests are a type of test where "[e]ach table entry is a complete test case with inputs and expected results, and sometimes with additional information such as a test name to make the test output easily readable." This type of test "...simply iterates through all table entries and for each entry performs the necessary tests... [so, the] test code is written once and amortized over all table entries".

Since it is easy to add table entries, I find myself giggling with glee as I add a new test case. However, I can't help but groan every time I have to write that anonymous struct of fields and format-specifier-laden run loop. So, I've included a couple of templates here that I can copy and paste into my own code to eliminate that painful initial setup.

These test templates try to follow best practices:

What was sacrificed to make these templates copy and pastable:

Template: Return Value and Error

import (
    "testing"

    "github.com/google/go-cmp/cmp"
)

// TestMethodName is a template for a table-driven test of a function that
// returns a value and an error. Replace the placeholder struct{} types and the
// anonymous function with the real ones, then add one table entry per case.
func TestMethodName(t *testing.T) {
    for _, tt := range []struct {
        desc    string
        input   struct{} // Replace with input type.
        want    struct{} // Replace with want type.
        wantErr bool
    }{
        {
            desc:    "describe the case", // Names the subtest in test output.
            input:   struct{}{},          // Specify input.
            want:    struct{}{},          // Specify want.
            wantErr: true,
        },
    } {
        t.Run(tt.desc, func(t *testing.T) {
            got, gotErr := func(struct{}) (struct{}, error) { /* Replace with function under test. */ return struct{}{}, nil }(tt.input)

            // Check the error first: if it is not what the case expects,
            // diffing the returned value would only add noise.
            if (gotErr != nil) != tt.wantErr {
                t.Fatalf("got error: %v, wanted err: %t", gotErr, tt.wantErr)
            }

            // By Go convention the value is meaningless alongside a non-nil
            // error, so error cases do not need to specify want.
            if gotErr != nil {
                return
            }

            if diff := cmp.Diff(tt.want, got); diff != "" {
                t.Errorf("diff (-want, +got)\n%s", diff)
            }
        })
    }
}
    

Example:

// TestModelOneDayChart checks that modelOneDayChart converts an *iex.Stock
// into the expected one-day *model.Chart.
func TestModelOneDayChart(t *testing.T) {
    for _, tt := range []struct {
        desc    string
        input   *iex.Stock
        want    *model.Chart
        wantErr bool
    }{
        {
            desc: "single chart point",
            input: &iex.Stock{
                Quote: &iex.Quote{CompanyName: "Apple Inc."},
                Chart: []*iex.ChartPoint{
                    {
                        Date:   time.Date(2018, time.September, 18, 15, 57, 0, 0, time.UTC),
                        Open:   218.44,
                        High:   218.49,
                        Low:    218.37,
                        Close:  218.49,
                        Volume: 2607,
                    },
                },
            },
            want: &model.Chart{
                Quote: &model.Quote{CompanyName: "Apple Inc."},
                Range: model.OneDay,
                TradingSessionSeries: &model.TradingSessionSeries{
                    TradingSessions: []*model.TradingSession{
                        {
                            Date:   time.Date(2018, time.September, 18, 15, 57, 0, 0, time.UTC),
                            Open:   218.44,
                            High:   218.49,
                            Low:    218.37,
                            Close:  218.49,
                            Volume: 2607,
                        },
                    },
                },
            },
        },
    } {
        t.Run(tt.desc, func(t *testing.T) {
            got, gotErr := modelOneDayChart(tt.input)

            // Check the error first: if it is not what the case expects,
            // diffing the returned chart would only add noise.
            if (gotErr != nil) != tt.wantErr {
                t.Fatalf("got error: %v, wanted err: %t", gotErr, tt.wantErr)
            }

            // The returned chart is not meaningful alongside an error.
            if gotErr != nil {
                return
            }

            if diff := cmp.Diff(tt.want, got); diff != "" {
                t.Errorf("diff (-want, +got)\n%s", diff)
            }
        })
    }
}
    

Template: Return Error

// go-cmp is intentionally not imported: this template never compares values,
// and Go refuses to compile a file with an unused import.
import (
    "testing"
)

// TestMethodName is a template for a table-driven test of a function that
// returns only an error. Replace the placeholder struct{} type and the
// anonymous function with the real ones, then add one table entry per case.
func TestMethodName(t *testing.T) {
    for _, tt := range []struct {
        desc    string
        input   struct{} // Replace with input type.
        wantErr bool
    }{
        {
            desc:    "describe the case", // Names the subtest in test output.
            input:   struct{}{},          // Specify input.
            wantErr: true,
        },
    } {
        t.Run(tt.desc, func(t *testing.T) {
            gotErr := func(struct{}) error { /* Replace with function under test. */ return nil }(tt.input)

            if (gotErr != nil) != tt.wantErr {
                t.Errorf("got error: %v, wanted err: %t", gotErr, tt.wantErr)
            }
        })
    }
}
    

Example:

// TestValidateSymbol checks which strings ValidateSymbol accepts and which it
// rejects with an error.
func TestValidateSymbol(t *testing.T) {
    for _, tt := range []struct {
        desc    string
        input   string
        wantErr bool
    }{
        {
            desc:  "valid three letter symbol",
            input: "SPY",
        },
        {
            desc:  "valid four letter symbol",
            input: "QQQQ",
        },
        {
            desc:    "lowercase not allowed",
            input:   "spy",
            wantErr: true,
        },
        {
            desc:    "spaces not allowed",
            input:   "S P Y",
            wantErr: true,
        },
        {
            desc:    "too long",
            input:   "SPYSPY",
            wantErr: true,
        },
        {
            desc:    "empty string not allowed",
            input:   "",
            wantErr: true,
        },
    } {
        t.Run(tt.desc, func(t *testing.T) {
            gotErr := ValidateSymbol(tt.input)

            // %q keeps inputs like "" and "S P Y" visible and unambiguous
            // in the failure message.
            if (gotErr != nil) != tt.wantErr {
                t.Errorf("ValidateSymbol(%q) got error: %v, wanted err: %t", tt.input, gotErr, tt.wantErr)
            }
        })
    }
}
    

Support

Donate to keep development chugging (0 commits) or just star the repository like 0 others!

bitcoin:38vo2oWYmqBUXCxL3avpueye6dPRahX7gC

© 2020 Brian Muramatsu