use testify asserts

This commit is contained in:
Jay
2025-10-23 13:20:55 -04:00
parent 417c97f168
commit e0c669351c
11 changed files with 142 additions and 294 deletions

View File

@@ -2,8 +2,7 @@ package roots
import (
"encoding/json"
//"fmt"
"reflect"
"github.com/stretchr/testify/assert"
"testing"
)
@@ -586,17 +585,15 @@ func TestFilterMarshalJSON(t *testing.T) {
for _, tc := range marshalTestCases {
t.Run(tc.name, func(t *testing.T) {
result, err := tc.filter.MarshalJSON()
expectOk(t, err)
assert.NoError(t, err)
var expectedMap, resultMap map[string]interface{}
var expectedMap, actualMap map[string]interface{}
err = json.Unmarshal([]byte(tc.expected), &expectedMap)
expectOk(t, err)
err = json.Unmarshal(result, &resultMap)
expectOk(t, err)
assert.NoError(t, err)
err = json.Unmarshal(result, &actualMap)
assert.NoError(t, err)
if !reflect.DeepEqual(expectedMap, resultMap) {
t.Errorf("marshal mismatch: got %s, want %s", result, tc.expected)
}
assert.Equal(t, expectedMap, actualMap)
})
}
}
@@ -606,7 +603,7 @@ func TestFilterUnmarshalJSON(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
var result Filter
err := result.UnmarshalJSON([]byte(tc.input))
expectOk(t, err)
assert.NoError(t, err)
expectEqualFilters(t, result, tc.expected)
})
@@ -617,11 +614,11 @@ func TestFilterRoundTrip(t *testing.T) {
for _, tc := range roundTripTestCases {
t.Run(tc.name, func(t *testing.T) {
jsonBytes, err := tc.filter.MarshalJSON()
expectOk(t, err)
assert.NoError(t, err)
var result Filter
err = result.UnmarshalJSON(jsonBytes)
expectOk(t, err)
assert.NoError(t, err)
expectEqualFilters(t, result, tc.filter)
})
@@ -632,120 +629,29 @@ func TestFilterRoundTrip(t *testing.T) {
// Helpers
func expectEqualFilters(t *testing.T, got, want Filter) {
// Compare IDs
if got.IDs == nil && want.IDs == nil {
// pass
} else if got.IDs == nil || want.IDs == nil {
t.Errorf("mismatched ids: got %v, want %v", got.IDs, want.IDs)
} else {
expectEqualStringSlices(t, got.IDs, want.IDs)
}
assert.Equal(t, want.IDs, got.IDs)
assert.Equal(t, want.Authors, got.Authors)
assert.Equal(t, want.Kinds, got.Kinds)
assert.Equal(t, want.Since, got.Since)
assert.Equal(t, want.Until, got.Until)
assert.Equal(t, want.Limit, got.Limit)
assert.Equal(t, want.Tags, got.Tags)
// Compare Authors
if got.Authors == nil && want.Authors == nil {
// pass
} else if got.Authors == nil || want.Authors == nil {
t.Errorf("mismatched authors: got %v, want %v", got.Authors, want.Authors)
} else {
expectEqualStringSlices(t, got.Authors, want.Authors)
if want.Extensions == nil && got.Extensions == nil {
return
}
assert.NotNil(t, got.Extensions)
assert.NotNil(t, want.Extensions)
// Compare Kinds
if got.Kinds == nil && want.Kinds == nil {
// pass
} else if got.Kinds == nil || want.Kinds == nil {
t.Errorf("mismatched kinds: got %v, want %v", got.Kinds, want.Kinds)
} else {
expectEqualIntSlices(t, got.Kinds, want.Kinds)
}
assert.Equal(t, len(want.Extensions), len(got.Extensions))
for key, wantValue := range want.Extensions {
gotValue, ok := got.Extensions[key]
assert.True(t, ok, "expected key %s", key)
// Compare Timestamps
if got.Since == nil && want.Since == nil {
// pass
} else if got.Since == nil || want.Since == nil {
t.Errorf("mismatched since pointers: got %v, want %v", got.Since, want.Since)
} else {
expectEqualIntPointers(t, got.Since, want.Since)
}
if got.Until == nil && want.Until == nil {
// pass
} else if got.Until == nil || want.Until == nil {
t.Errorf("mismatched until pointers: got %v, want %v", got.Until, want.Until)
} else {
expectEqualIntPointers(t, got.Until, want.Until)
}
// Compare Limit
if got.Limit == nil && want.Limit == nil {
// pass
} else if got.Limit == nil || want.Limit == nil {
t.Errorf("mismatched limit pointers: got %v, want %v", got.Limit, want.Limit)
} else {
expectEqualIntPointers(t, got.Limit, want.Limit)
}
// Compare Tags
if got.Tags == nil && want.Tags == nil {
// pass
} else if got.Tags == nil || want.Tags == nil {
t.Errorf("mismatched tags: got %v, want %v", got.Tags, want.Tags)
} else {
expectEqualTags(t, got.Tags, want.Tags)
}
// Compare Extensions
if got.Extensions == nil && want.Extensions == nil {
// pass
} else if got.Extensions == nil || want.Extensions == nil {
t.Errorf("mismatched extensions: got %v, want %v", got.Extensions, want.Extensions)
} else {
expectEqualExtensions(t, got.Extensions, want.Extensions)
}
}
func expectEqualTags(t *testing.T, got, want map[string][]string) {
if len(got) != len(want) {
t.Errorf("length mismatch: got %d, want %d", len(got), len(want))
}
for key, wantValues := range want {
gotValues, exists := got[key]
if !exists {
t.Fatalf("expected key %q to exist", key)
}
if len(wantValues) != len(gotValues) {
t.Errorf(
"key %q: length mismatch: got %d, want %d",
key, len(gotValues), len(wantValues))
}
for i := range wantValues {
if gotValues[i] != wantValues[i] {
t.Errorf(
"key %q: index %d: got %s, want %s",
key, i, gotValues[i], wantValues[i])
}
}
}
}
func expectEqualExtensions(t *testing.T, got, want map[string]json.RawMessage) {
if len(got) != len(want) {
t.Errorf("length mismatch: got %d, want %d", len(got), len(want))
}
for key, wantValue := range want {
gotValue, ok := got[key]
if !ok {
t.Errorf("expected key %s, got nil", key)
}
var gotJSON, wantJSON interface{}
if err := json.Unmarshal(wantValue, &wantJSON); err != nil {
t.Errorf("key %q: failed to unmarshal 'want' value: %s", key, wantValue)
}
if err := json.Unmarshal(gotValue, &gotJSON); err != nil {
t.Errorf("key %q: failed to unmarshal 'got' value: %s", key, wantValue)
}
if !reflect.DeepEqual(gotJSON, wantJSON) {
t.Errorf("key %q: got %s, want %s", key, gotValue, wantValue)
}
assert.NoError(t, json.Unmarshal(wantValue, &wantJSON))
assert.NoError(t, json.Unmarshal(gotValue, &gotJSON))
assert.Equal(t, wantJSON, gotJSON)
}
}