Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions client/trino-cli/src/main/java/io/trino/cli/Console.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package io.trino.cli;

import com.google.common.base.CharMatcher;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.ByteStreams;
import io.airlift.units.Duration;
Expand Down Expand Up @@ -388,6 +389,17 @@ private static boolean process(
builder = builder.path(query.getSetPath().get());
}

// update authorization user if present
if (query.getSetAuthorizationUser().isPresent()) {
builder = builder.authorizationUser(query.getSetAuthorizationUser());
builder = builder.roles(ImmutableMap.of());
}

if (query.isResetAuthorizationUser()) {
builder = builder.authorizationUser(Optional.empty());
builder = builder.roles(ImmutableMap.of());
}

// update session properties if present
if (!query.getSetSessionProperties().isEmpty() || !query.getResetSessionProperties().isEmpty()) {
Map<String, String> sessionProperties = new HashMap<>(session.getProperties());
Expand Down
10 changes: 10 additions & 0 deletions client/trino-cli/src/main/java/io/trino/cli/Query.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,16 @@ public Optional<String> getSetPath()
return client.getSetPath();
}

public Optional<String> getSetAuthorizationUser()
{
return client.getSetAuthorizationUser();
}

public boolean isResetAuthorizationUser()
{
return client.isResetAuthorizationUser();
}

