Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 90 additions & 4 deletions codex-rs/app-server-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,8 +868,11 @@ mod tests {
use tokio::net::TcpListener;
use tokio::time::Duration;
use tokio::time::timeout;
use tokio_tungstenite::accept_async;
use tokio_tungstenite::accept_hdr_async;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::handshake::server::Request as WebSocketRequest;
use tokio_tungstenite::tungstenite::handshake::server::Response as WebSocketResponse;
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;

async fn build_test_config() -> Config {
match ConfigBuilder::default().build().await {
Expand Down Expand Up @@ -908,6 +911,19 @@ mod tests {
}

async fn start_test_remote_server<F, Fut>(handler: F) -> String
where
F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
+ Send
+ 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
start_test_remote_server_with_auth(None, handler).await
}

async fn start_test_remote_server_with_auth<F, Fut>(
expected_auth_token: Option<String>,
handler: F,
) -> String
where
F: FnOnce(tokio_tungstenite::WebSocketStream<tokio::net::TcpStream>) -> Fut
+ Send
Expand All @@ -920,9 +936,23 @@ mod tests {
let addr = listener.local_addr().expect("listener address");
tokio::spawn(async move {
let (stream, _) = listener.accept().await.expect("accept should succeed");
let websocket = accept_async(stream)
.await
.expect("websocket upgrade should succeed");
let websocket = accept_hdr_async(
stream,
move |request: &WebSocketRequest, response: WebSocketResponse| {
let provided_auth_token = request
.headers()
.get(AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.map(str::to_owned);
let expected_auth_token = expected_auth_token
.as_ref()
.map(|token| format!("Bearer {token}"));
assert_eq!(provided_auth_token, expected_auth_token);
Ok(response)
},
)
.await
.expect("websocket upgrade should succeed");
handler(websocket).await;
});
format!("ws://{addr}")
Expand Down Expand Up @@ -1037,6 +1067,7 @@ mod tests {
fn test_remote_connect_args(websocket_url: String) -> RemoteAppServerConnectArgs {
RemoteAppServerConnectArgs {
websocket_url,
auth_token: None,
client_name: "codex-app-server-client-test".to_string(),
client_version: "0.0.0-test".to_string(),
experimental_api: true,
Expand Down Expand Up @@ -1253,6 +1284,7 @@ mod tests {
}),
)
.await;
websocket.close(None).await.expect("close should succeed");
})
.await;
let client = RemoteAppServerClient::connect(test_remote_connect_args(websocket_url))
Expand All @@ -1273,6 +1305,59 @@ mod tests {
client.shutdown().await.expect("shutdown should complete");
}

#[tokio::test]
async fn remote_connect_includes_auth_header_when_configured() {
let auth_token = "remote-bearer-token".to_string();
let websocket_url = start_test_remote_server_with_auth(
Some(auth_token.clone()),
|mut websocket| async move {
expect_remote_initialize(&mut websocket).await;
websocket.close(None).await.expect("close should succeed");
},
)
.await;
let client = RemoteAppServerClient::connect(RemoteAppServerConnectArgs {
auth_token: Some(auth_token),
..test_remote_connect_args(websocket_url)
})
.await
.expect("remote client should connect");

client.shutdown().await.expect("shutdown should complete");
}

#[tokio::test]
async fn remote_connect_rejects_non_loopback_ws_when_auth_configured() {
let result = RemoteAppServerClient::connect(RemoteAppServerConnectArgs {
websocket_url: "ws://example.com:4500".to_string(),
auth_token: Some("remote-bearer-token".to_string()),
..test_remote_connect_args("ws://127.0.0.1:1".to_string())
})
.await;
let err = match result {
Ok(_) => panic!("non-loopback ws should be rejected before connect"),
Err(err) => err,
};
assert_eq!(err.kind(), ErrorKind::InvalidInput);
assert!(
err.to_string()
.contains("remote auth tokens require `wss://` or loopback `ws://` URLs")
);
}

#[test]
fn remote_auth_token_transport_policy_allows_wss_and_loopback_ws() {
assert!(crate::remote::websocket_url_supports_auth_token(
&url::Url::parse("wss://example.com:443").expect("wss URL should parse")
));
assert!(crate::remote::websocket_url_supports_auth_token(
&url::Url::parse("ws://127.0.0.1:4500").expect("loopback ws URL should parse")
));
assert!(!crate::remote::websocket_url_supports_auth_token(
&url::Url::parse("ws://example.com:4500").expect("non-loopback ws URL should parse")
));
}

#[tokio::test]
async fn remote_duplicate_request_id_keeps_original_waiter() {
let (first_request_seen_tx, first_request_seen_rx) = tokio::sync::oneshot::channel();
Expand Down Expand Up @@ -1425,6 +1510,7 @@ mod tests {
.await;
let mut client = RemoteAppServerClient::connect(RemoteAppServerConnectArgs {
websocket_url,
auth_token: None,
client_name: "codex-app-server-client-test".to_string(),
client_version: "0.0.0-test".to_string(),
experimental_api: true,
Expand Down
40 changes: 39 additions & 1 deletion codex-rs/app-server-client/src/remote.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::connect_async;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::tungstenite::http::HeaderValue;
use tokio_tungstenite::tungstenite::http::header::AUTHORIZATION;
use tracing::warn;
use url::Url;

Expand All @@ -57,6 +60,7 @@ const INITIALIZE_TIMEOUT: Duration = Duration::from_secs(10);
#[derive(Debug, Clone)]
pub struct RemoteAppServerConnectArgs {
pub websocket_url: String,
pub auth_token: Option<String>,
pub client_name: String,
pub client_version: String,
pub experimental_api: bool,
Expand Down Expand Up @@ -86,6 +90,16 @@ impl RemoteAppServerConnectArgs {
}
}

pub(crate) fn websocket_url_supports_auth_token(url: &Url) -> bool {
match (url.scheme(), url.host()) {
("wss", Some(_)) => true,
("ws", Some(url::Host::Domain(domain))) => domain.eq_ignore_ascii_case("localhost"),
("ws", Some(url::Host::Ipv4(addr))) => addr.is_loopback(),
("ws", Some(url::Host::Ipv6(addr))) => addr.is_loopback(),
_ => false,
}
}

enum RemoteClientCommand {
Request {
request: Box<ClientRequest>,
Expand Down Expand Up @@ -132,7 +146,31 @@ impl RemoteAppServerClient {
format!("invalid websocket URL `{websocket_url}`: {err}"),
)
})?;
let stream = timeout(CONNECT_TIMEOUT, connect_async(url.as_str()))
if args.auth_token.is_some() && !websocket_url_supports_auth_token(&url) {
return Err(IoError::new(
ErrorKind::InvalidInput,
format!(
"remote auth tokens require `wss://` or loopback `ws://` URLs; got `{websocket_url}`"
),
));
}
let mut request = url.as_str().into_client_request().map_err(|err| {
IoError::new(
ErrorKind::InvalidInput,
format!("invalid websocket URL `{websocket_url}`: {err}"),
)
})?;
if let Some(auth_token) = args.auth_token.as_deref() {
let header_value =
HeaderValue::from_str(&format!("Bearer {auth_token}")).map_err(|err| {
IoError::new(
ErrorKind::InvalidInput,
format!("invalid remote authorization header value: {err}"),
)
})?;
request.headers_mut().insert(AUTHORIZATION, header_value);
}
let stream = timeout(CONNECT_TIMEOUT, connect_async(request))
.await
.map_err(|_| {
IoError::new(
Expand Down
Loading
Loading