From eb3fdc3a227bbf29eb5d68e233ea0045508d5d9d Mon Sep 17 00:00:00 2001 From: Pascal Nowack Date: Sat, 5 Apr 2025 16:33:57 +0200 Subject: [PATCH] rdpei/server: Add optional threaded handling of messages --- channels/rdpei/server/rdpei_main.c | 287 +++++++++++++++++++++++------ include/freerdp/server/rdpei.h | 17 ++ 2 files changed, 252 insertions(+), 52 deletions(-) diff --git a/channels/rdpei/server/rdpei_main.c b/channels/rdpei/server/rdpei_main.c index 4efc76032..4354edf76 100644 --- a/channels/rdpei/server/rdpei_main.c +++ b/channels/rdpei/server/rdpei_main.c @@ -32,6 +32,12 @@ #include #include +typedef enum +{ + RDPEI_INITIAL, + RDPEI_OPENED, +} eRdpEiChannelState; + enum RdpEiState { STATE_INITIAL, @@ -45,6 +51,12 @@ struct s_rdpei_server_private HANDLE channelHandle; HANDLE eventHandle; + HANDLE stopEvent; + HANDLE thread; + + /* Channel state */ + eRdpEiChannelState channelState; + UINT32 expectedBytes; BOOL waitingHeaders; wStream* inputStream; @@ -58,24 +70,233 @@ struct s_rdpei_server_private enum RdpEiState automataState; }; +static UINT rdpei_server_open_channel(RdpeiServerContext* context) +{ + DWORD error = ERROR_SUCCESS; + DWORD bytesReturned = 0; + PULONG pSessionId = NULL; + BOOL status = TRUE; + + WINPR_ASSERT(context); + + RdpeiServerPrivate* priv = context->priv; + WINPR_ASSERT(priv); + + if (WTSQuerySessionInformationA(context->vcm, WTS_CURRENT_SESSION, WTSSessionId, + (LPSTR*)&pSessionId, &bytesReturned) == FALSE) + { + WLog_ERR(TAG, "WTSQuerySessionInformationA failed!"); + return ERROR_INTERNAL_ERROR; + } + + DWORD sessionId = (DWORD)*pSessionId; + WTSFreeMemory(pSessionId); + + priv->channelHandle = + WTSVirtualChannelOpenEx(sessionId, RDPEI_DVC_CHANNEL_NAME, WTS_CHANNEL_OPTION_DYNAMIC); + if (!priv->channelHandle) + { + error = GetLastError(); + WLog_ERR(TAG, "WTSVirtualChannelOpenEx failed with error %" PRIu32 "!", error); + return error; + } + + const UINT32 channelId = WTSChannelGetIdByHandle(priv->channelHandle); + + IFCALLRET(context->onChannelIdAssigned, status, context, channelId); + if (!status) + { + WLog_ERR(TAG, "context->onChannelIdAssigned failed!"); + return ERROR_INTERNAL_ERROR; + } + + return error; +} + +static UINT rdpei_server_context_poll_int(RdpeiServerContext* context) +{ + RdpeiServerPrivate* priv = NULL; + UINT error = ERROR_INTERNAL_ERROR; + + WINPR_ASSERT(context); + priv = context->priv; + WINPR_ASSERT(priv); + + switch (priv->channelState) + { + case RDPEI_INITIAL: + error = rdpei_server_open_channel(context); + if (error) + WLog_ERR(TAG, "rdpei_server_open_channel failed with error %" PRIu32 "!", error); + else + priv->channelState = RDPEI_OPENED; + break; + case RDPEI_OPENED: + error = rdpei_server_handle_messages(context); + break; + default: + break; + } + + return error; +} + +static HANDLE rdpei_server_get_channel_handle(RdpeiServerContext* context) +{ + RdpeiServerPrivate* priv = NULL; + void* buffer = NULL; + DWORD bytesReturned = 0; + HANDLE channelEvent = NULL; + + WINPR_ASSERT(context); + priv = context->priv; + WINPR_ASSERT(priv); + + if (WTSVirtualChannelQuery(priv->channelHandle, WTSVirtualEventHandle, &buffer, + &bytesReturned) == TRUE) + { + if (bytesReturned == sizeof(HANDLE)) + channelEvent = *(HANDLE*)buffer; + + WTSFreeMemory(buffer); + } + + return channelEvent; +} + +static DWORD WINAPI rdpei_server_thread_func(LPVOID arg) +{ + RdpeiServerContext* context = (RdpeiServerContext*)arg; + RdpeiServerPrivate* priv = NULL; + HANDLE events[2] = { 0 }; + DWORD nCount = 0; + UINT error = CHANNEL_RC_OK; + DWORD status = 0; + + WINPR_ASSERT(context); + priv = context->priv; + WINPR_ASSERT(priv); + + events[nCount++] = priv->stopEvent; + + while ((error == CHANNEL_RC_OK) && (WaitForSingleObject(events[0], 0) != WAIT_OBJECT_0)) + { + switch (priv->channelState) + { + case RDPEI_INITIAL: + error = rdpei_server_context_poll_int(context); + if (error == CHANNEL_RC_OK) + { + events[1] = rdpei_server_get_channel_handle(context); + nCount = 2; + } + break; + case RDPEI_OPENED: + status = WaitForMultipleObjects(nCount, events, FALSE, INFINITE); + switch (status) + { + case WAIT_OBJECT_0: + break; + case WAIT_OBJECT_0 + 1: + case WAIT_TIMEOUT: + error = rdpei_server_context_poll_int(context); + break; + + case WAIT_FAILED: + default: + error = ERROR_INTERNAL_ERROR; + break; + } + break; + default: + break; + } + } + + (void)WTSVirtualChannelClose(priv->channelHandle); + priv->channelHandle = NULL; + + ExitThread(error); + return error; +} + +static UINT rdpei_server_open(RdpeiServerContext* context) +{ + RdpeiServerPrivate* priv = NULL; + + priv = context->priv; + WINPR_ASSERT(priv); + + if (!priv->thread) + { + priv->stopEvent = CreateEvent(NULL, TRUE, FALSE, NULL); + if (!priv->stopEvent) + { + WLog_ERR(TAG, "CreateEvent failed!"); + return ERROR_INTERNAL_ERROR; + } + + priv->thread = CreateThread(NULL, 0, rdpei_server_thread_func, context, 0, NULL); + if (!priv->thread) + { + WLog_ERR(TAG, "CreateThread failed!"); + (void)CloseHandle(priv->stopEvent); + priv->stopEvent = NULL; + return ERROR_INTERNAL_ERROR; + } + } + + return CHANNEL_RC_OK; +} + +static UINT rdpei_server_close(RdpeiServerContext* context) +{ + RdpeiServerPrivate* priv = NULL; + UINT error = CHANNEL_RC_OK; + + priv = context->priv; + WINPR_ASSERT(priv); + + if (priv->thread) + { + (void)SetEvent(priv->stopEvent); + + if (WaitForSingleObject(priv->thread, INFINITE) == WAIT_FAILED) + { + error = GetLastError(); + WLog_ERR(TAG, "WaitForSingleObject failed with error %" PRIu32 "", error); + return error; + } + + (void)CloseHandle(priv->thread); + (void)CloseHandle(priv->stopEvent); + priv->thread = NULL; + priv->stopEvent = NULL; + } + + return error; +} + RdpeiServerContext* rdpei_server_context_new(HANDLE vcm) { RdpeiServerContext* ret = calloc(1, sizeof(*ret)); - RdpeiServerPrivate* priv = NULL; if (!ret) return NULL; - ret->priv = priv = calloc(1, sizeof(*ret->priv)); - if (!priv) + ret->Open = rdpei_server_open; + ret->Close = rdpei_server_close; + + ret->priv = calloc(1, sizeof(*ret->priv)); + if (!ret->priv) goto fail; - priv->inputStream = Stream_New(NULL, 256); - if (!priv->inputStream) + ret->priv->inputStream = Stream_New(NULL, 256); + if (!ret->priv->inputStream) goto fail; - priv->outputStream = Stream_New(NULL, 200); - if (!priv->inputStream) + ret->priv->outputStream = Stream_New(NULL, 200); + if (!ret->priv->inputStream) goto fail; ret->vcm = vcm; @@ -97,60 +318,22 @@ fail: */ UINT rdpei_server_init(RdpeiServerContext* context) { - void* buffer = NULL; - DWORD bytesReturned = 0; RdpeiServerPrivate* priv = context->priv; - UINT32 channelId = 0; - BOOL status = TRUE; - DWORD BytesReturned = 0; - PULONG pSessionId = NULL; - DWORD SessionId = 0; + UINT error = rdpei_server_open_channel(context); + if (error) + return error; - if (WTSQuerySessionInformationA(context->vcm, WTS_CURRENT_SESSION, WTSSessionId, - (LPSTR*)&pSessionId, &BytesReturned) == FALSE) + priv->eventHandle = rdpei_server_get_channel_handle(context); + if (!priv->eventHandle) { - WLog_ERR(TAG, "WTSQuerySessionInformationA failed!"); - return ERROR_INTERNAL_ERROR; - } - - SessionId = (DWORD)*pSessionId; - WTSFreeMemory(pSessionId); - - priv->channelHandle = - WTSVirtualChannelOpenEx(SessionId, RDPEI_DVC_CHANNEL_NAME, WTS_CHANNEL_OPTION_DYNAMIC); - if (!priv->channelHandle) - { - WLog_ERR(TAG, "WTSVirtualChannelOpenEx failed!"); - return CHANNEL_RC_INITIALIZATION_ERROR; - } - - channelId = WTSChannelGetIdByHandle(priv->channelHandle); - - IFCALLRET(context->onChannelIdAssigned, status, context, channelId); - if (!status) - { - WLog_ERR(TAG, "context->onChannelIdAssigned failed!"); + WLog_ERR(TAG, "Failed to get channel handle!"); goto out_close; } - if (!WTSVirtualChannelQuery(priv->channelHandle, WTSVirtualEventHandle, &buffer, - &bytesReturned) || - (bytesReturned != sizeof(HANDLE))) - { - WLog_ERR(TAG, "WTSVirtualChannelQuery failed or invalid returned size(%" PRIu32 ")!", - bytesReturned); - if (buffer) - WTSFreeMemory(buffer); - goto out_close; - } - priv->eventHandle = *(HANDLE*)buffer; - WTSFreeMemory(buffer); - return CHANNEL_RC_OK; out_close: (void)WTSVirtualChannelClose(priv->channelHandle); - priv->channelHandle = NULL; return CHANNEL_RC_INITIALIZATION_ERROR; } @@ -174,7 +357,7 @@ void rdpei_server_context_free(RdpeiServerContext* context) priv = context->priv; if (priv) { - if (priv->channelHandle != INVALID_HANDLE_VALUE) + if (priv->channelHandle && priv->channelHandle != INVALID_HANDLE_VALUE) (void)WTSVirtualChannelClose(priv->channelHandle); Stream_Free(priv->inputStream, TRUE); } diff --git a/include/freerdp/server/rdpei.h b/include/freerdp/server/rdpei.h index 215f7f1e5..812532e62 100644 --- a/include/freerdp/server/rdpei.h +++ b/include/freerdp/server/rdpei.h @@ -34,6 +34,9 @@ extern "C" typedef struct s_rdpei_server_context RdpeiServerContext; typedef struct s_rdpei_server_private RdpeiServerPrivate; + typedef UINT (*psRdpeiServerOpen)(RdpeiServerContext* context); + typedef UINT (*psRdpeiServerClose)(RdpeiServerContext* context); + struct s_rdpei_server_context { HANDLE vcm; @@ -56,6 +59,20 @@ extern "C" * Callback, when the channel got its id assigned. */ BOOL (*onChannelIdAssigned)(RdpeiServerContext* context, UINT32 channelId); + + /*** APIs called by the server. ***/ + + /** + * Open the input channel. + * @since version 3.15.0 + */ + psRdpeiServerOpen Open; + + /** + * Close the input channel. + * @since version 3.15.0 + */ + psRdpeiServerClose Close; }; FREERDP_API void rdpei_server_context_free(RdpeiServerContext* context);