Skip to content
Snippets Groups Projects
Commit fcc9d9b5 authored by Jens Drenhaus's avatar Jens Drenhaus
Browse files

Add trust.NewPolicy()

parent 2cae1098
Branches
Tags
No related merge requests found
......@@ -16,6 +16,20 @@ type Policy struct {
FetchMethod ospkg.FetchMethod `json:"ospkg_fetch_method"`
}
// NewPolicy creates a Policy from template.
// If the template is not a valid Policy, the returned error wrapps ErrInvalidPolicy.
func NewPolicy(template Policy) (Policy, error) {
var ret Policy
if err := template.validate(); err != nil {
return ret, fmt.Errorf("%w: %v", ErrInvalidPolicy, err)
}
ret.SignatureThreshold = template.SignatureThreshold
ret.FetchMethod = template.FetchMethod
return ret, nil
}
// policy is used as an alias in Policy.UnmarshalJSON.
type policy struct {
SignatureThreshold int `json:"ospkg_signature_threshold"`
......
......@@ -9,6 +9,93 @@ import (
"system-transparency.org/stboot/ospkg"
)
func TestPolicyNew(t *testing.T) {
validtests := []struct {
name string
template Policy
want Policy
}{
{
name: "All set",
template: Policy{
SignatureThreshold: 1,
FetchMethod: ospkg.FetchFromNetwork,
},
want: Policy{
SignatureThreshold: 1,
FetchMethod: ospkg.FetchFromNetwork,
},
},
}
invalidtests := []struct {
name string
template Policy
}{
{
name: "Empty Policy",
template: Policy{},
},
{
name: "SignaturesThreshold missing",
template: Policy{
FetchMethod: ospkg.FetchFromNetwork,
},
},
{
name: "SignaturesThreshold 0",
template: Policy{
SignatureThreshold: 0,
FetchMethod: ospkg.FetchFromNetwork,
},
},
{
name: "SignaturesThreshold negative",
template: Policy{
SignatureThreshold: -1,
FetchMethod: ospkg.FetchFromNetwork,
},
},
{
name: "FetchMethod missing",
template: Policy{
SignatureThreshold: 1,
},
},
{
name: "FetchMethod unknown",
template: Policy{
SignatureThreshold: 1,
FetchMethod: 100,
},
},
}
for _, tt := range validtests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewPolicy(tt.template)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %+v, want %+v", got, tt.want)
}
})
}
for _, tt := range invalidtests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewPolicy(tt.template)
if err == nil {
t.Fatalf("expect an error")
}
if !errors.Is(err, ErrInvalidPolicy) {
t.Errorf("expect error to wrap ErrInvalidPolicy")
}
})
}
}
func TestPolicyUnmarshalJSON(t *testing.T) {
validtests := []struct {
name string
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment