diff --git a/src/windows/common/relay.hpp b/src/windows/common/relay.hpp index ef11d8721..75fb1a19c 100644 --- a/src/windows/common/relay.hpp +++ b/src/windows/common/relay.hpp @@ -492,6 +492,9 @@ class DockerIORelayHandle : public OverlappedIOHandle class MultiHandleWait { public: + NON_COPYABLE(MultiHandleWait); + DEFAULT_MOVABLE(MultiHandleWait); + enum Flags { None = 0, diff --git a/src/windows/wslasession/WSLASession.cpp b/src/windows/wslasession/WSLASession.cpp index f39df81b4..8cc75d54c 100644 --- a/src/windows/wslasession/WSLASession.cpp +++ b/src/windows/wslasession/WSLASession.cpp @@ -280,7 +280,7 @@ try auto requestContext = m_dockerClient->PullImage(repo, tag); - relay::MultiHandleWait io; + auto io = CreateIOContext(); std::optional pullResult; @@ -315,7 +315,6 @@ try auto onCompleted = [&]() { io.Cancel(); }; - io.AddHandle(std::make_unique(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); })); io.AddHandle(std::make_unique( *requestContext, std::move(onHttpResponse), std::move(onChunk), std::move(onCompleted))); @@ -393,7 +392,7 @@ try ServiceProcessLauncher buildLauncher(buildArgs[0], buildArgs, {}, dockerfileFileHandle ? WSLAProcessFlagsStdin : WSLAProcessFlagsNone); auto buildProcess = buildLauncher.Launch(*m_virtualMachine); - relay::MultiHandleWait io; + auto io = CreateIOContext(); if (dockerfileFileHandle) { @@ -469,8 +468,6 @@ try io.AddHandle(std::make_unique(buildProcess.GetStdHandle(2), captureOutput, false)); - io.AddHandle(std::make_unique(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); })); - io.Run({}); int exitCode = buildProcess.Wait(); @@ -532,7 +529,7 @@ void WSLASession::ImportImageImpl(DockerHTTPClient::HTTPRequestContext& Request, THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value()); - relay::MultiHandleWait io; + auto io = CreateIOContext(); std::optional importResult; @@ -559,15 +556,12 @@ void WSLASession::ImportImageImpl(DockerHTTPClient::HTTPRequestContext& Request, } }; - auto onCompleted = [&]() { io.Cancel(); }; - io.AddHandle(std::make_unique>( common::relay::HandleWrapper{std::move(imageFileHandle)}, common::relay::HandleWrapper{Request.stream.native_handle()})); - io.AddHandle(std::make_unique(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); })); - - io.AddHandle(std::make_unique( - Request, std::move(onHttpResponse), std::move(onProgress), std::move(onCompleted))); + io.AddHandle( + std::make_unique(Request, std::move(onHttpResponse), std::move(onProgress)), + MultiHandleWait::CancelOnCompleted); io.Run({}); @@ -607,29 +601,27 @@ void WSLASession::ExportContainerImpl(std::pair& S THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value()); - relay::MultiHandleWait io; - - auto onCompleted = [&]() { - io.Cancel(); - WSL_LOG("OnCompletedCalledForExport", TraceLoggingValue("OnCompletedCalledForExport", "Content")); - }; + auto io = CreateIOContext(); std::string errorJson; - auto accumulateError = [&](const gsl::span& buffer) { - // If the export failed, accumulate the error message. - errorJson.append(buffer.data(), buffer.size()); - }; if (SocketCodePair.first != 200) { - io.AddHandle(std::make_unique(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError))); + auto accumulateError = [&](const gsl::span& buffer) { + // If the export failed, accumulate the error message. + errorJson.append(buffer.data(), buffer.size()); + }; + + io.AddHandle( + std::make_unique(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError)), + MultiHandleWait::CancelOnCompleted); } else { - io.AddHandle(std::make_unique>( - common::relay::HandleWrapper{std::move(SocketCodePair.second)}, - common::relay::HandleWrapper{std::move(containerFileHandle), std::move(onCompleted)})); - io.AddHandle(std::make_unique(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); })); + io.AddHandle( + std::make_unique>( + common::relay::HandleWrapper{std::move(SocketCodePair.second)}, common::relay::HandleWrapper{std::move(containerFileHandle)}), + MultiHandleWait::CancelOnCompleted); } io.Run({}); @@ -668,25 +660,27 @@ void WSLASession::SaveImageImpl(std::pair& SocketC THROW_HR_IF(HRESULT_FROM_WIN32(ERROR_INVALID_STATE), !m_dockerClient.has_value()); - relay::MultiHandleWait io; + auto io = CreateIOContext(); - auto onCompleted = [&]() { io.Cancel(); }; std::string errorJson; - auto accumulateError = [&](const gsl::span& buffer) { - // If the save failed, accumulate the error message. - errorJson.append(buffer.data(), buffer.size()); - }; if (SocketCodePair.first != 200) { - io.AddHandle(std::make_unique(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError))); + auto accumulateError = [&](const gsl::span& buffer) { + // If the save failed, accumulate the error message. + errorJson.append(buffer.data(), buffer.size()); + }; + + io.AddHandle( + std::make_unique(common::relay::HandleWrapper{std::move(SocketCodePair.second)}, std::move(accumulateError)), + MultiHandleWait::CancelOnCompleted); } else { - io.AddHandle(std::make_unique>( - common::relay::HandleWrapper{std::move(SocketCodePair.second)}, - common::relay::HandleWrapper{std::move(imageFileHandle), std::move(onCompleted)})); - io.AddHandle(std::make_unique(m_sessionTerminatingEvent.get(), [&]() { THROW_HR(E_ABORT); })); + io.AddHandle( + std::make_unique>( + common::relay::HandleWrapper{std::move(SocketCodePair.second)}, common::relay::HandleWrapper{std::move(imageFileHandle)}), + MultiHandleWait::CancelOnCompleted); } io.Run({}); @@ -1127,6 +1121,22 @@ HRESULT WSLASession::InterfaceSupportsErrorInfo(REFIID riid) return riid == __uuidof(IWSLASession) ? S_OK : S_FALSE; } +// TODO consider allowing callers to pass cancellation handles. +MultiHandleWait WSLASession::CreateIOContext() +{ + relay::MultiHandleWait io; + + // Cancel with E_ABORT if the session is terminating. + io.AddHandle(std::make_unique( + m_sessionTerminatingEvent.get(), [this]() { THROW_HR_MSG(E_ABORT, "Session %lu is terminating", m_id); })); + + // Cancel with E_ABORT if the client process exits. + io.AddHandle(std::make_unique( + wslutil::OpenCallingProcess(SYNCHRONIZE), [this]() { THROW_HR_MSG(E_ABORT, "Client process has exited"); })); + + return io; +} + void WSLASession::OnContainerDeleted(const WSLAContainerImpl* Container) { std::lock_guard lock{m_lock}; diff --git a/src/windows/wslasession/WSLASession.h b/src/windows/wslasession/WSLASession.h index c40476e10..78f75f7b5 100644 --- a/src/windows/wslasession/WSLASession.h +++ b/src/windows/wslasession/WSLASession.h @@ -86,6 +86,8 @@ class DECLSPEC_UUID("4877FEFC-4977-4929-A958-9F36AA1892A4") WSLASession IFACEMETHOD(MapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) override; IFACEMETHOD(UnmapVmPort)(_In_ int Family, _In_ short WindowsPort, _In_ short LinuxPort) override; + common::relay::MultiHandleWait CreateIOContext(); + private: ULONG m_id = 0;