public Map<String, String> getSetSessionProperties()
{
return client.getSetSessionProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class ClientSession
private final URI server;
private final Optional<String> principal;
private final Optional<String> user;
private final Optional<String> authorizationUser;
private final String source;
private final Optional<String> traceToken;
private final Set<String> clientTags;
Expand Down Expand Up @@ -75,6 +76,7 @@ private ClientSession(
URI server,
Optional<String> principal,
Optional<String> user,
Optional<String> authorizationUser,
String source,
Optional<String> traceToken,
Set<String> clientTags,
Expand All @@ -96,6 +98,7 @@ private ClientSession(
this.server = requireNonNull(server, "server is null");
this.principal = requireNonNull(principal, "principal is null");
this.user = requireNonNull(user, "user is null");
this.authorizationUser = requireNonNull(authorizationUser, "authorizationUser is null");
this.source = source;
this.traceToken = requireNonNull(traceToken, "traceToken is null");
this.clientTags = ImmutableSet.copyOf(requireNonNull(clientTags, "clientTags is null"));
Expand Down Expand Up @@ -158,6 +161,11 @@ public Optional<String> getUser()
return user;
}

public Optional<String> getAuthorizationUser()
{
return authorizationUser;
}

public String getSource()
{
return source;
Expand Down Expand Up @@ -258,6 +266,7 @@ public String toString()
.add("server", server)
.add("principal", principal)
.add("user", user)
.add("authorizationUser", authorizationUser)
.add("clientTags", clientTags)
.add("clientInfo", clientInfo)
.add("catalog", catalog)
Expand All @@ -277,6 +286,7 @@ public static final class Builder
private URI server;
private Optional<String> principal = Optional.empty();
private Optional<String> user = Optional.empty();
private Optional<String> authorizationUser = Optional.empty();
private String source;
private Optional<String> traceToken = Optional.empty();
private Set<String> clientTags = ImmutableSet.of();
Expand All @@ -303,6 +313,7 @@ private Builder(ClientSession clientSession)
server = clientSession.getServer();
principal = clientSession.getPrincipal();
user = clientSession.getUser();
authorizationUser = clientSession.getAuthorizationUser();
source = clientSession.getSource();
traceToken = clientSession.getTraceToken();
clientTags = clientSession.getClientTags();
Expand Down Expand Up @@ -334,6 +345,12 @@ public Builder user(Optional<String> user)
return this;
}

public Builder authorizationUser(Optional<String> authorizationUser)
{
this.authorizationUser = authorizationUser;
return this;
}

public Builder principal(Optional<String> principal)
{
this.principal = principal;
Expand Down Expand Up @@ -448,6 +465,7 @@ public ClientSession build()
server,
principal,
user,
authorizationUser,
source,
traceToken,
clientTags,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ public final class ProtocolHeaders

private final String name;
private final String requestUser;
private final String requestOriginalUser;
private final String requestSource;
private final String requestCatalog;
private final String requestSchema;
Expand All @@ -52,6 +53,8 @@ public final class ProtocolHeaders
private final String responseDeallocatedPrepare;
private final String responseStartedTransactionId;
private final String responseClearTransactionId;
private final String responseSetAuthorizationUser;
private final String responseResetAuthorizationUser;

public static ProtocolHeaders createProtocolHeaders(String name)
{
Expand All @@ -69,6 +72,7 @@ private ProtocolHeaders(String name)
this.name = name;
String prefix = "X-" + name + "-";
requestUser = prefix + "User";
requestOriginalUser = prefix + "Original-User";
requestSource = prefix + "Source";
requestCatalog = prefix + "Catalog";
requestSchema = prefix + "Schema";
Expand All @@ -95,6 +99,8 @@ private ProtocolHeaders(String name)
responseDeallocatedPrepare = prefix + "Deallocated-Prepare";
responseStartedTransactionId = prefix + "Started-Transaction-Id";
responseClearTransactionId = prefix + "Clear-Transaction-Id";
responseSetAuthorizationUser = prefix + "Set-Authorization-User";
responseResetAuthorizationUser = prefix + "Reset-Authorization-User";
}

public String getProtocolName()
Expand All @@ -107,6 +113,11 @@ public String requestUser()
return requestUser;
}

public String requestOriginalUser()
{
return requestOriginalUser;
}

public String requestSource()
{
return requestSource;
Expand Down Expand Up @@ -237,6 +248,16 @@ public String responseClearTransactionId()
return responseClearTransactionId;
}

public String responseSetAuthorizationUser()
{
return responseSetAuthorizationUser;
}

public String responseResetAuthorizationUser()
{
return responseResetAuthorizationUser;
}

public static ProtocolHeaders detectProtocol(Optional<String> alternateHeaderName, Set<String> headerNames)
throws ProtocolDetectionException
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ public interface StatementClient

Optional<String> getSetPath();

Optional<String> getSetAuthorizationUser();

boolean isResetAuthorizationUser();

Map<String, String> getSetSessionProperties();

Set<String> getResetSessionProperties();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ class StatementClientV1
private final AtomicReference<String> setCatalog = new AtomicReference<>();
private final AtomicReference<String> setSchema = new AtomicReference<>();
private final AtomicReference<String> setPath = new AtomicReference<>();
private final AtomicReference<String> setAuthorizationUser = new AtomicReference<>();
private final AtomicBoolean resetAuthorizationUser = new AtomicBoolean();
private final Map<String, String> setSessionProperties = new ConcurrentHashMap<>();
private final Set<String> resetSessionProperties = Sets.newConcurrentHashSet();
private final Map<String, ClientSelectedRole> setRoles = new ConcurrentHashMap<>();
Expand All @@ -89,6 +91,7 @@ class StatementClientV1
private final ZoneId timeZone;
private final Duration requestTimeoutNanos;
private final Optional<String> user;
private final Optional<String> originalUser;
private final String clientCapabilities;
private final boolean compressionDisabled;

Expand All @@ -104,7 +107,11 @@ public StatementClientV1(Call.Factory httpCallFactory, ClientSession session, St
this.timeZone = session.getTimeZone();
this.query = query;
this.requestTimeoutNanos = session.getClientRequestTimeout();
this.user = Stream.of(session.getUser(), session.getPrincipal())
this.user = Stream.of(session.getAuthorizationUser(), session.getUser(), session.getPrincipal())
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
this.originalUser = Stream.of(session.getUser(), session.getPrincipal())
.filter(Optional::isPresent)
.map(Optional::get)
.findFirst();
Expand Down Expand Up @@ -270,6 +277,18 @@ public Optional<String> getSetPath()
return Optional.ofNullable(setPath.get());
}

@Override
public Optional<String> getSetAuthorizationUser()
{
return Optional.ofNullable(setAuthorizationUser.get());
}

@Override
public boolean isResetAuthorizationUser()
{
return resetAuthorizationUser.get();
}

@Override
public Map<String, String> getSetSessionProperties()
{
Expand Down Expand Up @@ -319,6 +338,7 @@ private Request.Builder prepareRequest(HttpUrl url)
.addHeader(USER_AGENT, USER_AGENT_VALUE)
.url(url);
user.ifPresent(requestUser -> builder.addHeader(TRINO_HEADERS.requestUser(), requestUser));
originalUser.ifPresent(originalUser -> builder.addHeader(TRINO_HEADERS.requestOriginalUser(), originalUser));
if (compressionDisabled) {
builder.header(ACCEPT_ENCODING, "identity");
}
Expand Down Expand Up @@ -399,6 +419,16 @@ private void processResponse(Headers headers, QueryResults results)
setSchema.set(headers.get(TRINO_HEADERS.responseSetSchema()));
setPath.set(headers.get(TRINO_HEADERS.responseSetPath()));

String setAuthorizationUser = headers.get(TRINO_HEADERS.responseSetAuthorizationUser());
if (setAuthorizationUser != null) {
this.setAuthorizationUser.set(setAuthorizationUser);
}

String resetAuthorizationUser = headers.get(TRINO_HEADERS.responseResetAuthorizationUser());
if (resetAuthorizationUser != null) {
this.resetAuthorizationUser.set(Boolean.parseBoolean(resetAuthorizationUser));
}

for (String setSession : headers.values(TRINO_HEADERS.responseSetSession())) {
List<String> keyValue = COLLECTION_HEADER_SPLITTER.splitToList(setSession);
if (keyValue.size() != 2) {
Expand Down
17 changes: 17 additions & 0 deletions client/trino-jdbc/src/main/java/io/trino/jdbc/TrinoConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class TrinoConnection
private final AtomicReference<String> catalog = new AtomicReference<>();
private final AtomicReference<String> schema = new AtomicReference<>();
private final AtomicReference<String> path = new AtomicReference<>();
private final AtomicReference<String> authorizationUser = new AtomicReference<>();
private final AtomicReference<ZoneId> timeZoneId = new AtomicReference<>();
private final AtomicReference<Locale> locale = new AtomicReference<>();
private final AtomicReference<Integer> networkTimeoutMillis = new AtomicReference<>(Ints.saturatedCast(MINUTES.toMillis(2)));
Expand Down Expand Up @@ -746,6 +747,7 @@ StatementClient startQuery(String sql, Map<String, String> sessionPropertiesOver
.server(httpUri)
.principal(user)
.user(sessionUser.get())
.authorizationUser(Optional.ofNullable(authorizationUser.get()))
.source(source)
.traceToken(Optional.ofNullable(clientInfo.get(TRACE_TOKEN)))
.clientTags(ImmutableSet.copyOf(clientTags))
Expand Down Expand Up @@ -781,6 +783,15 @@ void updateSession(StatementClient client)
client.getSetSchema().ifPresent(schema::set);
client.getSetPath().ifPresent(path::set);

if (client.getSetAuthorizationUser().isPresent()) {
authorizationUser.set(client.getSetAuthorizationUser().get());
roles.clear();
}
if (client.isResetAuthorizationUser()) {
authorizationUser.set(null);
roles.clear();
}

if (client.getStartedTransactionId() != null) {
transactionId.set(client.getStartedTransactionId());
}
Expand Down Expand Up @@ -810,6 +821,12 @@ int activeStatements()
return statements.size();
}

@VisibleForTesting
String getAuthorizationUser()
{
return authorizationUser.get();
}

private void checkOpen()
throws SQLException
{
Expand Down
54 changes: 54 additions & 0 deletions client/trino-jdbc/src/test/java/io/trino/jdbc/TestTrinoDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,47 @@ public void testCustomDnsResolver()
}
}

@Test(timeOut = 10000)
public void testResetSessionAuthorization()
throws Exception
{
try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class);
Statement statement = connection.createStatement()) {
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(getCurrentUser(connection), "test");
statement.execute("SET SESSION AUTHORIZATION john");
assertEquals(connection.getAuthorizationUser(), "john");
assertEquals(getCurrentUser(connection), "john");
statement.execute("SET SESSION AUTHORIZATION bob");
assertEquals(connection.getAuthorizationUser(), "bob");
assertEquals(getCurrentUser(connection), "bob");
statement.execute("RESET SESSION AUTHORIZATION");
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(getCurrentUser(connection), "test");
}
}

@Test(timeOut = 10000)
public void testSetRoleAfterSetSessionAuthorization()
throws Exception
{
try (TrinoConnection connection = createConnection("blackhole", "blackhole").unwrap(TrinoConnection.class);
Statement statement = connection.createStatement()) {
statement.execute("SET SESSION AUTHORIZATION john");
assertEquals(connection.getAuthorizationUser(), "john");
statement.execute("SET ROLE ALL");
assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.ALL, Optional.empty())));
statement.execute("SET SESSION AUTHORIZATION bob");
assertEquals(connection.getAuthorizationUser(), "bob");
assertEquals(connection.getRoles(), ImmutableMap.of());
statement.execute("SET ROLE NONE");
assertEquals(connection.getRoles(), ImmutableMap.of("system", new ClientSelectedRole(ClientSelectedRole.Type.NONE, Optional.empty())));
statement.execute("RESET SESSION AUTHORIZATION");
assertEquals(connection.getAuthorizationUser(), null);
assertEquals(connection.getRoles(), ImmutableMap.of());
}
}

private QueryState getQueryState(String queryId)
throws SQLException
{
Expand Down Expand Up @@ -1166,6 +1207,19 @@ private static Properties toProperties(Map<String, String> map)
return properties;
}

private static String getCurrentUser(Connection connection)
throws SQLException
{
try (Statement statement = connection.createStatement();
ResultSet rs = statement.executeQuery("SELECT current_user")) {
while (rs.next()) {
return rs.getString(1);
}
}

throw new RuntimeException("Failed to get CURRENT_USER");
}

public static class TestingDnsResolver
implements DnsResolver
{
Expand Down
Loading