Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/windows/common/relay.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,9 @@ class DockerIORelayHandle : public OverlappedIOHandle
class MultiHandleWait
{
public:
NON_COPYABLE(MultiHandleWait);
DEFAULT_MOVABLE(MultiHandleWait);

enum Flags
{
None = 0,
Expand Down
86 changes: 48 additions & 38 deletions src/windows/wslasession/WSLASession.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ try

auto requestContext = m_dockerClient->PullImage(repo, tag);

relay::MultiHandleWait io;
auto io = CreateIOContext();

std::optional<boost::beast::http::status> pullResult;

Expand Down Expand Up @@ -315,7 +315,6 @@ try

auto onCompleted = [&]() { io.Cancel(); };

io.AddHandle(std::make_unique<relay::EventHandle>(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); }));
io.AddHandle(std::make_unique<DockerHTTPClient::DockerHttpResponseHandle>(
*requestContext, std::move(onHttpResponse), std::move(onChunk), std::move(onCompleted)));

Expand Down Expand Up @@ -393,7 +392,7 @@ try
ServiceProcessLauncher buildLauncher(buildArgs[0], buildArgs, {}, dockerfileFileHandle ? WSLAProcessFlagsStdin : WSLAProcessFlagsNone);
auto buildProcess = buildLauncher.Launch(*m_virtualMachine);

relay::MultiHandleWait io;
auto io = CreateIOContext();

if (dockerfileFileHandle)
{
Expand Down Expand Up @@ -469,8 +468,6 @@ try

io.AddHandle(std::make_unique<relay::LineBasedReadHandle>(buildProcess.GetStdHandle(2), captureOutput, false));

io.AddHandle(std::make_unique<relay::EventHandle>(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); }));

io.Run({});

int exitCode = buildProcess.Wait();
Expand Down Expand Up @@ -532,7 +529,7 @@ void WSLASession::ImportImageImpl(DockerHTTPClient::HTTPRequestContext& Request,

THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value());

relay::MultiHandleWait io;
auto io = CreateIOContext();

std::optional<boost::beast::http::status> importResult;

Expand All @@ -559,15 +556,12 @@ void WSLASession::ImportImageImpl(DockerHTTPClient::HTTPRequestContext& Request,
}
};

auto onCompleted = [&]() { io.Cancel(); };

io.AddHandle(std::make_unique<relay::RelayHandle<relay::ReadHandle>>(
common::relay::HandleWrapper{std::move(imageFileHandle)}, common::relay::HandleWrapper{Request.stream.native_handle()}));

io.AddHandle(std::make_unique<relay::EventHandle>(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); }));

io.AddHandle(std::make_unique<DockerHTTPClient::DockerHttpResponseHandle>(
Request, std::move(onHttpResponse), std::move(onProgress), std::move(onCompleted)));
io.AddHandle(
std::make_unique<DockerHTTPClient::DockerHttpResponseHandle>(Request, std::move(onHttpResponse), std::move(onProgress)),
MultiHandleWait::CancelOnCompleted);

io.Run({});

Expand Down Expand Up @@ -607,29 +601,27 @@ void WSLASession::ExportContainerImpl(std::pair<uint32_t, wil::unique_socket>& S

THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value());

relay::MultiHandleWait io;

auto onCompleted = [&]() {
io.Cancel();
WSL_LOG("OnCompletedCalledForExport", TraceLoggingValue("OnCompletedCalledForExport", "Content"));
};
auto io = CreateIOContext();

std::string errorJson;
auto accumulateError = [&](const gsl::span<char>& buffer) {
// If the export failed, accumulate the error message.
errorJson.append(buffer.data(), buffer.size());
};

if (SocketCodePair.first != 200)
{
io.AddHandle(std::make_unique<relay::ReadHandle>(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError)));
auto accumulateError = [&](const gsl::span<char>& buffer) {
// If the export failed, accumulate the error message.
errorJson.append(buffer.data(), buffer.size());
};

