Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AuthN support using Nested Cred #30

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions pkg/dataplane/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ var (
errInvalidRequest = errors.New("invalid request")
errNilMSI = errors.New("expected non-nil user-assigned managed identity")
errNumberOfMSIs = errors.New("returned MSIs does not match number of requested MSIs")
errGetNestedCreds = errors.New("failed to get nested credentials object")
)

// TODO - Add parameter to specify module name in azcore.NewClient()
Expand All @@ -73,7 +74,7 @@ func NewClient(cloud string, authenticator policy.Policy, clientOpts *policy.Cli
return &ManagedIdentityClient{swaggerClient: swaggerClient, cloud: cloud}, nil
}

func (c *ManagedIdentityClient) GetUserAssignedIdentities(ctx context.Context, request UserAssignedMSIRequest) (*UserAssignedIdentities, error) {
func (c *ManagedIdentityClient) getUserAssignedIdentities(ctx context.Context, request UserAssignedMSIRequest) (*CredentialsObject, error) {
validate := validator.New(validator.WithRequiredStructEnabled())
validate.RegisterValidation(resourceIDsTag, validateResourceIDs)
if err := validate.Struct(request); err != nil {
Expand Down Expand Up @@ -107,7 +108,25 @@ func (c *ManagedIdentityClient) GetUserAssignedIdentities(ctx context.Context, r
}

credentialsObject := CredentialsObject{CredentialsObject: creds.CredentialsObject}
return NewUserAssignedIdentities(credentialsObject, c.cloud)
return NewCredentialsObjectUAIdentities(credentialsObject, c.cloud)
}

func (c *ManagedIdentityClient) GetCredentialsObjectUserAssignedIdentities(ctx context.Context, request UserAssignedMSIRequest) (*CredentialsObject, error) {
return c.getUserAssignedIdentities(ctx, request)
}

func (c *ManagedIdentityClient) GetNestedCredentialsObjectUserAssignedIdentities(ctx context.Context, request UserAssignedMSIRequest) (*NestedCredentialsObject, error) {
ua, err := c.getUserAssignedIdentities(ctx, request)
if err != nil {
return nil, err
}

if len(ua.CredentialsObject.ExplicitIdentities) == 0 ||
ua.CredentialsObject.ExplicitIdentities[0] == nil {
return nil, errGetNestedCreds
}

return NewNestedCredentialsObjectUAIdentities(*ua.CredentialsObject.ExplicitIdentities[0], c.cloud)
}

func validateResourceIDs(fl validator.FieldLevel) bool {
Expand Down
2 changes: 1 addition & 1 deletion pkg/dataplane/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func TestGetUserAssignedIdentities(t *testing.T) {
tc.goMockCall(swaggerClient)

msiClient := &ManagedIdentityClient{swaggerClient: swaggerClient}
if _, err := msiClient.GetUserAssignedIdentities(context.Background(), tc.request); !errors.Is(err, tc.expectedErr) {
if _, err := msiClient.GetCredentialsObjectUserAssignedIdentities(context.Background(), tc.request); !errors.Is(err, tc.expectedErr) {
t.Errorf("expected error: `%s` but got: `%s`", tc.expectedErr, err)
}
})
Expand Down
48 changes: 39 additions & 9 deletions pkg/dataplane/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,27 @@ var (
// swagger.Credentials object can represent either system or user-assigned managed identity
type CredentialsObject struct {
swagger.CredentialsObject
cloud string
}

type UserAssignedIdentities struct {
CredentialsObject
// NestedCredentialsObject is a wrapper around the swagger.NestedCredentialsObject to add additional functionality
// swagger.NestedCredentials object can represent only user-assigned managed identity
type NestedCredentialsObject struct {
swagger.NestedCredentialsObject
cloud string
}

// Constructor for UserAssignedIdentities object
func NewUserAssignedIdentities(c CredentialsObject, cloud string) (*UserAssignedIdentities, error) {
// Constructor for Credentials Object UserAssignedIdentities
func NewCredentialsObjectUAIdentities(c CredentialsObject, cloud string) (*CredentialsObject, error) {
if !c.IsUserAssigned() {
return nil, errNoUserAssignedMSIs
}
return &UserAssignedIdentities{CredentialsObject: c, cloud: cloud}, nil
return &CredentialsObject{CredentialsObject: c.CredentialsObject, cloud: cloud}, nil
}

// Constructor for Nested Credentials Object UserAssignedIdentities
func NewNestedCredentialsObjectUAIdentities(c swagger.NestedCredentialsObject, cloud string) (*NestedCredentialsObject, error) {
return &NestedCredentialsObject{NestedCredentialsObject: c, cloud: cloud}, nil

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised to see an exported member field set by a constructor - either omit the constructor and export all the members or hide them and close over object initialization with the constructor.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to omit the constructor and export all the members.

}

// This method may be used by clients to check if they can use the object as a user-assigned managed identity
Expand All @@ -48,30 +56,52 @@ func (c CredentialsObject) IsUserAssigned() bool {
return len(c.ExplicitIdentities) > 0
}

// Get an AzIdentity credential for the given user-assigned identity resource ID
// Get an AzIdentity credential for the given credential object user-assigned identity resource ID
// Clients can use the credential to get a token for the user-assigned identity
func (u UserAssignedIdentities) GetCredential(requestedResourceID string) (*azidentity.ClientCertificateCredential, error) {
func (c CredentialsObject) GetCredential(requestedResourceID string) (*azidentity.ClientCertificateCredential, error) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's go ahead and implement this interface instead:

// TokenCredential represents a credential capable of providing an OAuth token.
// Exported as azcore.TokenCredential.
type TokenCredential interface {
	// GetToken requests an access token for the specified set of scopes.
	GetToken(ctx context.Context, options TokenRequestOptions) (AccessToken, error)
}

We know that reloading the value from disk is critically important and every client will need to do it. A method off of a credentials object that returns a static *azidentity.ClientCertificateCredential is insufficient.

Copy link

@stevekuznetsov stevekuznetsov Jan 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The higher-level construct should take a path as input and can use the code you have here as the implementation for loading the credential. Look at the azidentity.WorkloadIdentityCredential as prior art here. )

Copy link
Author

@gouthamMN gouthamMN Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I already see an implementation of GetToken() func for Object azidentity.ClientCertificateCredential under azidentity.client_certificate_credential. Client could directly use that to get token? The implementation logic is exactly similar to azidentity.WorkloadIdentityCredential. Doesn't this suffice the ask?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We chatted about this out-of-band but just for posterity - no, it doesn't, since the code here is loading the credential once, and what clients will need is to refresh it from disk when it changes.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To recap requirements after talking with HyperShift and CS. It would be useful to have the following functions:

  • One to support authenticating as the NestedCredential directly. This shouldn't need to take any arguments and it should already have the NestedCredential embedded so just authenticate using the information. This doesn't need to refresh credentials at all as the auth will be short-lived. This will be used by Cluster Service at points during installation and updates by fetching the information directly from key vault and authenticating as the identity, hence it does not need to be long-lived.
  • One to support authenticating as a NestedCredential using a file on the file system. This will take a path to the file that's a json formatted NestedCredential object. It needs to be able to refresh the credential should the contents of the file change on disk. This will be used by HyperShift control plane components which use a secretsproviderclass to mount the contents of the key vault nestedcredential object on disk, hence reloading is necessary.

If you want to separate the implementation of these out into PRs that's fine, and if so I'd suggest working on the one for CS first as CS will leverage this functionality first.

cc @miguelsorianod and @bryan-cox

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this PR implements the first of the two you listed, so we can merge it in and tackle the second next. @gouthamMN does that sound like a good plan to you?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@stevekuznetsov - Yes, it's better to have the 2nd functionality implemented in a separate.

requestedARMResourceID, err := arm.ParseResourceID(requestedResourceID)
if err != nil {
return nil, fmt.Errorf("%w for requested resource ID %s: %w", errParseResourceID, requestedResourceID, err)
}
requestedResourceID = requestedARMResourceID.String()

for _, id := range u.ExplicitIdentities {
for _, id := range c.ExplicitIdentities {
if id != nil && id.ResourceID != nil {
idARMResourceID, err := arm.ParseResourceID(*id.ResourceID)
if err != nil {
return nil, fmt.Errorf("%w for identity resource ID %s: %w", errParseResourceID, *id.ResourceID, err)
}
if requestedResourceID == idARMResourceID.String() {
return getClientCertificateCredential(*id, u.cloud)
return getClientCertificateCredential(*id, c.cloud)
}
}
}

return nil, errResourceIDNotFound
}

// Get an AzIdentity credential for the given nested credential object user-assigned identity resource ID
// Clients can use the credential to get a token for the user-assigned identity
func (n NestedCredentialsObject) GetCredential(requestedResourceID string) (*azidentity.ClientCertificateCredential, error) {
requestedARMResourceID, err := arm.ParseResourceID(requestedResourceID)
if err != nil {
return nil, fmt.Errorf("%w for requested resource ID %s: %w", errParseResourceID, requestedResourceID, err)
}
requestedResourceID = requestedARMResourceID.String()

if n.ResourceID != nil {
idARMResourceID, err := arm.ParseResourceID(*n.ResourceID)
if err != nil {
return nil, fmt.Errorf("%w for identity resource ID %s: %w", errParseResourceID, *n.ResourceID, err)
}
if requestedResourceID == idARMResourceID.String() {
return getClientCertificateCredential(n.NestedCredentialsObject, n.cloud)
}
}

return nil, errResourceIDNotFound
}

func getAzCoreCloud(cloud string) azcloud.Configuration {
switch cloud {
case AzureUSGovCloud:
Expand Down
70 changes: 30 additions & 40 deletions pkg/dataplane/identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestNewUserAssignedIdentities(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if _, err := NewUserAssignedIdentities(tc.c, test.Bogus); !errors.Is(err, tc.expectedErr) {
if _, err := NewCredentialsObjectUAIdentities(tc.c, test.Bogus); !errors.Is(err, tc.expectedErr) {
t.Errorf("expected error: `%s` but got: `%s`", tc.expectedErr, err)
}
})
Expand All @@ -101,39 +101,35 @@ func TestGetCredential(t *testing.T) {
validIdentity.AuthenticationEndpoint = test.StringPtr(test.ValidAuthenticationEndpoint)

testCases := []struct {
name string
uaIdentities UserAssignedIdentities
resourceID string
expectedErr error
name string
credentialsObject CredentialsObject
resourceID string
expectedErr error
}{
{
name: "empty resourceID",
uaIdentities: UserAssignedIdentities{
CredentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.ValidResourceID),
},
credentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.ValidResourceID),
},
},
},
resourceID: "",
expectedErr: errParseResourceID,
},
{
name: "no identities present",
uaIdentities: UserAssignedIdentities{},
resourceID: test.ValidResourceID,
expectedErr: errResourceIDNotFound,
name: "no identities present",
credentialsObject: CredentialsObject{},
resourceID: test.ValidResourceID,
expectedErr: errResourceIDNotFound,
},
{
name: "invalid requested resourceID",
uaIdentities: UserAssignedIdentities{
CredentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.ValidResourceID),
},
credentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.ValidResourceID),
},
},
},
Expand All @@ -142,12 +138,10 @@ func TestGetCredential(t *testing.T) {
},
{
name: "invalid identity resourceID",
uaIdentities: UserAssignedIdentities{
CredentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.Bogus),
},
credentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.Bogus),
},
},
},
Expand All @@ -156,12 +150,10 @@ func TestGetCredential(t *testing.T) {
},
{
name: "invalid client secret",
uaIdentities: UserAssignedIdentities{
CredentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.ValidResourceID),
},
credentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
test.GetTestMSI(test.ValidResourceID),
},
},
},
Expand All @@ -170,12 +162,10 @@ func TestGetCredential(t *testing.T) {
},
{
name: "success",
uaIdentities: UserAssignedIdentities{
CredentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
validIdentity,
},
credentialsObject: CredentialsObject{
CredentialsObject: swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{
validIdentity,
},
},
},
Expand All @@ -188,7 +178,7 @@ func TestGetCredential(t *testing.T) {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if _, err := tc.uaIdentities.GetCredential(tc.resourceID); !errors.Is(err, tc.expectedErr) {
if _, err := tc.credentialsObject.GetCredential(tc.resourceID); !errors.Is(err, tc.expectedErr) {
t.Errorf("expected error: `%s` but got: `%s`", tc.expectedErr, err)
}
})
Expand Down
6 changes: 5 additions & 1 deletion pkg/dataplane/stub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ func TestNewStub(t *testing.T) {
swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{uaMSI},
},
AzurePublicCloud,
}
testStub := NewStub([]*CredentialsObject{credObject})
if testStub == nil {
Expand All @@ -37,6 +38,7 @@ func TestDo(t *testing.T) {
swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{uaMSI},
},
AzurePublicCloud,
}
credRequest := &swagger.CredRequestDefinition{
IdentityIDs: []*string{test.StringPtr(test.ValidResourceID)},
Expand Down Expand Up @@ -118,6 +120,7 @@ func TestPost(t *testing.T) {
swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{uaMSI},
},
AzurePublicCloud,
}

testCases := []struct {
Expand Down Expand Up @@ -179,6 +182,7 @@ func TestStubWithClient(t *testing.T) {
swagger.CredentialsObject{
ExplicitIdentities: []*swagger.NestedCredentialsObject{uaMSI},
},
AzurePublicCloud,
}
testStub := NewStub([]*CredentialsObject{credObject})
clientOpts := &policy.ClientOptions{
Expand All @@ -193,7 +197,7 @@ func TestStubWithClient(t *testing.T) {
IdentityURL: test.ValidIdentityURL,
TenantID: test.ValidTenantID,
}
identities, err := client.GetUserAssignedIdentities(context.Background(), request)
identities, err := client.GetCredentialsObjectUserAssignedIdentities(context.Background(), request)
if err != nil {
t.Fatalf("unable to get user assigned msi: %s", err)
}
Expand Down