Derive tenant id for signed-in user in OAuth2-based multi-tenant Spring Boot application using Spring Security

I am working on a Spring Boot B2B application, where I would like to onboard a considerable number of tenants. Each of the tenants has its own authentication providers. I have decided to support only OAuth2-based authentication.

Until now my application has been single-tenant, so I never needed to derive a tenant id for a user. With multi-tenancy, I need to serve resources based on the tenant the user belongs to, the role the user has within that tenant, and so on. In other words, every flow in my application will depend on the user's tenant id.

To achieve this, I have added a tenantId column to the existing User entity as follows:

@Entity
@Table(name = "user")
public class User implements Serializable {

  private static final long serialVersionUID = 1L;

  @Id
  private String id;

  @NotNull
  private String login;

  @Column(name = "first_name", length = 50)
  private String firstName;

  @Column(name = "last_name", length = 50)
  private String lastName;

  @Email
  @Column(length = 254, unique = true)
  private String email;

  ...

  // Not unique: many users can belong to the same tenant
  @Column(length = 254)
  private String tenantId;

  ...
}

I derive and store the tenant id for a user when they sign up for my application, as follows:

  // This method is called while the user signs up for the application
  User getUser(AbstractAuthenticationToken authToken) {
    Map<String, Object> attributes;
    if (authToken instanceof OAuth2AuthenticationToken) {
      attributes = ((OAuth2AuthenticationToken) authToken).getPrincipal().getAttributes();
    } else if (authToken instanceof JwtAuthenticationToken) {
      attributes = ((JwtAuthenticationToken) authToken).getTokenAttributes();
    } else {
      throw new IllegalArgumentException("AuthenticationToken is not OAuth2 or JWT!");
    }
    // The declared return type is User, so return the user rather than a Pair
    return getUser(attributes);
  }

  // Not static: getTenantIdFromEmail below is an instance method that uses tenantRepository
  private User getUser(Map<String, Object> details) {
    User user = new User();
    if (details.get("uid") != null) {
       user.setId((String) details.get("uid"));
       user.setLogin((String) details.get("sub"));
    } else if (details.get("user_id") != null) {
       user.setId((String) details.get("user_id"));
    } else {
       user.setId((String) details.get("sub"));
    }
    if (details.get("email") != null) {
       user.setEmail(((String) details.get("email")).toLowerCase());
    } else {
      user.setEmail((String) details.get("sub"));
    }
    user.setTenantId(getTenantIdFromEmail(user.getEmail()));
    ...
  }

  private String getTenantIdFromEmail(String email) {
       String subDomain = getSubDomain(email);
       return tenantRepository.findBySubDomainName(subDomain).getTenantId();
  }
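
The getSubDomain helper is not shown above. A minimal sketch of one way it might work, assuming the tenant is keyed by the part of the email after the '@'. The class and method names here are hypothetical and not part of the original code:

```java
import java.util.Locale;
import java.util.Optional;

// Hypothetical helper: the question does not show getSubDomain,
// so this is only one possible reading of it.
final class EmailDomains {

    private EmailDomains() {
    }

    // Returns the lower-cased part after the last '@',
    // or empty if the input does not look like an email address.
    static Optional<String> domainOf(String email) {
        if (email == null) {
            return Optional.empty();
        }
        int at = email.lastIndexOf('@');
        if (at <= 0 || at == email.length() - 1) {
            return Optional.empty();
        }
        return Optional.of(email.substring(at + 1).trim().toLowerCase(Locale.ROOT));
    }
}
```

When the provider sends no email claim, the code above stores sub in the email field, and sub is an opaque identifier that may not contain an '@'. In that case a helper like this returns empty, so the tenant lookup would need a fallback instead of calling getTenantId() on a missing tenant.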

Now, whenever I need the tenant id of the signed-in user, I can get it like this:

  public String getTenantId() {
    // Look up the signed-in user by login; returns null if no matching user exists
    return userRepository.findByLogin(SecurityUtils.getCurrentUserLogin())
        .map(User::getTenantId)
        .orElse(null);
  }

Here, I have two questions as follows:

  1. I am not sure whether this code for deriving the tenant id will work correctly across the different OAuth2 authentication providers of multiple tenants. If not, could anyone please help? (By different authentication providers, I mean that each tenant may have its own internal authentication provider.)
  2. In almost every application flow (for example, any CRUD operation on a resource), I have to derive the tenant id as shown above. Is there a way to make this more efficient, i.e. compute it once for a signed-in user and then pass it through the entire application flow for that user's session?


Solution 1:[1]

Consider storing the tenant id in a ThreadLocal variable and setting it on successful authentication, for example in the OAuthFilter in your case.

public class TenantIdContextHolder {

    private static final ThreadLocal<String> contextHolder = new ThreadLocal<>();

    public static void setCurrentTenant(String tenantId) {
        contextHolder.set(tenantId);
    }

    public static String getCurrentTenant() {
        return contextHolder.get();
    }


    public static void clear() {
        contextHolder.remove();
    }
}

Usage

// Not static, because getTenantIdFromEmail is an instance method that uses tenantRepository
private User getUser(Map<String, Object> details) {
    User user = new User();
    if (details.get("uid") != null) {
       user.setId((String) details.get("uid"));
       user.setLogin((String) details.get("sub"));
    } else if (details.get("user_id") != null) {
       user.setId((String) details.get("user_id"));
    } else {
       user.setId((String) details.get("sub"));
    }
    if (details.get("email") != null) {
       user.setEmail(((String) details.get("email")).toLowerCase());
    } else {
      user.setEmail((String) details.get("sub"));
    }
    TenantIdContextHolder.setCurrentTenant(getTenantIdFromEmail(user.getEmail()));
    ...
  }

Getting current tenant anytime

String currentTenant = TenantIdContextHolder.getCurrentTenant();
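
Because servers usually reuse threads from a pool, a ThreadLocal value that is never removed can still be visible to the next request handled on the same thread. A minimal sketch of pairing the set with clear() in a try/finally follows. TenantScope and callAs are hypothetical names, not part of the answer, and the holder is repeated only so that the sketch compiles on its own:

```java
import java.util.function.Supplier;

// Same shape as the TenantIdContextHolder above, repeated so this sketch is self-contained.
final class TenantIdContextHolder {

    private static final ThreadLocal<String> CONTEXT = new ThreadLocal<>();

    private TenantIdContextHolder() {
    }

    static void setCurrentTenant(String tenantId) {
        CONTEXT.set(tenantId);
    }

    static String getCurrentTenant() {
        return CONTEXT.get();
    }

    static void clear() {
        CONTEXT.remove();
    }
}

// Hypothetical helper: binds a tenant id for the duration of one unit of work.
final class TenantScope {

    private TenantScope() {
    }

    static <T> T callAs(String tenantId, Supplier<T> work) {
        TenantIdContextHolder.setCurrentTenant(tenantId);
        try {
            return work.get();
        } finally {
            // Always clear, so a pooled thread does not carry the tenant into the next request
            TenantIdContextHolder.clear();
        }
    }
}
```

For example, TenantScope.callAs("tenant-a", TenantIdContextHolder::getCurrentTenant) returns "tenant-a", and getCurrentTenant() returns null again once the call has finished. In a servlet application the same try/finally would typically wrap the call to chain.doFilter(request, response) inside a filter that runs after authentication.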

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