io.AddHandle(
std::make_unique<relay::ReadHandle>(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError)),
MultiHandleWait::CancelOnCompleted);
}
else
{
io.AddHandle(std::make_unique<relay::RelayHandle<relay::HTTPChunkBasedReadHandle>>(
common::relay::HandleWrapper{std::move(SocketCodePair.second)},
common::relay::HandleWrapper{std::move(containerFileHandle), std::move(onCompleted)}));
io.AddHandle(std::make_unique<relay::EventHandle>(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); }));
io.AddHandle(
std::make_unique<relay::RelayHandle<relay::HTTPChunkBasedReadHandle>>(
common::relay::HandleWrapper{std::move(SocketCodePair.second)}, common::relay::HandleWrapper{std::move(containerFileHandle)}),
MultiHandleWait::CancelOnCompleted);
}

io.Run({});
Expand Down Expand Up @@ -668,25 +660,27 @@ void WSLASession::SaveImageImpl(std::pair<uint32_t, wil::unique_socket>& SocketC

THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value());

relay::MultiHandleWait io;
auto io = CreateIOContext();

auto onCompleted = [&]() { io.Cancel(); };
std::string errorJson;
auto accumulateError = [&](const gsl::span<char>& buffer) {
// If the save failed, accumulate the error message.
errorJson.append(buffer.data(), buffer.size());
};

if (SocketCodePair.first != 200)
{
io.AddHandle(std::make_unique<relay::ReadHandle>(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError)));
auto accumulateError = [&](const gsl::span<char>& buffer) {
// If the save failed, accumulate the error message.
errorJson.append(buffer.data(), buffer.size());
};

io.AddHandle(
std::make_unique<relay::ReadHandle>(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError)),
MultiHandleWait::CancelOnCompleted);
}
else
{
io.AddHandle(std::make_unique<relay::RelayHandle<relay::HTTPChunkBasedReadHandle>>(
common::relay::HandleWrapper{std::move(SocketCodePair.second)},
common::relay::HandleWrapper{std::move(imageFileHandle), std::move(onCompleted)}));
io.AddHandle(std::make_unique<relay::EventHandle>(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); }));
io.AddHandle(
std::make_unique<relay::RelayHandle<relay::HTTPChunkBasedReadHandle>>(
common::relay::HandleWrapper{std::move(SocketCodePair.second)}, common::relay::HandleWrapper{std::move(imageFileHandle)}),
MultiHandleWait::CancelOnCompleted);
}

io.Run({});
Expand Down Expand Up @@ -1127,6 +1121,22 @@ HRESULT WSLASession::InterfaceSupportsErrorInfo(REFIID riid)
return riid == __uuidof(IWSLASession) ? S_OK : S_FALSE;
}

// TODO consider allowing callers to pass cancellation handles.
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TODO suggests allowing callers to pass custom cancellation handles. This would be a useful enhancement, but consider whether the current design is flexible enough. For example, if a caller needs different cancellation behavior (e.g., only cancel on session termination but not on client exit), they would have no way to achieve this with the current design. Consider adding an optional parameter to CreateIOContext() to allow customization of which handles to add.

Copilot uses AI. Check for mistakes.
MultiHandleWait WSLASession::CreateIOContext()
{
relay::MultiHandleWait io;

// Cancel with E_ABORT if the session is terminating.
io.AddHandle(std::make_unique<relay::EventHandle>(
m_sessionTerminatingEvent.get(), [this]() { THROW_HR_MSG(E_ABORT, "Session %lu is terminating", m_id); }));

// Cancel with E_ABORT if the client process exits.
io.AddHandle(std::make_unique<relay::EventHandle>(
wslutil::OpenCallingProcess(SYNCHRONIZE), [this]() { THROW_HR_MSG(E_ABORT, "Client process has exited"); }));

return io;
}

void WSLASession::OnContainerDeleted(const WSLAContainerImpl* Container)
{
std::lock_guard lock{m_lock};
Expand Down
2 changes: 2 additions & 0 deletions src/windows/wslasession/WSLASession.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession
IFACEMETHOD(MapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) override;
IFACEMETHOD(UnmapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) override;

common::relay::MultiHandleWait CreateIOContext();

private:
ULONG m_id = 0;

Expand Down
Loading