IAM Role Authentication for Postgres RDS using Python and Go

In this article I will discuss how to connect to Postgres RDS with IAM authentication, from Python (using Django) and from Go.

AWS supports IAM roles to authenticate with RDS instead of a conventional username and password. We can create an IAM policy for RDS and attach it to an EC2 instance, and the instance will be able to connect to RDS using an IAM token instead of a password. In this article we will talk about how to keep the token fresh when using pgx as the Postgres driver in Go, and how to use and refresh IAM tokens when using Django in Python.

In this blog I will not go into the details of creating and attaching an IAM role or policy; there are a lot of resources available for this.

What is an IAM role and an IAM token

With IAM database authentication, you use an authentication token when you connect to your DB instance. An authentication token is a string of characters that you use instead of a password. After you generate an authentication token, it's valid for 15 minutes before it expires. If you try to connect using an expired token, the connection request is denied.

Network traffic to and from the database is encrypted using Transport Layer Security (TLS).

Using this authentication we can avoid keeping the password in the code/config. The AWS SDK will programmatically create and sign an authentication token.

How does the IAM token work

The AWS SDK for Go, or boto3 in the case of Python, is used to generate a new token every time a DB connection needs to be made. The important thing to note is that the token is generated on the client side: no network call is made to fetch the token from AWS. The SDK only needs valid credentials, which it normally already holds. So it is completely okay to generate a new token every time a new connection is made.

This is achieved using AWS's Signature Version 4 signing process: the AWS SDK signs the token with the access key ID and the secret access key.

We use this token to connect to Postgres instead of the password.
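To make the "no network call" point concrete, here is a minimal sketch of local signing that uses only the Python standard library. It assumes the token has the presigned-URL layout that the SDKs produce (an `Action=connect` request signed for the `rds-db` service and valid for 900 seconds). The function name `sketch_rds_auth_token` and all hosts and keys are made up for illustration. In real code, use boto3 or the AWS SDK for Go rather than this.

```python
import datetime
import hashlib
import hmac
from urllib.parse import quote


def _hmac_sha256(key: bytes, msg: str) -> bytes:
    return hmac.new(key, msg.encode("utf-8"), hashlib.sha256).digest()


def sketch_rds_auth_token(host, port, user, region, access_key, secret_key,
                          session_token=None, now=None):
    """Illustrative only: sign an RDS-style auth token locally with SigV4."""
    now = now or datetime.datetime.now(datetime.timezone.utc)
    amz_date = now.strftime("%Y%m%dT%H%M%SZ")
    date_stamp = now.strftime("%Y%m%d")
    scope = f"{date_stamp}/{region}/rds-db/aws4_request"
    endpoint = f"{host}:{port}"

    params = {
        "Action": "connect",
        "DBUser": user,
        "X-Amz-Algorithm": "AWS4-HMAC-SHA256",
        "X-Amz-Credential": f"{access_key}/{scope}",
        "X-Amz-Date": amz_date,
        "X-Amz-Expires": "900",  # 15 minutes
        "X-Amz-SignedHeaders": "host",
    }
    if session_token:  # temporary credentials (e.g. from a role) carry one
        params["X-Amz-Security-Token"] = session_token

    # SigV4 wants the query parameters URL-encoded and sorted by name.
    query = "&".join(
        f"{quote(k, safe='')}={quote(v, safe='')}" for k, v in sorted(params.items())
    )

    # Canonical request: method, path, query, headers, signed headers, body hash.
    canonical_request = "\n".join(
        ["GET", "/", query, f"host:{endpoint}\n", "host", hashlib.sha256(b"").hexdigest()]
    )
    string_to_sign = "\n".join([
        "AWS4-HMAC-SHA256",
        amz_date,
        scope,
        hashlib.sha256(canonical_request.encode("utf-8")).hexdigest(),
    ])

    # Derive the signing key from the secret key; everything happens in memory.
    key = _hmac_sha256(("AWS4" + secret_key).encode("utf-8"), date_stamp)
    for part in (region, "rds-db", "aws4_request"):
        key = _hmac_sha256(key, part)
    signature = hmac.new(key, string_to_sign.encode("utf-8"), hashlib.sha256).hexdigest()

    return f"{endpoint}/?{query}&X-Amz-Signature={signature}"


# Example with made-up values; the keys below are not real credentials.
fixed_time = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc)
token = sketch_rds_auth_token(
    "db.example.com", 5432, "app_user", "us-east-1",
    "AKIDEXAMPLE", "not-a-real-secret", now=fixed_time,
)
print(token[:70] + "...")
```

Nothing in the function opens a socket: the token is a string derived from the endpoint, the user, the time and the credentials, which is why generating one per connection is cheap.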

How can this be used to connect to Postgres in Go/Python

AWS has quite good documentation for connecting to Postgres on RDS using IAM tokens instead of passwords, in Go as well as in Python. But when I tried to implement this, the problem I faced was that we used client-side database pools. It is fair to assume that most real-world production systems use pooling.

Database Pooling

Database connection pooling is a method used to keep database connections open so they can be reused by others.

Typically, opening a database connection is an expensive operation. You have to open up network sessions, authenticate and so on. Pooling keeps the connections active so that, when a connection is later requested, one of the active ones is used in preference to having to create another one.

This adds a layer of complexity, because the pools are created lazily. If a pool has 50 connections, they are not all stood up when the app starts; connections are added to the pool as queries and transactions arrive. If we were to reuse a single IAM token, we would have to keep track of its age, so we would need a timer to refresh it every 15 minutes. But we cannot close individual connections inside the pool, because these pools are self-managing, so we would have to close the whole pool and re-initialize everything. You can see where I am going with this: the approach is hacky and not elegant.
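A toy model makes the timing problem concrete. The sketch below does not talk to AWS or to a database: `make_token` and `server_accepts` are hypothetical stand-ins for the SDK call and for the check RDS performs, and time is just a number of seconds since the app started. It compares reusing one token created at startup with generating a token at the moment each lazy connection is opened.

```python
TOKEN_TTL_SECONDS = 15 * 60  # IAM auth tokens are valid for 15 minutes


def make_token(now):
    """Stand-in for the SDK call: records only when the token was issued."""
    return {"issued_at": now}


def server_accepts(token, now):
    """Stand-in for RDS: a token is accepted only while it is unexpired."""
    return now - token["issued_at"] < TOKEN_TTL_SECONDS


def open_connections(connect_times, token_for):
    """Open one lazy connection per timestamp; report which attempts succeed."""
    return [server_accepts(token_for(t), t) for t in connect_times]


# Connections are opened lazily: at startup, after 10 minutes, after 20 minutes.
connect_times = [0, 600, 1200]

startup_token = make_token(0)
shared = open_connections(connect_times, lambda t: startup_token)
fresh = open_connections(connect_times, make_token)

print(shared)  # [True, True, False] -> the third connection presents an expired token
print(fresh)   # [True, True, True]  -> every connection presents a new token
```

With a shared token, the pool works until the first connection that happens to be opened more than 15 minutes after startup.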

Go implementation

Fortunately, pgx has recently added a BeforeConnect hook to its API. We can use this hook to modify the DB credentials before any connection is made. This way every connection in the pool gets its own IAM token and, as established earlier, tokens are created on the client side, so there is no network call to "fetch" the token.

But how do we refresh the token after 15 minutes? We don't! We set the connection lifetime to 14 minutes, so each connection is closed after 14 minutes and the connection that replaces it is created using a new token. This way we inject a token for every connection. The approach also works because we never close the complete pool; we only close one connection at a time (the connection lifetime is per connection). And because each connection is established lazily and has its own token, connections are created and closed at different times, so the pool should always have some connections available for use.

// peer contains the connection pool
// IAMRoleAuth is the flag to toggle IAM role based authentication
type peer struct {
    name        string
    dbPool      *pgxpool.Pool
    weight      int
    logger      log.Logger
    mu          sync.Mutex
    IAMRoleAuth bool
}

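// DBConfig holds the connection and pool settings for one database.
type DBConfig struct {
    Host        string
    Port        int
    User        string
    Password    string // replaced by an IAM token when IAMRoleAuth is true
    SSLMode     string
    Name        string
    MinConn     int
    MaxConn     int
    LifeTime    string // maximum connection lifetime (14 minutes here)
    IdleTime    string
    LogLevel    string
    Region      string // AWS region of the RDS instance
    IAMRoleAuth bool   // toggles IAM role based authentication
}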

We create a session struct, and pass this on to create a new IAM token.

import (
    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/session"
)

// Build one AWS session; it supplies the credentials used to sign the tokens.
opts := session.Options{Config: aws.Config{
    CredentialsChainVerboseErrors: aws.Bool(true),
    Region:                        aws.String("us-east-1"),
    MaxRetries:                    aws.Int(3),
}}
sess := session.Must(session.NewSessionWithOptions(opts))

// getDBPool returns a new pgxpool.Pool instance (pgx v4 API).
// If the peer's IAMRoleAuth flag is true, a BeforeConnect hook is installed;
// it injects a fresh auth token as the password before each connection is made.
// The connection lifetime is set to 14 minutes (for example through
// pool_max_conn_lifetime in the connection URL or poolCfg.MaxConnLifetime),
// so connections expire on their own and nothing has to close them manually.
func (p *peer) getDBPool(ctx context.Context, cfg DBConfig, sess *session.Session) (*pgxpool.Pool, error) {
    poolCfg, err := pgxpool.ParseConfig(getDBURL(cfg))
    if err != nil {
        // Log only the peer name: cfg also carries the password.
        p.logger.Err(err).Msgf("unable to parse config for peer: %v", cfg.Name)
        return nil, err
    }

    if p.IAMRoleAuth {
        poolCfg.BeforeConnect = func(ctx context.Context, config *pgx.ConnConfig) error {
            p.logger.Info().Msg("RDS Credential beforeConnect(), creating new credential")
            newPassword, err := p.getCredential(poolCfg, cfg, sess)
            if err != nil {
                return err
            }
            p.mu.Lock()
            config.Password = newPassword
            p.mu.Unlock()
            return nil
        }
    }

    pool, err := pgxpool.ConnectConfig(ctx, poolCfg)
    if err != nil {
        p.logger.Err(err).Msg("unable to connect to db")
        return nil, err
    }

    return pool, nil
}

// getCredential returns the new password (an IAM auth token) to connect to RDS.
// The token is signed locally with the credentials held by the session.
func (p *peer) getCredential(poolCfg *pgxpool.Config, cfg DBConfig, sess *session.Session) (string, error) {
    dbEndpoint := fmt.Sprintf("%s:%d", poolCfg.ConnConfig.Host, poolCfg.ConnConfig.Port)
    awsRegion := cfg.Region
    dbUser := poolCfg.ConnConfig.User
    authToken, err := rdsutils.BuildAuthToken(dbEndpoint, awsRegion, dbUser, sess.Config.Credentials)
    if err != nil {
        // Log at error level and return the error, so that BeforeConnect can
        // report the failure instead of the process panicking.
        p.logger.Error().Err(err).Msg("error building auth token to connect with RDS")
        return "", err
    }
    return authToken, nil
}

If you are using sqlx for database access, things get trickier, because sqlx does not provide a BeforeConnect() hook to inject the token. I have not found an elegant way to do this with sqlx yet.

Django implementation

Django is a very popular web framework in Python.

Unlike the Go implementation, Django does not use database pools by default; instead it can keep persistent connections (controlled by the CONN_MAX_AGE setting). This is different from pooling because Django maintains one persistent connection per thread. With pooling, the pool takes care of opening and closing connections, and every thread that needs a connection requests one from the pool.

Django's PostgreSQL database backend internally uses Psycopg as the database adapter. We can create a custom database backend by extending the PostgreSQL backend. Django calls get_connection_params() whenever it opens a new connection, so overriding it gives every new connection a freshly generated token.

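import getpass

import boto3
from django.db.backends.postgresql import base


def get_aws_connection_params(params):
    """
    Replace the password with an IAM auth token when IAM authentication is enabled.

    "rds_iam_auth" : If set to `True`, IAM authentication is used; otherwise password-based authentication is used.
    "region_name" : The name of the AWS region where the DB is located.

    Parameters
    ----------
        params : dict
            The connection parameters that Django builds from the DATABASES entry in settings.py
    """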
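    # Always remove the custom keys so that they are never passed on to Psycopg,
    # which would reject them as unknown connection options.
    enabled = params.pop("rds_iam_auth", False)
    region_name = params.pop("region_name", None)
    if enabled:
        rds_client = boto3.client(service_name="rds", region_name=region_name)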

        hostname = params.get("host")
        hostname = hostname if hostname else "localhost"

        params["password"] = rds_client.generate_db_auth_token(
            DBHostname=hostname,
            Port=params.get("port", 5432),
            DBUsername=params.get("user", getpass.getuser()),
        )

    return params

class DatabaseWrapper(base.DatabaseWrapper):
    def get_connection_params(self):
        params = super().get_connection_params()
        params.setdefault("port", 5432)

        return get_aws_connection_params(params)

In settings.py we can use this backend as follows:

DATABASES = {
    "default": {
        "ENGINE": "path.to.backend.aws.postgres",
        "NAME": "postgres",
        "USER": "user",
        "PASSWORD": "welcome",
        "HOST": "localhost",
        "PORT": "5432",
        "OPTIONS": {
            "rds_iam_auth": True,
            "region_name": "us-east-1",
            "sslmode": "verify-full",
            "sslrootcert": "/path/to/rds-ca-2019-root.pem",
        },
    },
}
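
Django merges the OPTIONS entries into the connection parameters, so the two custom keys arrive in the same dict as `sslmode` and must be removed before the dict reaches Psycopg. The sketch below is a way to check that flow without AWS or Django: it mirrors the logic of `get_aws_connection_params`, but takes the token generator as an argument. The names `apply_iam_auth` and `fake_token` are hypothetical and exist only for this illustration.

```python
def apply_iam_auth(params, generate_token):
    """Mirror of the backend logic, with the token source injected for testing."""
    params = dict(params)  # work on a copy so the caller's dict is untouched
    enabled = params.pop("rds_iam_auth", False)
    region_name = params.pop("region_name", None)
    if enabled:
        params["password"] = generate_token(
            hostname=params.get("host") or "localhost",
            port=params.get("port", 5432),
            username=params["user"],
            region=region_name,
        )
    return params


def fake_token(hostname, port, username, region):
    """Stand-in for boto3's generate_db_auth_token; returns a readable string."""
    return f"token:{username}@{hostname}:{port}/{region}"


# Roughly what the backend sees for the settings above (values are examples).
conn_params = {
    "host": "localhost",
    "port": "5432",
    "user": "user",
    "password": "welcome",
    "rds_iam_auth": True,
    "region_name": "us-east-1",
    "sslmode": "verify-full",
}

result = apply_iam_auth(conn_params, fake_token)
print(result["password"])
print(sorted(result))
```

The static `PASSWORD` from the settings is overwritten whenever `rds_iam_auth` is true, and only options that Psycopg understands are left in the result.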

Another thing to note for both approaches is that the RDS endpoint hostname must be used to connect to RDS PostgreSQL. If a DNS CNAME is used, the connection fails with a PAM authentication error. This happens because the IAM token is signed for the CNAME, but RDS expects its own hostname. One workaround is to resolve the CNAME to the hostname in the custom backend before making the connection. I would advise against this, because CNAME records are cached in multiple places and have a TTL, so the connection would fail if a CNAME change has not propagated yet.

