diff --git a/packages/api/model.go b/packages/api/model.go index 92879031..a17441f5 100644 --- a/packages/api/model.go +++ b/packages/api/model.go @@ -852,24 +852,31 @@ type RegisterGatewayResponse struct { } type PAMAccessRequest struct { - Duration string `json:"duration,omitempty"` + // New path-based fields (PAM revamp) + Path string `json:"path,omitempty"` + + // Legacy resource-based fields (kept so we can temporarily keep the other proxy files for furtherr revamp) ResourceName string `json:"resourceName,omitempty"` AccountName string `json:"accountName,omitempty"` ProjectId string `json:"projectId,omitempty"` MfaSessionId string `json:"mfaSessionId,omitempty"` - Reason string `json:"reason,omitempty"` + + // Common fields + Duration string `json:"duration,omitempty"` + Reason string `json:"reason,omitempty"` } type PAMAccessResponse struct { SessionId string `json:"sessionId"` + AccountType string `json:"accountType"` ResourceType string `json:"resourceType"` + RelayHost string `json:"relayHost"` RelayClientCertificate string `json:"relayClientCertificate"` RelayClientPrivateKey string `json:"relayClientPrivateKey"` RelayServerCertificateChain string `json:"relayServerCertificateChain"` GatewayClientCertificate string `json:"gatewayClientCertificate"` GatewayClientPrivateKey string `json:"gatewayClientPrivateKey"` GatewayServerCertificateChain string `json:"gatewayServerCertificateChain"` - RelayHost string `json:"relayHost"` Metadata map[string]string `json:"metadata,omitempty"` } diff --git a/packages/cmd/pam.go b/packages/cmd/pam.go index 5f9d99b0..d40d4b6b 100644 --- a/packages/cmd/pam.go +++ b/packages/cmd/pam.go @@ -1,36 +1,13 @@ package cmd import ( - "os" "time" pam "github.com/Infisical/infisical-merge/packages/pam/local" "github.com/Infisical/infisical-merge/packages/util" - "github.com/mattn/go-isatty" - "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) -func readReasonFlag(cmd *cobra.Command) string { - reason, _ := cmd.Flags().GetString("reason") - return reason -} - -func resolveReason(cmd *cobra.Command) string { - if cmd.Flags().Changed("reason") { - reason, _ := cmd.Flags().GetString("reason") - return reason - } - if !isatty.IsTerminal(os.Stdin.Fd()) { - return "" - } - reason, err := pam.PromptForReason(false) - if err != nil { - return "" - } - return reason -} - var pamCmd = &cobra.Command{ Use: "pam", Short: "PAM-related commands", @@ -39,338 +16,22 @@ var pamCmd = &cobra.Command{ Args: cobra.NoArgs, } -// ==================== Database Commands ==================== - -var pamDbCmd = &cobra.Command{ - Use: "db", - Short: "Database-related PAM commands", - Long: "Database-related PAM commands for Infisical", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, -} - -var pamDbAccessCmd = &cobra.Command{ - Use: "access", - Short: "Access PAM database accounts", - Long: "Access PAM database accounts for Infisical. This starts a local database proxy server that you can use to connect to databases directly.", - Example: "infisical pam db access --resource infisical-shared-cloud-instances --account infisical --project-id b38bef10-2685-43c4-9a2c-635206d60bec --duration 4h", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - util.RequireLogin() - - resourceName, _ := cmd.Flags().GetString("resource") - accountName, _ := cmd.Flags().GetString("account") - - if resourceName == "" || accountName == "" { - util.PrintErrorMessageAndExit("Both --resource and --account flags are required") - } - - projectID, err := cmd.Flags().GetString("project-id") - if err != nil { - util.HandleError(err, "Unable to parse project-id flag") - } - - if projectID == "" { - workspaceFile, err := util.GetWorkSpaceFromFile() - if err != nil { - util.PrintErrorMessageAndExit("Please either run infisical init to connect to a project or pass in project id with --project-id flag") - } - projectID = workspaceFile.WorkspaceId - } - - durationStr, err := cmd.Flags().GetString("duration") - if err != nil { - util.HandleError(err, "Unable to parse duration flag") - } - - _, err = time.ParseDuration(durationStr) - if err != nil { - util.HandleError(err, "Invalid duration format. Use formats like '1h', '30m', '2h30m'") - } - - port, err := cmd.Flags().GetInt("port") - if err != nil { - util.HandleError(err, "Unable to parse port flag") - } - - reason := resolveReason(cmd) - - log.Debug().Msg("PAM Database Access: Trying to fetch secrets using logged in details") - - loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) - isConnected := util.ValidateInfisicalAPIConnection() - - if isConnected { - log.Debug().Msg("PAM Database Access: Connected to Infisical instance, checking logged in creds") - } - - if err != nil { - util.HandleError(err, "Unable to get logged in user details") - } - - if isConnected && loggedInUserDetails.LoginExpired { - loggedInUserDetails = util.EstablishUserLoginSession() - } - - pam.StartDatabaseLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ - ResourceName: resourceName, - AccountName: accountName, - Reason: reason, - }, projectID, durationStr, port) - }, -} - -// ==================== SSH Commands ==================== - -var pamSshCmd = &cobra.Command{ - Use: "ssh", - Short: "SSH-related PAM commands", - Long: "SSH-related PAM commands for Infisical", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, -} - -var pamSshAccessCmd = &cobra.Command{ - Use: "access", - Short: "Start interactive SSH session to PAM account", - Long: "Start an interactive SSH session to a PAM-managed SSH account. This command automatically launches an SSH client connected through the Infisical Gateway.", - Example: "infisical pam ssh access --resource prod-servers --account root --project-id --duration 1h", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - runSSHCommand(cmd, args, pam.SSHAccessOptions{}) - }, -} - -var pamSshExecCmd = &cobra.Command{ - Use: "exec [command]", - Short: "Execute a command on a PAM SSH account", - Long: `Execute a single command on a PAM-managed SSH account and return the output. -This is useful for CI/CD pipelines and scripting where interactive sessions are not needed.`, - Example: ` # Run a command and get output - infisical pam ssh exec "ls -la /var/log" --resource prod-servers --account root --project-id - - # Use in a script - OUTPUT=$(infisical pam ssh exec "cat /etc/hostname" --resource prod-servers --account root --project-id )`, +var pamAccessCmd = &cobra.Command{ + Use: "access ", + Short: "Launch a PAM session for the account at the given path", + Long: `Launch a PAM session for the account at the given path. +The path format is: /folder/account-name (leading slash optional)`, + Example: "infisical pam access /production/postgres-main --duration 2h", DisableFlagsInUseLine: true, Args: cobra.ExactArgs(1), - Run: func(cmd *cobra.Command, args []string) { - runSSHCommand(cmd, args, pam.SSHAccessOptions{ - ExecCommand: args[0], - }) - }, -} - -var pamSshProxyCmd = &cobra.Command{ - Use: "proxy", - Short: "Start SSH proxy for SCP, SFTP, or rsync", - Long: `Start an SSH proxy without launching an interactive session. -This is useful for file transfers using SCP, SFTP, rsync, or other SSH-based tools. -The proxy prints connection details and waits until terminated with Ctrl+C.`, - Example: ` # Start the proxy - infisical pam ssh proxy --resource prod-servers --account root --project-id - - # Then in another terminal, use SCP: - scp -P -o StrictHostKeyChecking=no local-file.txt root@127.0.0.1:/remote/path/ - - # Or use rsync: - rsync -e "ssh -p -o StrictHostKeyChecking=no" local-dir/ root@127.0.0.1:/remote/path/`, - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - runSSHCommand(cmd, args, pam.SSHAccessOptions{ - ProxyOnly: true, - }) - }, -} - -// runSSHCommand is the shared implementation for all SSH subcommands -func runSSHCommand(cmd *cobra.Command, args []string, options pam.SSHAccessOptions) { - util.RequireLogin() - - resourceName, _ := cmd.Flags().GetString("resource") - accountName, _ := cmd.Flags().GetString("account") - - if resourceName == "" || accountName == "" { - util.PrintErrorMessageAndExit("Both --resource and --account flags are required") - } - - durationStr, err := cmd.Flags().GetString("duration") - if err != nil { - util.HandleError(err, "Unable to parse duration flag") - } - - _, err = time.ParseDuration(durationStr) - if err != nil { - util.HandleError(err, "Invalid duration format. Use formats like '1h', '30m', '2h30m'") - } - - projectID, err := cmd.Flags().GetString("project-id") - if err != nil { - util.HandleError(err, "Unable to parse project-id flag") - } - - if projectID == "" { - workspaceFile, err := util.GetWorkSpaceFromFile() - if err != nil { - util.PrintErrorMessageAndExit("Please either run infisical init to connect to a project or pass in project id with --project-id flag") - } - projectID = workspaceFile.WorkspaceId - } - - var reason string - if options.ExecCommand != "" { - reason = readReasonFlag(cmd) - } else { - reason = resolveReason(cmd) - } - - log.Debug().Msg("PAM SSH: Trying to fetch credentials using logged in details") - - loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) - isConnected := util.ValidateInfisicalAPIConnection() - - if isConnected { - log.Debug().Msg("PAM SSH: Connected to Infisical instance, checking logged in creds") - } - - if err != nil { - util.HandleError(err, "Unable to get logged in user details") - } - - if isConnected && loggedInUserDetails.LoginExpired { - loggedInUserDetails = util.EstablishUserLoginSession() - } - - pam.StartSSHLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ - ResourceName: resourceName, - AccountName: accountName, - Reason: reason, - }, projectID, durationStr, options) -} - -// ==================== Kubernetes Commands ==================== - -var pamKubernetesCmd = &cobra.Command{ - Use: "kubernetes", - Aliases: []string{"k8s"}, - Short: "Kubernetes-related PAM commands", - Long: "Kubernetes-related PAM commands for Infisical", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, -} - -var pamKubernetesAccessCmd = &cobra.Command{ - Use: "access", - Short: "Access Kubernetes PAM account", - Long: "Access Kubernetes via a PAM-managed Kubernetes account. This command automatically launches a proxy connected to your Kubernetes cluster through the Infisical Gateway.", - Example: "infisical pam kubernetes access --resource prod-cluster --account developer --project-id b38bef10-2685-43c4-9a2c-635206d60bec --duration 4h", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - util.RequireLogin() - - resourceName, _ := cmd.Flags().GetString("resource") - accountName, _ := cmd.Flags().GetString("account") - - if resourceName == "" || accountName == "" { - util.PrintErrorMessageAndExit("Both --resource and --account flags are required") - } - - durationStr, err := cmd.Flags().GetString("duration") - if err != nil { - util.HandleError(err, "Unable to parse duration flag") - } - - _, err = time.ParseDuration(durationStr) - if err != nil { - util.HandleError(err, "Invalid duration format. Use formats like '1h', '30m', '2h30m'") - } - - port, err := cmd.Flags().GetInt("port") - if err != nil { - util.HandleError(err, "Unable to parse port flag") - } - - projectID, err := cmd.Flags().GetString("project-id") - if err != nil { - util.HandleError(err, "Unable to parse project-id flag") - } - - if projectID == "" { - workspaceFile, err := util.GetWorkSpaceFromFile() - if err != nil { - util.PrintErrorMessageAndExit("Please either run infisical init to connect to a project or pass in project id with --project-id flag") - } - projectID = workspaceFile.WorkspaceId - } - - reason := resolveReason(cmd) - - log.Debug().Msg("PAM Kubernetes Access: Trying to fetch credentials using logged in details") - - loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) - isConnected := util.ValidateInfisicalAPIConnection() - - if isConnected { - log.Debug().Msg("PAM Kubernetes Access: Connected to Infisical instance, checking logged in creds") - } - - if err != nil { - util.HandleError(err, "Unable to get logged in user details") - } - - if isConnected && loggedInUserDetails.LoginExpired { - loggedInUserDetails = util.EstablishUserLoginSession() - } - - pam.StartKubernetesLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ - ResourceName: resourceName, - AccountName: accountName, - Reason: reason, - }, projectID, durationStr, port) - }, -} - -// ==================== Redis Commands ==================== - -var pamRedisCmd = &cobra.Command{ - Use: "redis", - Short: "Redis-related PAM commands", - Long: "Redis-related PAM commands for Infisical", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, -} - -var pamRedisAccessCmd = &cobra.Command{ - Use: "access", - Short: "Access PAM Redis accounts", - Long: "Access PAM Redis accounts for Infisical. This starts a local Redis proxy server that you can use to connect to Redis directly.", - Example: "infisical pam redis access --resource my-redis-resource --account redis-admin --duration 4h --port 6379 --project-id ", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, Run: func(cmd *cobra.Command, args []string) { util.RequireLogin() - resourceName, _ := cmd.Flags().GetString("resource") - accountName, _ := cmd.Flags().GetString("account") + path := args[0] - if resourceName == "" || accountName == "" { - util.PrintErrorMessageAndExit("Both --resource and --account flags are required") - } - - projectID, err := cmd.Flags().GetString("project-id") + reason, err := cmd.Flags().GetString("reason") if err != nil { - util.HandleError(err, "Unable to parse project-id flag") - } - - if projectID == "" { - workspaceFile, err := util.GetWorkSpaceFromFile() - if err != nil { - util.PrintErrorMessageAndExit("Please either run infisical init to connect to a project or pass in project id with --project-id flag") - } - projectID = workspaceFile.WorkspaceId + util.HandleError(err, "Unable to parse reason flag") } durationStr, err := cmd.Flags().GetString("duration") @@ -388,190 +49,25 @@ var pamRedisAccessCmd = &cobra.Command{ util.HandleError(err, "Unable to parse port flag") } - reason := resolveReason(cmd) - - log.Debug().Msg("PAM Redis Access: Trying to fetch secrets using logged in details") - loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) - isConnected := util.ValidateInfisicalAPIConnection() - - if isConnected { - log.Debug().Msg("PAM Redis Access: Connected to Infisical instance, checking logged in creds") - } - if err != nil { util.HandleError(err, "Unable to get logged in user details") } - if isConnected && loggedInUserDetails.LoginExpired { - loggedInUserDetails = util.EstablishUserLoginSession() - } - - pam.StartRedisLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ - ResourceName: resourceName, - AccountName: accountName, - Reason: reason, - }, projectID, durationStr, port) - }, -} - -// ==================== RDP Commands ==================== - -var pamRdpCmd = &cobra.Command{ - Use: "rdp", - Short: "RDP-related PAM commands", - Long: "RDP-related PAM commands for Infisical (Windows Server / Remote Desktop)", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, -} - -var pamRdpAccessCmd = &cobra.Command{ - Use: "access", - Short: "Access PAM Windows/RDP accounts", - Long: "Access a PAM-managed Windows target over RDP. This starts a local loopback proxy that your RDP client connects to; the session tunnels through the Infisical Gateway with credentials injected server-side.", - Example: "infisical pam rdp access --resource windows-prod --account administrator --duration 1h --project-id ", - DisableFlagsInUseLine: true, - Args: cobra.NoArgs, - Run: func(cmd *cobra.Command, args []string) { - util.RequireLogin() - - resourceName, _ := cmd.Flags().GetString("resource") - accountName, _ := cmd.Flags().GetString("account") - - if resourceName == "" || accountName == "" { - util.PrintErrorMessageAndExit("Both --resource and --account flags are required") - } - - projectID, err := cmd.Flags().GetString("project-id") - if err != nil { - util.HandleError(err, "Unable to parse project-id flag") - } - - if projectID == "" { - workspaceFile, err := util.GetWorkSpaceFromFile() - if err != nil { - util.PrintErrorMessageAndExit("Please either run infisical init to connect to a project or pass in project id with --project-id flag") - } - projectID = workspaceFile.WorkspaceId - } - - durationStr, err := cmd.Flags().GetString("duration") - if err != nil { - util.HandleError(err, "Unable to parse duration flag") - } - - _, err = time.ParseDuration(durationStr) - if err != nil { - util.HandleError(err, "Invalid duration format. Use formats like '1h', '30m', '2h30m'") - } - - port, err := cmd.Flags().GetInt("port") - if err != nil { - util.HandleError(err, "Unable to parse port flag") - } - - noLaunch, err := cmd.Flags().GetBool("no-launch") - if err != nil { - util.HandleError(err, "Unable to parse no-launch flag") - } - - reason := resolveReason(cmd) - - log.Debug().Msg("PAM RDP Access: Trying to start session using logged in details") - - loggedInUserDetails, err := util.GetCurrentLoggedInUserDetails(true) isConnected := util.ValidateInfisicalAPIConnection() - - if isConnected { - log.Debug().Msg("PAM RDP Access: Connected to Infisical instance, checking logged in creds") - } - - if err != nil { - util.HandleError(err, "Unable to get logged in user details") - } - if isConnected && loggedInUserDetails.LoginExpired { loggedInUserDetails = util.EstablishUserLoginSession() } - pam.StartRDPLocalProxy(loggedInUserDetails.UserCredentials.JTWToken, pam.PAMAccessParams{ - ResourceName: resourceName, - AccountName: accountName, - Reason: reason, - }, projectID, durationStr, port, noLaunch) + pam.StartPAMAccess(loggedInUserDetails.UserCredentials.JTWToken, path, reason, durationStr, port) }, } func init() { - // Database commands - pamDbCmd.AddCommand(pamDbAccessCmd) - pamDbAccessCmd.Flags().String("resource", "", "Name of the PAM resource to access") - pamDbAccessCmd.Flags().String("account", "", "Name of the account within the resource") - pamDbAccessCmd.Flags().String("duration", "1h", "Duration for database access session (e.g., '1h', '30m', '2h30m')") - pamDbAccessCmd.Flags().Int("port", 0, "Port for the local database proxy server (0 for auto-assign)") - pamDbAccessCmd.Flags().String("project-id", "", "Project ID of the account to access") - pamDbAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") - pamDbAccessCmd.MarkFlagRequired("resource") - pamDbAccessCmd.MarkFlagRequired("account") - - // SSH commands - shared flags helper - addSSHFlags := func(cmd *cobra.Command) { - cmd.Flags().String("resource", "", "Name of the PAM resource to access") - cmd.Flags().String("account", "", "Name of the account within the resource") - cmd.Flags().String("duration", "1h", "Duration for SSH access session (e.g., '1h', '30m', '2h30m')") - cmd.Flags().String("project-id", "", "Project ID of the account to access") - cmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") - cmd.MarkFlagRequired("resource") - cmd.MarkFlagRequired("account") - } - - pamSshCmd.AddCommand(pamSshAccessCmd) - addSSHFlags(pamSshAccessCmd) - - pamSshCmd.AddCommand(pamSshExecCmd) - addSSHFlags(pamSshExecCmd) - - pamSshCmd.AddCommand(pamSshProxyCmd) - addSSHFlags(pamSshProxyCmd) - - // Kubernetes commands - pamKubernetesCmd.AddCommand(pamKubernetesAccessCmd) - pamKubernetesAccessCmd.Flags().String("resource", "", "Name of the PAM resource to access") - pamKubernetesAccessCmd.Flags().String("account", "", "Name of the account within the resource") - pamKubernetesAccessCmd.Flags().String("duration", "1h", "Duration for kubernetes access session (e.g., '1h', '30m', '2h30m')") - pamKubernetesAccessCmd.Flags().Int("port", 0, "Port for the local kubernetes proxy server (0 for auto-assign)") - pamKubernetesAccessCmd.Flags().String("project-id", "", "Project ID of the account to access") - pamKubernetesAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") - pamKubernetesAccessCmd.MarkFlagRequired("resource") - pamKubernetesAccessCmd.MarkFlagRequired("account") - - // Redis commands - pamRedisCmd.AddCommand(pamRedisAccessCmd) - pamRedisAccessCmd.Flags().String("resource", "", "Name of the PAM resource to access") - pamRedisAccessCmd.Flags().String("account", "", "Name of the account within the resource") - pamRedisAccessCmd.Flags().String("duration", "1h", "Duration for Redis access session (e.g., '1h', '30m', '2h30m')") - pamRedisAccessCmd.Flags().Int("port", 0, "Port for the local Redis proxy server (0 for auto-assign)") - pamRedisAccessCmd.Flags().String("project-id", "", "Project ID of the account to access") - pamRedisAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") - pamRedisAccessCmd.MarkFlagRequired("resource") - pamRedisAccessCmd.MarkFlagRequired("account") - - // RDP commands - pamRdpCmd.AddCommand(pamRdpAccessCmd) - pamRdpAccessCmd.Flags().String("resource", "", "Name of the PAM resource to access") - pamRdpAccessCmd.Flags().String("account", "", "Name of the account within the resource") - pamRdpAccessCmd.Flags().String("duration", "1h", "Duration for RDP access session (e.g., '1h', '30m', '2h30m')") - pamRdpAccessCmd.Flags().Int("port", 0, "Port for the local RDP proxy server (0 for auto-assign)") - pamRdpAccessCmd.Flags().String("project-id", "", "Project ID of the account to access") - pamRdpAccessCmd.Flags().Bool("no-launch", false, "Do not auto-launch the system RDP client; print connection details only") - pamRdpAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") - pamRdpAccessCmd.MarkFlagRequired("resource") - pamRdpAccessCmd.MarkFlagRequired("account") + pamAccessCmd.Flags().String("reason", "", "Reason for accessing the account (stored for audit purposes)") + pamAccessCmd.Flags().String("duration", "1h", "Duration for access session (e.g., '1h', '30m', '2h30m')") + pamAccessCmd.Flags().Int("port", 0, "Port for the local proxy server (0 for auto-assign)") - pamCmd.AddCommand(pamDbCmd) - pamCmd.AddCommand(pamSshCmd) - pamCmd.AddCommand(pamKubernetesCmd) - pamCmd.AddCommand(pamRedisCmd) - pamCmd.AddCommand(pamRdpCmd) + pamCmd.AddCommand(pamAccessCmd) RootCmd.AddCommand(pamCmd) } diff --git a/packages/pam/local/access.go b/packages/pam/local/access.go new file mode 100644 index 00000000..fb482312 --- /dev/null +++ b/packages/pam/local/access.go @@ -0,0 +1,303 @@ +package pam + +import ( + "context" + "fmt" + "os" + "os/signal" + "strings" + "syscall" + "time" + + "github.com/Infisical/infisical-merge/packages/api" + "github.com/Infisical/infisical-merge/packages/util" + "github.com/go-resty/resty/v2" + "github.com/rs/zerolog/log" +) + +// Account type constants (match API enum) +const ( + AccountTypePostgres = "postgres" + AccountTypeSSH = "ssh" + AccountTypeMySQL = "mysql" + AccountTypeMsSQL = "mssql" + AccountTypeMongoDB = "mongodb" + AccountTypeOracleDB = "oracledb" + AccountTypeRedis = "redis" + AccountTypeKubernetes = "kubernetes" + AccountTypeAwsIam = "aws-iam" + AccountTypeWindows = "windows" + AccountTypeActiveDirectory = "active-directory" +) + +// normalizePath ensures the path has a leading slash for display purposes. +// Both "/folder/account" and "folder/account" are accepted as input. +func normalizePath(path string) string { + if !strings.HasPrefix(path, "/") { + return "/" + path + } + return path +} + +// parsePath extracts the folder and account name from a path like "/folder/account" +func parsePath(path string) (folder, account string) { + // Remove leading slash if present + cleanPath := strings.TrimPrefix(path, "/") + parts := strings.SplitN(cleanPath, "/", 2) + if len(parts) == 2 { + return parts[0], parts[1] + } + // If no slash, treat the whole thing as account name + return "", cleanPath +} + +// StartPAMAccess initiates a PAM session for the account at the given path. +// The account type is determined from the API response and routed to the appropriate handler. +func StartPAMAccess(accessToken, path, reason, durationStr string, port int) { + // Normalize path for display (ensure leading slash) + displayPath := normalizePath(path) + + log.Info().Msgf("Starting PAM access for: %s", strings.TrimPrefix(displayPath, "/")) + log.Info().Msgf("Session duration: %s", durationStr) + + httpClient := resty.New() + httpClient.SetAuthToken(accessToken) + httpClient.SetHeader("User-Agent", api.USER_AGENT) + + pamResponse, err := api.CallPAMAccess(httpClient, api.PAMAccessRequest{ + Path: path, + Duration: durationStr, + Reason: reason, + }) + if err != nil { + util.HandleError(err, "Failed to create PAM session") + return + } + + log.Info().Msgf("Session created with ID: %s", pamResponse.SessionId) + log.Info().Msgf("Account type: %s", pamResponse.AccountType) + + // Route based on account type from API response + switch pamResponse.AccountType { + // Database types - all use the same proxy mechanism with different display configs + case AccountTypePostgres, AccountTypeMySQL, AccountTypeMsSQL, AccountTypeMongoDB, AccountTypeOracleDB: + startDatabaseProxy(httpClient, &pamResponse, displayPath, durationStr, port) + + // Non-database types - not yet implemented + case AccountTypeSSH: + util.PrintErrorMessageAndExit("SSH access not yet supported in the new PAM model") + case AccountTypeRedis: + util.PrintErrorMessageAndExit("Redis access not yet supported in the new PAM model") + case AccountTypeKubernetes: + util.PrintErrorMessageAndExit("Kubernetes access not yet supported in the new PAM model") + case AccountTypeAwsIam: + util.PrintErrorMessageAndExit("AWS IAM access not yet supported in the new PAM model") + case AccountTypeWindows: + util.PrintErrorMessageAndExit("Windows/RDP access not yet supported in the new PAM model") + case AccountTypeActiveDirectory: + util.PrintErrorMessageAndExit("Active Directory access not yet supported in the new PAM model") + default: + util.PrintErrorMessageAndExit(fmt.Sprintf("Unsupported account type: %s", pamResponse.AccountType)) + } +} + +// DatabaseDisplayConfig holds the display configuration for a database type +type DatabaseDisplayConfig struct { + TypeLabel string // e.g., "PostgreSQL", "MySQL", "SQL Server" + DefaultPort int // default port for this database type + ConnectionString func(username, database string, port int) string // builds the connection string + UsageExamples func(username, database string, port int) []string // CLI usage examples +} + +// databaseConfigs maps account types to their display configurations +var databaseConfigs = map[string]DatabaseDisplayConfig{ + AccountTypePostgres: { + TypeLabel: "PostgreSQL", + DefaultPort: 5432, + ConnectionString: func(username, database string, port int) string { + return fmt.Sprintf("postgres://%s@localhost:%d/%s", username, port, database) + }, + UsageExamples: func(username, database string, port int) []string { + return []string{ + fmt.Sprintf("psql -h localhost -p %d -U %s -d %s", port, username, database), + } + }, + }, + AccountTypeMySQL: { + TypeLabel: "MySQL", + DefaultPort: 3306, + ConnectionString: func(username, database string, port int) string { + return fmt.Sprintf("mysql://%s@localhost:%d/%s", username, port, database) + }, + UsageExamples: func(username, database string, port int) []string { + return []string{ + fmt.Sprintf("mysql -h 127.0.0.1 -P %d -u %s %s", port, username, database), + } + }, + }, + AccountTypeMsSQL: { + TypeLabel: "SQL Server", + DefaultPort: 1433, + ConnectionString: func(username, database string, port int) string { + return fmt.Sprintf("sqlserver://%s@localhost:%d?database=%s", username, port, database) + }, + UsageExamples: func(username, database string, port int) []string { + return []string{ + fmt.Sprintf("sqlcmd -S localhost,%d -U %s -d %s", port, username, database), + } + }, + }, + AccountTypeMongoDB: { + TypeLabel: "MongoDB", + DefaultPort: 27017, + ConnectionString: func(username, database string, port int) string { + return fmt.Sprintf("mongodb://localhost:%d/%s", port, database) + }, + UsageExamples: func(username, database string, port int) []string { + return []string{ + fmt.Sprintf("mongosh --host localhost --port %d %s", port, database), + } + }, + }, + AccountTypeOracleDB: { + TypeLabel: "Oracle", + DefaultPort: 1521, + ConnectionString: func(username, database string, port int) string { + return fmt.Sprintf("%s@localhost:%d/%s", username, port, database) + }, + UsageExamples: func(username, database string, port int) []string { + return []string{ + fmt.Sprintf("sqlplus %s@localhost:%d/%s", username, port, database), + } + }, + }, +} + +// startDatabaseProxy starts a local database proxy for any SQL-like database type +func startDatabaseProxy(httpClient *resty.Client, response *api.PAMAccessResponse, path, durationStr string, port int) { + config, ok := databaseConfigs[response.AccountType] + if !ok { + util.PrintErrorMessageAndExit(fmt.Sprintf("No display config for database type: %s", response.AccountType)) + return + } + + duration, err := time.ParseDuration(durationStr) + if err != nil { + util.HandleError(err, "Failed to parse duration") + return + } + + // Get connection details from metadata (validate before starting proxy) + username, ok := response.Metadata["username"] + if !ok { + util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start proxy server") + return + } + database, ok := response.Metadata["database"] + if !ok { + util.HandleError(fmt.Errorf("PAM response metadata is missing 'database'"), "Failed to start proxy server") + return + } + + ctx, cancel := context.WithCancel(context.Background()) + + proxy := &DatabaseProxyServer{ + BaseProxyServer: BaseProxyServer{ + httpClient: httpClient, + relayHost: response.RelayHost, + relayClientCert: response.RelayClientCertificate, + relayClientKey: response.RelayClientPrivateKey, + relayServerCertChain: response.RelayServerCertificateChain, + gatewayClientCert: response.GatewayClientCertificate, + gatewayClientKey: response.GatewayClientPrivateKey, + gatewayServerCertChain: response.GatewayServerCertificateChain, + sessionExpiry: time.Now().Add(duration), + sessionId: response.SessionId, + resourceType: response.AccountType, + ctx: ctx, + cancel: cancel, + shutdownCh: make(chan struct{}), + }, + } + + if err := proxy.ValidateResourceTypeSupported(); err != nil { + util.HandleError(err, "Gateway version outdated") + return + } + + err = proxy.Start(port) + if err != nil { + util.HandleError(err, "Failed to start proxy server") + return + } + + // Parse path into folder and account + folder, account := parsePath(path) + + log.Info().Msgf("%s proxy server listening on port %d", config.TypeLabel, proxy.port) + printDatabaseSessionInfo(config, folder, account, duration, username, database, proxy.port) + + // Handle shutdown signals + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-sigChan + log.Info().Msgf("Received signal %v, initiating graceful shutdown...", sig) + proxy.gracefulShutdown() + }() + + proxy.Run() +} + +// printDatabaseSessionInfo prints the connection info banner for database sessions +func printDatabaseSessionInfo(config DatabaseDisplayConfig, folder, account string, duration time.Duration, username, database string, port int) { + fmt.Printf("\n") + fmt.Printf("**********************************************************************\n") + fmt.Printf(" %s Proxy Session Started! \n", config.TypeLabel) + fmt.Printf("**********************************************************************\n") + fmt.Printf("\n") + if folder != "" { + fmt.Printf(" Folder: %s\n", folder) + } + fmt.Printf(" Account: %s\n", account) + fmt.Printf(" Duration: %s\n", duration.String()) + fmt.Printf("\n") + fmt.Printf("----------------------------------------------------------------------\n") + fmt.Printf(" Connection Details \n") + fmt.Printf("----------------------------------------------------------------------\n") + fmt.Printf("\n") + fmt.Printf(" Host: localhost\n") + fmt.Printf(" Port: %d\n", port) + if username != "" { + fmt.Printf(" Username: %s\n", username) + } + fmt.Printf(" Password: (injected by gateway)\n") + if database != "" { + fmt.Printf(" Database: %s\n", database) + } + fmt.Printf("\n") + fmt.Printf("----------------------------------------------------------------------\n") + fmt.Printf(" How to Connect \n") + fmt.Printf("----------------------------------------------------------------------\n") + fmt.Printf("\n") + fmt.Printf(" Use your preferred database client (CLI, GUI, or IDE) to connect\n") + fmt.Printf(" to localhost:%d. The password will be injected automatically.\n", port) + fmt.Printf("\n") + if config.UsageExamples != nil { + examples := config.UsageExamples(username, database, port) + if len(examples) > 0 { + fmt.Printf(" Example:\n") + for _, ex := range examples { + util.PrintfStderr(" $ %s\n", ex) + } + fmt.Printf("\n") + } + } + fmt.Printf(" Connection string:\n") + connStr := config.ConnectionString(username, database, port) + util.PrintfStderr(" %s\n", connStr) + fmt.Printf("\n") + fmt.Printf("**********************************************************************\n") + fmt.Printf("\n") +} diff --git a/packages/pam/local/base-proxy.go b/packages/pam/local/base-proxy.go index 0ef603cd..0cc52108 100644 --- a/packages/pam/local/base-proxy.go +++ b/packages/pam/local/base-proxy.go @@ -26,6 +26,8 @@ import ( "github.com/rs/zerolog/log" ) +// PAMAccessParams holds the legacy resource-based access parameters. +// Used by the old proxy implementations (ssh, redis, kubernetes, rdp). type PAMAccessParams struct { ResourceName string AccountName string @@ -317,8 +319,6 @@ func (b *BaseProxyServer) WaitForConnectionsWithTimeout(timeout time.Duration) { } } -const reasonRequiredErrorName = "PAM_REASON_REQUIRED" - func PromptForReason(required bool) (string, error) { label := "Reason for access" prompt := promptui.Prompt{ @@ -337,18 +337,18 @@ func PromptForReason(required bool) (string, error) { return strings.TrimSpace(result), nil } -// CallPAMAccessWithMFA attempts to access a PAM account and handles MFA if required -// This is a shared function used by all PAM proxies +const reasonRequiredErrorName = "PAM_REASON_REQUIRED" + +// CallPAMAccessWithMFA attempts to access a PAM account and handles MFA if required. +// This is used by the legacy proxy implementations. func CallPAMAccessWithMFA( httpClient *resty.Client, pamRequest api.PAMAccessRequest, interactive bool, ) (api.PAMAccessResponse, error) { - // Initial request pamResponse, err := api.CallPAMAccess(httpClient, pamRequest) if err != nil { if apiErr, ok := err.(*api.APIError); ok { - // Reason required by account policy if apiErr.Name == reasonRequiredErrorName { if !interactive || !isatty.IsTerminal(os.Stdin.Fd()) { return api.PAMAccessResponse{}, fmt.Errorf( @@ -363,21 +363,17 @@ func CallPAMAccessWithMFA( return CallPAMAccessWithMFA(httpClient, pamRequest, interactive) } - // MFA required if apiErr.Name == "SESSION_MFA_REQUIRED" { - // Extract MFA details from error if details, ok := apiErr.Details.(map[string]interface{}); ok { mfaSessionId, _ := details["mfaSessionId"].(string) mfaMethod, _ := details["mfaMethod"].(string) if mfaSessionId != "" { - // Handle MFA flow err := util.HandleMFASession(httpClient, mfaSessionId, mfaMethod, config.INFISICAL_URL) if err != nil { return api.PAMAccessResponse{}, fmt.Errorf("MFA verification failed: %w", err) } - // Retry request with MFA session ID log.Debug().Msg("Retrying PAM access with MFA session...") pamRequest.MfaSessionId = mfaSessionId pamResponse, err = api.CallPAMAccess(httpClient, pamRequest) @@ -390,7 +386,6 @@ func CallPAMAccessWithMFA( } } } - // Return original error if not MFA/reason-related return api.PAMAccessResponse{}, err } @@ -398,7 +393,7 @@ func CallPAMAccessWithMFA( } // HandleApprovalWorkflow checks if an error is due to an approval policy and handles the approval request flow. -// Returns true if the error was handled (either approval request created or user declined), false otherwise. +// Returns true if the error was handled, false otherwise. func HandleApprovalWorkflow(httpClient *resty.Client, err error, projectID string, accessParams PAMAccessParams, durationStr string) bool { var apiErr *api.APIError if !errors.As(err, &apiErr) || apiErr.ErrorMessage != "A policy is in place for this resource" { @@ -462,3 +457,4 @@ func askForApprovalRequestTrigger() (bool, error) { } return strings.ToLower(result) == "y", nil } + diff --git a/packages/pam/local/database-proxy.go b/packages/pam/local/database-proxy.go index c418795c..872a13ee 100644 --- a/packages/pam/local/database-proxy.go +++ b/packages/pam/local/database-proxy.go @@ -6,21 +6,15 @@ import ( "io" "net" "os" - "os/signal" - "syscall" "time" - "github.com/Infisical/infisical-merge/packages/pam/handlers/oracle" - "github.com/Infisical/infisical-merge/packages/pam/session" - "github.com/Infisical/infisical-merge/packages/util" - "github.com/go-resty/resty/v2" "github.com/rs/zerolog/log" ) type DatabaseProxyServer struct { - BaseProxyServer // Embed common functionality - server net.Listener - port int + BaseProxyServer + server net.Listener + port int } type ALPN string @@ -31,123 +25,6 @@ const ( ALPNInfisicalPAMCapabilities ALPN = "infisical-pam-capabilities" ) -func StartDatabaseLocalProxy(accessToken string, accessParams PAMAccessParams, projectID string, durationStr string, port int) { - log.Info().Msgf("Starting database proxy for account: %s", accessParams.GetDisplayName()) - log.Info().Msgf("Session duration: %s", durationStr) - - httpClient := resty.New() - httpClient.SetAuthToken(accessToken) - httpClient.SetHeader("User-Agent", "infisical-cli") - - pamRequest := accessParams.ToAPIRequest(projectID, durationStr) - - pamResponse, err := CallPAMAccessWithMFA(httpClient, pamRequest, true) - if err != nil { - if HandleApprovalWorkflow(httpClient, err, projectID, accessParams, durationStr) { - return - } - util.HandleError(err, "Failed to access PAM account") - return - } - - log.Info().Msgf("Database session created with ID: %s", pamResponse.SessionId) - - duration, err := time.ParseDuration(durationStr) - if err != nil { - util.HandleError(err, "Failed to parse duration") - return - } - - ctx, cancel := context.WithCancel(context.Background()) - - proxy := &DatabaseProxyServer{ - BaseProxyServer: BaseProxyServer{ - httpClient: httpClient, - relayHost: pamResponse.RelayHost, - relayClientCert: pamResponse.RelayClientCertificate, - relayClientKey: pamResponse.RelayClientPrivateKey, - relayServerCertChain: pamResponse.RelayServerCertificateChain, - gatewayClientCert: pamResponse.GatewayClientCertificate, - gatewayClientKey: pamResponse.GatewayClientPrivateKey, - gatewayServerCertChain: pamResponse.GatewayServerCertificateChain, - sessionExpiry: time.Now().Add(duration), - sessionId: pamResponse.SessionId, - resourceType: pamResponse.ResourceType, - ctx: ctx, - cancel: cancel, - shutdownCh: make(chan struct{}), - }, - } - - if err := proxy.ValidateResourceTypeSupported(); err != nil { - util.HandleError(err, "Gateway version outdated") - return - } - - err = proxy.Start(port) - if err != nil { - util.HandleError(err, "Failed to start proxy server") - return - } - - if port == 0 { - fmt.Printf("Database proxy started for account %s with duration %s on port %d (auto-assigned)\n", accessParams.GetDisplayName(), duration.String(), proxy.port) - } else { - fmt.Printf("Database proxy started for account %s with duration %s on port %d\n", accessParams.GetDisplayName(), duration.String(), proxy.port) - } - - username, ok := pamResponse.Metadata["username"] - if !ok { - util.HandleError(fmt.Errorf("PAM response metadata is missing 'username'"), "Failed to start proxy server") - return - } - database, ok := pamResponse.Metadata["database"] - if !ok { - util.HandleError(fmt.Errorf("PAM response metadata is missing 'database'"), "Failed to start proxy server") - return - } - - log.Info().Msgf("Database proxy server listening on port %d", proxy.port) - fmt.Printf("\n") - fmt.Printf("**********************************************************************\n") - fmt.Printf(" Database Proxy Session Started! \n") - fmt.Printf("----------------------------------------------------------------------\n") - fmt.Printf("Resource: %s\n", accessParams.ResourceName) - fmt.Printf("Account: %s\n", accessParams.AccountName) - fmt.Printf("\n") - fmt.Printf("You can now connect to your database using this connection string:\n") - - switch pamResponse.ResourceType { - case session.ResourceTypePostgres: - util.PrintfStderr("postgres://%s@localhost:%d/%s", username, proxy.port, database) - case session.ResourceTypeMysql: - util.PrintfStderr("mysql://%s@localhost:%d/%s", username, proxy.port, database) - case session.ResourceTypeMssql: - util.PrintfStderr("sqlserver://%s@localhost:%d?database=%s&encrypt=false&trustServerCertificate=true", username, proxy.port, database) - case session.ResourceTypeMongodb: - util.PrintfStderr("mongodb://localhost:%d/%s?serverSelectionTimeoutMS=15000", proxy.port, database) - case session.ResourceTypeOracledb: - util.PrintfStderr("%s/%s@localhost:%d/%s", username, oracle.ProxyPasswordPlaceholder, proxy.port, database) - util.PrintfStderr("\njdbc:oracle:thin:@localhost:%d/%s (user: %s, password: %s)", proxy.port, database, username, oracle.ProxyPasswordPlaceholder) - util.PrintfStderr("\n\nNote: the password shown is a protocol placeholder required by Oracle, not a secret.") - default: - util.PrintfStderr("localhost:%d", proxy.port) - } - util.PrintfStderr("\n**********************************************************************\n") - util.PrintfStderr("\n") - - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - - go func() { - sig := <-sigChan - log.Info().Msgf("Received signal %v, initiating graceful shutdown...", sig) - proxy.gracefulShutdown() - }() - - proxy.Run() -} - func (p *DatabaseProxyServer) Start(port int) error { var err error if port == 0 { @@ -170,21 +47,16 @@ func (p *DatabaseProxyServer) gracefulShutdown() { p.shutdownOnce.Do(func() { log.Info().Msg("Starting graceful shutdown of database proxy...") - // Send session termination notification before cancelling context p.NotifySessionTermination() - // Signal the accept loop to stop close(p.shutdownCh) - // Close the server to stop accepting new connections if p.server != nil { p.server.Close() } - // Cancel context to signal all goroutines to stop p.cancel() - // Wait for connections to close p.WaitForConnectionsWithTimeout(10 * time.Second) log.Info().Msg("Database proxy shutdown complete") @@ -204,7 +76,6 @@ func (p *DatabaseProxyServer) Run() { log.Info().Msg("Shutdown signal received, stopping proxy server") return default: - // Check if session has expired if time.Now().After(p.sessionExpiry) { log.Warn().Msg("Database session expired, shutting down proxy") p.gracefulShutdown() @@ -231,7 +102,6 @@ func (p *DatabaseProxyServer) Run() { } } - // Track active connection p.activeConnections.Add(1) go p.handleConnection(conn) } @@ -274,7 +144,6 @@ func (p *DatabaseProxyServer) handleConnection(clientConn net.Conn) { gatewayErrCh, clientErrCh := p.NewDisconnectChannels() - // Gateway → Client: if this side closes first, the gateway dropped the connection go func() { defer connCancel() _, err := io.Copy(clientConn, gatewayConn) @@ -288,7 +157,6 @@ func (p *DatabaseProxyServer) handleConnection(clientConn net.Conn) { gatewayErrCh <- err }() - // Client → Gateway: if this side closes first, the client disconnected normally go func() { defer connCancel() _, err := io.Copy(gatewayConn, clientConn)