HDFS-17575. SaslDataTransferClient should use SaslParticipant to create messages. (#6954)

This commit is contained in:
Tsz-Wo Nicholas Sze 2024-08-05 10:42:12 -07:00 committed by GitHub
parent 59d5e0bb2e
commit b189ef8197
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 20 additions and 7 deletions

View File

@ -519,25 +519,25 @@ private IOStreamPair doSaslHandshake(InetAddress addr,
// In which case there will be no encrypted secret sent from NN. // In which case there will be no encrypted secret sent from NN.
BlockTokenIdentifier blockTokenIdentifier = BlockTokenIdentifier blockTokenIdentifier =
accessToken.decodeIdentifier(); accessToken.decodeIdentifier();
final byte[] first = sasl.createFirstMessage();
if (blockTokenIdentifier != null) { if (blockTokenIdentifier != null) {
byte[] handshakeSecret = byte[] handshakeSecret =
accessToken.decodeIdentifier().getHandshakeMsg(); accessToken.decodeIdentifier().getHandshakeMsg();
if (handshakeSecret == null || handshakeSecret.length == 0) { if (handshakeSecret == null || handshakeSecret.length == 0) {
LOG.debug("Handshake secret is null, " LOG.debug("Handshake secret is null, "
+ "sending without handshake secret."); + "sending without handshake secret.");
sendSaslMessage(out, new byte[0]); sendSaslMessage(out, first);
} else { } else {
LOG.debug("Sending handshake secret."); LOG.debug("Sending handshake secret.");
BlockTokenIdentifier identifier = new BlockTokenIdentifier(); BlockTokenIdentifier identifier = new BlockTokenIdentifier();
identifier.readFields(new DataInputStream( identifier.readFields(new DataInputStream(
new ByteArrayInputStream(accessToken.getIdentifier()))); new ByteArrayInputStream(accessToken.getIdentifier())));
String bpid = identifier.getBlockPoolId(); String bpid = identifier.getBlockPoolId();
sendSaslMessageHandshakeSecret(out, new byte[0], sendSaslMessageHandshakeSecret(out, first, handshakeSecret, bpid);
handshakeSecret, bpid);
} }
} else { } else {
LOG.debug("Block token id is null, sending without handshake secret."); LOG.debug("Block token id is null, sending without handshake secret.");
sendSaslMessage(out, new byte[0]); sendSaslMessage(out, first);
} }
// step 1 // step 1
@ -565,6 +565,7 @@ private IOStreamPair doSaslHandshake(InetAddress addr,
cipherOptions.add(option); cipherOptions.add(option);
} }
} }
LOG.debug("{}: cipherOptions={}", sasl, cipherOptions);
sendSaslMessageAndNegotiationCipherOptions(out, localResponse, sendSaslMessageAndNegotiationCipherOptions(out, localResponse,
cipherOptions); cipherOptions);

View File

@ -20,6 +20,7 @@
import java.io.DataInputStream; import java.io.DataInputStream;
import java.io.DataOutputStream; import java.io.DataOutputStream;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import javax.security.auth.callback.CallbackHandler; import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl; import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient; import javax.security.sasl.SaslClient;
@ -52,6 +53,7 @@ class SaslParticipant {
private static final String SERVER_NAME = "0"; private static final String SERVER_NAME = "0";
private static final String PROTOCOL = "hdfs"; private static final String PROTOCOL = "hdfs";
private static final String[] MECHANISM_ARRAY = {SaslConstants.SASL_MECHANISM}; private static final String[] MECHANISM_ARRAY = {SaslConstants.SASL_MECHANISM};
private static final byte[] EMPTY_BYTE_ARRAY = {};
// One of these will always be null. // One of these will always be null.
private final SaslServer saslServer; private final SaslServer saslServer;
@ -110,7 +112,7 @@ public static SaslParticipant createClientSaslParticipant(String userName,
* @param saslServer to wrap * @param saslServer to wrap
*/ */
private SaslParticipant(SaslServer saslServer) { private SaslParticipant(SaslServer saslServer) {
this.saslServer = saslServer; this.saslServer = Objects.requireNonNull(saslServer, "saslServer == null");
this.saslClient = null; this.saslClient = null;
} }
@ -121,7 +123,12 @@ private SaslParticipant(SaslServer saslServer) {
*/ */
private SaslParticipant(SaslClient saslClient) { private SaslParticipant(SaslClient saslClient) {
this.saslServer = null; this.saslServer = null;
this.saslClient = saslClient; this.saslClient = Objects.requireNonNull(saslClient, "saslClient == null");
}
byte[] createFirstMessage() throws SaslException {
return MECHANISM_ARRAY[0].equals(SaslConstants.SASL_MECHANISM_DEFAULT) ? EMPTY_BYTE_ARRAY
: evaluateChallengeOrResponse(EMPTY_BYTE_ARRAY);
} }
/** /**
@ -228,4 +235,9 @@ public IOStreamPair createStreamPair(DataOutputStream out,
new SaslOutputStream(out, saslServer)); new SaslOutputStream(out, saslServer));
} }
} }
@Override
public String toString() {
return "Sasl" + (saslServer != null? "Server" : "Client");
}
} }

View File

@ -77,7 +77,7 @@ public class TestSaslDataTransfer extends SaslDataTransferTestCase {
public ExpectedException exception = ExpectedException.none(); public ExpectedException exception = ExpectedException.none();
@Rule @Rule
public Timeout timeout = new Timeout(60000); public Timeout timeout = new Timeout(300_000);
@After @After
public void shutdown() { public void shutdown() {