diff --git a/docs/stackit_auth_login.md b/docs/stackit_auth_login.md index 3cd888bd2..d23d9d23f 100644 --- a/docs/stackit_auth_login.md +++ b/docs/stackit_auth_login.md @@ -5,7 +5,7 @@ Logs in to the STACKIT CLI ### Synopsis Logs in to the STACKIT CLI using a user account. -The authentication is done via a web-based authorization flow, where the command will open a browser window in which you can login to your STACKIT account. +By default, the authentication uses a web-based authorization flow and opens a browser window where you can login to your STACKIT account. You can alternatively use the OAuth 2.0 device flow for environments where receiving a local callback is not possible. ``` stackit auth login [flags] @@ -16,14 +16,18 @@ stackit auth login [flags] ``` Login to the STACKIT CLI. This command will open a browser window where you can login to your STACKIT account $ stackit auth login + + Login to the STACKIT CLI using OAuth 2.0 device flow + $ stackit auth login --use-device-flow ``` ### Options ``` - -h, --help Help for "stackit auth login" - --port int The port on which the callback server will listen to. By default, it tries to bind a port between 8000 and 8020. - When a value is specified, it will only try to use the specified port. Valid values are within the range of 8000 to 8020. + -h, --help Help for "stackit auth login" + --port int The port on which the callback server will listen to. By default, it tries to bind a port between 8000 and 8020. + When a value is specified, it will only try to use the specified port. Valid values are within the range of 8000 to 8020. + --use-device-flow Use OAuth 2.0 device authorization grant (device flow) instead of the browser callback flow. ``` ### Options inherited from parent commands diff --git a/internal/cmd/auth/login/login.go b/internal/cmd/auth/login/login.go index 8a03d19af..ad1b59aa7 100644 --- a/internal/cmd/auth/login/login.go +++ b/internal/cmd/auth/login/login.go @@ -14,11 +14,13 @@ import ( ) const ( - portFlag = "port" + portFlag = "port" + useDeviceFlowFlag = "use-device-flow" ) type inputModel struct { - Port *int + Port *int + UseDeviceFlow bool } func NewCmd(params *types.CmdParams) *cobra.Command { @@ -27,12 +29,15 @@ func NewCmd(params *types.CmdParams) *cobra.Command { Short: "Logs in to the STACKIT CLI", Long: fmt.Sprintf("%s\n%s", "Logs in to the STACKIT CLI using a user account.", - "The authentication is done via a web-based authorization flow, where the command will open a browser window in which you can login to your STACKIT account."), + "By default, the authentication uses a web-based authorization flow and opens a browser window where you can login to your STACKIT account. You can alternatively use the OAuth 2.0 device flow for environments where receiving a local callback is not possible."), Args: args.NoArgs, Example: examples.Build( examples.NewExample( `Login to the STACKIT CLI. This command will open a browser window where you can login to your STACKIT account`, "$ stackit auth login"), + examples.NewExample( + `Login to the STACKIT CLI using OAuth 2.0 device flow`, + "$ stackit auth login --use-device-flow"), ), RunE: func(cmd *cobra.Command, args []string) error { model, err := parseInput(params.Printer, cmd, args) @@ -43,6 +48,7 @@ func NewCmd(params *types.CmdParams) *cobra.Command { err = auth.AuthorizeUser(params.Printer, auth.UserAuthConfig{ IsReauthentication: false, Port: model.Port, + UseDeviceFlow: model.UseDeviceFlow, }) if err != nil { return fmt.Errorf("authorization failed: %w", err) @@ -62,17 +68,21 @@ func configureFlags(cmd *cobra.Command) { "The port on which the callback server will listen to. By default, it tries to bind a port between 8000 and 8020.\n"+ "When a value is specified, it will only try to use the specified port. Valid values are within the range of 8000 to 8020.", ) + cmd.Flags().Bool(useDeviceFlowFlag, false, + "Use OAuth 2.0 device authorization grant (device flow) instead of the browser callback flow.") } func parseInput(p *print.Printer, cmd *cobra.Command, _ []string) (*inputModel, error) { port := flags.FlagToIntPointer(p, cmd, portFlag) + useDeviceFlow := flags.FlagToBoolValue(p, cmd, useDeviceFlowFlag) // For the CLI client only callback URLs with localhost:[8000-8020] are valid. Additional callbacks must be enabled in the backend. if port != nil && (*port < 8000 || 8020 < *port) { return nil, fmt.Errorf("port must be between 8000 and 8020") } model := inputModel{ - Port: port, + Port: port, + UseDeviceFlow: useDeviceFlow, } p.DebugInputModel(model) diff --git a/internal/cmd/auth/login/login_test.go b/internal/cmd/auth/login/login_test.go index 823fa863e..24a973348 100644 --- a/internal/cmd/auth/login/login_test.go +++ b/internal/cmd/auth/login/login_test.go @@ -9,7 +9,8 @@ import ( func fixtureFlagValues(mods ...func(flagValues map[string]string)) map[string]string { flagValues := map[string]string{ - portFlag: "8010", + portFlag: "8010", + useDeviceFlowFlag: "false", } for _, mod := range mods { mod(flagValues) @@ -19,7 +20,8 @@ func fixtureFlagValues(mods ...func(flagValues map[string]string)) map[string]st func fixtureInputModel(mods ...func(model *inputModel)) *inputModel { model := &inputModel{ - Port: utils.Ptr(8010), + Port: utils.Ptr(8010), + UseDeviceFlow: false, } for _, mod := range mods { mod(model) @@ -46,7 +48,19 @@ func TestParseInput(t *testing.T) { flagValues: map[string]string{}, isValid: true, expectedModel: &inputModel{ - Port: nil, + Port: nil, + UseDeviceFlow: false, + }, + }, + { + description: "device flow enabled", + flagValues: map[string]string{ + useDeviceFlowFlag: "true", + }, + isValid: true, + expectedModel: &inputModel{ + Port: nil, + UseDeviceFlow: true, }, }, { @@ -56,7 +70,8 @@ func TestParseInput(t *testing.T) { }, isValid: true, expectedModel: &inputModel{ - Port: utils.Ptr(8000), + Port: utils.Ptr(8000), + UseDeviceFlow: false, }, }, { @@ -73,7 +88,8 @@ func TestParseInput(t *testing.T) { }, isValid: true, expectedModel: &inputModel{ - Port: utils.Ptr(8020), + Port: utils.Ptr(8020), + UseDeviceFlow: false, }, }, { diff --git a/internal/pkg/auth/user_login.go b/internal/pkg/auth/user_login.go index cefde3868..12774fa8c 100644 --- a/internal/pkg/auth/user_login.go +++ b/internal/pkg/auth/user_login.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "net/url" "os" "os/exec" "runtime" @@ -25,6 +26,7 @@ import ( const ( defaultWellKnownConfig = "https://accounts.stackit.cloud/.well-known/openid-configuration" defaultCLIClientID = "stackit-cli-0000-0000-000000000001" + scope = "openid groups offline_access email" loginSuccessPath = "/login-successful" @@ -32,6 +34,10 @@ const ( // so we configure a range of ports from 8000 to 8020 defaultPort = 8000 configuredPortRange = 20 + + deviceCodeGrantType = "urn:ietf:params:oauth:grant-type:device_code" + defaultDevicePollInterval = 5 * time.Second + devicePollSlowDownStep = 5 * time.Second ) //go:embed templates/login-successful.html @@ -50,13 +56,15 @@ type UserAuthConfig struct { IsReauthentication bool // Port defines which port should be used for the UserAuthFlow callback Port *int + // UseDeviceFlow defines if the login should use OAuth 2.0 device flow + UseDeviceFlow bool } type apiClient interface { Do(req *http.Request) (*http.Response, error) } -// AuthorizeUser implements the PKCE OAuth2 flow. +// AuthorizeUser performs user login using either PKCE or OAuth 2.0 device flow. func AuthorizeUser(p *print.Printer, authConfig UserAuthConfig) error { idpWellKnownConfig, err := retrieveIDPWellKnownConfig(p) if err != nil { @@ -82,6 +90,14 @@ func AuthorizeUser(p *print.Printer, authConfig UserAuthConfig) error { } } + if authConfig.UseDeviceFlow { + return authorizeUserWithDeviceFlow(p, idpWellKnownConfig, idpClientID) + } + + return authorizeUserWithPKCE(p, idpWellKnownConfig, idpClientID, authConfig) +} + +func authorizeUserWithPKCE(p *print.Printer, idpWellKnownConfig *wellKnownConfig, idpClientID string, authConfig UserAuthConfig) error { var redirectURL string var listener net.Listener var listenerErr error @@ -113,7 +129,7 @@ func AuthorizeUser(p *print.Printer, authConfig UserAuthConfig) error { Endpoint: oauth2.Endpoint{ AuthURL: idpWellKnownConfig.AuthorizationEndpoint, }, - Scopes: []string{"openid offline_access email"}, + Scopes: []string{scope}, RedirectURL: redirectURL, } @@ -268,6 +284,238 @@ func AuthorizeUser(p *print.Printer, authConfig UserAuthConfig) error { return nil } +func authorizeUserWithDeviceFlow(p *print.Printer, idpWellKnownConfig *wellKnownConfig, idpClientID string) error { + if idpWellKnownConfig.DeviceAuthorizationEndpoint == "" { + return fmt.Errorf("IDP does not provide a device authorization endpoint") + } + if len(idpWellKnownConfig.GrantTypesSupported) > 0 && !containsString(idpWellKnownConfig.GrantTypesSupported, deviceCodeGrantType) { + return fmt.Errorf("IDP does not advertise support for grant type %q", deviceCodeGrantType) + } + + p.Debug(print.DebugLevel, "using device authorization endpoint %s", idpWellKnownConfig.DeviceAuthorizationEndpoint) + p.Debug(print.DebugLevel, "using token endpoint %s", idpWellKnownConfig.TokenEndpoint) + p.Debug(print.DebugLevel, "using client ID %s for authentication ", idpClientID) + + deviceAuthorization, err := getDeviceAuthorizationData(idpWellKnownConfig.DeviceAuthorizationEndpoint, idpClientID) + if err != nil { + return fmt.Errorf("request device authorization: %w", err) + } + + verificationURL := deviceAuthorization.VerificationURIComplete + if verificationURL == "" { + verificationURL = deviceAuthorization.VerificationURI + } + + p.Info("To complete login with device flow:\n") + p.Info("1. Open this URL in your browser:\n") + p.Info("%s\n\n", verificationURL) + p.Info("2. Enter this code when prompted: %s\n\n", deviceAuthorization.UserCode) + + if verificationURL != "" { + err = openBrowser(verificationURL) + if err != nil { + p.Warn("Could not open browser automatically: %v\n", err) + } + } + + accessToken, refreshToken, err := waitForDeviceFlowTokens(idpWellKnownConfig.TokenEndpoint, idpClientID, deviceAuthorization) + if err != nil { + return fmt.Errorf("retrieve tokens: %w", err) + } + + sessionExpiresAtUnix, err := getStartingSessionExpiresAtUnix() + if err != nil { + return fmt.Errorf("compute session expiration timestamp: %w", err) + } + + err = SetAuthFlow(AUTH_FLOW_USER_TOKEN) + if err != nil { + return fmt.Errorf("set auth flow type: %w", err) + } + + email, err := getEmailFromToken(accessToken) + if err != nil { + return fmt.Errorf("get email from access token: %w", err) + } + + p.Debug(print.DebugLevel, "user %s logged in successfully via device flow", email) + + err = LoginUser(email, accessToken, refreshToken, sessionExpiresAtUnix) + if err != nil { + return fmt.Errorf("set in auth storage: %w", err) + } + + return nil +} + +type deviceAuthorizationResponse struct { + DeviceCode string `json:"device_code"` + UserCode string `json:"user_code"` + VerificationURI string `json:"verification_uri"` + VerificationURIComplete string `json:"verification_uri_complete"` + ExpiresIn int `json:"expires_in"` + Interval int `json:"interval"` +} + +func getDeviceAuthorizationData(deviceAuthorizationEndpoint, clientID string) (*deviceAuthorizationResponse, error) { + form := url.Values{} + form.Set("client_id", clientID) + form.Set("scope", scope) + + req, err := http.NewRequest("POST", deviceAuthorizationEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + req.Header.Add("content-type", "application/x-www-form-urlencoded") + + httpClient := &http.Client{} + res, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("call device authorization endpoint: %w", err) + } + defer func() { + _ = res.Body.Close() + }() + + body, err := io.ReadAll(res.Body) + if err != nil { + return nil, fmt.Errorf("read response body: %w", err) + } + + responseData := deviceAuthorizationResponse{} + err = json.Unmarshal(body, &responseData) + if err != nil { + return nil, fmt.Errorf("unmarshal response: %w", err) + } + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("unexpected status code %d", res.StatusCode) + } + if responseData.DeviceCode == "" { + return nil, fmt.Errorf("found no device code") + } + if responseData.UserCode == "" { + return nil, fmt.Errorf("found no user code") + } + if responseData.VerificationURI == "" && responseData.VerificationURIComplete == "" { + return nil, fmt.Errorf("found no verification URI") + } + if responseData.ExpiresIn <= 0 { + return nil, fmt.Errorf("found invalid expiration") + } + + return &responseData, nil +} + +type tokenEndpointResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +func waitForDeviceFlowTokens(tokenEndpoint, clientID string, deviceAuthorization *deviceAuthorizationResponse) (accessToken, refreshToken string, err error) { + pollInterval := defaultDevicePollInterval + if deviceAuthorization.Interval > 0 { + pollInterval = time.Duration(deviceAuthorization.Interval) * time.Second + } + deadline := time.Now().Add(time.Duration(deviceAuthorization.ExpiresIn) * time.Second) + + for { + if time.Now().After(deadline) { + return "", "", fmt.Errorf("device authorization expired before login was completed") + } + + accessToken, refreshToken, tokenError, tokenErrorDescription, err := getTokensWithDeviceCode(tokenEndpoint, clientID, deviceAuthorization.DeviceCode) + if err != nil { + return "", "", err + } + + switch tokenError { + case "": + if accessToken == "" { + return "", "", fmt.Errorf("found no access token") + } + if refreshToken == "" { + return "", "", fmt.Errorf("found no refresh token") + } + return accessToken, refreshToken, nil + case "authorization_pending": + // Keep polling until the user confirms authorization or the device code expires. + case "slow_down": + pollInterval += devicePollSlowDownStep + case "access_denied": + if tokenErrorDescription == "" { + return "", "", fmt.Errorf("device authorization was denied by the user") + } + return "", "", fmt.Errorf("device authorization denied: %s", tokenErrorDescription) + case "expired_token": + if tokenErrorDescription == "" { + return "", "", fmt.Errorf("device authorization expired") + } + return "", "", fmt.Errorf("device authorization expired: %s", tokenErrorDescription) + default: + if tokenErrorDescription == "" { + return "", "", fmt.Errorf("token endpoint returned error %q", tokenError) + } + return "", "", fmt.Errorf("token endpoint returned error %q: %s", tokenError, tokenErrorDescription) + } + + time.Sleep(pollInterval) + } +} + +func getTokensWithDeviceCode(tokenEndpoint, clientID, deviceCode string) (accessToken, refreshToken, tokenError, tokenErrorDescription string, err error) { + form := url.Values{} + form.Set("grant_type", deviceCodeGrantType) + form.Set("client_id", clientID) + form.Set("device_code", deviceCode) + + req, err := http.NewRequest("POST", tokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return "", "", "", "", fmt.Errorf("create request: %w", err) + } + req.Header.Add("content-type", "application/x-www-form-urlencoded") + + httpClient := &http.Client{} + res, err := httpClient.Do(req) + if err != nil { + return "", "", "", "", fmt.Errorf("call access token endpoint: %w", err) + } + defer func() { + closeErr := res.Body.Close() + if closeErr != nil && err == nil { + err = fmt.Errorf("close response body: %w", closeErr) + } + }() + + body, err := io.ReadAll(res.Body) + if err != nil { + return "", "", "", "", fmt.Errorf("read response body: %w", err) + } + + responseData := tokenEndpointResponse{} + err = json.Unmarshal(body, &responseData) + if err != nil { + return "", "", "", "", fmt.Errorf("unmarshal response: %w", err) + } + if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices { + if responseData.Error == "" { + return "", "", "", "", fmt.Errorf("token endpoint returned status code %d", res.StatusCode) + } + } + + return responseData.AccessToken, responseData.RefreshToken, responseData.Error, responseData.ErrorDescription, nil +} + +func containsString(values []string, expected string) bool { + for _, value := range values { + if value == expected { + return true + } + } + return false +} + // getUserAccessAndRefreshTokens trades the authorization code retrieved from the first OAuth2 leg for an access token and a refresh token func getUserAccessAndRefreshTokens(idpWellKnownConfig *wellKnownConfig, clientID, codeVerifier, authorizationCode, callbackURL string) (accessToken, refreshToken string, err error) { // Set form-encoded data for the POST to the access token endpoint diff --git a/internal/pkg/auth/utils.go b/internal/pkg/auth/utils.go index a1be5a546..5585c585a 100644 --- a/internal/pkg/auth/utils.go +++ b/internal/pkg/auth/utils.go @@ -13,9 +13,11 @@ import ( ) type wellKnownConfig struct { - Issuer string `json:"issuer"` - AuthorizationEndpoint string `json:"authorization_endpoint"` - TokenEndpoint string `json:"token_endpoint"` + Issuer string `json:"issuer"` + AuthorizationEndpoint string `json:"authorization_endpoint"` + TokenEndpoint string `json:"token_endpoint"` + DeviceAuthorizationEndpoint string `json:"device_authorization_endpoint"` + GrantTypesSupported []string `json:"grant_types_supported"` } func getIDPWellKnownConfigURL() (wellKnownConfigURL string, err error) {